From d3d6cb6b67f4e06704b9a2e926460c4f77799a7b Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 25 Feb 2026 17:20:40 +0700 Subject: [PATCH] fix: attempt on-load quantize experts instead of post-load --- src/axolotl/loaders/model.py | 22 +++- src/axolotl/monkeypatch/moe_quant.py | 163 +++++++++++++++------------ 2 files changed, 107 insertions(+), 78 deletions(-) diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 81866347b..742a3fb92 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -172,18 +172,30 @@ class ModelLoader: # Build the model PLUGIN_MANAGER.pre_model_load(self.cfg) self.patch_manager.apply_post_plugin_pre_model_load_patches() + + # Activate loading-time quantization for 3D MoE expert params before + # from_pretrained() runs. This patches set_param_for_module so each + # expert weight is quantized to 4-bit as it's loaded, keeping peak + # VRAM to one expert param in bf16 at a time. + moe_quant_active = False + if self.cfg.adapter in ("qlora", "lora") and self.cfg.load_in_4bit: + from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load + + patch_moe_quantization_on_load(self.cfg) + moe_quant_active = True + skip_move_to_device = self._build_model() - # Quantize 3D MoE expert nn.Parameter tensors that BnB skips during loading. + # Check if any MoE expert params were quantized during loading. self.model._moe_experts_quantized = False - if self.cfg.adapter in ("qlora", "lora") and self.cfg.load_in_4bit: + if moe_quant_active: from axolotl.monkeypatch.moe_quant import ( + get_moe_quantized_count, patch_peft_target_parameters_matching, - quantize_moe_expert_params, ) - self.model._moe_experts_quantized = quantize_moe_expert_params(self.model) - if self.model._moe_experts_quantized: + if get_moe_quantized_count() > 0: + self.model._moe_experts_quantized = True patch_peft_target_parameters_matching() PLUGIN_MANAGER.post_model_build(self.cfg, self.model) diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py index be3b3e1e1..ad81556da 100644 --- a/src/axolotl/monkeypatch/moe_quant.py +++ b/src/axolotl/monkeypatch/moe_quant.py @@ -1,14 +1,20 @@ """ -Post-load quantization for MoE expert weights stored as 3D nn.Parameter tensors. +Loading-time quantization for MoE expert weights stored as 3D nn.Parameter tensors. In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets -nn.Linear, so these expert weights are skipped during model loading, causing OOM. +nn.Linear, so these expert weights are loaded in full precision, causing high peak VRAM. + +This module patches transformers' weight loading to quantize 3D expert parameters +on-the-fly as they're assigned to modules, using bitsandbytes.nn.parametrize. +replace_parameter_4bit (requires bitsandbytes >= 0.48.0). This reduces peak VRAM +from "all experts in bf16" to "one expert param in bf16 at a time." -This module provides a post-load fixup that quantizes those skipped parameters using -bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0). PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quantized params via stacked parametrizations. + +Note: FSDP2 cpu ram efficient loading and Tensor Parallel (DTensor) compatibility +with parametrization is untested. """ import bitsandbytes as bnb @@ -18,6 +24,86 @@ from axolotl.utils.logging import get_logger LOG = get_logger(__name__) +# Module-level state for the loading-time quantization patch. +_moe_load_state = { + "count": 0, + "quant_type": "nf4", + "compress_statistics": True, + "patched": False, +} + + +def patch_moe_quantization_on_load(cfg): + """Patch transformers' weight loading to quantize 3D MoE expert params on-the-fly. + + Wraps ``transformers.core_model_loading.set_param_for_module`` so that after each + parameter is assigned to its module, any 3D+ tensor on CUDA that BnB skipped + (i.e. not inside a Linear4bit/Linear8bitLt) is immediately quantized via + ``replace_parameter_4bit``. This keeps peak VRAM to one expert param in bf16 + at a time, instead of loading all experts in bf16 first. + + The patch stays active permanently — the ``ndim >= 3`` and ``is_cuda`` checks + make it safe for non-MoE models (no false positives). + + Args: + cfg: Axolotl DictDefault config. Reads bnb_4bit_quant_type and + bnb_4bit_use_double_quant for quantization settings. + """ + if _moe_load_state["patched"]: + LOG.debug("MoE loading-time quantization patch already active") + return + + import transformers.core_model_loading + from bitsandbytes.nn.parametrize import replace_parameter_4bit + + # Read quantization settings from config + quant_type = getattr(cfg, "bnb_4bit_quant_type", None) or "nf4" + compress_statistics = getattr(cfg, "bnb_4bit_use_double_quant", None) + if compress_statistics is None: + compress_statistics = True + + _moe_load_state["quant_type"] = quant_type + _moe_load_state["compress_statistics"] = compress_statistics + _moe_load_state["count"] = 0 + + original_set_param = transformers.core_model_loading.set_param_for_module + + def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs): + original_set_param(model, target_name, param_value, *args, **kwargs) + + # Quantize 3D+ expert params that BnB skipped (only on CUDA). + if param_value.ndim >= 3 and param_value.is_cuda: + mod_path, _, pname = target_name.rpartition(".") + mod = model.get_submodule(mod_path) if mod_path else model + if not isinstance(mod, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)): + replace_parameter_4bit( + mod, + pname, + compress_statistics=_moe_load_state["compress_statistics"], + quant_type=_moe_load_state["quant_type"], + ) + torch.cuda.empty_cache() + _moe_load_state["count"] += 1 + LOG.debug( + "Quantized 3D expert param during loading: %s (shape %s)", + target_name, + param_value.shape, + ) + + transformers.core_model_loading.set_param_for_module = _patched_set_param_for_module + _moe_load_state["patched"] = True + LOG.info( + "Activated MoE loading-time quantization patch " + "(quant_type=%s, compress_statistics=%s)", + quant_type, + compress_statistics, + ) + + +def get_moe_quantized_count(): + """Return the number of expert parameters quantized during loading.""" + return _moe_load_state["count"] + def patch_peft_target_parameters_matching(): """Fix PEFT's _inject_parameters to use suffix matching for parametrized modules. @@ -59,72 +145,3 @@ def patch_peft_target_parameters_matching(): BaseTuner._inject_parameters = _patched_inject_parameters LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching") - - -def find_unquantized_expert_params(model): - """Find 3D+ nn.Parameter tensors that BnB quantization skipped. - - Returns: - List of (module, param_name) tuples to quantize. - """ - params_to_quantize = [] - for _, module in model.named_modules(): - if isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)): - continue - for param_name, param in module.named_parameters(recurse=False): - if param.ndim >= 3 and any( - kw in param_name for kw in ("experts", "gate_up_proj", "down_proj") - ): - params_to_quantize.append((module, param_name)) - return params_to_quantize - - -def quantize_moe_expert_params(model, quant_type=None, compress_statistics=None): - """Quantize 3D nn.Parameter expert weights that BnB skips during model loading. - - Reads quant_type and compress_statistics from the model's quantization_config - when not explicitly provided, so that the same settings used for nn.Linear - quantization are applied to the MoE expert parameters. - """ - from bitsandbytes.nn.parametrize import replace_parameter_4bit - - params_to_quantize = find_unquantized_expert_params(model) - if not params_to_quantize: - return False - - # Derive settings from model's BnB config if not explicitly provided - if quant_type is None or compress_statistics is None: - bnb_config = getattr(model.config, "quantization_config", None) - if bnb_config is not None: - if quant_type is None: - quant_type = getattr(bnb_config, "bnb_4bit_quant_type", "nf4") - if compress_statistics is None: - compress_statistics = getattr( - bnb_config, "bnb_4bit_use_double_quant", True - ) - # Final defaults - if quant_type is None: - quant_type = "nf4" - if compress_statistics is None: - compress_statistics = True - - count = 0 - for module, param_name in params_to_quantize: - replace_parameter_4bit( - module, - param_name, - compress_statistics=compress_statistics, - quant_type=quant_type, - ) - count += 1 - # Free the bf16 → 4-bit conversion buffers after each parameter - # to avoid accumulating peak reserved VRAM. - torch.cuda.empty_cache() - - LOG.info( - "Quantized %d MoE expert parameters to 4-bit (quant_type=%s, compress_statistics=%s)", - count, - quant_type, - compress_statistics, - ) - return True