chore: add log
This commit is contained in:
@@ -77,27 +77,42 @@ def patch_moe_quantization_on_load(cfg):
|
|||||||
|
|
||||||
original_set_param = transformers.core_model_loading.set_param_for_module
|
original_set_param = transformers.core_model_loading.set_param_for_module
|
||||||
|
|
||||||
|
_first_call = [True]
|
||||||
|
|
||||||
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
|
def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
|
||||||
|
if _first_call[0]:
|
||||||
|
LOG.info("MoE quant patch: set_param_for_module intercepted (first call)")
|
||||||
|
_first_call[0] = False
|
||||||
|
|
||||||
original_set_param(model, target_name, param_value, *args, **kwargs)
|
original_set_param(model, target_name, param_value, *args, **kwargs)
|
||||||
|
|
||||||
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
|
# Quantize 3D+ expert params that BnB skipped (only on CUDA).
|
||||||
if param_value.ndim >= 3 and param_value.is_cuda:
|
if param_value.ndim >= 3:
|
||||||
mod_path, _, pname = target_name.rpartition(".")
|
LOG.info(
|
||||||
mod = model.get_submodule(mod_path) if mod_path else model
|
"MoE quant patch: 3D param %s shape=%s cuda=%s",
|
||||||
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
|
target_name,
|
||||||
replace_parameter_4bit(
|
param_value.shape,
|
||||||
mod,
|
param_value.is_cuda,
|
||||||
pname,
|
)
|
||||||
compress_statistics=_moe_load_state["compress_statistics"],
|
if param_value.is_cuda:
|
||||||
quant_type=_moe_load_state["quant_type"],
|
mod_path, _, pname = target_name.rpartition(".")
|
||||||
)
|
mod = model.get_submodule(mod_path) if mod_path else model
|
||||||
torch.cuda.empty_cache()
|
if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
|
||||||
_moe_load_state["count"] += 1
|
replace_parameter_4bit(
|
||||||
LOG.debug(
|
mod,
|
||||||
"Quantized 3D expert param during loading: %s (shape %s)",
|
pname,
|
||||||
target_name,
|
compress_statistics=_moe_load_state["compress_statistics"],
|
||||||
param_value.shape,
|
quant_type=_moe_load_state["quant_type"],
|
||||||
)
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
_moe_load_state["count"] += 1
|
||||||
|
LOG.info(
|
||||||
|
"Quantized 3D expert param: %s "
|
||||||
|
"(alloc=%.2f GiB, reserved=%.2f GiB)",
|
||||||
|
target_name,
|
||||||
|
torch.cuda.memory_allocated() / 1024**3,
|
||||||
|
torch.cuda.memory_reserved() / 1024**3,
|
||||||
|
)
|
||||||
|
|
||||||
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
|
transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
|
||||||
_moe_load_state["patched"] = True
|
_moe_load_state["patched"] = True
|
||||||
|
|||||||
Reference in New Issue
Block a user