# Source code for pipecat.audio.turn.smart_turn.local_smart_turn_v3

#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

"""Local turn analyzer for on-device ML inference using the smart-turn-v3 model.

This module provides a smart turn analyzer that uses an ONNX model for
local end-of-turn detection without requiring network connectivity.
"""

from typing import Any

import numpy as np
import onnxruntime as ort
import soxr
from loguru import logger

from pipecat.audio.turn.smart_turn._whisper_features import compute_whisper_log_mel_features
from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
from pipecat.utils.env import env_truthy

# The Whisper-based ONNX model expects 16 kHz audio input.
_MODEL_SAMPLE_RATE = 16000


class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
    """Local turn analyzer using the smart-turn-v3 ONNX model.

    Provides end-of-turn detection using a locally-stored ONNX model, enabling
    offline operation without network dependencies.
    """

    def __init__(self, *, smart_turn_model_path: str | None = None, cpu_count: int = 1, **kwargs):
        """Initialize the local ONNX smart-turn-v3 analyzer.

        Args:
            smart_turn_model_path: Path to the ONNX model file. If this is not set,
                the bundled smart-turn-v3.2-cpu model will be used.
            cpu_count: The number of CPUs to use for inference. Defaults to 1.
            **kwargs: Additional arguments passed to BaseSmartTurn.
        """
        super().__init__(**kwargs)

        # When enabled, every analyzed utterance is also dumped to a WAV file.
        self._log_data = env_truthy("PIPECAT_SMART_TURN_LOG_DATA", default=False)

        if not smart_turn_model_path:
            smart_turn_model_path = self._bundled_model_path()

        logger.debug(f"Loading Local Smart Turn v3.x model from {smart_turn_model_path}...")

        so = ort.SessionOptions()
        so.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        so.inter_op_num_threads = 1
        so.intra_op_num_threads = cpu_count
        so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        # str() so that callers passing a pathlib.Path still work with ONNX Runtime
        # versions that only accept str/bytes.
        self._session = ort.InferenceSession(str(smart_turn_model_path), sess_options=so)

        logger.debug("Loaded Local Smart Turn v3.x")

    @staticmethod
    def _bundled_model_path() -> str:
        """Return the path of the smart-turn-v3.2-cpu ONNX model shipped with pipecat.

        Returns:
            The model location inside the ``pipecat.audio.turn.smart_turn.data``
            package, as a plain string.
        """
        # importlib.resources.files() is in the stdlib since Python 3.9, and this
        # module already needs 3.10+ (it evaluates ``str | None`` annotations), so
        # no backport or deprecated ``resources.path()`` fallback is required.
        # NOTE(review): assumes the package is installed on a real filesystem (not
        # zip-imported); otherwise the returned string is not an openable path.
        from importlib import resources

        model_name = "smart-turn-v3.2-cpu.onnx"
        package_path = "pipecat.audio.turn.smart_turn.data"
        return str(resources.files(package_path).joinpath(model_name))

    def _write_audio_to_wav(
        self, audio_array: np.ndarray, sample_rate: int = _MODEL_SAMPLE_RATE, suffix: str = ""
    ) -> None:
        """Write audio data to a WAV file in a background thread.

        Files go to ``./smart_turn_audio_log`` and are named after the current
        local time (millisecond precision) followed by ``suffix``. Failures in the
        background thread are logged, not raised.

        Args:
            audio_array: The audio data as a numpy array (float32, normalized to [-1, 1]).
                Samples outside [-1, 1] are clipped before conversion to 16-bit PCM.
            sample_rate: The sample rate of the audio data.
            suffix: Optional suffix to append to the filename (e.g., "_raw", "_padded").
        """
        import os
        import threading
        import wave
        from datetime import datetime

        # Hyphens instead of colons in the time part: ':' is not allowed in
        # Windows file names.
        timestamp = datetime.now().strftime("%Y-%m-%d__%H-%M-%S.%f")[:-3]
        log_dir = "./smart_turn_audio_log"
        os.makedirs(log_dir, exist_ok=True)
        filename = os.path.join(log_dir, f"{timestamp}{suffix}.wav")

        # np.clip returns a new array, so this also snapshots the samples before the
        # caller can modify them. Clipping prevents int16 wrap-around (loud clicks)
        # for samples that stray outside [-1, 1].
        audio_clipped = np.clip(audio_array, -1.0, 1.0)

        def write_wav():
            try:
                audio_int16 = (audio_clipped * 32767).astype(np.int16)
                with wave.open(filename, "wb") as wav_file:
                    wav_file.setnchannels(1)  # Mono
                    wav_file.setsampwidth(2)  # 2 bytes for int16
                    wav_file.setframerate(sample_rate)
                    wav_file.writeframes(audio_int16.tobytes())
                logger.debug(f"Wrote audio to {filename}")
            except Exception as e:
                logger.error(f"Failed to write audio to {filename}: {e}")

        # Daemon thread: logging is best-effort and must not block interpreter exit.
        thread = threading.Thread(target=write_wav, daemon=True)
        thread.start()

    def _resample_to_model_rate(self, audio_array: np.ndarray) -> np.ndarray:
        """Resample audio to the model's expected sample rate (16 kHz).

        The source rate is read from ``self._sample_rate`` (presumably set by
        BaseSmartTurn); when it is unset/falsy the audio is assumed to be 16 kHz.

        Args:
            audio_array: Audio data as a float32 numpy array.

        Returns:
            Resampled audio array at 16 kHz (the input array itself when no
            resampling is needed).
        """
        actual_rate = self._sample_rate or _MODEL_SAMPLE_RATE
        if actual_rate == _MODEL_SAMPLE_RATE:
            return audio_array
        return soxr.resample(audio_array, actual_rate, _MODEL_SAMPLE_RATE, quality="HQ")

    @staticmethod
    def _fit_to_last_n_seconds(
        audio_array: np.ndarray, n_seconds: int = 8, sample_rate: int = _MODEL_SAMPLE_RATE
    ) -> np.ndarray:
        """Keep the last ``n_seconds`` of audio, left-padding with zeros if shorter.

        Args:
            audio_array: 1-D audio samples.
            n_seconds: Target duration in seconds.
            sample_rate: Sample rate of ``audio_array``.

        Returns:
            An array of exactly ``n_seconds * sample_rate`` samples.
        """
        max_samples = n_seconds * sample_rate
        if len(audio_array) > max_samples:
            return audio_array[-max_samples:]
        if len(audio_array) < max_samples:
            # Pad at the beginning so the most recent audio stays at the end.
            padding = max_samples - len(audio_array)
            return np.pad(audio_array, (padding, 0), mode="constant", constant_values=0)
        return audio_array

    def _predict_endpoint(self, audio_array: np.ndarray) -> dict[str, Any]:
        """Predict end-of-turn using the local ONNX model.

        Args:
            audio_array: Utterance audio at the pipeline's sample rate
                (``self._sample_rate``, or 16 kHz when unset).

        Returns:
            A dict with ``"prediction"`` (1 = turn complete, 0 = incomplete) and
            ``"probability"`` (the model's completion probability as a float).
        """
        # Keep the original (pre-resampling) audio and rate for optional WAV logging.
        audio_for_logging = audio_array
        actual_rate = self._sample_rate or _MODEL_SAMPLE_RATE

        # Resample to 16 kHz if the pipeline uses a different sample rate.
        audio_array = self._resample_to_model_rate(audio_array)

        # The model sees a fixed 8-second window ending at the latest sample.
        audio_array = self._fit_to_last_n_seconds(audio_array, n_seconds=8)

        # Compute Whisper-style log-mel features (vendored numpy implementation).
        log_mel = compute_whisper_log_mel_features(audio_array, do_normalize=True)
        input_features = np.expand_dims(log_mel, axis=0)  # Add batch dimension

        outputs = self._session.run(None, {"input_features": input_features})

        # Extract probability (ONNX model returns sigmoid probabilities).
        probability = outputs[0][0].item()

        # Make prediction (1 for Complete, 0 for Incomplete).
        prediction = 1 if probability > 0.5 else 0

        if self._log_data:
            suffix = "_complete" if prediction == 1 else "_incomplete"
            self._write_audio_to_wav(audio_for_logging, sample_rate=actual_rate, suffix=suffix)

        return {
            "prediction": prediction,
            "probability": probability,
        }