fix: remove unnecessary gc.collect() and torch.cuda.empty_cache() calls after MoE expert quantization on load
This commit is contained in:
@@ -16,8 +16,6 @@ PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quan
|
||||
params via stacked parametrizations.
|
||||
"""
|
||||
|
||||
import gc
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import torch.nn.utils.parametrize as P
|
||||
@@ -188,14 +186,16 @@ def patch_moe_quantization_on_load(cfg):
|
||||
replace_parameter_8bit(mod, pname)
|
||||
_moe_load_state["count"] += 1
|
||||
|
||||
# Release the bf16 CUDA storage. After quantization, the
|
||||
# module holds the quantized parametrization — but the
|
||||
# caller's references (loop var + dict) keep the bf16 CUDA
|
||||
# storage alive past empty_cache(). Replacing .data frees
|
||||
# the CUDA memory immediately regardless of Python refcount.
|
||||
# Release the bf16 CUDA storage immediately. After
|
||||
# quantization, the module holds the quantized
|
||||
# parametrization, but the caller's references (loop var
|
||||
# + realized_value dict) keep the bf16 storage alive.
|
||||
# Replacing .data frees it regardless of Python refcount.
|
||||
# We intentionally skip empty_cache() here so the CUDA
|
||||
# caching allocator can reuse the freed block for the
|
||||
# next expert param (same size). Cleanup happens after
|
||||
# loading completes (model.py:202 and model.py:456-458).
|
||||
param_value.data = torch.empty(0, device="cpu")
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
|
||||
_moe_load_state["patched"] = True
|
||||
|
||||
Reference in New Issue
Block a user