diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py index c4dc760e4..8c5e31d8d 100644 --- a/src/axolotl/monkeypatch/moe_quant.py +++ b/src/axolotl/monkeypatch/moe_quant.py @@ -16,8 +16,6 @@ PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quan params via stacked parametrizations. """ -import gc - import bitsandbytes as bnb import torch import torch.nn.utils.parametrize as P @@ -188,14 +186,16 @@ def patch_moe_quantization_on_load(cfg): replace_parameter_8bit(mod, pname) _moe_load_state["count"] += 1 - # Release the bf16 CUDA storage. After quantization, the - # module holds the quantized parametrization — but the - # caller's references (loop var + dict) keep the bf16 CUDA - # storage alive past empty_cache(). Replacing .data frees - # the CUDA memory immediately regardless of Python refcount. + # Release the bf16 CUDA storage immediately. After + # quantization, the module holds the quantized + # parametrization, but the caller's references (loop var + # + realized_value dict) keep the bf16 storage alive. + # Replacing .data frees it regardless of Python refcount. + # We intentionally skip empty_cache() here so the CUDA + # caching allocator can reuse the freed block for the + # next expert param (same size). Cleanup happens after + # loading completes (model.py:202 and model.py:456-458). param_value.data = torch.empty(0, device="cpu") - gc.collect() - torch.cuda.empty_cache() transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module _moe_load_state["patched"] = True