Add optional Axolotl MoRA/ReMoRA integration (#3647) [skip ci]
* Add optional Axolotl MoRA/ReMoRA integration
* Isolate MoRA adapter behavior in plugin
* Constrain MoRA variants to supported enum values
* Keep MoRA validation out of core config

---------

Co-authored-by: Swarm <swarm@localhost>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
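For context, here is a minimal sketch of the adapter options the tests below exercise, expressed as the in-memory config object Axolotl uses (DictDefault). Only the keys asserted by the tests are shown; any final YAML surface for these options is not part of this change and should be treated as an assumption:

    from axolotl.utils.dict import DictDefault

    # "adapter": "mora" routes adapter loading through the plugin, and the
    # "mora" block selects the MoRA variant; the standard LoRA
    # hyperparameters still apply.
    cfg = DictDefault(
        {
            "adapter": "mora",
            "mora": {"use_mora": True, "mora_type": "rope"},
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.0,
        }
    )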
tests/integrations/mora/test_mora.py (new file, +160 lines)
@@ -0,0 +1,160 @@
"""Integration tests for the MoRA / ReMoRA adapter path."""

from types import SimpleNamespace
from unittest.mock import Mock

import pytest
import torch

from axolotl.integrations.base import PluginManager
from axolotl.integrations.mora import plugin as mora_plugin
from axolotl.loaders import adapter as adapter_module
from axolotl.loaders.adapter import load_adapter
from axolotl.utils.dict import DictDefault


class TestMoraAdapterLoading:
    """MoRA adapter selection and config wiring."""
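
    # Selecting adapter "mora" should reuse the generic LoRA loading path:
    # with the plugin registered, load_adapter delegates to load_lora with
    # config_only=False.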
    def test_load_adapter_uses_plugin_lora_like_registration(self, monkeypatch):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {"use_mora": True, "mora_type": "rope"},
            }
        )

        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
            mora_plugin.MoraPlugin()
        )

        calls = []

        def fake_load_lora(*args, **kwargs):
            calls.append((args, kwargs))
            return args[0], "adapter-config"

        monkeypatch.setattr(adapter_module, "load_lora", fake_load_lora)

        _, config = load_adapter(model, cfg, "mora")

        assert config == "adapter-config"
        assert calls[0][1]["config_only"] is False
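
    # If the installed PEFT build lacks MoRA support, the plugin should fail
    # fast with an ImportError rather than silently training plain LoRA.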
    def test_mora_plugin_raises_when_peft_missing_support(self):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {"use_mora": True, "mora_type": "rope"},
            }
        )
        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
            mora_plugin.MoraPlugin()
        )

        with pytest.raises(ImportError, match="MoRA support"):
            load_adapter(model, cfg, "mora", config_only=True)
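
    # MoRA's high-rank square-matrix update is only supported against
    # full-precision weights, so a 4-bit quantized base model is rejected.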
    def test_mora_plugin_rejects_quantized_base_model(self):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "load_in_4bit": True,
                "mora": {"use_mora": True, "mora_type": "rope"},
            }
        )
        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
            mora_plugin.MoraPlugin()
        )

        with pytest.raises(ValueError, match="full-precision base model"):
            load_adapter(model, cfg, "mora", config_only=True)
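
    # With support present, the plugin should translate the config into
    # LoraConfig kwargs: use_mora=True, the "rope" variant mapped to PEFT's
    # integer code 6, and a CAUSAL_LM task type.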
    def test_mora_plugin_builds_mora_config_when_supported(self, monkeypatch):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {
                    "use_mora": True,
                    "mora_type": "rope",
                },
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.0,
            }
        )

        captured = {}

        class FakeLoraConfig:
            def __init__(self, **kwargs):
                captured.update(kwargs)
                self.__dict__.update(kwargs)

        fake_model = SimpleNamespace(print_trainable_parameters=Mock())
        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
            mora_plugin.MoraPlugin()
        )
        monkeypatch.setattr(mora_plugin, "_peft_supports_mora", lambda: True)
        monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
        monkeypatch.setattr(
            adapter_module, "get_peft_model", Mock(return_value=fake_model)
        )

        _, config = load_adapter(model, cfg, "mora", config_only=True)

        assert captured["use_mora"] is True
        assert captured["mora_type"] == 6
        assert captured["task_type"].name == "CAUSAL_LM"
        assert config is not None
        assert config.use_mora is True
        assert config.mora_type == 6
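
    # Resuming from lora_model_dir should route through
    # PeftModel.from_pretrained with the checkpoint path and keep the
    # adapter trainable (is_trainable=True).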
    def test_mora_plugin_uses_lora_model_dir_resume_path(self, monkeypatch):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {"use_mora": True, "mora_type": "rope"},
                "lora_model_dir": "adapter-checkpoint",
                "lora_on_cpu": False,
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.0,
            }
        )

        class FakeLoraConfig:
            def __init__(self, **kwargs):
                self.__dict__.update(kwargs)

        class FakePeftModel:
            def print_trainable_parameters(self):
                pass

            def named_parameters(self):
                return []

        from_pretrained = Mock(return_value=FakePeftModel())
        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
            mora_plugin.MoraPlugin()
        )
        monkeypatch.setattr(mora_plugin, "_peft_supports_mora", lambda: True)
        monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
        monkeypatch.setattr(
            adapter_module.PeftModel,
            "from_pretrained",
            from_pretrained,
        )

        peft_model, config = load_adapter(model, cfg, "mora")

        assert isinstance(peft_model, FakePeftModel)
        assert config.use_mora is True
        from_pretrained.assert_called_once()
        assert from_pretrained.call_args.args[:2] == (model, "adapter-checkpoint")
        assert from_pretrained.call_args.kwargs["is_trainable"] is True
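The tests monkeypatch _peft_supports_mora, whose implementation is not part of this diff. A minimal sketch of one way such a capability probe could work, assuming a MoRA-enabled fork of PEFT adds use_mora and mora_type fields to LoraConfig; only the function name comes from the tests, and the probe strategy here is an assumption:

    import dataclasses

    from peft import LoraConfig


    def _peft_supports_mora() -> bool:
        # Stock PEFT has no MoRA fields; a MoRA-enabled fork would extend
        # the LoraConfig dataclass, so probing its fields detects support
        # without importing anything fork-specific.
        field_names = {f.name for f in dataclasses.fields(LoraConfig)}
        return {"use_mora", "mora_type"} <= field_names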