feat: attempt to release bf16 experts from vram

This commit is contained in:
NanoCode012
2026-02-26 16:19:07 +07:00
parent f68d9f839d
commit 21b2dfef2d

View File

@@ -16,6 +16,8 @@ PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quan
params via stacked parametrizations. params via stacked parametrizations.
""" """
import gc
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import torch.nn.utils.parametrize as P import torch.nn.utils.parametrize as P
@@ -184,9 +186,17 @@ def patch_moe_quantization_on_load(cfg):
) )
else: else:
replace_parameter_8bit(mod, pname) replace_parameter_8bit(mod, pname)
torch.cuda.empty_cache()
_moe_load_state["count"] += 1 _moe_load_state["count"] += 1
# Release the bf16 CUDA storage. After quantization, the
# module holds the quantized parametrization — but the
# caller's references (loop var + dict) keep the bf16 CUDA
# storage alive past empty_cache(). Replacing .data frees
# the CUDA memory immediately regardless of Python refcount.
param_value.data = torch.empty(0, device="cpu")
gc.collect()
torch.cuda.empty_cache()
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
_moe_load_state["patched"] = True _moe_load_state["patched"] = True