Skip to content

Commit f030839

Browse files
committed
feat(genai): Add TextGen socket example using Audio Input
1 parent d037a20 commit f030839

2 files changed

Lines changed: 164 additions & 0 deletions

File tree

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
import os
17+
import subprocess
18+
19+
20+
def get_bearer_token():
21+
command = "gcloud auth application-default print-access-token"
22+
result = subprocess.check_output(command, shell=True, text=True)
23+
return result.strip()
24+
25+
26+
# get bearer token
27+
BEARER_TOKEN = get_bearer_token()
28+
29+
30+
async def generate_content() -> str:
31+
"""
32+
Connects to the Gemini API via WebSocket, sends a text prompt,
33+
and returns the aggregated text response.
34+
"""
35+
# [START googlegenaisdk_live_websocket_textgen_with_audio]
36+
import base64
37+
import json
38+
39+
from websockets.asyncio.client import connect
40+
from scipy.io import wavfile
41+
42+
43+
def read_wavefile(filepath):
44+
# Read the .wav file.
45+
rate, data = wavfile.read(filepath)
46+
# Convert the NumPy array of audio samples back to raw bytes
47+
raw_audio_bytes = data.tobytes()
48+
# Encode the raw bytes to a base64 string.
49+
# The result needs to be decoded from bytes to a UTF-8 string
50+
base64_encoded_data = base64.b64encode(raw_audio_bytes).decode('ascii')
51+
mime_type = f"audio/pcm;rate={rate}"
52+
return base64_encoded_data, mime_type
53+
54+
# Configuration Constants
55+
PROJECT_ID = os.getenv("GOOGLE_SAMPLES_PROJECT")
56+
LOCATION = "us-central1"
57+
GEMINI_MODEL_NAME = "gemini-2.0-flash-live-preview-04-09"
58+
# To generate a bearer token, use:
59+
# $ gcloud auth application-default print-access-token
60+
# It's recommended to fetch this token dynamically rather than hardcoding.
61+
# BEARER_TOKEN = "ya29.a0AW4XtxhRb1s51TxLPnj..."
62+
63+
# Websocket Configuration
64+
WEBSOCKET_HOST = "us-central1-aiplatform.googleapis.com"
65+
WEBSOCKET_SERVICE_URL = (
66+
f"wss://{WEBSOCKET_HOST}/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"
67+
)
68+
69+
# Websocket Authentication
70+
headers = {
71+
"Content-Type": "application/json",
72+
"Authorization": f"Bearer {BEARER_TOKEN}",
73+
}
74+
75+
# Model Configuration
76+
model_path = (
77+
f"projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{GEMINI_MODEL_NAME}"
78+
)
79+
model_generation_config = {"response_modalities": ["TEXT"]}
80+
81+
async with connect(WEBSOCKET_SERVICE_URL, additional_headers=headers) as websocket_session:
82+
# 1. Send setup configuration
83+
websocket_config = {
84+
"setup": {
85+
"model": model_path,
86+
"generation_config": model_generation_config,
87+
}
88+
}
89+
await websocket_session.send(json.dumps(websocket_config))
90+
91+
# 2. Receive setup response
92+
raw_setup_response = await websocket_session.recv()
93+
setup_response = json.loads(
94+
raw_setup_response.decode("utf-8")
95+
if isinstance(raw_setup_response, bytes)
96+
else raw_setup_response
97+
)
98+
print(f"Setup Response: {setup_response}")
99+
# Example response: {'setupComplete': {}}
100+
if "setupComplete" not in setup_response:
101+
print(f"Setup failed: {setup_response}")
102+
return "Error: WebSocket setup failed."
103+
104+
# 3. Send audio message
105+
encoded_audio_message, mime_type = read_wavefile("hello_gemini_are_you_there.wav")
106+
# Example audio message: "Hello? Gemini are you there?"
107+
108+
user_message = {
109+
"client_content": {
110+
"turns": [
111+
{
112+
"role": "user",
113+
"parts": [
114+
{
115+
"inlineData": {
116+
"mimeType": mime_type, # Example value: "audio/pcm;rate=24000"
117+
"data": encoded_audio_message, # Example value: "AQD//wAAAAAAA....."
118+
}
119+
}
120+
],
121+
}
122+
],
123+
"turn_complete": True,
124+
}
125+
}
126+
await websocket_session.send(json.dumps(user_message))
127+
128+
# 4. Receive model response
129+
aggregated_response_parts = []
130+
async for raw_response_chunk in websocket_session:
131+
response_chunk = json.loads(raw_response_chunk.decode("utf-8"))
132+
133+
server_content = response_chunk.get("serverContent")
134+
if not server_content:
135+
# This might indicate an error or an unexpected message format
136+
print(f"Received non-serverContent message or empty content: {response_chunk}")
137+
break
138+
139+
# Collect text responses
140+
model_turn = server_content.get("modelTurn")
141+
if model_turn and "parts" in model_turn and model_turn["parts"]:
142+
aggregated_response_parts.append(model_turn["parts"][0].get("text", ""))
143+
144+
# End of response
145+
if server_content.get("turnComplete"):
146+
break
147+
148+
final_response_text = "".join(aggregated_response_parts)
149+
print(f"Response: {final_response_text}")
150+
# Example response:
151+
# Setup Response: {'setupComplete': {}}
152+
# Response: Hey there. What's on your mind today?
153+
# [END googlegenaisdk_live_websocket_textgen_with_audio]
154+
return final_response_text
155+
156+
157+
if __name__ == "__main__":
158+
asyncio.run(generate_content())

genai/live/test_live_examples.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import pytest
2222

2323
import live_with_txt
24+
import live_websocket_textgen_with_audio
2425
import live_websocket_textgen_with_txt
2526
import live_websocket_audiogen_with_txt
2627
import live_websocket_audiotranscript_with_txt
@@ -36,6 +37,11 @@ async def test_live_with_text() -> None:
3637
assert await live_with_txt.generate_content()
3738

3839

40+
@pytest.mark.asyncio
41+
async def test_live_websocket_textgen_with_audio() -> None:
42+
assert await live_websocket_textgen_with_audio.generate_content()
43+
44+
3945
@pytest.mark.asyncio
4046
async def test_live_websocket_textgen_with_txt() -> None:
4147
assert await live_websocket_textgen_with_txt.generate_content()

0 commit comments

Comments
 (0)