From a806704e94b6f6d988072ccd073acc38724da5c8 Mon Sep 17 00:00:00 2001
From: VED <146507396+ved1beta@users.noreply.github.com>
Date: Mon, 16 Mar 2026 07:40:30 +0530
Subject: [PATCH] moe quant patch for merge mismatch (#3483)

* moe quant patch for merge mismatch
* lint
* revert test + fix moe patch
* comment fixes
* e2e tests
* mismatch fix tested
* mismatch fix with vllm compatibility + test
* comment lint
* fix: missing os import, duplicate no op
* chore: simplify comments

---------

Co-authored-by: NanoCode012
---
 src/axolotl/loaders/patch_manager.py     |  13 +-
 src/axolotl/monkeypatch/moe_quant.py     | 124 ++++++++++++++----
 .../schemas/validation/test_moe_quant.py | 116 ++++++++++++++++
 3 files changed, 220 insertions(+), 33 deletions(-)

diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index f516166c3..857a2f76f 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -416,16 +416,21 @@ class PatchManager:
         os.environ["HF_DEACTIVATE_ASYNC_LOAD"] = "1"
 
     def _apply_moe_expert_quantization_patch(self):
-        """Patch transformers weight loading to quantize MoE expert params on-the-fly."""
-        if not self.cfg.quantize_moe_experts:
+        """Patch transformers weight loading and PEFT for MoE expert quantization."""
+        has_target_params = bool(getattr(self.cfg, "lora_target_parameters", None))
+
+        if not self.cfg.quantize_moe_experts and not has_target_params:
             return
 
         from axolotl.monkeypatch.moe_quant import (
-            patch_moe_quantization_on_load,
             patch_peft_target_parameters_matching,
         )
 
-        patch_moe_quantization_on_load(self.cfg)
+        if self.cfg.quantize_moe_experts:
+            from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load
+
+            patch_moe_quantization_on_load(self.cfg)
+
         patch_peft_target_parameters_matching()
 
     def _finalize_moe_expert_quantization(self, model: PreTrainedModel):
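
For readers wiring these hooks up outside PatchManager, the control flow above reduces to the following sketch. This is an illustration, not axolotl API: `apply_moe_patches` is a hypothetical helper, and `cfg` is assumed to be any config object exposing the `quantize_moe_experts` and `lora_target_parameters` fields referenced in the hunk above.

    # Hypothetical helper mirroring PatchManager._apply_moe_expert_quantization_patch.
    from axolotl.monkeypatch.moe_quant import patch_peft_target_parameters_matching

    def apply_moe_patches(cfg):
        has_target_params = bool(getattr(cfg, "lora_target_parameters", None))
        if not cfg.quantize_moe_experts and not has_target_params:
            return  # neither feature requested, nothing to patch
        if cfg.quantize_moe_experts:
            # Deferred import, as in the patch: quantization is optional.
            from axolotl.monkeypatch.moe_quant import patch_moe_quantization_on_load

            patch_moe_quantization_on_load(cfg)
        # The PEFT patch is applied whenever LoRA targets raw nn.Parameters,
        # even without quantization; that is why the guard checks both flags.
        patch_peft_target_parameters_matching()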
+ "expert_param_order": {}, } class Bnb8bitParametrization(torch.nn.Module): - """Parametrization that dequantizes int8 row-wise quantized data on access.""" + """Dequantizes int8 row-wise quantized data on access.""" def __init__(self, row_stats: torch.Tensor): super().__init__() @@ -34,7 +29,7 @@ class Bnb8bitParametrization(torch.nn.Module): @torch.no_grad() def forward(self, quantized_param: torch.Tensor) -> torch.Tensor: - # Flatten 3D+ to 2D for BnB's dequant, then reshape back. + """Flatten 3D+ to 2D for BnB's dequant, then reshape back.""" orig_shape = quantized_param.shape if quantized_param.ndim > 2: quantized_param = quantized_param.reshape(-1, orig_shape[-1]) @@ -74,14 +69,11 @@ def replace_parameter_8bit(module, param_name): def patch_moe_quantization_on_load(cfg): - """Patch transformers' weight loading to quantize MoE expert params on-the-fly. - - Wraps ``set_param_for_module`` so that 3D+ CUDA tensors with "expert" in their - name are quantized (4-bit or 8-bit) as they're loaded, keeping peak VRAM low. - """ + """Patch transformers' weight loading to quantize MoE expert params on-the-fly.""" mode = "8bit" if getattr(cfg, "load_in_8bit", False) else "4bit" _moe_load_state["mode"] = mode _moe_load_state["count"] = 0 + _moe_load_state["expert_param_order"] = {} if _moe_load_state["patched"]: LOG.debug("MoE loading-time quantization patch already active") @@ -113,7 +105,6 @@ def patch_moe_quantization_on_load(cfg): def _patched_set_param_for_module(model, target_name, param_value, *args, **kwargs): original_set_param(model, target_name, param_value, *args, **kwargs) - # Quantize 3D+ expert params that BnB skipped (only on CUDA). if param_value.ndim >= 3 and param_value.is_cuda: mod_path, _, pname = target_name.rpartition(".") mod = model.get_submodule(mod_path) if mod_path else model @@ -126,6 +117,13 @@ def patch_moe_quantization_on_load(cfg): ) return + # Record definition order before parametrizations override it + # with alphabetical order. + if mod_path not in _moe_load_state["expert_param_order"]: + _moe_load_state["expert_param_order"][mod_path] = list( + mod._parameters.keys() + ) + if _moe_load_state["mode"] == "4bit": replace_parameter_4bit( mod, @@ -151,20 +149,28 @@ def get_moe_quantized_count(): def patch_peft_target_parameters_matching(): - """Fix PEFT's _inject_parameters to use suffix matching for parametrized modules.""" + """Fix PEFT's _inject_parameters for target_parameters on quantized MoE experts. + + 1. Expands short suffixes to full module paths for parametrized modules. + 2. Iterates params in definition order (not alphabetical order) so saved + adapters are compatible with standard PEFT, vLLM, etc. + """ if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False): return - from peft.tuners.tuners_utils import BaseTuner - original_inject = BaseTuner._inject_parameters + from contextlib import nullcontext + + from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer + from peft.utils.integrations import init_empty_weights + from peft.utils.other import _get_submodules def _patched_inject_parameters( self, peft_config, model, adapter_name, low_cpu_mem_usage ): - # Patch target_parameters to use full paths for parametrized modules original_targets = list(peft_config.target_parameters) expanded = set(original_targets) + # Expand short suffixes to full paths for parametrized modules. 
for module_name, module in model.named_modules(): if not hasattr(module, "parametrizations"): continue @@ -175,14 +181,74 @@ def patch_peft_target_parameters_matching(): ) and hasattr(module, param_name): expanded.add(f"{module_name}.{param_name}") - peft_config.target_parameters = sorted(expanded) - try: - return original_inject( - self, peft_config, model, adapter_name, low_cpu_mem_usage - ) - finally: - peft_config.target_parameters = original_targets + target_names_set = expanded + + def strip_base_layer_from_name(module_name): + name = ".base_layer" + while name in module_name: + prefix, _, suffix = module_name.rpartition(name) + module_name = prefix + suffix + return module_name + + def create_and_replace_param(module_name, key, param_name): + parent, target, target_name = _get_submodules(model, module_name) + unwrapped_module_name = strip_base_layer_from_name(module_name) + unwrapped_module = model.get_submodule(unwrapped_module_name) + if ( + isinstance(unwrapped_module, BaseTunerLayer) + and unwrapped_module.__class__.__name__ != "ParamWrapper" + ): + raise ValueError( + f"Trying to wrap an `nn.Parameter` of layer " + f"'{unwrapped_module_name}' of type " + f"{type(target).__name__}, which is not a valid target. " + f"Make sure that this layer is not also targeted with " + f"`target_modules`." + ) + self._check_target_module_compatiblity(peft_config, model, target_name) + ctx = init_empty_weights if low_cpu_mem_usage else nullcontext + with ctx(): + self._create_and_replace( + peft_config, + adapter_name, + target, + target_name, + parent, + current_key=key, + parameter_name=param_name.rpartition(".")[-1], + ) + + # Use definition order (not alphabetical order) for parametrized modules + # so ParamWrapper nesting matches vanilla PEFT on a plain model. + expert_param_order = _moe_load_state.get("expert_param_order", {}) + + for module_name, module in model.named_modules(): + if hasattr(module, "parametrizations"): + stored_order = expert_param_order.get(module_name) + if stored_order is not None: + params_iter = [ + p for p in stored_order if p in module.parametrizations + ] + else: + # Fallback for paths that bypass model loading (e.g. unit tests). 
+ params_iter = list(module.parametrizations.keys()) + for param_name in params_iter: + key = f"{module_name}.{param_name}" + if (key in target_names_set) or any( + key.endswith(f".{t}") for t in target_names_set + ): + create_and_replace_param(module_name, key, param_name) + self.targeted_parameter_names.append(key) + else: + unwrapped_module_name = strip_base_layer_from_name(module_name) + for param_name, _ in module.named_parameters(recurse=False): + key = f"{unwrapped_module_name}.{param_name}" + if (key in target_names_set) or any( + key.endswith(f".{t}") for t in target_names_set + ): + create_and_replace_param(module_name, key, param_name) + self.targeted_parameter_names.append(key) BaseTuner._inject_parameters = _patched_inject_parameters patch_peft_target_parameters_matching._axolotl_patched = True - LOG.info("Patched PEFT _inject_parameters for parametrized module suffix matching") + LOG.info("Patched PEFT _inject_parameters for consistent ParamWrapper ordering") diff --git a/tests/utils/schemas/validation/test_moe_quant.py b/tests/utils/schemas/validation/test_moe_quant.py index a2121473a..2c34582c3 100644 --- a/tests/utils/schemas/validation/test_moe_quant.py +++ b/tests/utils/schemas/validation/test_moe_quant.py @@ -154,3 +154,119 @@ class TestPeftPatchIdempotency: finally: BaseTuner._inject_parameters = original patch_peft_target_parameters_matching._axolotl_patched = False + + +class TestMoeAdapterTrainMergeRoundtrip: + """E2E: train adapter on quantized MoE experts, then merge onto plain model. + + Verifies that param wrapping order during training matches merge, preventing + size mismatch errors when loading adapters in standard PEFT/vLLM. + """ + + @staticmethod + def _make_classes(): + """Return FakeExperts and FakeModel classes shared by both model builders.""" + import torch + import torch.nn as nn + + class FakeExperts(nn.Module): + def __init__(self): + super().__init__() + # Model definition order: gate_up_proj first, then down_proj. + self.gate_up_proj = nn.Parameter(torch.randn(4, 16, 8)) + self.down_proj = nn.Parameter(torch.randn(4, 8, 16)) + + def forward(self, x): + x = torch.matmul(x, self.gate_up_proj[0].T) # (batch, 16) + x = torch.matmul(x, self.down_proj[0].T) # (batch, 8) + return x + + class FakeModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(8, 8) + self.experts = FakeExperts() + + def forward(self, x): + return self.linear(x) + self.experts(x) + + return FakeExperts, FakeModel + + @staticmethod + def _make_quantized_model(): + """Training model: parametrizations registered in alphabetical order.""" + import torch.nn as nn + import torch.nn.utils.parametrize as P + + from axolotl.monkeypatch.moe_quant import _moe_load_state + + _, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes() + + class PassthroughParametrization(nn.Module): + def forward(self, x): + return x + + model = FakeModel() + + # Record definition order before parametrization (mirrors real loading). + _moe_load_state["expert_param_order"]["experts"] = list( + model.experts._parameters.keys() + ) + + # Register in alphabetical order to expose the ordering mismatch. 
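
The `expert_param_order` bookkeeping above exists because `torch.nn.utils.parametrize` iterates parametrized names in registration order, while PEFT wraps the parameters of a plain (unparametrized) model in definition order. A self-contained sketch of the mismatch, using plain torch and a hypothetical `Experts` module whose parameter names match the tests below:

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    class Experts(nn.Module):
        def __init__(self):
            super().__init__()
            # Definition order: gate_up_proj, then down_proj.
            self.gate_up_proj = nn.Parameter(torch.randn(4, 16, 8))
            self.down_proj = nn.Parameter(torch.randn(4, 8, 16))

    class Identity(nn.Module):
        def forward(self, x):
            return x

    experts = Experts()
    print(list(experts._parameters))       # ['gate_up_proj', 'down_proj']

    # A checkpoint-driven loader may register parametrizations in checkpoint-key
    # (alphabetical) order rather than definition order.
    parametrize.register_parametrization(experts, "down_proj", Identity())
    parametrize.register_parametrization(experts, "gate_up_proj", Identity())
    print(list(experts.parametrizations))  # ['down_proj', 'gate_up_proj']

Iterating `module.parametrizations` directly would therefore wrap `down_proj` first during training, while a merge onto a plain model wraps `gate_up_proj` first; recording definition order up front keeps the two passes consistent.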
diff --git a/tests/utils/schemas/validation/test_moe_quant.py b/tests/utils/schemas/validation/test_moe_quant.py
index a2121473a..2c34582c3 100644
--- a/tests/utils/schemas/validation/test_moe_quant.py
+++ b/tests/utils/schemas/validation/test_moe_quant.py
@@ -154,3 +154,119 @@ class TestPeftPatchIdempotency:
         finally:
             BaseTuner._inject_parameters = original
             patch_peft_target_parameters_matching._axolotl_patched = False
+
+
+class TestMoeAdapterTrainMergeRoundtrip:
+    """E2E: train adapter on quantized MoE experts, then merge onto plain model.
+
+    Verifies that param wrapping order during training matches merge, preventing
+    size mismatch errors when loading adapters in standard PEFT/vLLM.
+    """
+
+    @staticmethod
+    def _make_classes():
+        """Return FakeExperts and FakeModel classes shared by both model builders."""
+        import torch
+        import torch.nn as nn
+
+        class FakeExperts(nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Model definition order: gate_up_proj first, then down_proj.
+                self.gate_up_proj = nn.Parameter(torch.randn(4, 16, 8))
+                self.down_proj = nn.Parameter(torch.randn(4, 8, 16))
+
+            def forward(self, x):
+                x = torch.matmul(x, self.gate_up_proj[0].T)  # (batch, 16)
+                x = torch.matmul(x, self.down_proj[0].T)  # (batch, 8)
+                return x
+
+        class FakeModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(8, 8)
+                self.experts = FakeExperts()
+
+            def forward(self, x):
+                return self.linear(x) + self.experts(x)
+
+        return FakeExperts, FakeModel
+
+    @staticmethod
+    def _make_quantized_model():
+        """Training model: parametrizations registered in alphabetical order."""
+        import torch.nn as nn
+        import torch.nn.utils.parametrize as P
+
+        from axolotl.monkeypatch.moe_quant import _moe_load_state
+
+        _, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()
+
+        class PassthroughParametrization(nn.Module):
+            def forward(self, x):
+                return x
+
+        model = FakeModel()
+
+        # Record definition order before parametrization (mirrors real loading).
+        _moe_load_state["expert_param_order"]["experts"] = list(
+            model.experts._parameters.keys()
+        )
+
+        # Register in alphabetical order to expose the ordering mismatch.
+        P.register_parametrization(
+            model.experts, "down_proj", PassthroughParametrization(), unsafe=True
+        )
+        P.register_parametrization(
+            model.experts, "gate_up_proj", PassthroughParametrization(), unsafe=True
+        )
+        return model
+
+    @staticmethod
+    def _make_plain_model():
+        """Merge model: no parametrizations — standard branch uses definition order."""
+        _, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()
+        return FakeModel()
+
+    def test_train_save_merge_no_size_mismatch(self, tmp_path):
+        """Train on quantized experts, merge onto plain model — must not raise."""
+        import torch
+        from peft import LoraConfig, PeftModel, get_peft_model
+        from peft.tuners.tuners_utils import BaseTuner
+
+        from axolotl.monkeypatch.moe_quant import (
+            _moe_load_state,
+            patch_peft_target_parameters_matching,
+        )
+
+        adapter_dir = tmp_path / "adapter"
+        lora_cfg = LoraConfig(
+            r=4,
+            lora_alpha=8,
+            target_modules=[],
+            target_parameters=["experts.gate_up_proj", "experts.down_proj"],
+            lora_dropout=0.0,
+            bias="none",
+        )
+        original_inject = BaseTuner._inject_parameters
+
+        # Training phase: quantized model (parametrized branch) with axolotl patch.
+        _moe_load_state["expert_param_order"] = {}
+        patch_peft_target_parameters_matching()
+        try:
+            peft_model = get_peft_model(self._make_quantized_model(), lora_cfg)
+        finally:
+            BaseTuner._inject_parameters = original_inject
+            patch_peft_target_parameters_matching._axolotl_patched = False
+
+        optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-3)
+        for _ in range(3):
+            peft_model(torch.randn(2, 8)).sum().backward()
+            optimizer.step()
+            optimizer.zero_grad()
+        peft_model.save_pretrained(str(adapter_dir))
+
+        # Merge with standard PEFT (no axolotl patch) to verify external compatibility.
+        loaded = PeftModel.from_pretrained(self._make_plain_model(), str(adapter_dir))
+        merged = loaded.merge_and_unload()
+        assert merged is not None
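
Assuming a development install with `torch` and `peft` available, the new roundtrip test can be run in isolation with `pytest tests/utils/schemas/validation/test_moe_quant.py -k TestMoeAdapterTrainMergeRoundtrip`.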