diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py
index 990327be1..f9003e879 100644
--- a/src/axolotl/monkeypatch/moe_quant.py
+++ b/src/axolotl/monkeypatch/moe_quant.py
@@ -81,7 +81,13 @@ def patch_moe_quantization_on_load(cfg):
 
     def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
         if _first_call[0]:
-            LOG.info("MoE quant patch: set_param_for_module intercepted (first call)")
+            LOG.info(
+                "MoE quant patch: set_param_for_module intercepted (first call) "
+                "(alloc=%.2f GiB, reserved=%.2f GiB, max_alloc=%.2f GiB)",
+                torch.cuda.memory_allocated() / 1024**3,
+                torch.cuda.memory_reserved() / 1024**3,
+                torch.cuda.max_memory_allocated() / 1024**3,
+            )
             _first_call[0] = False
         original_set_param(model, target_name, param_value, *args, **kwargs)
 
@@ -116,6 +122,11 @@ def patch_moe_quantization_on_load(cfg):
     transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
     _moe_load_state["patched"] = True
 
+    LOG.info(
+        "Pre-load GPU memory: alloc=%.2f GiB, reserved=%.2f GiB",
+        torch.cuda.memory_allocated() / 1024**3,
+        torch.cuda.memory_reserved() / 1024**3,
+    )
     LOG.info(
         "Activated MoE loading-time quantization patch "
         "(quant_type=%s, compress_statistics=%s)",
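
For reference, a minimal standalone sketch of the wrap-and-rebind interception this patch performs, assuming only what the diff itself shows: transformers.core_model_loading exposes a module-level set_param_for_module, and the wrapper forwards all arguments unchanged. The torch.cuda memory counters are standard PyTorch APIs that report bytes, hence the division by 1024**3 for GiB. The names _original_set_param and _logged_set_param are hypothetical, not from the patch.

# Minimal sketch of the interception technique, not the axolotl implementation.
import torch
import transformers.core_model_loading as cml

_original_set_param = cml.set_param_for_module
_first_call = [True]  # one-element list so the closure can mutate the flag

def _logged_set_param(model, target_name, param_value, *args, **kwargs):
    # On the first parameter materialized, log the CUDA allocator counters.
    # All three counters return bytes; divide by 1024**3 to report GiB.
    if _first_call[0] and torch.cuda.is_available():
        print(
            "first set_param_for_module: alloc=%.2f GiB, reserved=%.2f GiB" % (
                torch.cuda.memory_allocated() / 1024**3,
                torch.cuda.memory_reserved() / 1024**3,
            )
        )
        _first_call[0] = False
    # Delegate to the real implementation with arguments untouched.
    return _original_set_param(model, target_name, param_value, *args, **kwargs)

cml.set_param_for_module = _logged_set_param  # install the wrapper

Keeping the original function in a closure variable is what makes the hook reversible: assigning cml.set_param_for_module = _original_set_param restores the unpatched behavior, which is presumably why the diff routes every call through original_set_param rather than copying its body.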