fix: remove unnecessary gc and empty cache

This commit is contained in:
NanoCode012
2026-02-26 19:18:29 +07:00
parent e0eed7542d
commit 1d54518990

View File

@@ -16,8 +16,6 @@ PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quan
 params via stacked parametrizations.
 """
-import gc
 import bitsandbytes as bnb
 import torch
 import torch.nn.utils.parametrize as P
@@ -188,14 +186,16 @@ def patch_moe_quantization_on_load(cfg):
 replace_parameter_8bit(mod, pname)
 _moe_load_state["count"] += 1
-# Release the bf16 CUDA storage. After quantization, the
-# module holds the quantized parametrization — but the
-# caller's references (loop var + dict) keep the bf16 CUDA
-# storage alive past empty_cache(). Replacing .data frees
-# the CUDA memory immediately regardless of Python refcount.
+# Release the bf16 CUDA storage immediately. After
+# quantization, the module holds the quantized
+# parametrization, but the caller's references (loop var
+# + realized_value dict) keep the bf16 storage alive.
+# Replacing .data frees it regardless of Python refcount.
+# We intentionally skip empty_cache() here so the CUDA
+# caching allocator can reuse the freed block for the
+# next expert param (same size). Cleanup happens after
+# loading completes (model.py:202 and model.py:456-458).
 param_value.data = torch.empty(0, device="cpu")
-gc.collect()
-torch.cuda.empty_cache()
 transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
 _moe_load_state["patched"] = True