Skip to content

Commit d037a20

Browse files
committed
feat(genai): Live API WebSocket Example
1 parent 26d6aa4 commit d037a20

6 files changed

Lines changed: 462 additions & 3 deletions
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
import os
17+
import subprocess
18+
19+
20+
def get_bearer_token() -> str:
    """Return an access token obtained from the gcloud CLI.

    Runs ``gcloud auth application-default print-access-token`` and returns
    its standard output with surrounding whitespace removed.

    Returns:
        The access token printed by gcloud.

    Raises:
        FileNotFoundError: If the ``gcloud`` executable cannot be found.
        subprocess.CalledProcessError: If gcloud exits with a non-zero status.
    """
    import shutil

    # Resolve the executable explicitly so the command can be run as an
    # argument list without going through a shell.
    gcloud_executable = shutil.which("gcloud") or "gcloud"
    command = [
        gcloud_executable,
        "auth",
        "application-default",
        "print-access-token",
    ]
    result = subprocess.check_output(command, text=True)
    return result.strip()
24+
25+
26+
# Fetch the access token once, when this module is loaded; generate_content()
# reads it to build the Authorization header of the WebSocket handshake.
BEARER_TOKEN = get_bearer_token()
28+
29+
30+
async def generate_content() -> str:
    """
    Generate spoken audio from a text prompt over the Live API WebSocket.

    Opens a WebSocket session to the ``BidiGenerateContent`` endpoint, sends
    the setup message followed by a single text turn, collects the
    ``audio/pcm`` chunks of the model turn and writes them to ``output.wav``
    (24000 Hz, 16-bit samples).

    Returns:
        "output.wav" once the response stream has been consumed (the file is
        only written when at least one audio chunk was received), or
        "Error: WebSocket setup failed." when the server does not answer the
        setup message with ``setupComplete``.
    """
    # [START googlegenaisdk_live_audiogen_websocket_with_txt]
    import base64
    import json
    import numpy as np

    from websockets.asyncio.client import connect
    from scipy.io.wavfile import write as save_audio

    # Configuration Constants
    PROJECT_ID = os.getenv("GOOGLE_SAMPLES_PROJECT")
    LOCATION = "us-central1"
    GEMINI_MODEL_NAME = "gemini-2.0-flash-live-preview-04-09"
    # To generate a bearer token, use:
    #   $ gcloud auth application-default print-access-token
    # It's recommended to fetch this token dynamically rather than hardcoding.
    # BEARER_TOKEN = "ya29.a0AW4XtxhRb1s51TxLPnj..."

    # Websocket Configuration
    WEBSOCKET_HOST = "us-central1-aiplatform.googleapis.com"
    WEBSOCKET_SERVICE_URL = (
        f"wss://{WEBSOCKET_HOST}/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"
    )

    # Websocket Authentication
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {BEARER_TOKEN}",
    }

    # Model Configuration
    model_path = (
        f"projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{GEMINI_MODEL_NAME}"
    )
    model_generation_config = {
        "response_modalities": ["AUDIO"],
        "speech_config": {
            "voice_config": {"prebuilt_voice_config": {"voice_name": "Aoede"}},
            "language_code": "es-ES",
        },
    }

    def parse_message(raw_message):
        # A frame is delivered as bytes or str depending on its type, so
        # normalise to text before parsing the JSON payload.
        if isinstance(raw_message, bytes):
            raw_message = raw_message.decode("utf-8")
        return json.loads(raw_message)

    async with connect(WEBSOCKET_SERVICE_URL, additional_headers=headers) as websocket_session:
        # 1. Send setup configuration
        websocket_config = {
            "setup": {
                "model": model_path,
                "generation_config": model_generation_config,
            }
        }
        await websocket_session.send(json.dumps(websocket_config))

        # 2. Receive setup response
        setup_response = parse_message(await websocket_session.recv())
        print(f"Setup Response: {setup_response}")
        # Example response: {'setupComplete': {}}
        if "setupComplete" not in setup_response:
            print(f"Setup failed: {setup_response}")
            return "Error: WebSocket setup failed."

        # 3. Send text message
        text_input = "Hello? Gemini are you there?"
        print(f"Input: {text_input}")

        user_message = {
            "client_content": {
                "turns": [{"role": "user", "parts": [{"text": text_input}]}],
                "turn_complete": True,
            }
        }
        await websocket_session.send(json.dumps(user_message))

        # 4. Receive model response
        aggregated_response_parts = []
        async for raw_response_chunk in websocket_session:
            response_chunk = parse_message(raw_response_chunk)

            server_content = response_chunk.get("serverContent")
            if not server_content:
                # This might indicate an error or an unexpected message format
                print(f"Received non-serverContent message or empty content: {response_chunk}")
                break

            # Collect audio chunks
            model_turn = server_content.get("modelTurn") or {}
            for part in model_turn.get("parts") or []:
                # Parts without inline data carry no audio and are skipped.
                inline_data = part.get("inlineData")
                if not inline_data:
                    continue
                # Compare only the base type so a parameterised value such as
                # "audio/pcm;rate=24000" is accepted as well.
                mime_type = (inline_data.get("mimeType") or "").split(";")[0].strip()
                if mime_type == "audio/pcm":
                    audio_chunk = base64.b64decode(inline_data["data"])
                    aggregated_response_parts.append(np.frombuffer(audio_chunk, dtype=np.int16))

            # End of response
            if server_content.get("turnComplete"):
                break

        # Save audio to a file
        if aggregated_response_parts:
            save_audio("output.wav", 24000, np.concatenate(aggregated_response_parts))
    # Example response:
    #     Setup Response: {'setupComplete': {}}
    #     Input: Hello? Gemini are you there?
    #     Audio Response: Hello there. I'm here. What can I do for you today?
    # [END googlegenaisdk_live_audiogen_websocket_with_txt]
    return "output.wav"
143+
144+
145+
# Run the sample on a fresh event loop when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(generate_content())
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
import os
17+
import subprocess
18+
19+
20+
def get_bearer_token() -> str:
    """Return an access token obtained from the gcloud CLI.

    Runs ``gcloud auth application-default print-access-token`` and returns
    its standard output with surrounding whitespace removed.

    Returns:
        The access token printed by gcloud.

    Raises:
        FileNotFoundError: If the ``gcloud`` executable cannot be found.
        subprocess.CalledProcessError: If gcloud exits with a non-zero status.
    """
    import shutil

    # Resolve the executable explicitly so the command can be run as an
    # argument list without going through a shell.
    gcloud_executable = shutil.which("gcloud") or "gcloud"
    command = [
        gcloud_executable,
        "auth",
        "application-default",
        "print-access-token",
    ]
    result = subprocess.check_output(command, text=True)
    return result.strip()
24+
25+
26+
# Fetch the access token once, when this module is loaded; generate_content()
# reads it to build the Authorization header of the WebSocket handshake.
BEARER_TOKEN = get_bearer_token()
28+
29+
30+
async def generate_content() -> str:
    """
    Generate spoken audio plus transcriptions from a text prompt.

    Opens a WebSocket session to the ``BidiGenerateContent`` endpoint with
    input and output audio transcription enabled, sends a single text turn,
    collects the ``audio/pcm`` chunks of the model turn into ``output.wav``
    (24000 Hz, 16-bit samples) and prints the accumulated input and output
    transcriptions.

    Returns:
        "output.wav" once the response stream has been consumed (the file is
        only written when at least one audio chunk was received), or
        "Error: WebSocket setup failed." when the server does not answer the
        setup message with ``setupComplete``.
    """
    # [START googlegenaisdk_live_websocket_audiotranscript_with_txt]
    import base64
    import json
    import numpy as np

    from websockets.asyncio.client import connect
    from scipy.io.wavfile import write as save_audio

    # Configuration Constants
    PROJECT_ID = os.getenv("GOOGLE_SAMPLES_PROJECT")
    LOCATION = "us-central1"
    GEMINI_MODEL_NAME = "gemini-2.0-flash-live-preview-04-09"
    # To generate a bearer token, use:
    #   $ gcloud auth application-default print-access-token
    # It's recommended to fetch this token dynamically rather than hardcoding.
    # BEARER_TOKEN = "ya29.a0AW4XtxhRb1s51TxLPnj..."

    # Websocket Configuration
    WEBSOCKET_HOST = "us-central1-aiplatform.googleapis.com"
    WEBSOCKET_SERVICE_URL = (
        f"wss://{WEBSOCKET_HOST}/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"
    )

    # Websocket Authentication
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {BEARER_TOKEN}",
    }

    # Model Configuration
    model_path = (
        f"projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{GEMINI_MODEL_NAME}"
    )
    model_generation_config = {
        "response_modalities": ["AUDIO"],
        "speech_config": {
            "voice_config": {"prebuilt_voice_config": {"voice_name": "Aoede"}},
            "language_code": "es-ES",
        },
    }

    def parse_message(raw_message):
        # A frame is delivered as bytes or str depending on its type, so
        # normalise to text before parsing the JSON payload.
        if isinstance(raw_message, bytes):
            raw_message = raw_message.decode("utf-8")
        return json.loads(raw_message)

    async with connect(WEBSOCKET_SERVICE_URL, additional_headers=headers) as websocket_session:
        # 1. Send setup configuration
        websocket_config = {
            "setup": {
                "model": model_path,
                "generation_config": model_generation_config,
                # Audio transcriptions for input and output
                "input_audio_transcription": {},
                "output_audio_transcription": {},
            }
        }
        await websocket_session.send(json.dumps(websocket_config))

        # 2. Receive setup response
        setup_response = parse_message(await websocket_session.recv())
        print(f"Setup Response: {setup_response}")
        # Expected response: {'setupComplete': {}}
        if "setupComplete" not in setup_response:
            print(f"Setup failed: {setup_response}")
            return "Error: WebSocket setup failed."

        # 3. Send text message
        text_input = "Hello? Gemini are you there?"
        print(f"Input: {text_input}")

        user_message = {
            "client_content": {
                "turns": [{"role": "user", "parts": [{"text": text_input}]}],
                "turn_complete": True,
            }
        }
        await websocket_session.send(json.dumps(user_message))

        # 4. Receive model response
        aggregated_response_parts = []
        input_transcriptions_parts = []
        output_transcriptions_parts = []
        async for raw_response_chunk in websocket_session:
            response_chunk = parse_message(raw_response_chunk)

            server_content = response_chunk.get("serverContent")
            if not server_content:
                # This might indicate an error or an unexpected message format
                print(f"Received non-serverContent message or empty content: {response_chunk}")
                break

            # Transcriptions
            input_transcription = server_content.get("inputTranscription")
            if input_transcription:
                input_transcriptions_parts.append(input_transcription.get("text", ""))
            output_transcription = server_content.get("outputTranscription")
            if output_transcription:
                output_transcriptions_parts.append(output_transcription.get("text", ""))

            # Collect audio chunks
            model_turn = server_content.get("modelTurn") or {}
            for part in model_turn.get("parts") or []:
                # Parts without inline data carry no audio and are skipped.
                inline_data = part.get("inlineData")
                if not inline_data:
                    continue
                # Compare only the base type so a parameterised value such as
                # "audio/pcm;rate=24000" is accepted as well.
                mime_type = (inline_data.get("mimeType") or "").split(";")[0].strip()
                if mime_type == "audio/pcm":
                    audio_chunk = base64.b64decode(inline_data["data"])
                    aggregated_response_parts.append(np.frombuffer(audio_chunk, dtype=np.int16))

            # End of response
            if server_content.get("turnComplete"):
                break

        # Save audio to a file. np.concatenate() rejects an empty sequence,
        # so the file is only written when audio was actually received.
        if aggregated_response_parts:
            final_response_audio = np.concatenate(aggregated_response_parts)
            save_audio("output.wav", 24000, final_response_audio)
        print(f"Input transcriptions: {''.join(input_transcriptions_parts)}")
        print(f"Output transcriptions: {''.join(output_transcriptions_parts)}")
    # Example response:
    #     Setup Response: {'setupComplete': {}}
    #     Input: Hello? Gemini are you there?
    #     Audio Response(output.wav): Yes, I'm here. How can I help you today?
    #     Input transcriptions:
    #     Output transcriptions: Yes, I'm here. How can I help you today?
    # [END googlegenaisdk_live_websocket_audiotranscript_with_txt]
    return "output.wav"
160+
161+
162+
# Run the sample on a fresh event loop when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(generate_content())

0 commit comments

Comments
 (0)