Fix: quantize and target MoE layers in transformers v5 for adapters, plus many misc fixes (#3439)
* fix: saving clones state dict
* fix: apply fix for only CP mode
* fix: add dropout check when using lora target param
* fix: re-add patch from transformers PR #39866
* feat: add moe quant to test by ved
* fix: try to match target param properly with endswith
* fix: clear cache per param quant
* fix: attempt on-load quantize experts instead of post-load
* fix: attempt disable async load
* chore: add log
* chore: adjust log
* fix: remove cuda alloc for moe and enable async load
* chore: remove leftover logs
* chore: add extra empty cache
* fix(doc): clarify support
* fix: handle fsdp2 for paramwrapper dtensor
* feat: attempt to quant experts in 8bit mode too
* feat: attempt to release bf16 experts from vram
* feat: upgrade cce
* fix: fsdp2 init_sharded_param load int8/uint4 dtensor as require_grad=true on init
* fix: remove unnecessary gc and empty cache
* Revert "fix: remove unnecessary gc and empty cache"
This reverts commit 1d54518990.
* fix: do not call full_tensor on non-dtensors
* fix: attempt to address fsdp2 with quant exp high loss
* fix: attempt lora quant experts wrong dim
* fix: ensure require_grad patch applied for lora 8bit
* fix: attempt lora 8bit fsdp2
* fix: attribute access on save for lora 8bit fsdp2
* fix: wrong weight attrib access
* chore(refactor): add config, re-arrange position of patches, clean comments
* feat: add example docs
* chore: cherry pick trinity fixes from PR 3399
* chore: comments refactor; add guards
* fix: guard using wrong key
* fix: mamba save does not accept main process param
* fix: guard prevent double hook
* fix: move gc to upper scope
* chore: add comment on proxy forward patch
* fix: add comment to clarify
* feat: add test idempotency
* fix: AttributeError: `e_score_correction_bias` is not an nn.Parameter
* fix: AttributeError: 'NoneType' object has no attribute 'to'
* fix: update docs on cpu_ram_efficient_loading
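As a bridge to the new validation tests in the diff below, here is a minimal sketch of a config that the new validator should accept, mirroring the qlora + 4-bit and lora_target_parameters test cases. Note the assumptions: build_moe_qlora_cfg is a hypothetical helper name, min_base_cfg stands in for the test fixture supplying the usual required base fields, and combining quantize_moe_experts with lora_target_parameters in a single config is assumed here for illustration (the tests below validate them separately).

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


def build_moe_qlora_cfg(min_base_cfg: DictDefault) -> DictDefault:
    """Hypothetical helper: assemble and validate a QLoRA + quantized-MoE-experts config."""
    cfg = (
        DictDefault(
            quantize_moe_experts=True,  # new flag from this PR: also quantize expert weights
            adapter="qlora",  # an adapter is required for quantize_moe_experts
            load_in_4bit=True,  # load_in_4bit or load_in_8bit is required
            lora_target_parameters=["mlp.experts.gate_up_proj"],  # target expert nn.Parameters
            lora_dropout=0.0,  # must be 0 when lora_target_parameters is set
        )
        | min_base_cfg  # fixture-style base config providing the usual required fields
    )
    return validate_config(cfg)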
tests/utils/schemas/validation/test_moe_quant.py (new file, 142 lines)
@@ -0,0 +1,142 @@
"""Tests for MoE expert quantization config validation and PEFT patch idempotency."""

import pytest

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


@pytest.fixture()
def gpu_caps():
    return {"compute_capability": "sm_89", "bf16": True, "n_gpu": 1, "n_node": 1}


@pytest.fixture()
def env_caps():
    return {"torch_version": "2.7.0"}


class TestQuantizeMoeExpertsValidation:
    """Test suite for quantize_moe_experts config validator."""

    def test_requires_adapter(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts without adapter should fail."""
        cfg = (
            DictDefault(
                quantize_moe_experts=True,
            )
            | min_base_cfg
        )
        with pytest.raises(ValueError, match="requires adapter"):
            validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)

    def test_requires_quantization(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts without load_in_4bit/8bit should fail."""
        cfg = (
            DictDefault(
                quantize_moe_experts=True,
                adapter="lora",
            )
            | min_base_cfg
        )
        with pytest.raises(ValueError, match="requires load_in_4bit or load_in_8bit"):
            validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)

    def test_valid_qlora_4bit(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts with qlora + 4bit should pass."""
        cfg = (
            DictDefault(
                quantize_moe_experts=True,
                adapter="qlora",
                load_in_4bit=True,
            )
            | min_base_cfg
        )
        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
        assert result["quantize_moe_experts"] is True

    def test_valid_lora_8bit(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts with lora + 8bit should pass."""
        cfg = (
            DictDefault(
                quantize_moe_experts=True,
                adapter="lora",
                load_in_8bit=True,
            )
            | min_base_cfg
        )
        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
        assert result["quantize_moe_experts"] is True

    def test_false_skips_validation(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts=false should not check adapter/quantization."""
        cfg = (
            DictDefault(
                quantize_moe_experts=False,
            )
            | min_base_cfg
        )
        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
        assert result["quantize_moe_experts"] is False

    def test_default_is_false(self, min_base_cfg, gpu_caps, env_caps):
        """quantize_moe_experts should default to false."""
        cfg = DictDefault({}) | min_base_cfg
        result = validate_config(cfg, capabilities=gpu_caps, env_capabilities=env_caps)
        assert result["quantize_moe_experts"] is False


class TestLoraTargetParametersDropout:
    """Test that lora_dropout must be 0 when lora_target_parameters is set."""

    def test_rejects_nonzero_dropout(self, min_base_cfg):
        """lora_dropout > 0 with lora_target_parameters should fail."""
        cfg = (
            DictDefault(
                adapter="lora",
                lora_target_parameters=["mlp.experts.gate_up_proj"],
                lora_dropout=0.1,
                load_in_8bit=True,
            )
            | min_base_cfg
        )
        with pytest.raises(ValueError, match="lora_dropout must be 0"):
            validate_config(cfg)

    def test_zero_dropout_passes(self, min_base_cfg):
        """lora_dropout=0 with lora_target_parameters should pass."""
        cfg = (
            DictDefault(
                adapter="lora",
                lora_target_parameters=["mlp.experts.gate_up_proj"],
                lora_dropout=0.0,
                load_in_8bit=True,
            )
            | min_base_cfg
        )
        result = validate_config(cfg)
        assert result["lora_dropout"] == 0.0


class TestPeftPatchIdempotency:
    """Test that patch_peft_target_parameters_matching is idempotent."""

    def test_double_call_does_not_stack_wrappers(self):
        """Calling patch twice should not double-wrap _inject_parameters."""
        from peft.tuners.tuners_utils import BaseTuner

        from axolotl.monkeypatch.moe_quant import (
            patch_peft_target_parameters_matching,
        )

        original = BaseTuner._inject_parameters
        try:
            patch_peft_target_parameters_matching()
            first_patched = BaseTuner._inject_parameters
            patch_peft_target_parameters_matching()
            second_patched = BaseTuner._inject_parameters
            # Should be same function, not double-wrapped
            assert first_patched is second_patched
        finally:
            BaseTuner._inject_parameters = original
            patch_peft_target_parameters_matching._axolotl_patched = False
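The idempotency assertion at the end of the file relies on a _axolotl_patched guard attribute on the patch function. The patch itself lives in axolotl.monkeypatch.moe_quant and is not part of this diff, so the following is only a rough sketch of the guard pattern the test implies; the wrapper body is a hypothetical pass-through rather than the real target-parameter matching logic.

import functools

from peft.tuners.tuners_utils import BaseTuner


def patch_peft_target_parameters_matching():
    """Sketch only: guard so repeated calls never double-wrap _inject_parameters."""
    if getattr(patch_peft_target_parameters_matching, "_axolotl_patched", False):
        return  # already applied; leave BaseTuner._inject_parameters untouched

    original = BaseTuner._inject_parameters

    @functools.wraps(original)
    def _inject_parameters(*args, **kwargs):
        # Placeholder: the real patch adjusts how lora_target_parameters are matched;
        # here we simply delegate to the original implementation.
        return original(*args, **kwargs)

    BaseTuner._inject_parameters = _inject_parameters
    patch_peft_target_parameters_matching._axolotl_patched = True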