chore: add log

This commit is contained in:
NanoCode012
2026-02-25 17:41:14 +07:00
parent 1558436c69
commit 4b2f568ee0

View File

@@ -77,27 +77,42 @@ def patch_moe_quantization_on_load(cfg):
original_set_param = transformers.core_model_loading.set_param_for_module original_set_param = transformers.core_model_loading.set_param_for_module
_first_call = [True]
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs): def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
if _first_call[0]:
LOG.info("MoE quant patch: set_param_for_module intercepted (first call)")
_first_call[0] = False
original_set_param(model, target_name, param_value, *args, **kwargs) original_set_param(model, target_name, param_value, *args, **kwargs)
# Quantize 3D+ expert params that BnB skipped (only on CUDA). # Quantize 3D+ expert params that BnB skipped (only on CUDA).
if param_value.ndim >= 3 and param_value.is_cuda: if param_value.ndim >= 3:
mod_path, _, pname = target_name.rpartition(".") LOG.info(
mod = model.get_submodule(mod_path) if mod_path else model "MoE quant patch: 3D param %s shape=%s cuda=%s",
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)): target_name,
replace_parameter_4bit( param_value.shape,
mod, param_value.is_cuda,
pname, )
compress_statistics=_moe_load_state["compress_statistics"], if param_value.is_cuda:
quant_type=_moe_load_state["quant_type"], mod_path, _, pname = target_name.rpartition(".")
) mod = model.get_submodule(mod_path) if mod_path else model
torch.cuda.empty_cache() if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
_moe_load_state["count"] += 1 replace_parameter_4bit(
LOG.debug( mod,
"Quantized 3D expert param during loading: %s (shape %s)", pname,
target_name, compress_statistics=_moe_load_state["compress_statistics"],
param_value.shape, quant_type=_moe_load_state["quant_type"],
) )
torch.cuda.empty_cache()
_moe_load_state["count"] += 1
LOG.info(
"Quantized 3D expert param: %s "
"(alloc=%.2f GiB, reserved=%.2f GiB)",
target_name,
torch.cuda.memory_allocated() / 1024**3,
torch.cuda.memory_reserved() / 1024**3,
)
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
_moe_load_state["patched"] = True _moe_load_state["patched"] = True