From 21b2dfef2da46a67a4fa8c88d722ad7d0ebe92b7 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Thu, 26 Feb 2026 16:19:07 +0700
Subject: [PATCH] feat: attempt to release bf16 experts from vram

---
 src/axolotl/monkeypatch/moe_quant.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py
index 360e4c571..c4dc760e4 100644
--- a/src/axolotl/monkeypatch/moe_quant.py
+++ b/src/axolotl/monkeypatch/moe_quant.py
@@ -16,6 +16,8 @@ PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quan
 params via stacked parametrizations.
 """
 
+import gc
+
 import bitsandbytes as bnb
 import torch
 import torch.nn.utils.parametrize as P
@@ -184,9 +186,17 @@ def patch_moe_quantization_on_load(cfg):
                 )
             else:
                 replace_parameter_8bit(mod, pname)
-            torch.cuda.empty_cache()
 
             _moe_load_state["count"] += 1
 
+            # Release the bf16 CUDA storage. After quantization, the
+            # module holds the quantized parametrization, but the
+            # caller's references (loop var + dict) keep the bf16 CUDA
+            # storage alive past empty_cache(). Replacing .data frees
+            # the CUDA memory immediately regardless of Python refcount.
+            param_value.data = torch.empty(0, device="cpu")
+            gc.collect()
+            torch.cuda.empty_cache()
+
     transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
     _moe_load_state["patched"] = True
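
Note for reviewers (not part of the patch): below is a minimal standalone sketch of the freeing behavior the new comment relies on, namely that torch.cuda.empty_cache() cannot release a tensor whose storage is still referenced, while reassigning .data drops the CUDA storage immediately even if other Python names still hold the Parameter. It uses only stock PyTorch APIs and assumes a CUDA device is available; the names demo and extra_ref are illustrative and do not come from axolotl.

# Standalone demo, not part of the patch. Assumes a CUDA device.
import gc

import torch


def demo():
    # Stand-in for a bf16 expert weight (~32 MiB) that the loader
    # still references after quantization.
    param = torch.nn.Parameter(
        torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda"),
        requires_grad=False,
    )
    extra_ref = param  # stand-in for the caller's loop var / dict entry

    print(f"allocated:        {torch.cuda.memory_allocated() / 2**20:.1f} MiB")

    # empty_cache() alone does not help here: the storage is still
    # referenced, so the allocator counts it as allocated, not cached.
    torch.cuda.empty_cache()
    print(f"after empty_cache: {torch.cuda.memory_allocated() / 2**20:.1f} MiB")

    # Swapping .data for an empty CPU tensor drops the last reference
    # to the CUDA storage, freeing it regardless of how many Python
    # names still point at the Parameter object itself.
    param.data = torch.empty(0, device="cpu")
    gc.collect()
    torch.cuda.empty_cache()
    print(f"after .data swap:  {torch.cuda.memory_allocated() / 2**20:.1f} MiB")

    # The surviving reference is still a valid (now empty) Parameter.
    assert extra_ref.numel() == 0


if __name__ == "__main__":
    demo()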