feat: attempt to release bf16 experts from vram
@@ -16,6 +16,8 @@ PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quant
 params via stacked parametrizations.
 """
 
+import gc
+
 import bitsandbytes as bnb
 import torch
 import torch.nn.utils.parametrize as P
@@ -184,9 +186,17 @@ def patch_moe_quantization_on_load(cfg):
                 )
             else:
                 replace_parameter_8bit(mod, pname)
             torch.cuda.empty_cache()
             _moe_load_state["count"] += 1
 
+            # Release the bf16 CUDA storage. After quantization, the
+            # module holds the quantized parametrization, but the
+            # caller's references (loop var + dict) keep the bf16 CUDA
+            # storage alive past empty_cache(). Replacing .data frees
+            # the CUDA memory immediately regardless of Python refcount.
+            param_value.data = torch.empty(0, device="cpu")
+            gc.collect()
+            torch.cuda.empty_cache()
 
     transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
     _moe_load_state["patched"] = True
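
A minimal standalone sketch of the release trick above (names and tensor sizes are illustrative, not from the commit): reassigning a parameter's .data swaps out the underlying storage in place, so the old bf16 CUDA block becomes reclaimable even while other Python names still reference the Parameter object.

import gc
import torch

# Illustrative parameter standing in for one bf16 expert weight.
param = torch.nn.Parameter(
    torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
)
extra_ref = param  # simulates the loader's loop var / dict holding the param

torch.cuda.empty_cache()
print(torch.cuda.memory_allocated())  # bf16 storage still resident

# Swap the underlying storage for an empty CPU tensor: the caching
# allocator can reclaim the CUDA block now, regardless of how many
# Python names still point at the Parameter object itself.
param.data = torch.empty(0, device="cpu")
gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated())  # bf16 block released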
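
The patch itself follows a wrap-and-delegate monkey-patch pattern: replace transformers' internal set_param_for_module hook, let the original run, then quantize and release. A rough sketch of that pattern, assuming only that the hook is a plain module-level function (its real signature is not shown in this hunk, hence the *args/**kwargs):

import transformers.core_model_loading as cml

_original_set_param_for_module = cml.set_param_for_module

def _patched_set_param_for_module(*args, **kwargs):
    # Delegate to the stock loader first so the parameter lands as usual.
    result = _original_set_param_for_module(*args, **kwargs)
    # ...then quantize the just-loaded expert and free its bf16 copy,
    # as in the hunk above (replace_parameter_8bit / .data swap).
    return result

cml.set_param_for_module = _patched_set_param_for_module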