From 6ad4b4ecbedf0bf283f0d410308924b26afe4370 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Wed, 25 Feb 2026 18:01:58 +0700
Subject: [PATCH] fix: remove cuda alloc for moe and enable async load

---
 src/axolotl/monkeypatch/moe_quant.py | 76 ++++++++++++----------------
 1 file changed, 31 insertions(+), 45 deletions(-)

diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py
index f9003e879..70fb0726d 100644
--- a/src/axolotl/monkeypatch/moe_quant.py
+++ b/src/axolotl/monkeypatch/moe_quant.py
@@ -53,17 +53,25 @@ def patch_moe_quantization_on_load(cfg):
         LOG.debug("MoE loading-time quantization patch already active")
         return
 
-    import os
-    import transformers.core_model_loading
+    import transformers.modeling_utils
     from bitsandbytes.nn.parametrize import replace_parameter_4bit
 
-    # Disable transformers' async weight loading thread pool. Without this,
-    # the ThreadPoolExecutor pre-fetches tensors to CUDA faster than the main
-    # loop can quantize them, causing all expert weights to accumulate in bf16
-    # on GPU — defeating the purpose of loading-time quantization.
-    os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
-    LOG.info("Disabled async weight loading (HF_DEACTIVATE_ASYNC_LOAD=1)")
+    # Patch caching_allocator_warmup to be a no-op. This function pre-allocates
+    # a single huge GPU tensor equal to the model's total param bytes to warm the
+    # CUDA caching allocator. For MoE models, it calculates expert params at bf16
+    # size (BnB doesn't know we'll quantize them), causing a ~50+ GiB reservation
+    # that defeats loading-time quantization. Disabling it trades slightly slower
+    # weight loading for dramatically lower peak VRAM.
+    _original_warmup = transformers.modeling_utils.caching_allocator_warmup
+
+    def _noop_warmup(*args, **kwargs):
+        LOG.info(
+            "Skipped caching_allocator_warmup (MoE loading-time quantization active)"
+        )
+
+    transformers.modeling_utils.caching_allocator_warmup = _noop_warmup
+    LOG.info("Patched caching_allocator_warmup to no-op for MoE quantization")
 
     # Read quantization settings from config
     quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4"
@@ -77,44 +85,27 @@ def patch_moe_quantization_on_load(cfg):
 
     original_set_param = transformers.core_model_loading.set_param_for_module
 
-    _first_call = [True]
-
     def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs):
-        if _first_call[0]:
-            LOG.info(
-                "MoE quant patch: set_param_for_module intercepted (first call) "
-                "(alloc=%.2f GiB, reserved=%.2f GiB, max_alloc=%.2f GiB)",
-                torch.cuda.memory_allocated() / 1024**3,
-                torch.cuda.memory_reserved() / 1024**3,
-                torch.cuda.max_memory_allocated() / 1024**3,
-            )
-            _first_call[0] = False
-
         original_set_param(model, target_name, param_value, *args, **kwargs)
 
         # Quantize 3D+ expert params that BnB skipped (only on CUDA).
-        if param_value.ndim >= 3:
-            LOG.info(
-                "MoE quant patch: 3D param %s shape=%s cuda=%s",
-                target_name,
-                param_value.shape,
-                param_value.is_cuda,
-            )
-        if param_value.is_cuda:
-            mod_path, _, pname = target_name.rpartition(".")
-            mod = model.get_submodule(mod_path) if mod_path else model
-            if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
-                replace_parameter_4bit(
-                    mod,
-                    pname,
-                    compress_statistics=_moe_load_state["compress_statistics"],
-                    quant_type=_moe_load_state["quant_type"],
-                )
-                torch.cuda.empty_cache()
-                _moe_load_state["count"] += 1
+        if param_value.ndim >= 3 and param_value.is_cuda:
+            mod_path, _, pname = target_name.rpartition(".")
+            mod = model.get_submodule(mod_path) if mod_path else model
+            if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
+                replace_parameter_4bit(
+                    mod,
+                    pname,
+                    compress_statistics=_moe_load_state["compress_statistics"],
+                    quant_type=_moe_load_state["quant_type"],
+                )
+                torch.cuda.empty_cache()
+                _moe_load_state["count"] += 1
+                if _moe_load_state["count"] % 10 == 1:
-                LOG.info(
-                    "Quantized 3D expert param: %s "
-                    "(alloc=%.2f GiB, reserved=%.2f GiB)",
-                    target_name,
-                    torch.cuda.memory_allocated() / 1024**3,
-                    torch.cuda.memory_reserved() / 1024**3,
+                    LOG.info(
+                        "Quantized expert param #%d: %s "
+                        "(alloc=%.2f GiB, reserved=%.2f GiB)",
+                        _moe_load_state["count"],
+                        target_name,
+                        torch.cuda.memory_allocated() / 1024**3,
+                        torch.cuda.memory_reserved() / 1024**3,
@@ -122,11 +113,6 @@ def patch_moe_quantization_on_load(cfg):
 
     transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module
     _moe_load_state["patched"] = True
-    LOG.info(
-        "Pre-load GPU memory: alloc=%.2f GiB, reserved=%.2f GiB",
-        torch.cuda.memory_allocated() / 1024**3,
-        torch.cuda.memory_reserved() / 1024**3,
-    )
     LOG.info(
         "Activated MoE loading-time quantization patch "
         "(quant_type=%s, compress_statistics=%s)",
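
Reviewer note: below is a minimal usage sketch of how the patched entry point is expected to be driven, and how to undo the warmup override once the model is on GPU. It relies only on names that appear in this diff (patch_moe_quantization_on_load, transformers.modeling_utils.caching_allocator_warmup); the SimpleNamespace cfg, the BitsAndBytesConfig arguments, and the checkpoint id are illustrative assumptions, not part of the change.

# Illustrative usage sketch (assumptions noted above); not part of the patch.
from types import SimpleNamespace

import torch
import transformers.modeling_utils
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load

# Keep a handle so the warmup hook can be restored after loading; the patch
# itself only stashes it in a local variable (_original_warmup).
original_warmup = transformers.modeling_utils.caching_allocator_warmup

# In axolotl this would be the full training config; only the field the patch
# reads via getattr() is shown here.
cfg = SimpleNamespace(bnb_4bit_quant_type="nf4")
patch_moe_quantization_on_load(cfg)  # installs no-op warmup + set_param hook

model = AutoModelForCausalLM.from_pretrained(
    "org/some-moe-checkpoint",  # hypothetical model id
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    device_map={"": 0},
)

# Restore the real warmup so later loads in the same process are unaffected.
transformers.modeling_utils.caching_allocator_warmup = original_warmup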