moe quant patch for merge miss match (#3483)
* moe quant patch for merge miss match * lint * revert test + fix moe patch * comment fixxes * e2e tests * mismatch fixx tested * mis match fix wwith vllm compatablity + test * comment lint * fix: missing os import, duplicate no op * chore: simplify comments --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -154,3 +154,119 @@ class TestPeftPatchIdempotency:
|
||||
finally:
|
||||
BaseTuner._inject_parameters = original
|
||||
patch_peft_target_parameters_matching._axolotl_patched = False
|
||||
|
||||
|
||||
class TestMoeAdapterTrainMergeRoundtrip:
|
||||
"""E2E: train adapter on quantized MoE experts, then merge onto plain model.
|
||||
|
||||
Verifies that param wrapping order during training matches merge, preventing
|
||||
size mismatch errors when loading adapters in standard PEFT/vLLM.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _make_classes():
|
||||
"""Return FakeExperts and FakeModel classes shared by both model builders."""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class FakeExperts(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Model definition order: gate_up_proj first, then down_proj.
|
||||
self.gate_up_proj = nn.Parameter(torch.randn(4, 16, 8))
|
||||
self.down_proj = nn.Parameter(torch.randn(4, 8, 16))
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.matmul(x, self.gate_up_proj[0].T) # (batch, 16)
|
||||
x = torch.matmul(x, self.down_proj[0].T) # (batch, 8)
|
||||
return x
|
||||
|
||||
class FakeModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(8, 8)
|
||||
self.experts = FakeExperts()
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x) + self.experts(x)
|
||||
|
||||
return FakeExperts, FakeModel
|
||||
|
||||
@staticmethod
|
||||
def _make_quantized_model():
|
||||
"""Training model: parametrizations registered in alphabetical order."""
|
||||
import torch.nn as nn
|
||||
import torch.nn.utils.parametrize as P
|
||||
|
||||
from axolotl.monkeypatch.moe_quant import _moe_load_state
|
||||
|
||||
_, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()
|
||||
|
||||
class PassthroughParametrization(nn.Module):
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
model = FakeModel()
|
||||
|
||||
# Record definition order before parametrization (mirrors real loading).
|
||||
_moe_load_state["expert_param_order"]["experts"] = list(
|
||||
model.experts._parameters.keys()
|
||||
)
|
||||
|
||||
# Register in alphabetical order to expose the ordering mismatch.
|
||||
P.register_parametrization(
|
||||
model.experts, "down_proj", PassthroughParametrization(), unsafe=True
|
||||
)
|
||||
P.register_parametrization(
|
||||
model.experts, "gate_up_proj", PassthroughParametrization(), unsafe=True
|
||||
)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _make_plain_model():
|
||||
"""Merge model: no parametrizations — standard branch uses definition order."""
|
||||
_, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()
|
||||
return FakeModel()
|
||||
|
||||
def test_train_save_merge_no_size_mismatch(self, tmp_path):
|
||||
"""Train on quantized experts, merge onto plain model — must not raise."""
|
||||
import torch
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
from peft.tuners.tuners_utils import BaseTuner
|
||||
|
||||
from axolotl.monkeypatch.moe_quant import (
|
||||
_moe_load_state,
|
||||
patch_peft_target_parameters_matching,
|
||||
)
|
||||
|
||||
adapter_dir = tmp_path / "adapter"
|
||||
lora_cfg = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=8,
|
||||
target_modules=[],
|
||||
target_parameters=["experts.gate_up_proj", "experts.down_proj"],
|
||||
lora_dropout=0.0,
|
||||
bias="none",
|
||||
)
|
||||
original_inject = BaseTuner._inject_parameters
|
||||
|
||||
# Training phase: quantized model (parametrized branch) with axolotl patch.
|
||||
_moe_load_state["expert_param_order"] = {}
|
||||
patch_peft_target_parameters_matching()
|
||||
try:
|
||||
peft_model = get_peft_model(self._make_quantized_model(), lora_cfg)
|
||||
finally:
|
||||
BaseTuner._inject_parameters = original_inject
|
||||
patch_peft_target_parameters_matching._axolotl_patched = False
|
||||
|
||||
optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-3)
|
||||
for _ in range(3):
|
||||
peft_model(torch.randn(2, 8)).sum().backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
peft_model.save_pretrained(str(adapter_dir))
|
||||
|
||||
# Merge with standard PEFT (no axolotl patch) to verify external compatibility.
|
||||
loaded = PeftModel.from_pretrained(self._make_plain_model(), str(adapter_dir))
|
||||
merged = loaded.merge_and_unload()
|
||||
assert merged is not None
|
||||
|
||||
Reference in New Issue
Block a user