#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""NVIDIA Magpie TTS service backed by an AWS SageMaker endpoint."""
import asyncio
import base64
import json
import os
from collections.abc import AsyncGenerator
from dataclasses import dataclass
import aioboto3
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterruptionFrame,
StartFrame,
TTSAudioRawFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.aws.sagemaker.bidi_client import SageMakerBidiClient
from pipecat.services.settings import TTSSettings
from pipecat.services.tts_service import InterruptibleTTSService, TTSService
from pipecat.utils.tracing.service_decorators import traced_tts
@dataclass
class NvidiaSageMakerTTSSettings(TTSSettings):
    """Settings for NVIDIA SageMaker TTS services.

    This subclass declares no fields of its own; it exists so both SageMaker
    TTS services expose a dedicated ``Settings`` type. The services in this
    module construct it with ``model``, ``voice`` and ``language``.

    Parameters:
        voice: NIM voice name (e.g. ``Magpie-Multilingual.EN-US.Aria``).
        language: BCP-47 language code passed to NIM (e.g. ``en-US``).
    """
class NvidiaSageMakerHTTPTTSService(TTSService):
    """NVIDIA Magpie TTS service that calls a SageMaker HTTP endpoint.

    Sends each text segment to the wrapper's ``POST /invocations`` endpoint
    as a JSON body and streams the raw PCM audio response back to the bot
    as :class:`TTSAudioRawFrame` frames.

    Example::

        tts = NvidiaSageMakerHTTPTTSService(
            endpoint_name=os.getenv("SAGEMAKER_MAGPIE_ENDPOINT_NAME"),
            region=os.getenv("AWS_REGION", "us-west-2"),
            settings=NvidiaSageMakerHTTPTTSService.Settings(
                voice="Magpie-Multilingual.EN-US.Aria",
                language="en-US",
            ),
        )
    """

    Settings = NvidiaSageMakerTTSSettings

    def __init__(
        self,
        *,
        endpoint_name: str,
        region: str = "us-west-2",
        sample_rate: int | None = None,
        settings: NvidiaSageMakerTTSSettings | None = None,
        **kwargs,
    ):
        """Initialize the SageMaker HTTP TTS service.

        Args:
            endpoint_name: Name of the deployed SageMaker endpoint.
            region: AWS region where the endpoint lives.
            sample_rate: Output sample rate in Hz. Defaults to the bot's pipeline rate.
            settings: Runtime-updatable settings (voice, language). Applied on
                top of the built-in defaults.
            **kwargs: Forwarded to :class:`TTSService`.
        """
        # Start from the built-in defaults and overlay whatever the caller supplied.
        default_settings = self.Settings(
            model="magpie",
            voice="Magpie-Multilingual.EN-US.Aria",
            language="en-US",
        )
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            sample_rate=sample_rate,
            push_start_frame=True,
            push_stop_frames=True,
            settings=default_settings,
            **kwargs,
        )

        self._endpoint_name = endpoint_name
        self._region = region
        # Both are populated in start() and cleared in _close_client().
        self._client = None
        self._client_ctx = None

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics.

        Returns:
            True, as this service supports metrics generation.
        """
        return True

    # ── Lifecycle ─────────────────────────────────────────────────────────────

    async def start(self, frame: StartFrame):
        """Start the TTS service and create the SageMaker client.

        Credentials are read from ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``
        and ``AWS_SESSION_TOKEN``; any that are unset are passed as ``None``.

        Args:
            frame: The start frame containing initialization parameters.
        """
        await super().start(frame)

        session = aioboto3.Session(
            aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
            # Temporary (STS) credentials are only valid together with their
            # session token; omitting it makes such credentials unusable.
            aws_session_token=os.environ.get("AWS_SESSION_TOKEN"),
            region_name=self._region,
        )

        # Only remember the context manager once it has been entered, so that
        # _close_client() never calls __aexit__ on a context that failed to open.
        client_ctx = session.client("sagemaker-runtime")
        self._client = await client_ctx.__aenter__()
        self._client_ctx = client_ctx

        logger.debug(f"{self}: connected to SageMaker endpoint '{self._endpoint_name}'")

    async def _close_client(self):
        """Exit the client context (best effort) and forget the client."""
        if self._client_ctx is None:
            return
        try:
            await self._client_ctx.__aexit__(None, None, None)
        except Exception as e:
            # Shutdown is best-effort: log and carry on.
            logger.warning(f"{self}: error closing SageMaker client: {e}")
        finally:
            self._client_ctx = None
            self._client = None

    async def stop(self, frame: EndFrame):
        """Stop the TTS service and close the SageMaker client.

        Args:
            frame: The end frame.
        """
        await super().stop(frame)
        await self._close_client()

    async def cancel(self, frame: CancelFrame):
        """Cancel the TTS service and close the SageMaker client.

        Args:
            frame: The cancel frame.
        """
        await super().cancel(frame)
        await self._close_client()

    # ── Synthesis ─────────────────────────────────────────────────────────────

    @traced_tts
    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame, None]:
        """Synthesize text via SageMaker and stream the PCM audio back in chunks.

        Text that is empty or contains no alphanumeric character is skipped
        without calling the endpoint.

        Args:
            text: The text to synthesize.
            context_id: Pipecat audio context identifier.

        Yields:
            :class:`TTSAudioRawFrame` chunks of signed 16-bit mono PCM, or an
            :class:`ErrorFrame` if the request fails.
        """
        logger.debug(f"{self}: Generating TTS [{text}]")

        text = text.strip()
        if not text or not any(c.isalnum() for c in text):
            return

        try:
            # Explicit check instead of `assert`: asserts are stripped under -O
            # and produce an ErrorFrame with an empty message.
            if self._client is None:
                raise RuntimeError("SageMaker client is not initialized; was start() called?")

            body = json.dumps(
                {
                    "text": text,
                    "voice_name": self._settings.voice,
                    "language_code": self._settings.language,
                    "sample_rate_hz": self.sample_rate,
                }
            )
            response = await self._client.invoke_endpoint(
                EndpointName=self._endpoint_name,
                ContentType="application/json",
                Accept="application/octet-stream",
                Body=body,
            )

            if "Body" not in response:
                yield ErrorFrame(error="SageMaker TTS returned no audio stream")
                return

            first_chunk = True
            async for chunk in response["Body"].iter_chunks(chunk_size=self.chunk_size):
                if not chunk:
                    continue
                if first_chunk:
                    # Time-to-first-byte ends when the first non-empty chunk arrives.
                    await self.stop_ttfb_metrics()
                    first_chunk = False
                yield TTSAudioRawFrame(
                    audio=chunk,
                    sample_rate=self.sample_rate,
                    num_channels=1,
                    context_id=context_id,
                )
        except Exception as e:
            logger.error(f"{self}: SageMaker TTS error: {e}")
            yield ErrorFrame(error=f"SageMaker TTS error: {e}")

        await self.start_tts_usage_metrics(text)
class NvidiaSageMakerTTSService(InterruptibleTTSService):
    """NVIDIA Magpie TTS service using SageMaker bidirectional streaming.

    Maintains a persistent HTTP/2 bidi-stream session to the SageMaker endpoint
    for the lifetime of the pipeline. Each text segment is sent as NIM realtime
    events; audio chunks arrive asynchronously and are pushed as
    :class:`TTSAudioRawFrame` frames.

    Example::

        tts = NvidiaSageMakerTTSService(
            endpoint_name=os.getenv("SAGEMAKER_MAGPIE_ENDPOINT_NAME"),
            region=os.getenv("AWS_REGION", "us-west-2"),
            settings=NvidiaSageMakerTTSService.Settings(
                voice="Magpie-Multilingual.EN-US.Aria",
                language="en-US",
            ),
        )
    """

    Settings = NvidiaSageMakerTTSSettings

    def __init__(
        self,
        *,
        endpoint_name: str,
        region: str = "us-west-2",
        sample_rate: int | None = None,
        settings: NvidiaSageMakerTTSSettings | None = None,
        **kwargs,
    ):
        """Initialize the SageMaker bidi-stream TTS service.

        Args:
            endpoint_name: Name of the deployed SageMaker endpoint.
            region: AWS region where the endpoint lives.
            sample_rate: Output sample rate in Hz. Defaults to pipeline rate.
            settings: Runtime-updatable settings (voice, language). Applied on
                top of the built-in defaults.
            **kwargs: Forwarded to :class:`InterruptibleTTSService`.
        """
        # Start from the built-in defaults and overlay whatever the caller supplied.
        default_settings = self.Settings(
            model="magpie",
            voice="Magpie-Multilingual.EN-US.Aria",
            language="en-US",
        )
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            sample_rate=sample_rate,
            push_start_frame=True,
            push_stop_frames=True,
            pause_frame_processing=True,
            append_trailing_space=True,
            settings=default_settings,
            **kwargs,
        )

        self._endpoint_name = endpoint_name
        self._region = region
        self._client: SageMakerBidiClient | None = None
        self._receive_task = None
        # Set by the receive loop once a response arrives while disconnecting;
        # awaited by _handle_interruption().
        self._speech_completed_event = asyncio.Event()
        # Jitter-buffer state, see _handle_audio_chunk().
        self._audio_buffer = b""
        self._playback_started = False

    def can_generate_metrics(self) -> bool:
        """Check if this service can generate processing metrics.

        Returns:
            True, as this service supports metrics generation.
        """
        return True

    # ── Lifecycle ─────────────────────────────────────────────────────────────

    async def start(self, frame: StartFrame):
        """Start the TTS service and connect to the SageMaker endpoint.

        Args:
            frame: The start frame containing initialization parameters.
        """
        await super().start(frame)
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the TTS service and disconnect from the SageMaker endpoint.

        Args:
            frame: The end frame.
        """
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the TTS service and disconnect from the SageMaker endpoint.

        Args:
            frame: The cancel frame.
        """
        await super().cancel(frame)
        await self._disconnect()

    # ── Connection management (WebsocketService abstract interface) ────────────

    async def _connect(self):
        """Open the bidi session and start the receive task if it is not running."""
        await super()._connect()
        await self._connect_websocket()

        if self._client and self._client.is_active and not self._receive_task:
            self._receive_task = self.create_task(self._receive_task_handler(self._report_error))

    async def _disconnect(self):
        """Cancel the receive task, then close the bidi session."""
        await super()._disconnect()

        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None

        await self._disconnect_websocket()

    async def _connect_websocket(self):
        """Create the bidi client, start its session and send the session config.

        Failures are logged and reported through the ``on_connection_error``
        event handler rather than raised; ``self._client`` is ``None`` afterwards.
        """
        if self._client and self._client.is_active:
            return

        logger.debug(
            f"{self}: connecting to SageMaker bidi-stream endpoint '{self._endpoint_name}'"
        )
        try:
            self._client = SageMakerBidiClient(
                endpoint_name=self._endpoint_name,
                region=self._region,
                model_query_string=None,
                model_invocation_path=None,
            )
            await self._client.start_session()
            await self._send_session_config()
            logger.debug(f"{self}: connected")
            await self._call_event_handler("on_connected")
        except Exception as e:
            logger.error(f"{self}: connection error: {e}")
            # The session may already be open (e.g. the config send failed).
            # Close it before dropping the reference so it is not leaked.
            failed_client, self._client = self._client, None
            if failed_client is not None and failed_client.is_active:
                try:
                    await failed_client.close_session()
                except Exception as close_error:
                    logger.warning(
                        f"{self}: error closing session after failed connect: {close_error}"
                    )
            await self._call_event_handler("on_connection_error", f"{e}")

    async def _disconnect_websocket(self):
        """Send ``session.end`` (best effort), close the session and drop the client."""
        try:
            if self._client and self._client.is_active:
                logger.debug(f"{self}: disconnecting")
                try:
                    await self._client.send_json({"type": "session.end"})
                except Exception as e:
                    # Still close the session even if the goodbye cannot be sent.
                    logger.warning(f"{self}: error sending session.end: {e}")
                await self._client.close_session()
                logger.debug(f"{self}: disconnected")
        except Exception as e:
            logger.warning(f"{self}: error during disconnect: {e}")
        finally:
            self._client = None
            await self._call_event_handler("on_disconnected")

    async def _verify_connection(self) -> bool:
        """Report whether a client exists and its session is active.

        Returns:
            True if the bidi session is active, False otherwise.
        """
        # Coerce to bool: the raw `and` expression yields None without a client.
        active = bool(self._client and self._client.is_active)
        logger.info(f"{self}: verifying if websocket connection is active {active}")
        return active

    def _reset_audio_buffer(self):
        """Discard buffered audio and re-arm the initial jitter buffer."""
        self._audio_buffer = b""
        self._playback_started = False

    async def _handle_interruption(self, frame: InterruptionFrame, direction: FrameDirection):
        """Discard buffered audio and end the in-flight synthesis before interrupting.

        If the bot is speaking, sends ``input_text.done`` and waits up to five
        seconds for the receive loop to signal, then defers to the base class.

        Args:
            frame: The interruption frame.
            direction: Direction the frame is travelling in.
        """
        self._reset_audio_buffer()

        if self._bot_speaking and self._client:
            logger.debug(
                f"{self}: interruption detected, sending input_text.done and waiting for speech.completed"
            )
            # Tells the receive loop to set the event on its next response and exit.
            self._disconnecting = True
            self._speech_completed_event.clear()
            try:
                await self._client.send_json({"type": "input_text.done"})
                await asyncio.wait_for(self._speech_completed_event.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                # asyncio.TimeoutError is the builtin TimeoutError on 3.11+ and
                # a distinct class on 3.10, so this name is correct on both.
                logger.warning(f"{self}: timed out waiting for conversation.item.speech.completed")

        await super()._handle_interruption(frame, direction)

    async def _flush_audio_buffer(self, context_id: str | None = None):
        """Push whatever audio is buffered as one frame and empty the buffer."""
        if not self._audio_buffer:
            return
        pending, self._audio_buffer = self._audio_buffer, b""
        await self.push_frame(
            TTSAudioRawFrame(
                audio=pending,
                sample_rate=self.sample_rate,
                num_channels=1,
                context_id=context_id,
            )
        )

    async def _handle_audio_chunk(self, audio: bytes, context_id: str | None = None):
        """Buffer audio and emit frames using a jitter-buffer approach.

        Holds back audio until chunk_size bytes have been accumulated (to avoid
        glitches at the start of playback), then emits each subsequent chunk
        immediately as it arrives.

        Args:
            audio: Raw PCM bytes received from the endpoint.
            context_id: Audio context the resulting frame belongs to.
        """
        self._audio_buffer += audio

        if not self._playback_started:
            if len(self._audio_buffer) < self.chunk_size:
                return
            self._playback_started = True

        await self._flush_audio_buffer(context_id)

    async def _receive_messages(self):
        """Receive NIM JSON events and push audio frames.

        The loop ends when the client goes inactive, a disconnect is in
        progress, the stream returns ``None``, or an ``error`` /
        ``conversation.item.speech.completed`` event arrives.
        """
        while self._client and self._client.is_active and not self._disconnecting:
            result = await self._client.receive_response()

            if self._disconnecting:
                # Unblock _handle_interruption(), which waits on this event.
                self._speech_completed_event.set()

            if result is None:
                break

            if not (hasattr(result, "value") and hasattr(result.value, "bytes_")):  # type: ignore[union-attr]
                continue

            payload = result.value.bytes_  # type: ignore[union-attr]
            if not payload:
                continue

            context_id = self.get_active_audio_context_id()

            try:
                msg = json.loads(payload.decode("utf-8"))
            except (UnicodeDecodeError, json.JSONDecodeError):
                # Unexpected binary frame — treat as raw PCM
                await self._handle_audio_chunk(payload, context_id)
                continue

            event_type = msg.get("type", "")
            if event_type != "conversation.item.speech.data":
                logger.debug(f"{self}: received event: {event_type}")

            if event_type == "conversation.item.speech.data":
                chunk_b64 = msg.get("audio", "")
                if chunk_b64:
                    await self.stop_ttfb_metrics()
                    await self._handle_audio_chunk(base64.b64decode(chunk_b64), context_id)
            elif event_type == "error":
                await self.push_error(error_msg=f"NIM error: {msg.get('message', msg)}")
                # In case of error we need to reconnect, otherwise we are not going to receive audio from the TTS service anymore
                break
            elif event_type == "conversation.item.speech.completed":
                # An utterance shorter than chunk_size never leaves the jitter
                # buffer on its own; emit it now so it is not held back until
                # the next utterance. Skipped while disconnecting (interruption).
                if not self._disconnecting:
                    await self._flush_audio_buffer(context_id)
                # Need to reconnect to reset the synthesis state and be able to synthesize new text
                break
            # synthesize_session.updated, input_text.committed, etc. are ignored.

    async def _send_session_config(self):
        """Send synthesize_session.update to configure voice and audio params.

        Raises:
            RuntimeError: If no client has been created yet.
        """
        logger.debug(f"{self}: sending session config, sample_rate={self.sample_rate}")
        if self._client is None:
            raise RuntimeError("cannot send session config: SageMaker bidi client is not connected")

        await self._client.send_json(
            {
                "type": "synthesize_session.update",
                "session": {
                    "input_text_synthesis": {
                        "voice_name": self._settings.voice,
                        "language_code": self._settings.language,
                    },
                    "output_audio_params": {
                        "sample_rate_hz": self.sample_rate,
                    },
                },
            }
        )

    # ── Synthesis ─────────────────────────────────────────────────────────────

    @traced_tts
    async def run_tts(self, text: str, context_id: str) -> AsyncGenerator[Frame | None, None]:
        """Send text to NIM; audio arrives asynchronously via _receive_messages.

        Text that is empty or contains no alphanumeric character is skipped.
        Reconnects first if the session is missing or inactive.

        Args:
            text: The text to synthesize.
            context_id: Pipecat audio context identifier.

        Yields:
            ``None`` once the text has been submitted, or an
            :class:`ErrorFrame` if submission fails.
        """
        logger.debug(f"{self}: Generating TTS [{text}]")

        text = text.strip()
        if not text or not any(c.isalnum() for c in text):
            return

        try:
            if not self._client or not self._client.is_active:
                await self._connect()

            # _connect() swallows connection failures and leaves the client as
            # None; raise with a real message instead of an empty AssertionError.
            if self._client is None:
                raise RuntimeError(
                    f"not connected to SageMaker endpoint '{self._endpoint_name}'"
                )

            await self._client.send_json({"type": "input_text.append", "text": text})
            await self._client.send_json({"type": "input_text.commit"})
            await self.start_tts_usage_metrics(text)
            yield None
        except Exception as e:
            logger.error(f"{self}: TTS error: {e}")
            yield ErrorFrame(error=f"NvidiaSageMakerTTSService error: {e}")