MoE quant patch for merge mismatch (#3483)

* MoE quant patch for merge mismatch

* lint

* revert test + fix moe patch

* comment fixes

* e2e tests

* mismatch fix tested

* mismatch fix with vLLM compatibility + test

* comment lint

* fix: missing os import, duplicate no-op

* chore: simplify comments

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
VED
2026-03-16 07:40:30 +05:30
committed by GitHub
parent d8a05744d7
commit a806704e94
3 changed files with 220 additions and 33 deletions

View File

@@ -154,3 +154,119 @@ class TestPeftPatchIdempotency:
finally:
BaseTuner._inject_parameters = original
patch_peft_target_parameters_matching._axolotl_patched = False
class TestMoeAdapterTrainMergeRoundtrip:
    """E2E: train adapter on quantized MoE experts, then merge onto plain model.

    Verifies that param wrapping order during training matches merge, preventing
    size mismatch errors when loading adapters in standard PEFT/vLLM.
    """

    @staticmethod
    def _make_classes():
        """Return FakeExperts and FakeModel classes shared by both model builders."""
        import torch
        import torch.nn as nn

        class FakeExperts(nn.Module):
            """Toy expert container holding two stacked 3-D weight parameters."""

            def __init__(self):
                super().__init__()
                # Model definition order: gate_up_proj first, then down_proj.
                self.gate_up_proj = nn.Parameter(torch.randn(4, 16, 8))
                self.down_proj = nn.Parameter(torch.randn(4, 8, 16))

            def forward(self, x):
                # Only the slice at index 0 of each stacked weight is used.
                x = torch.matmul(x, self.gate_up_proj[0].T)  # (batch, 16)
                x = torch.matmul(x, self.down_proj[0].T)  # (batch, 8)
                return x

        class FakeModel(nn.Module):
            """Minimal model: one linear layer plus the fake experts block."""

            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(8, 8)
                self.experts = FakeExperts()

            def forward(self, x):
                return self.linear(x) + self.experts(x)

        return FakeExperts, FakeModel

    @staticmethod
    def _make_quantized_model():
        """Training model: parametrizations registered in alphabetical order.

        Side effect: writes the experts' parameter definition order into
        ``_moe_load_state["expert_param_order"]["experts"]``; the caller must
        have initialised ``_moe_load_state["expert_param_order"]`` to a dict.
        """
        import torch.nn as nn
        import torch.nn.utils.parametrize as P

        from axolotl.monkeypatch.moe_quant import _moe_load_state

        _, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()

        class PassthroughParametrization(nn.Module):
            """Identity parametrization: returns the wrapped tensor unchanged."""

            def forward(self, x):
                return x

        model = FakeModel()

        # Record definition order before parametrization (mirrors real loading).
        _moe_load_state["expert_param_order"]["experts"] = list(
            model.experts._parameters.keys()
        )

        # Register in alphabetical order to expose the ordering mismatch.
        P.register_parametrization(
            model.experts, "down_proj", PassthroughParametrization(), unsafe=True
        )
        P.register_parametrization(
            model.experts, "gate_up_proj", PassthroughParametrization(), unsafe=True
        )
        return model

    @staticmethod
    def _make_plain_model():
        """Merge model: no parametrizations — standard branch uses definition order."""
        _, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()
        return FakeModel()

    def test_train_save_merge_no_size_mismatch(self, tmp_path):
        """Train on quantized experts, merge onto plain model — must not raise.

        Every piece of global state this test touches (the PEFT
        ``_inject_parameters`` hook, the patch's ``_axolotl_patched`` flag and
        ``_moe_load_state["expert_param_order"]``) is restored on exit so that
        nothing leaks into other tests, even when the test fails.
        """
        import torch
        from peft import LoraConfig, PeftModel, get_peft_model
        from peft.tuners.tuners_utils import BaseTuner

        from axolotl.monkeypatch.moe_quant import (
            _moe_load_state,
            patch_peft_target_parameters_matching,
        )

        adapter_dir = tmp_path / "adapter"
        lora_cfg = LoraConfig(
            r=4,
            lora_alpha=8,
            target_modules=[],
            target_parameters=["experts.gate_up_proj", "experts.down_proj"],
            lora_dropout=0.0,
            bias="none",
        )

        original_inject = BaseTuner._inject_parameters

        # Snapshot the global expert-order registry so it can be put back after
        # the test; a sentinel distinguishes "key absent" from any stored value.
        missing = object()
        saved_order = _moe_load_state.get("expert_param_order", missing)
        _moe_load_state["expert_param_order"] = {}

        try:
            # Training phase: quantized model (parametrized branch) with axolotl patch.
            try:
                # Patch inside the guarded region so a failure while patching
                # still triggers the restore below.
                patch_peft_target_parameters_matching()
                peft_model = get_peft_model(self._make_quantized_model(), lora_cfg)
            finally:
                BaseTuner._inject_parameters = original_inject
                patch_peft_target_parameters_matching._axolotl_patched = False

            optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-3)
            for _ in range(3):
                peft_model(torch.randn(2, 8)).sum().backward()
                optimizer.step()
                optimizer.zero_grad()

            peft_model.save_pretrained(str(adapter_dir))

            # Merge with standard PEFT (no axolotl patch) to verify external compatibility.
            loaded = PeftModel.from_pretrained(
                self._make_plain_model(), str(adapter_dir)
            )
            merged = loaded.merge_and_unload()
            assert merged is not None
        finally:
            # Undo the registry mutation made above and in _make_quantized_model.
            if saved_order is missing:
                _moe_load_state.pop("expert_param_order", None)
            else:
                _moe_load_state["expert_param_order"] = saved_order