Source code for pipecat.workers.proxy.websocket.server

#
# Copyright (c) 2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""WebSocket server proxy that receives bus messages from a remote client."""

import asyncio

from loguru import logger

from pipecat.bus import BusMessage, BusWorkerRegistryMessage
from pipecat.bus.messages import BusLocalMessage
from pipecat.bus.serializers import JSONMessageSerializer
from pipecat.bus.serializers.base import MessageSerializer
from pipecat.registry.types import WorkerReadyData, WorkerRegistryEntry
from pipecat.workers.base_worker import BaseWorker

try:
    from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
except ModuleNotFoundError as e:
    logger.error(f"Exception: {e}")
    logger.error("In order to use WebSocketProxyServer, you need to `pip install starlette`.")
    raise ImportError(f"Missing module: {e}") from e


[docs] class WebSocketProxyServer(BaseWorker): """Receives bus messages from a remote client over WebSocket. Accepts a FastAPI/Starlette WebSocket connection and forwards messages between the remote client and a local worker. Only messages from the local worker targeted at the remote worker are sent. Only inbound messages targeted at the local worker are accepted. Event handlers available: - on_client_connected: Fired when the WebSocket client connects and the proxy is ready. - on_client_disconnected: Fired when the WebSocket client disconnects. Example:: @app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket): await websocket.accept() proxy = WebSocketProxyServer( "gateway", websocket=websocket, worker_name="worker", remote_worker_name="voice", ) @proxy.event_handler("on_client_connected") async def on_client_connected(worker, websocket): logger.info("Client connected") @proxy.event_handler("on_client_disconnected") async def on_client_disconnected(worker, websocket): logger.info("Client disconnected") await runner.add_workers(proxy) """
[docs] def __init__( self, name: str, *, websocket: WebSocket, worker_name: str, remote_worker_name: str, forward_messages: tuple[type[BusMessage], ...] = (), serializer: MessageSerializer | None = None, ): """Initialize the WebSocketProxyServer. Args: name: Unique name for this worker. websocket: An accepted FastAPI/Starlette WebSocket connection. worker_name: Name of the local worker to route messages to/from. Only messages from this worker are forwarded to the client. remote_worker_name: Name of the worker on the remote client. Only outbound messages targeted at this worker are sent. Only inbound messages targeted at the local worker are accepted. forward_messages: Additional message types to forward from the local worker (e.g. ``(BusFrameMessage,)`` for frame routing). These are forwarded based on source worker name only, regardless of target. serializer: Serializer for bus messages. Defaults to `JSONMessageSerializer`. """ super().__init__(name) self._ws: WebSocket | None = websocket self._worker_name = worker_name self._remote_worker_name = remote_worker_name self._forward_messages = forward_messages self._serializer = serializer or JSONMessageSerializer() self._receive_task: asyncio.Task | None = None self._register_event_handler("on_client_connected") self._register_event_handler("on_client_disconnected")
[docs] async def start(self) -> None: """Start the WebSocket receive loop and watch the local worker.""" await super().start() logger.debug(f"Worker '{self}': WebSocket proxy server ready") await self._call_event_handler("on_client_connected", self._ws) self._receive_task = self.create_task(self._receive_loop()) # Schedule worker right away. await asyncio.sleep(0) # Watch the local worker so we can notify the remote side when it's ready. await self.watch_workers(self._worker_name)
[docs] async def stop(self) -> None: """Cancel the receive loop and close the WebSocket connection.""" if self._receive_task: await self.cancel_task(self._receive_task) self._receive_task = None if self._ws and self._ws.client_state == WebSocketState.CONNECTED: await self._ws.close() logger.debug(f"Worker '{self}': WebSocket connection closed") await super().stop()
[docs] async def on_worker_ready(self, data: WorkerReadyData) -> None: """Notify the remote client that the local worker is ready.""" if not self._ws: return if data.worker_name != self._worker_name: return logger.debug(f"Worker '{self}': local worker '{self._worker_name}' ready, notifying remote") try: msg = BusWorkerRegistryMessage( source=self.name, runner=data.runner, workers=[WorkerRegistryEntry(name=self._worker_name)], ) await self._send_ws(msg) except Exception: logger.exception(f"Worker '{self}': failed to send registry to remote")
[docs] async def on_bus_message(self, message: BusMessage) -> None: """Forward messages from the local worker to the remote client. Args: message: The bus message to process. """ await super().on_bus_message(message) if not self._ws: return if isinstance(message, BusLocalMessage): return if message.source != self._worker_name: return # Forward targeted messages from the local worker to the remote worker. if message.target == self._remote_worker_name: await self._send_ws(message) # Forward additional message types from the local worker. elif isinstance(message, self._forward_messages): await self._send_ws(message)
async def _send_ws(self, message: BusMessage) -> None: """Serialize and send a message over the WebSocket.""" if not self._ws: return try: data = self._serializer.serialize(message) await self._ws.send_bytes(data) logger.trace(f"Worker '{self}': sent {message}") except WebSocketDisconnect: logger.warning(f"Worker '{self}': connection closed, stopping forwarding") ws = self._ws self._ws = None await self._call_event_handler("on_client_disconnected", ws) async def _receive_loop(self) -> None: """Read messages from the WebSocket and put them on the local bus.""" assert self._ws is not None, "start() must run before _receive_loop" try: while True: data = await self._ws.receive_bytes() try: message = self._serializer.deserialize(data) if not message: continue # Accept additional message types (e.g. BusFrameMessage). if self._forward_messages and isinstance(message, self._forward_messages): logger.trace(f"Worker '{self}': received {message} from client") await self.send_bus_message(message) continue # Only accept other messages targeted at the local worker. if message.target != self._worker_name: logger.warning( f"Worker '{self}': dropped inbound message with " f"unexpected target '{message.target}'" ) continue logger.trace(f"Worker '{self}': received {message} from client") await self.send_bus_message(message) except Exception: logger.exception(f"Worker '{self}': failed to deserialize client message") except WebSocketDisconnect: logger.warning(f"Worker '{self}': client disconnected") ws = self._ws self._ws = None await self._call_event_handler("on_client_disconnected", ws) except asyncio.CancelledError: pass