feat: attempt to release bf16 experts from vram

NanoCode012
2026-02-26 16:19:07 +07:00
parent f68d9f839d
commit 21b2dfef2d

@@ -16,6 +16,8 @@ PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quan
params via stacked parametrizations.
"""
import gc
import bitsandbytes as bnb
import torch
import torch.nn.utils.parametrize as P
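
For context, the module docstring above refers to "stacked parametrizations": torch.nn.utils.parametrize lets several parametrizations be registered on the same parameter and composes them in registration order, which is how a LoRA-style delta can sit on top of a dequantize step. Below is a minimal sketch of that composition; the Scale and AddDelta modules are hypothetical stand-ins (a toy dequant step and a LoRA-like additive delta), not the classes used by this patch or by PEFT.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P

class Scale(nn.Module):
    """Toy stand-in for a dequantize parametrization."""
    def __init__(self, scale: float):
        super().__init__()
        self.scale = scale

    def forward(self, w: torch.Tensor) -> torch.Tensor:
        return w * self.scale

class AddDelta(nn.Module):
    """Toy stand-in for a LoRA-style additive delta stacked on top."""
    def __init__(self, shape: torch.Size):
        super().__init__()
        self.delta = nn.Parameter(torch.zeros(shape))

    def forward(self, w: torch.Tensor) -> torch.Tensor:
        return w + self.delta

lin = nn.Linear(4, 4, bias=False)
P.register_parametrization(lin, "weight", Scale(0.5))
# A second registration on the same tensor is applied to the output of the
# first, i.e. lin.weight is computed as AddDelta(Scale(original)).
P.register_parametrization(lin, "weight", AddDelta(lin.weight.shape))
print(lin.weight)
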
@@ -184,9 +186,17 @@ def patch_moe_quantization_on_load(cfg):
                )
            else:
                replace_parameter_8bit(mod, pname)
            torch.cuda.empty_cache()
            _moe_load_state["count"] += 1
            # Release the bf16 CUDA storage. After quantization, the
            # module holds the quantized parametrization, but the
            # caller's references (loop var + dict) keep the bf16 CUDA
            # storage alive past empty_cache(). Replacing .data frees
            # the CUDA memory immediately regardless of Python refcount.
            param_value.data = torch.empty(0, device="cpu")
            gc.collect()
            torch.cuda.empty_cache()

    transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
    _moe_load_state["patched"] = True
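
The release trick itself is independent of the MoE loader: rebinding a Parameter's .data to an empty CPU tensor drops the old CUDA storage even while other Python references to the Parameter survive, and empty_cache() then returns the freed blocks to the driver. A standalone sketch follows; the release_cuda_storage helper and the 4096x4096 test size are illustrative, not part of the commit.

import gc

import torch

def release_cuda_storage(param: torch.nn.Parameter) -> None:
    # Rebinding .data detaches the old CUDA storage; once no other tensor
    # views it, its refcount hits zero and the allocation is freed even
    # though callers may still hold references to `param` itself.
    param.data = torch.empty(0, device="cpu")
    gc.collect()
    torch.cuda.empty_cache()

if __name__ == "__main__" and torch.cuda.is_available():
    weight = torch.nn.Parameter(
        torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
    )
    lingering_ref = weight  # simulates the loader's loop var / dict reference
    print(f"before: {torch.cuda.memory_allocated() / 2**20:.1f} MiB")  # ~32 MiB
    release_cuda_storage(weight)
    print(f"after:  {torch.cuda.memory_allocated() / 2**20:.1f} MiB")  # ~0 MiB
    assert lingering_ref.data.device.type == "cpu"
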