Fix GPU OOM: share one Whisper model across calls (was leaking per call)

Calls were dropping right after answer with "CUDA failed with error out of
memory". Cause: each call constructed a new HintedWhisperSTTService -> new
ctranslate2 WhisperModel on the GPU, and that VRAM was never released when the
call ended. Over ~13 calls the python process grew to 9.7GB; with the pinned LLM
(6GB) the 16GB GPU filled (14 MiB free) and Whisper load failed on every call.

Fix: cache one WhisperModel per (model,device,compute) in _WHISPER_MODEL_CACHE
and reuse it across all calls; bake the fixed hotwords into the shared model's
transcribe() once (drops the racy per-call monkey-patch). VRAM now constant
(~6GB LLM + ~1.5GB Whisper). Verified: two instances share one model object;
GPU back to 6.0/16GB used after restart. Documented the VRAM budget.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
tocmo0nlord
2026-06-27 22:07:59 +00:00
parent ab15023651
commit a521dc168e
2 changed files with 39 additions and 19 deletions

View File

@@ -396,6 +396,16 @@ a ~3s load. Fix: `server.py` has a `lifespan` handler that warms + pins the mode
later turns are 8B generation variance. Switching Whisper size would NOT help — it's not the
bottleneck (STT model `medium` is for accuracy, not latency).
### VRAM budget — shared Whisper model (fixes OOM)
GPU is 16GB. Budget: pinned LLM ~6GB (num_ctx 8192) + **one shared** Whisper `medium` ~1.5GB +
overhead ≈ 8GB, leaving headroom. Critical: Whisper is loaded **once per process and reused
across calls** (`_WHISPER_MODEL_CACHE` in `bot.py`). Loading a new `WhisperModel` per call leaks
VRAM — ctranslate2 doesn't release it when the call ends, so models accumulated and the GPU OOM'd
after ~68 calls (`CUDA failed with error out of memory`, every call dropping right after answer).
Symptom to watch: `nvidia-smi` shows the python process growing call-over-call. Don't reintroduce
per-call model loads.
### Why Q4_K_M not Q8_0
Q8_0 consumed ~8.5GB VRAM for weights alone. Under telephony load this caused

48
bot.py
View File

@@ -424,33 +424,43 @@ class SilenceWatchdog(FrameProcessor):
await self.push_frame(frame, direction)
# One shared WhisperModel per (model, device, compute) for the whole process. Loading a new
# model per call leaks GPU memory — ctranslate2 doesn't release VRAM when the call's service is
# dropped, so models accumulate and the GPU OOMs after a handful of calls. Sharing one keeps
# VRAM constant.
_WHISPER_MODEL_CACHE = {}
class HintedWhisperSTTService(WhisperSTTService):
"""WhisperSTTService that biases transcription toward domain vocabulary via
faster-whisper `hotwords`. Pipecat's service doesn't expose hotwords, so we wrap
the model's transcribe() for the duration of each call. Each call gets its own
Whisper instance, so this per-instance patch is race-free."""
"""WhisperSTTService that shares ONE WhisperModel across all calls (avoids the per-call
GPU-memory leak/OOM) and biases transcription toward domain vocabulary via faster-whisper
`hotwords`. Hotwords are a fixed domain list, so they're baked into the shared model's
transcribe() once at load — concurrency-safe (no per-call monkey-patch)."""
def __init__(self, *args, hotwords: str | None = None, **kwargs):
self._hotwords = hotwords # set BEFORE super().__init__ (it calls _load)
super().__init__(*args, **kwargs)
self._hotwords = hotwords
async def run_stt(self, audio):
if self._hotwords and self._model is not None:
real = self._model.transcribe
def _load(self):
key = (self._settings.model, self._device, self._compute_type)
model = _WHISPER_MODEL_CACHE.get(key)
if model is None:
super()._load() # base sets self._model
model = self._model
if self._hotwords: # bake hotwords in once (value, not self)
_real = model.transcribe
_hw = self._hotwords
def patched(audio_arg, **kw):
kw.setdefault("hotwords", self._hotwords)
return real(audio_arg, **kw)
def _patched(audio_arg, **kw):
kw.setdefault("hotwords", _hw)
return _real(audio_arg, **kw)
self._model.transcribe = patched
try:
async for frame in super().run_stt(audio):
yield frame
finally:
self._model.transcribe = real
model.transcribe = _patched
_WHISPER_MODEL_CACHE[key] = model
logger.info(f"Loaded + cached shared Whisper model {key}")
else:
async for frame in super().run_stt(audio):
yield frame
logger.info(f"Reusing shared Whisper model {key}")
self._model = model
# ── TTS number normalization ──────────────────────────────────────────────────