chore: add log

This commit is contained in:
NanoCode012
2026-02-25 17:41:14 +07:00
parent 1558436c69
commit 4b2f568ee0

View File

@@ -77,27 +77,42 @@ def patch_moe_quantization_on_load(cfg):
# Keep a reference to the unpatched loader so the wrapper below can delegate to it.
original_set_param = transformers.core_model_loading.set_param_for_module
# One-element list used as a mutable closure flag for one-time logging.
_first_call = [True]
# Wrapper around transformers' set_param_for_module that, after the normal
# parameter load, 4-bit-quantizes 3D+ (stacked MoE expert) params that
# bitsandbytes skipped.
#
# NOTE(review): this span is a rendered commit diff with the +/- markers and
# indentation stripped. The two near-identical branches below are the OLD
# (removed) and NEW (added) versions of the same logic; they would never
# coexist in the real file — confirm against the actual source before
# treating this as runnable code.
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
if _first_call[0]:
# Log only on the first interception so the patch's activation is visible
# without spamming per-parameter output.
LOG.info("MoE quant patch: set_param_for_module intercepted (first call)")
_first_call[0] = False
# Delegate to the original loader first so the param is set normally.
original_set_param(model, target_name, param_value, *args, **kwargs)
# ---- OLD version (removed by this commit) ----
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
if param_value.ndim >= 3 and param_value.is_cuda:
mod_path, _, pname = target_name.rpartition(".")
mod = model.get_submodule(mod_path) if mod_path else model
# Skip modules bitsandbytes already quantized itself.
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
replace_parameter_4bit(
mod,
pname,
compress_statistics=_moe_load_state["compress_statistics"],
quant_type=_moe_load_state["quant_type"],
)
# Free the just-replaced fp tensor's memory immediately to keep
# peak VRAM down during loading.
torch.cuda.empty_cache()
_moe_load_state["count"] += 1
LOG.debug(
"Quantized 3D expert param during loading: %s (shape %s)",
target_name,
param_value.shape,
)
# ---- NEW version (added by this commit): same quantization path, but
# logs every 3D param at info level (incl. non-CUDA ones) and reports
# CUDA memory stats after each quantization ----
if param_value.ndim >= 3:
LOG.info(
"MoE quant patch: 3D param %s shape=%s cuda=%s",
target_name,
param_value.shape,
param_value.is_cuda,
)
if param_value.is_cuda:
mod_path, _, pname = target_name.rpartition(".")
mod = model.get_submodule(mod_path) if mod_path else model
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
replace_parameter_4bit(
mod,
pname,
compress_statistics=_moe_load_state["compress_statistics"],
quant_type=_moe_load_state["quant_type"],
)
torch.cuda.empty_cache()
_moe_load_state["count"] += 1
# Info-level (was debug) with allocator stats so memory growth during
# loading can be tracked from the default log level.
LOG.info(
"Quantized 3D expert param: %s "
"(alloc=%.2f GiB, reserved=%.2f GiB)",
target_name,
torch.cuda.memory_allocated() / 1024**3,
torch.cuda.memory_reserved() / 1024**3,
)
# Install the wrapper in place of the original loader and record that the
# patch has been applied (presumably to make the patch idempotent — confirm
# against the guard at the top of the enclosing function, not visible here).
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
_moe_load_state["patched"] = True