# Source code for pipecat.services.aws.utils

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""AWS utility functions for Pipecat services.

This module provides shared credential resolution and AWS Transcribe utilities
for creating presigned URLs, building event messages, and handling AWS event
stream protocol for real-time transcription services.
"""

import binascii
import datetime
import hashlib
import hmac
import json
import os
import struct
import urllib.parse
from dataclasses import dataclass
from typing import Any

from loguru import logger


[docs] @dataclass class AWSCredentials: """Resolved AWS credentials ready for use by any AWS service.""" access_key: str | None secret_key: str | None session_token: str | None region: str
[docs] def to_boto_kwargs(self) -> dict[str, str | None]: """Return credentials as kwargs accepted by ``boto3``/``aioboto3`` clients.""" return { "aws_access_key_id": self.access_key, "aws_secret_access_key": self.secret_key, "aws_session_token": self.session_token, "region_name": self.region, }
def resolve_credentials(
    *,
    aws_access_key_id: str | None = None,
    aws_secret_access_key: str | None = None,
    aws_session_token: str | None = None,
    region: str | None = None,
) -> AWSCredentials:
    """Resolve AWS credentials using the standard fallback chain.

    Resolution order:

    1. Explicit parameters
    2. Environment variables (``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``,
       ``AWS_SESSION_TOKEN``, ``AWS_REGION``)
    3. Default boto3/botocore credential chain (instance profiles, IRSA, ECS
       task roles, SSO, credential files, etc.)

    The boto3 fallback (step 3) is only attempted when *both* access key and
    secret key are still unresolved after steps 1-2. This avoids replacing
    explicitly provided credentials with ambient ones.

    When the fallback succeeds, the access key, secret key and session token
    are all taken from the boto3 result as one set. A session token found in
    steps 1-2 without accompanying keys is discarded in that case, because a
    token is only valid with the key pair it was issued for.

    Args:
        aws_access_key_id: Explicit access key ID.
        aws_secret_access_key: Explicit secret access key.
        aws_session_token: Explicit session token.
        region: Explicit AWS region. Falls back to ``AWS_REGION`` and then to
            ``"us-east-1"``; an empty value at any step is treated as unset.

    Returns:
        An :class:`AWSCredentials` instance. ``access_key`` and ``secret_key``
        may still be ``None`` if no credentials could be resolved (the caller
        should raise an appropriate error).
    """
    access_key = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
    secret_key = aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")
    session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")
    # Chain with ``or`` so an empty AWS_REGION does not produce an empty region.
    resolved_region = region or os.getenv("AWS_REGION") or "us-east-1"

    # Fall back to the boto3 credential provider chain (pod roles, IRSA,
    # instance profiles, SSO, credential files, etc.) when explicit
    # credentials were not supplied.
    if not access_key and not secret_key:
        try:
            # Imported lazily: boto3 is an optional dependency.
            import boto3

            session = boto3.Session(region_name=resolved_region)
            creds = session.get_credentials()
            if creds:
                frozen = creds.get_frozen_credentials()
                # Adopt the whole triple from one source. Keeping a token
                # from the environment here would pair it with keys it was
                # not issued for.
                access_key = frozen.access_key
                secret_key = frozen.secret_key
                session_token = frozen.token
        except ImportError:
            logger.debug(
                "boto3 not available for credential chain fallback; "
                "install pipecat-ai[aws] for full credential support."
            )
        except Exception as e:
            # Best effort: resolution failures are reported, not raised, so
            # the caller can decide how to handle missing credentials.
            logger.warning(f"Failed to resolve AWS credentials via boto3 chain: {e}")

    return AWSCredentials(
        access_key=access_key,
        secret_key=secret_key,
        session_token=session_token,
        region=resolved_region,
    )
[docs] def get_presigned_url( *, region: str, credentials: dict[str, str | None], language_code: str, media_encoding: str = "pcm", sample_rate: int = 16000, number_of_channels: int = 1, enable_partial_results_stabilization: bool = True, partial_results_stability: str = "high", vocabulary_name: str | None = None, vocabulary_filter_name: str | None = None, show_speaker_label: bool = False, enable_channel_identification: bool = False, ) -> str: """Create a presigned URL for AWS Transcribe streaming. Args: region: AWS region for the service. credentials: Dictionary containing AWS credentials. Must include 'access_key' and 'secret_key', with optional 'session_token'. language_code: Language code for transcription (e.g., "en-US"). media_encoding: Audio encoding format. Defaults to "pcm". sample_rate: Audio sample rate in Hz. Defaults to 16000. number_of_channels: Number of audio channels. Defaults to 1. enable_partial_results_stabilization: Whether to enable partial result stabilization. partial_results_stability: Stability level for partial results. vocabulary_name: Custom vocabulary name to use. vocabulary_filter_name: Vocabulary filter name to apply. show_speaker_label: Whether to include speaker labels. enable_channel_identification: Whether to enable channel identification. Returns: Presigned WebSocket URL for AWS Transcribe streaming. Raises: ValueError: If required AWS credentials are missing. 
""" access_key = credentials.get("access_key") secret_key = credentials.get("secret_key") session_token = credentials.get("session_token") if not access_key or not secret_key: raise ValueError("AWS credentials are required") # Initialize the URL generator url_generator = AWSTranscribePresignedURL( access_key=access_key, secret_key=secret_key, session_token=session_token, region=region ) # Get the presigned URL return url_generator.get_request_url( sample_rate=sample_rate, language_code=language_code, media_encoding=media_encoding, vocabulary_name=vocabulary_name, vocabulary_filter_name=vocabulary_filter_name, show_speaker_label=show_speaker_label, enable_channel_identification=enable_channel_identification, number_of_channels=number_of_channels, enable_partial_results_stabilization=enable_partial_results_stabilization, partial_results_stability=partial_results_stability, )
[docs] class AWSTranscribePresignedURL: """Generator for AWS Transcribe presigned WebSocket URLs. Handles AWS Signature Version 4 signing process to create authenticated WebSocket URLs for streaming transcription requests. """
[docs] def __init__( self, access_key: str, secret_key: str, session_token: str | None, region: str = "us-east-1", ): """Initialize the presigned URL generator. Args: access_key: AWS access key ID. secret_key: AWS secret access key. session_token: AWS session token for temporary credentials (optional). region: AWS region for the service. Defaults to "us-east-1". """ self.access_key = access_key self.secret_key = secret_key self.session_token = session_token self.method = "GET" self.service = "transcribe" self.region = region self.endpoint = "" self.host = "" self.amz_date = "" self.datestamp = "" self.canonical_uri = "/stream-transcription-websocket" self.canonical_headers = "" self.signed_headers = "host" self.algorithm = "AWS4-HMAC-SHA256" self.credential_scope = "" self.canonical_querystring = "" self.payload_hash = "" self.canonical_request = "" self.string_to_sign = "" self.signature = "" self.request_url = ""
[docs] def get_request_url( self, sample_rate: int, language_code: str = "", media_encoding: str = "pcm", vocabulary_name: str | None = None, vocabulary_filter_name: str | None = None, show_speaker_label: bool = False, enable_channel_identification: bool = False, number_of_channels: int = 1, enable_partial_results_stabilization: bool = False, partial_results_stability: str = "", ) -> str: """Generate a presigned WebSocket URL for AWS Transcribe. Args: sample_rate: Audio sample rate in Hz. language_code: Language code for transcription. media_encoding: Audio encoding format. vocabulary_name: Custom vocabulary name. vocabulary_filter_name: Vocabulary filter name. show_speaker_label: Whether to include speaker labels. enable_channel_identification: Whether to enable channel identification. number_of_channels: Number of audio channels. enable_partial_results_stabilization: Whether to enable partial result stabilization. partial_results_stability: Stability level for partial results. Returns: Presigned WebSocket URL with authentication parameters. 
""" self.endpoint = f"wss://transcribestreaming.{self.region}.amazonaws.com:8443" self.host = f"transcribestreaming.{self.region}.amazonaws.com:8443" now = datetime.datetime.utcnow() self.amz_date = now.strftime("%Y%m%dT%H%M%SZ") self.datestamp = now.strftime("%Y%m%d") self.canonical_headers = f"host:{self.host}\n" self.credential_scope = f"{self.datestamp}%2F{self.region}%2F{self.service}%2Faws4_request" # Create canonical querystring self.canonical_querystring = "X-Amz-Algorithm=" + self.algorithm self.canonical_querystring += ( "&X-Amz-Credential=" + self.access_key + "%2F" + self.credential_scope ) self.canonical_querystring += "&X-Amz-Date=" + self.amz_date self.canonical_querystring += "&X-Amz-Expires=300" if self.session_token: self.canonical_querystring += "&X-Amz-Security-Token=" + urllib.parse.quote( self.session_token, safe="" ) self.canonical_querystring += "&X-Amz-SignedHeaders=" + self.signed_headers if enable_channel_identification: self.canonical_querystring += "&enable-channel-identification=true" if enable_partial_results_stabilization: self.canonical_querystring += "&enable-partial-results-stabilization=true" if language_code: self.canonical_querystring += "&language-code=" + language_code if media_encoding: self.canonical_querystring += "&media-encoding=" + media_encoding if number_of_channels > 1: self.canonical_querystring += "&number-of-channels=" + str(number_of_channels) if partial_results_stability: self.canonical_querystring += "&partial-results-stability=" + partial_results_stability if sample_rate: self.canonical_querystring += "&sample-rate=" + str(sample_rate) if show_speaker_label: self.canonical_querystring += "&show-speaker-label=true" if vocabulary_filter_name: self.canonical_querystring += "&vocabulary-filter-name=" + vocabulary_filter_name if vocabulary_name: self.canonical_querystring += "&vocabulary-name=" + vocabulary_name # Create payload hash self.payload_hash = hashlib.sha256(b"").hexdigest() # Create canonical request 
self.canonical_request = f"{self.method}\n{self.canonical_uri}\n{self.canonical_querystring}\n{self.canonical_headers}\n{self.signed_headers}\n{self.payload_hash}" # Create string to sign credential_scope = f"{self.datestamp}/{self.region}/{self.service}/aws4_request" string_to_sign = ( f"{self.algorithm}\n{self.amz_date}\n{credential_scope}\n" + hashlib.sha256(self.canonical_request.encode("utf-8")).hexdigest() ) # Calculate signature k_date = hmac.new( f"AWS4{self.secret_key}".encode(), self.datestamp.encode("utf-8"), hashlib.sha256 ).digest() k_region = hmac.new(k_date, self.region.encode("utf-8"), hashlib.sha256).digest() k_service = hmac.new(k_region, self.service.encode("utf-8"), hashlib.sha256).digest() k_signing = hmac.new(k_service, b"aws4_request", hashlib.sha256).digest() self.signature = hmac.new( k_signing, string_to_sign.encode("utf-8"), hashlib.sha256 ).hexdigest() # Add signature to query string self.canonical_querystring += "&X-Amz-Signature=" + self.signature # Create request URL self.request_url = self.endpoint + self.canonical_uri + "?" + self.canonical_querystring return self.request_url
def get_headers(header_name: str, header_value: str) -> bytearray:
    """Encode a single string-valued header in AWS event stream format.

    Layout produced: one length byte for the name, the UTF-8 name, one
    value-type byte (7), a big-endian 16-bit value length, the UTF-8 value.

    Args:
        header_name: Name of the header.
        header_value: Value of the header.

    Returns:
        Encoded header as a bytearray following AWS event stream protocol.
    """
    encoded_name = header_name.encode("utf-8")
    header = bytearray(bytes([len(encoded_name)]))
    header += encoded_name

    # Value type 7 marks the value as a string.
    header += bytes([7])

    encoded_value = header_value.encode("utf-8")
    header += struct.pack(">H", len(encoded_value))
    header += encoded_value
    return header
def build_event_message(payload: bytes) -> bytes:
    """Wrap audio bytes in an AWS event stream ``AudioEvent`` message.

    The message consists of an 8-byte prelude (total length, headers length),
    a CRC32 of the prelude, the encoded headers, the payload, and a trailing
    CRC32 computed over everything that precedes it.

    Args:
        payload: Raw audio bytes to include in the event message.

    Returns:
        Complete event message as bytes, ready to send via WebSocket.

    Note:
        Implementation matches AWS sample: https://github.com/aws-samples/amazon-transcribe-streaming-python-websockets/blob/main/eventstream.py
    """
    # Encode the three fixed headers in the order the message carries them.
    header_block = bytearray()
    for key, val in (
        (":content-type", "application/octet-stream"),
        (":event-type", "AudioEvent"),
        (":message-type", "event"),
    ):
        header_block.extend(get_headers(key, val))

    # 16 bytes of framing: 8-byte prelude plus two 4-byte CRCs.
    framing_overhead = 16
    prelude = struct.pack(
        ">II", len(header_block) + len(payload) + framing_overhead, len(header_block)
    )
    prelude_checksum = struct.pack(">I", binascii.crc32(prelude) & 0xFFFFFFFF)

    frame = bytearray(prelude)
    frame.extend(prelude_checksum)
    frame.extend(header_block)
    frame.extend(payload)

    # The trailing CRC covers the prelude, its CRC, the headers and the payload.
    frame.extend(struct.pack(">I", binascii.crc32(bytes(frame)) & 0xFFFFFFFF))
    return bytes(frame)
def decode_event(message: bytes) -> tuple[dict[str, str], Any]:
    """Decode an AWS event stream message.

    Parses an AWS event stream message to extract headers and payload,
    verifying CRC checksums for data integrity.

    Args:
        message: Raw event stream message bytes received from AWS.

    Returns:
        A tuple of (headers, payload) where:

        - headers: Dictionary of parsed headers (every value is decoded as a
          UTF-8 string; the value-type byte is not inspected)
        - payload: The payload parsed with ``json.loads``

    Raises:
        AssertionError: If CRC checksum verification fails. Raised
            explicitly, so the check is still performed when Python runs
            with ``-O`` (which strips ``assert`` statements).
    """
    # Layout: 8-byte prelude | 4-byte prelude CRC | headers | payload | 4-byte CRC
    prelude = message[:8]
    _total_length, headers_length = struct.unpack(">II", prelude)
    prelude_crc = struct.unpack(">I", message[8:12])[0]
    headers = message[12 : 12 + headers_length]
    payload = message[12 + headers_length : -4]
    message_crc = struct.unpack(">I", message[-4:])[0]

    # Integrity checks. AssertionError is kept as the exception type for
    # backward compatibility with callers that already catch it.
    if prelude_crc != binascii.crc32(prelude) & 0xFFFFFFFF:
        raise AssertionError("Prelude CRC check failed")
    if message_crc != binascii.crc32(message[:-4]) & 0xFFFFFFFF:
        raise AssertionError("Message CRC check failed")

    # Walk the header block by offset instead of re-slicing the remainder.
    # Each header: name length (1) | name | value type (1) | value length (2) | value
    headers_dict: dict[str, str] = {}
    offset = 0
    while offset < len(headers):
        name_len = headers[offset]
        name_end = offset + 1 + name_len
        name = headers[offset + 1 : name_end].decode("utf-8")
        # The byte at name_end is the value type; it is skipped.
        value_start = name_end + 3
        value_len = struct.unpack(">H", headers[name_end + 1 : value_start])[0]
        value_end = value_start + value_len
        headers_dict[name] = headers[value_start:value_end].decode("utf-8")
        offset = value_end

    return headers_dict, json.loads(payload)