diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py
index b16d30e13..990327be1 100644
--- a/src/axolotl/monkeypatch/moe_quant.py
+++ b/src/axolotl/monkeypatch/moe_quant.py
@@ -77,27 +77,42 @@ def patch_moe_quantization_on_load(cfg):
 
     original_set_param = transformers.core_model_loading.set_param_for_module
 
+    _first_call = [True]
+
     def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
+        if _first_call[0]:
+            LOG.info("MoE quant patch: set_param_for_module intercepted (first call)")
+            _first_call[0] = False
+
         original_set_param(model, target_name, param_value, *args, **kwargs)
 
         # Quantize 3D+ expert params that BnB skipped (only on CUDA).
-        if param_value.ndim >= 3 and param_value.is_cuda:
-            mod_path, _, pname = target_name.rpartition(".")
-            mod = model.get_submodule(mod_path) if mod_path else model
-            if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
-                replace_parameter_4bit(
-                    mod,
-                    pname,
-                    compress_statistics=_moe_load_state["compress_statistics"],
-                    quant_type=_moe_load_state["quant_type"],
-                )
-                torch.cuda.empty_cache()
-                _moe_load_state["count"] += 1
-                LOG.debug(
-                    "Quantized 3D expert param during loading: %s (shape %s)",
-                    target_name,
-                    param_value.shape,
-                )
+        if param_value.ndim >= 3:
+            LOG.info(
+                "MoE quant patch: 3D param %s shape=%s cuda=%s",
+                target_name,
+                param_value.shape,
+                param_value.is_cuda,
+            )
+            if param_value.is_cuda:
+                mod_path, _, pname = target_name.rpartition(".")
+                mod = model.get_submodule(mod_path) if mod_path else model
+                if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
+                    replace_parameter_4bit(
+                        mod,
+                        pname,
+                        compress_statistics=_moe_load_state["compress_statistics"],
+                        quant_type=_moe_load_state["quant_type"],
+                    )
+                    torch.cuda.empty_cache()
+                    _moe_load_state["count"] += 1
+                    LOG.info(
+                        "Quantized 3D expert param: %s "
+                        "(alloc=%.2f GiB, reserved=%.2f GiB)",
+                        target_name,
+                        torch.cuda.memory_allocated() / 1024**3,
+                        torch.cuda.memory_reserved() / 1024**3,
+                    )
 
     transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
     _moe_load_state["patched"] = True