diff --git a/CLAUDE.md b/CLAUDE.md
index a50061c..bd3aeff 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -396,6 +396,16 @@ a ~3s load. Fix: `server.py` has a `lifespan` handler that warms + pins the mode
 later turns are 8B generation variance. Switching Whisper size would NOT help — it's not the
 bottleneck (STT model `medium` is for accuracy, not latency).
 
+### VRAM budget — shared Whisper model (fixes OOM)
+
+GPU is 16GB. Budget: pinned LLM ~6GB (num_ctx 8192) + **one shared** Whisper `medium` ~1.5GB +
+overhead ≈ 8GB, leaving headroom. Critical: Whisper is loaded **once per process and reused
+across calls** (`_WHISPER_MODEL_CACHE` in `bot.py`). Loading a new `WhisperModel` per call leaks
+VRAM — ctranslate2 doesn't release it when the call ends, so models accumulated and the GPU OOM'd
+after ~6–8 calls (`CUDA failed with error out of memory`, every call dropping right after answer).
+Symptom to watch: `nvidia-smi` shows the python process growing call-over-call. Don't reintroduce
+per-call model loads.
+
 ### Why Q4_K_M not Q8_0
 
 Q8_0 consumed ~8.5GB VRAM for weights alone. Under telephony load this caused
diff --git a/bot.py b/bot.py
index b48d268..9d03524 100644
--- a/bot.py
+++ b/bot.py
@@ -424,33 +424,63 @@ class SilenceWatchdog(FrameProcessor):
         await self.push_frame(frame, direction)
 
 
+# One shared WhisperModel per (model, device, compute) for the whole process. Loading a new
+# model per call leaks GPU memory — ctranslate2 doesn't release VRAM when the call's service is
+# dropped, so models accumulate and the GPU OOMs after a handful of calls. Sharing one keeps
+# VRAM constant. The cached model itself is never patched; hotwords are applied by a wrapper.
+_WHISPER_MODEL_CACHE = {}
+
+
+class _HotwordedWhisperModel:
+    """Per-service view over a shared WhisperModel that defaults `hotwords` on transcribe().
+
+    The shared model is left untouched, so services created with different (or no) hotwords
+    can reuse one model without affecting each other."""
+
+    def __init__(self, model, hotwords: str):
+        self._model = model
+        self._hotwords = hotwords
+
+    def transcribe(self, audio_arg, **kw):
+        kw.setdefault("hotwords", self._hotwords)  # an explicit hotwords= still wins
+        return self._model.transcribe(audio_arg, **kw)
+
+    def __getattr__(self, name):
+        # Reached only for names missing on the view: delegate to the real model.
+        return getattr(object.__getattribute__(self, "_model"), name)
+
+
 class HintedWhisperSTTService(WhisperSTTService):
-    """WhisperSTTService that biases transcription toward domain vocabulary via
-    faster-whisper `hotwords`. Pipecat's service doesn't expose hotwords, so we wrap
-    the model's transcribe() for the duration of each call. Each call gets its own
-    Whisper instance, so this per-instance patch is race-free."""
+    """WhisperSTTService that shares ONE WhisperModel across all calls (avoids the per-call
+    GPU-memory leak/OOM) and biases transcription toward domain vocabulary via faster-whisper
+    `hotwords`. Hotwords are applied by a per-service wrapper around the shared model, so the
+    cached model is never monkey-patched: concurrency-safe, and services created with
+    different hotwords cannot leak into each other."""
 
     def __init__(self, *args, hotwords: str | None = None, **kwargs):
+        self._hotwords = hotwords  # set BEFORE super().__init__ (it calls _load)
         super().__init__(*args, **kwargs)
-        self._hotwords = hotwords
 
-    async def run_stt(self, audio):
-        if self._hotwords and self._model is not None:
-            real = self._model.transcribe
-
-            def patched(audio_arg, **kw):
-                kw.setdefault("hotwords", self._hotwords)
-                return real(audio_arg, **kw)
-
-            self._model.transcribe = patched
-            try:
-                async for frame in super().run_stt(audio):
-                    yield frame
-            finally:
-                self._model.transcribe = real
+    def _load(self):
+        """Attach the process-wide Whisper model for this config, loading it on first use."""
+        key = (self._settings.model, self._device, self._compute_type)
+        model = _WHISPER_MODEL_CACHE.get(key)
+        if model is None:
+            super()._load()  # base sets self._model
+            model = self._model
+            if model is None:
+                # NOTE(review): the old run_stt guarded `self._model is not None`, so the base
+                # load can presumably leave no model — confirm. Cache nothing; next call retries.
+                logger.error(f"Whisper model {key} did not load; nothing cached")
+                return
+            _WHISPER_MODEL_CACHE[key] = model
+            logger.info(f"Loaded + cached shared Whisper model {key}")
         else:
-            async for frame in super().run_stt(audio):
-                yield frame
+            logger.info(f"Reusing shared Whisper model {key}")
+        if self._hotwords:
+            # Wrap per service instead of patching the shared model.
+            model = _HotwordedWhisperModel(model, self._hotwords)
+        self._model = model
 
 
 # ── TTS number normalization ──────────────────────────────────────────────────