feat: attempt to release bf16 experts from vram
This commit is contained in:
@@ -16,6 +16,8 @@ PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quan
|
|||||||
params via stacked parametrizations.
|
params via stacked parametrizations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.utils.parametrize as P
|
import torch.nn.utils.parametrize as P
|
||||||
@@ -184,9 +186,17 @@ def patch_moe_quantization_on_load(cfg):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
replace_parameter_8bit(mod, pname)
|
replace_parameter_8bit(mod, pname)
|
||||||
torch.cuda.empty_cache()
|
|
||||||
_moe_load_state["count"] += 1
|
_moe_load_state["count"] += 1
|
||||||
|
|
||||||
|
# Release the bf16 CUDA storage. After quantization, the
|
||||||
|
# module holds the quantized parametrization — but the
|
||||||
|
# caller's references (loop var + dict) keep the bf16 CUDA
|
||||||
|
# storage alive past empty_cache(). Replacing .data frees
|
||||||
|
# the CUDA memory immediately regardless of Python refcount.
|
||||||
|
param_value.data = torch.empty(0, device="cpu")
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
|
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
|
||||||
_moe_load_state["patched"] = True
|
_moe_load_state["patched"] = True
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user