Fix GPU OOM: share one Whisper model across calls (was leaking per call)
Calls were dropping right after answer with "CUDA failed with error out of memory". Cause: each call constructed a new HintedWhisperSTTService -> new ctranslate2 WhisperModel on the GPU, and that VRAM was never released when the call ended. Over ~13 calls the python process grew to 9.7GB; with the pinned LLM (6GB) the 16GB GPU filled (14 MiB free) and Whisper load failed on every call. Fix: cache one WhisperModel per (model,device,compute) in _WHISPER_MODEL_CACHE and reuse it across all calls; bake the fixed hotwords into the shared model's transcribe() once (drops the racy per-call monkey-patch). VRAM now constant (~6GB LLM + ~1.5GB Whisper). Verified: two instances share one model object; GPU back to 6.0/16GB used after restart. Documented the VRAM budget. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
10
CLAUDE.md
10
CLAUDE.md
@@ -396,6 +396,16 @@ a ~3s load. Fix: `server.py` has a `lifespan` handler that warms + pins the mode
|
||||
later turns are 8B generation variance. Switching Whisper size would NOT help — it's not the
|
||||
bottleneck (STT model `medium` is for accuracy, not latency).
|
||||
|
||||
### VRAM budget — shared Whisper model (fixes OOM)
|
||||
|
||||
GPU is 16GB. Budget: pinned LLM ~6GB (num_ctx 8192) + **one shared** Whisper `medium` ~1.5GB +
|
||||
overhead ≈ 8GB, leaving headroom. Critical: Whisper is loaded **once per process and reused
|
||||
across calls** (`_WHISPER_MODEL_CACHE` in `bot.py`). Loading a new `WhisperModel` per call leaks
|
||||
VRAM — ctranslate2 doesn't release it when the call ends, so models accumulated and the GPU OOM'd
|
||||
after ~6–8 calls (`CUDA failed with error out of memory`, every call dropping right after answer).
|
||||
Symptom to watch: `nvidia-smi` shows the python process growing call-over-call. Don't reintroduce
|
||||
per-call model loads.
|
||||
|
||||
### Why Q4_K_M not Q8_0
|
||||
|
||||
Q8_0 consumed ~8.5GB VRAM for weights alone. Under telephony load this caused
|
||||
|
||||
48
bot.py
48
bot.py
@@ -424,33 +424,43 @@ class SilenceWatchdog(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# One shared WhisperModel per (model, device, compute) for the whole process. Loading a new
|
||||
# model per call leaks GPU memory — ctranslate2 doesn't release VRAM when the call's service is
|
||||
# dropped, so models accumulate and the GPU OOMs after a handful of calls. Sharing one keeps
|
||||
# VRAM constant.
|
||||
_WHISPER_MODEL_CACHE = {}
|
||||
|
||||
|
||||
class HintedWhisperSTTService(WhisperSTTService):
|
||||
"""WhisperSTTService that biases transcription toward domain vocabulary via
|
||||
faster-whisper `hotwords`. Pipecat's service doesn't expose hotwords, so we wrap
|
||||
the model's transcribe() for the duration of each call. Each call gets its own
|
||||
Whisper instance, so this per-instance patch is race-free."""
|
||||
"""WhisperSTTService that shares ONE WhisperModel across all calls (avoids the per-call
|
||||
GPU-memory leak/OOM) and biases transcription toward domain vocabulary via faster-whisper
|
||||
`hotwords`. Hotwords are a fixed domain list, so they're baked into the shared model's
|
||||
transcribe() once at load — concurrency-safe (no per-call monkey-patch)."""
|
||||
|
||||
def __init__(self, *args, hotwords: str | None = None, **kwargs):
|
||||
self._hotwords = hotwords # set BEFORE super().__init__ (it calls _load)
|
||||
super().__init__(*args, **kwargs)
|
||||
self._hotwords = hotwords
|
||||
|
||||
async def run_stt(self, audio):
|
||||
if self._hotwords and self._model is not None:
|
||||
real = self._model.transcribe
|
||||
def _load(self):
|
||||
key = (self._settings.model, self._device, self._compute_type)
|
||||
model = _WHISPER_MODEL_CACHE.get(key)
|
||||
if model is None:
|
||||
super()._load() # base sets self._model
|
||||
model = self._model
|
||||
if self._hotwords: # bake hotwords in once (value, not self)
|
||||
_real = model.transcribe
|
||||
_hw = self._hotwords
|
||||
|
||||
def patched(audio_arg, **kw):
|
||||
kw.setdefault("hotwords", self._hotwords)
|
||||
return real(audio_arg, **kw)
|
||||
def _patched(audio_arg, **kw):
|
||||
kw.setdefault("hotwords", _hw)
|
||||
return _real(audio_arg, **kw)
|
||||
|
||||
self._model.transcribe = patched
|
||||
try:
|
||||
async for frame in super().run_stt(audio):
|
||||
yield frame
|
||||
finally:
|
||||
self._model.transcribe = real
|
||||
model.transcribe = _patched
|
||||
_WHISPER_MODEL_CACHE[key] = model
|
||||
logger.info(f"Loaded + cached shared Whisper model {key}")
|
||||
else:
|
||||
async for frame in super().run_stt(audio):
|
||||
yield frame
|
||||
logger.info(f"Reusing shared Whisper model {key}")
|
||||
self._model = model
|
||||
|
||||
|
||||
# ── TTS number normalization ──────────────────────────────────────────────────
|
||||
|
||||
Reference in New Issue
Block a user