diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 6c8885526..eb5c087c4 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -173,6 +173,14 @@ class ModelLoader:
         PLUGIN_MANAGER.pre_model_load(self.cfg)
         self.patch_manager.apply_post_plugin_pre_model_load_patches()
         skip_move_to_device = self._build_model()
+
+        # Quantize 3D MoE expert nn.Parameter tensors that BnB skips during loading.
+        self.model._moe_experts_quantized = False
+        if self.cfg.adapter in ("qlora", "lora") and self.cfg.load_in_4bit:
+            from axolotl.monkeypatch.moe_quant import quantize_moe_expert_params
+
+            self.model._moe_experts_quantized = quantize_moe_expert_params(self.model)
+
         PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
 
         # Post-build model configuration
@@ -860,6 +868,10 @@ class ModelLoader:
             # Make sure everything is in the same dtype
             skip_prepare_model_for_kbit_training = True
 
+        if getattr(self.model, "_moe_experts_quantized", False):
+            # Parametrized expert tensors dequantize on access; the prepare pass would OOM.
+            skip_prepare_model_for_kbit_training = True
+
         if (
             not skip_prepare_model_for_kbit_training
             and self.cfg.adapter in ["lora", "qlora"]
diff --git a/src/axolotl/monkeypatch/moe_quant.py b/src/axolotl/monkeypatch/moe_quant.py
new file mode 100644
index 000000000..c9a03e221
--- /dev/null
+++ b/src/axolotl/monkeypatch/moe_quant.py
@@ -0,0 +1,86 @@
+"""
+Post-load quantization for MoE expert weights stored as 3D nn.Parameter tensors.
+
+In transformers v5, many MoE models store expert weights as fused 3D nn.Parameter
+tensors instead of individual nn.Linear modules. BnB 4-bit quantization only targets
+nn.Linear, so these expert weights are skipped during model loading, causing OOM.
+
+This module provides a post-load fixup that quantizes those skipped parameters using
+bitsandbytes.nn.parametrize.replace_parameter_4bit (requires bitsandbytes >= 0.48.0).
+PEFT's target_parameters / ParamWrapper can then apply LoRA on top of these quantized
+params via stacked parametrizations.
+"""
+
+import bitsandbytes as bnb
+import torch
+
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+
+def find_unquantized_expert_params(model):
+    """Find 3D+ nn.Parameter tensors that BnB quantization skipped.
+
+    Returns:
+        List of (module, param_name) tuples to quantize.
+    """
+    params_to_quantize = []
+    for _, module in model.named_modules():
+        if isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
+            continue
+        for param_name, param in module.named_parameters(recurse=False):
+            if param.ndim >= 3 and any(
+                kw in param_name for kw in ("experts", "gate_up_proj", "down_proj")
+            ):
+                params_to_quantize.append((module, param_name))
+    return params_to_quantize
+
+
+def quantize_moe_expert_params(model, quant_type=None, compress_statistics=None):
+    """Quantize 3D nn.Parameter expert weights that BnB skips during model loading.
+
+    Reads quant_type and compress_statistics from the model's quantization_config
+    when not explicitly provided, so that the same settings used for nn.Linear
+    quantization are applied to the MoE expert parameters.
+    """
+    from bitsandbytes.nn.parametrize import replace_parameter_4bit
+
+    params_to_quantize = find_unquantized_expert_params(model)
+    if not params_to_quantize:
+        return False
+
+    # Derive settings from model's BnB config if not explicitly provided
+    if quant_type is None or compress_statistics is None:
+        bnb_config = getattr(model.config, "quantization_config", None)
+        if bnb_config is not None:
+            if quant_type is None:
+                quant_type = getattr(bnb_config, "bnb_4bit_quant_type", "nf4")
+            if compress_statistics is None:
+                compress_statistics = getattr(
+                    bnb_config, "bnb_4bit_use_double_quant", True
+                )
+    # Final defaults
+    if quant_type is None:
+        quant_type = "nf4"
+    if compress_statistics is None:
+        compress_statistics = True
+
+    count = 0
+    for module, param_name in params_to_quantize:
+        replace_parameter_4bit(
+            module,
+            param_name,
+            compress_statistics=compress_statistics,
+            quant_type=quant_type,
+        )
+        count += 1
+
+    torch.cuda.empty_cache()
+    LOG.info(
+        "Quantized %d MoE expert parameters to 4-bit (quant_type=%s, compress_statistics=%s)",
+        count,
+        quant_type,
+        compress_statistics,
+    )
+    return True
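
As a sanity check of the mechanism, the parametrize-based path can be exercised in isolation. The sketch below is not part of the patch: ToyExperts and its dimensions are invented for illustration, and it assumes a CUDA device plus bitsandbytes >= 0.48.0 (the version gate the module docstring names for replace_parameter_4bit).

    import torch
    import torch.nn as nn
    from bitsandbytes.nn.parametrize import replace_parameter_4bit

    class ToyExperts(nn.Module):
        """Fused 3D expert weights, mimicking transformers v5 MoE layers."""

        def __init__(self, num_experts=8, hidden=64, intermediate=128):
            super().__init__()
            # A raw 3D nn.Parameter: exactly the kind of tensor that BnB's
            # nn.Linear-only quantization skips and this patch has to catch.
            self.gate_up_proj = nn.Parameter(
                torch.randn(num_experts, hidden, 2 * intermediate, dtype=torch.bfloat16)
            )

    experts = ToyExperts().cuda()
    # The same call quantize_moe_expert_params() makes for each skipped parameter.
    replace_parameter_4bit(
        experts, "gate_up_proj", compress_statistics=True, quant_type="nf4"
    )

    # Storage is now 4-bit; attribute access goes through the parametrization
    # and materializes a dequantized tensor on the fly. This dequantize-on-access
    # behavior is why the loader skips prepare_model_for_kbit_training above: a
    # dtype-normalizing pass over every expert tensor would OOM.
    w = experts.gate_up_proj
    print(w.shape, w.dtype)

PEFT can then stack its own LoRA parametrization (target_parameters / ParamWrapper, per the module docstring) on the same attribute, which is what makes qlora training on these fused expert weights possible.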