chore: add log
This commit is contained in:
@@ -77,27 +77,42 @@ def patch_moe_quantization_on_load(cfg):
|
||||
|
||||
# Intercept transformers' low-level parameter-setting hook so that 3D+
# expert (MoE) parameters — which bitsandbytes skips during its normal
# module replacement — get 4-bit-quantized as they are loaded.
original_set_param = transformers.core_model_loading.set_param_for_module

# One-element list acts as a mutable cell: the closure flips it so the
# "intercepted" message is logged only on the very first call.
_first_call = [True]

def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
    """Delegate to the original setter, then quantize 3D+ CUDA params.

    Wraps ``transformers.core_model_loading.set_param_for_module``.
    After the original call places ``param_value`` on the model, any
    parameter with ``ndim >= 3`` that resides on CUDA and whose owning
    module was not already converted by bitsandbytes is re-registered
    via ``replace_parameter_4bit`` using the quantization settings
    stashed in ``_moe_load_state``.
    """
    if _first_call[0]:
        LOG.info("MoE quant patch: set_param_for_module intercepted (first call)")
        _first_call[0] = False

    # Let the stock loader place the parameter first; quantization
    # operates on the already-assigned tensor.
    original_set_param(model, target_name, param_value, *args, **kwargs)

    # Quantize 3D+ expert params that BnB skipped (only on CUDA).
    if param_value.ndim >= 3:
        LOG.info(
            "MoE quant patch: 3D param %s shape=%s cuda=%s",
            target_name,
            param_value.shape,
            param_value.is_cuda,
        )
        if param_value.is_cuda:
            # Split "path.to.module.weight" into owning module + param name.
            mod_path, _, pname = target_name.rpartition(".")
            mod = model.get_submodule(mod_path) if mod_path else model
            # Modules BnB already converted are quantized; skip them.
            if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
                replace_parameter_4bit(
                    mod,
                    pname,
                    compress_statistics=_moe_load_state["compress_statistics"],
                    quant_type=_moe_load_state["quant_type"],
                )
                # Release the fp tensor's memory before the next load step.
                torch.cuda.empty_cache()
                _moe_load_state["count"] += 1
                LOG.info(
                    "Quantized 3D expert param: %s "
                    "(alloc=%.2f GiB, reserved=%.2f GiB)",
                    target_name,
                    torch.cuda.memory_allocated() / 1024**3,
                    torch.cuda.memory_reserved() / 1024**3,
                )

transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
_moe_load_state["patched"] = True
|
||||
|
||||
Reference in New Issue
Block a user