"""Tests for MoE expert quantization config validation and PEFT patch idempotency."""
|
|
|
|
import pytest
|
|
|
|
from axolotl.utils.config import validate_config
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
|
|
@pytest.fixture()
|
|
def gpu_caps():
|
|
return {
|
|
"compute_capability": "sm_89",
|
|
"bf16": True,
|
|
"tf32": False,
|
|
"n_gpu": 1,
|
|
"n_node": 1,
|
|
}
|
|
|
|
|
|
@pytest.fixture()
|
|
def env_caps():
|
|
return {"torch_version": "2.7.0"}
|
|
|
|
|
|
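# NOTE: the `min_base_cfg` fixture used throughout these tests is assumed to come
# from the shared conftest (it is not defined in this module); it is expected to
# supply the minimal base config fields that `validate_config` requires.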
class TestQuantizeMoeExpertsValidation:
    """Test suite for quantize_moe_experts config validator."""

    def test_requires_adapter(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts without adapter should fail."""
        cfg = (
            DictDefault(
                quantize_moe_experts=True,
            )
            | min_base_cfg
        )
        with pytest.raises(ValueError, match="requires adapter"):
            validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)

    def test_requires_quantization(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts without load_in_4bit/8bit should fail."""
        cfg = (
            DictDefault(
                quantize_moe_experts=True,
                adapter="lora",
            )
            | min_base_cfg
        )
        with pytest.raises(ValueError, match="requires load_in_4bit or load_in_8bit"):
            validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)

    def test_valid_qlora_4bit(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts with qlora + 4bit should pass."""
        cfg = (
            DictDefault(
                quantize_moe_experts=True,
                adapter="qlora",
                load_in_4bit=True,
            )
            | min_base_cfg
        )
        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
        assert result["quantize_moe_experts"] is True

    def test_valid_lora_8bit(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts with lora + 8bit should pass."""
        cfg = (
            DictDefault(
                quantize_moe_experts=True,
                adapter="lora",
                load_in_8bit=True,
            )
            | min_base_cfg
        )
        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
        assert result["quantize_moe_experts"] is True

    def test_false_skips_validation(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts=false should not check adapter/quantization."""
        cfg = (
            DictDefault(
                quantize_moe_experts=False,
            )
            | min_base_cfg
        )
        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
        assert result["quantize_moe_experts"] is False

    def test_rejects_lora_target_linear(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts with lora_target_linear should fail."""
        cfg = (
            DictDefault(
                quantize_moe_experts=True,
                adapter="qlora",
                load_in_4bit=True,
                lora_target_linear=True,
            )
            | min_base_cfg
        )
        with pytest.raises(ValueError, match="lora_target_linear is not compatible"):
            validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)

    def test_default_is_false(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts should default to false."""
        cfg = DictDefault({}) | min_base_cfg
        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
        assert result["quantize_moe_experts"] is False


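# For reference, the accepted combination exercised above corresponds roughly to a
# user-facing config like the following (a sketch; only the fields touched by these
# tests are shown, other required fields omitted):
#
#   adapter: qlora
#   load_in_4bit: true
#   quantize_moe_experts: true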
class TestLoraTargetParametersDropout:
    """Test that lora_dropout must be 0 when lora_target_parameters is set."""

    def test_rejects_nonzero_dropout(self, min_base_cfg):
        """lora_dropout > 0 with lora_target_parameters should fail."""
        cfg = (
            DictDefault(
                adapter="lora",
                lora_target_parameters=["mlp.experts.gate_up_proj"],
                lora_dropout=0.1,
                load_in_8bit=True,
            )
            | min_base_cfg
        )
        with pytest.raises(ValueError, match="lora_dropout must be 0"):
            validate_config(cfg)

    def test_zero_dropout_passes(self, min_base_cfg):
        """lora_dropout=0 with lora_target_parameters should pass."""
        cfg = (
            DictDefault(
                adapter="lora",
                lora_target_parameters=["mlp.experts.gate_up_proj"],
                lora_dropout=0.0,
                load_in_8bit=True,
            )
            | min_base_cfg
        )
        result = validate_config(cfg)
        assert result["lora_dropout"] == 0.0


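# The passing case above corresponds roughly to (a sketch; other required fields
# omitted):
#
#   adapter: lora
#   load_in_8bit: true
#   lora_dropout: 0.0
#   lora_target_parameters:
#     - mlp.experts.gate_up_proj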
class TestPeftPatchIdempotency:
    """Test that patch_peft_target_parameters_matching is idempotent."""

    def test_double_call_does_not_stack_wrappers(self):
        """Calling patch twice should not double-wrap _inject_parameters."""
        from peft.tuners.tuners_utils import BaseTuner

        from axolotl.monkeypatch.moe_quant import (
            patch_peft_target_parameters_matching,
        )

        original = BaseTuner._inject_parameters
        try:
            patch_peft_target_parameters_matching()
            first_patched = BaseTuner._inject_parameters
            patch_peft_target_parameters_matching()
            second_patched = BaseTuner._inject_parameters
            # Should be same function, not double-wrapped
            assert first_patched is second_patched
        finally:
            BaseTuner._inject_parameters = original
            patch_peft_target_parameters_matching._axolotl_patched = False


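# The teardown above resets `_axolotl_patched`, which suggests the patch guards
# itself roughly like this (a sketch, not the actual implementation):
#
#   def patch_peft_target_parameters_matching():
#       if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
#           return
#       ...  # wrap BaseTuner._inject_parameters once
#       patch_peft_target_parameters_matching._axolotl_patched = True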
class TestMoeAdapterTrainMergeRoundtrip:
    """E2E: train an adapter on quantized MoE experts, then merge it onto a plain model.

    Verifies that the parameter wrapping order used during training matches the
    order used at merge time, preventing size-mismatch errors when the adapter is
    loaded with standard PEFT/vLLM.
    """

    @staticmethod
    def _make_classes():
        """Return FakeExperts and FakeModel classes shared by both model builders."""
        import torch
        import torch.nn as nn

        class FakeExperts(nn.Module):
            def __init__(self):
                super().__init__()
                # Model definition order: gate_up_proj first, then down_proj.
                self.gate_up_proj = nn.Parameter(torch.randn(4, 16, 8))
                self.down_proj = nn.Parameter(torch.randn(4, 8, 16))

            def forward(self, x):
                x = torch.matmul(x, self.gate_up_proj[0].T)  # (batch, 16)
                x = torch.matmul(x, self.down_proj[0].T)  # (batch, 8)
                return x

        class FakeModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(8, 8)
                self.experts = FakeExperts()

            def forward(self, x):
                return self.linear(x) + self.experts(x)

        return FakeExperts, FakeModel

    @staticmethod
    def _make_quantized_model():
        """Training model: parametrizations registered in alphabetical order."""
        import torch.nn as nn
        import torch.nn.utils.parametrize as P

        from axolotl.monkeypatch.moe_quant import _moe_load_state

        _, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()

        class PassthroughParametrization(nn.Module):
            def forward(self, x):
                return x

        model = FakeModel()

        # Record definition order before parametrization (mirrors real loading).
        _moe_load_state["expert_param_order"]["experts"] = list(
            model.experts._parameters.keys()
        )

        # Register in alphabetical order to expose the ordering mismatch.
        P.register_parametrization(
            model.experts, "down_proj", PassthroughParametrization(), unsafe=True
        )
        P.register_parametrization(
            model.experts, "gate_up_proj", PassthroughParametrization(), unsafe=True
        )
        return model
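    # Note: torch.nn.utils.parametrize.register_parametrization moves each parameter
    # under `parametrizations.<name>.original`, so after the calls above the experts'
    # parameters iterate in registration (alphabetical) order, down_proj before
    # gate_up_proj, rather than in definition order; reconciling that ordering via
    # `expert_param_order` is what the axolotl patch is expected to handle.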
    @staticmethod
    def _make_plain_model():
        """Merge model: no parametrizations — standard branch uses definition order."""
        _, FakeModel = TestMoeAdapterTrainMergeRoundtrip._make_classes()
        return FakeModel()

    def test_train_save_merge_no_size_mismatch(self, tmp_path):
        """Train on quantized experts, merge onto plain model — must not raise."""
        import torch
        from peft import LoraConfig, PeftModel, get_peft_model
        from peft.tuners.tuners_utils import BaseTuner

        from axolotl.monkeypatch.moe_quant import (
            _moe_load_state,
            patch_peft_target_parameters_matching,
        )

        adapter_dir = tmp_path / "adapter"
        lora_cfg = LoraConfig(
            r=4,
            lora_alpha=8,
            target_modules=[],
            target_parameters=["experts.gate_up_proj", "experts.down_proj"],
            lora_dropout=0.0,
            bias="none",
        )
        original_inject = BaseTuner._inject_parameters

        # Training phase: quantized model (parametrized branch) with axolotl patch.
        _moe_load_state["expert_param_order"] = {}
        patch_peft_target_parameters_matching()
        try:
            peft_model = get_peft_model(self._make_quantized_model(), lora_cfg)
        finally:
            BaseTuner._inject_parameters = original_inject
            patch_peft_target_parameters_matching._axolotl_patched = False

        optimizer = torch.optim.SGD(peft_model.parameters(), lr=1e-3)
        for _ in range(3):
            peft_model(torch.randn(2, 8)).sum().backward()
            optimizer.step()
            optimizer.zero_grad()
        peft_model.save_pretrained(str(adapter_dir))

        # Merge with standard PEFT (no axolotl patch) to verify external compatibility.
        loaded = PeftModel.from_pretrained(self._make_plain_model(), str(adapter_dir))
        merged = loaded.merge_and_unload()
        assert merged is not None