Add optional Axolotl MoRA/ReMoRA integration (#3647) [skip ci]

* Add optional Axolotl MoRA/ReMoRA integration

Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* Isolate MoRA adapter behavior in plugin

Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* Constrain MoRA variants to supported enum values

* Keep MoRA validation out of core config

---------

Co-authored-by: Swarm <swarm@localhost>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
This commit is contained in:
Wing Lian
2026-05-12 07:19:55 -04:00
committed by GitHub
parent e2f01de0e8
commit b7ec06b8a1
12 changed files with 729 additions and 71 deletions

View File

@@ -0,0 +1,160 @@
"""Integration tests for the MoRA / ReMoRA adapter path."""
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
import torch
from axolotl.integrations.base import PluginManager
from axolotl.integrations.mora import plugin as mora_plugin
from axolotl.loaders import adapter as adapter_module
from axolotl.loaders.adapter import load_adapter
from axolotl.utils.dict import DictDefault
class TestMoraAdapterLoading:
    """MoRA adapter selection and config wiring."""

    @pytest.fixture(autouse=True)
    def _isolated_plugin_registry(self):
        """Snapshot and restore the global plugin registry around each test.

        Every test here registers a ``MoraPlugin`` in the process-wide
        ``PluginManager`` singleton; without teardown that registration
        leaks into unrelated tests that run later in the session.
        """
        registry = PluginManager.get_instance().plugins
        saved = dict(registry)  # assumes `plugins` is a plain mapping — TODO confirm
        yield
        registry.clear()
        registry.update(saved)

    @staticmethod
    def _register_mora_plugin():
        """Register a fresh MoraPlugin under its canonical dotted path."""
        PluginManager.get_instance().plugins[
            "axolotl.integrations.mora.MoraPlugin"
        ] = mora_plugin.MoraPlugin()

    def test_load_adapter_uses_plugin_lora_like_registration(self, monkeypatch):
        """A lora-like plugin adapter is routed through ``load_lora``."""
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {"use_mora": True, "mora_type": "rope"},
            }
        )
        self._register_mora_plugin()

        calls = []

        def fake_load_lora(*args, **kwargs):
            calls.append((args, kwargs))
            return args[0], "adapter-config"

        monkeypatch.setattr(adapter_module, "load_lora", fake_load_lora)

        _, config = load_adapter(model, cfg, "mora")

        assert config == "adapter-config"
        # load_adapter must forward its (default False) config_only flag.
        assert calls[0][1]["config_only"] is False

    def test_mora_plugin_raises_when_peft_missing_support(self):
        """Without a MoRA-capable peft install, loading fails loudly."""
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {"use_mora": True, "mora_type": "rope"},
            }
        )
        self._register_mora_plugin()

        # NOTE(review): relies on the environment's peft actually lacking
        # MoRA support — `_peft_supports_mora` is deliberately NOT patched.
        with pytest.raises(ImportError, match="MoRA support"):
            load_adapter(model, cfg, "mora", config_only=True)

    def test_mora_plugin_rejects_quantized_base_model(self):
        """MoRA refuses to attach to a 4-bit quantized base model."""
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "load_in_4bit": True,
                "mora": {"use_mora": True, "mora_type": "rope"},
            }
        )
        self._register_mora_plugin()

        with pytest.raises(ValueError, match="full-precision base model"):
            load_adapter(model, cfg, "mora", config_only=True)

    def test_mora_plugin_builds_mora_config_when_supported(self, monkeypatch):
        """MoRA kwargs (use_mora, numeric mora_type) reach LoraConfig."""
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {
                    "use_mora": True,
                    "mora_type": "rope",
                },
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.0,
            }
        )

        captured = {}

        class FakeLoraConfig:
            def __init__(self, **kwargs):
                captured.update(kwargs)
                self.__dict__.update(kwargs)

        fake_model = SimpleNamespace(print_trainable_parameters=Mock())
        self._register_mora_plugin()
        monkeypatch.setattr(mora_plugin, "_peft_supports_mora", lambda: True)
        monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
        monkeypatch.setattr(
            adapter_module, "get_peft_model", Mock(return_value=fake_model)
        )

        _, config = load_adapter(model, cfg, "mora", config_only=True)

        # presumably "rope" maps to legacy numeric variant 6 in the MoRA
        # peft fork — verify against the plugin's type table.
        assert captured["use_mora"] is True
        assert captured["mora_type"] == 6
        assert captured["task_type"].name == "CAUSAL_LM"
        assert config is not None
        assert config.use_mora is True
        assert config.mora_type == 6

    def test_mora_plugin_uses_lora_model_dir_resume_path(self, monkeypatch):
        """With ``lora_model_dir`` set, loading resumes via PeftModel.from_pretrained."""
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {"use_mora": True, "mora_type": "rope"},
                "lora_model_dir": "adapter-checkpoint",
                "lora_on_cpu": False,
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.0,
            }
        )

        class FakeLoraConfig:
            def __init__(self, **kwargs):
                self.__dict__.update(kwargs)

        class FakePeftModel:
            def print_trainable_parameters(self):
                pass

            def named_parameters(self):
                return []

        from_pretrained = Mock(return_value=FakePeftModel())
        self._register_mora_plugin()
        monkeypatch.setattr(mora_plugin, "_peft_supports_mora", lambda: True)
        monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
        monkeypatch.setattr(
            adapter_module.PeftModel,
            "from_pretrained",
            from_pretrained,
        )

        peft_model, config = load_adapter(model, cfg, "mora")

        assert isinstance(peft_model, FakePeftModel)
        assert config.use_mora is True
        from_pretrained.assert_called_once()
        assert from_pretrained.call_args.args[:2] == (model, "adapter-checkpoint")
        assert from_pretrained.call_args.kwargs["is_trainable"] is True

View File

@@ -0,0 +1,73 @@
"""Core adapter plugin registry tests."""
from unittest.mock import Mock
import pytest
import torch
from axolotl.integrations.base import AdapterCapabilities, BasePlugin, PluginManager
from axolotl.loaders import adapter as adapter_module
from axolotl.loaders.adapter import load_adapter
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
class FakeAdapterPlugin(BasePlugin):
    """Minimal plugin exposing one lora-like, relora-capable adapter."""

    def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
        """Advertise a single 'fake-adapter' capability."""
        capability = AdapterCapabilities(
            name="fake-adapter", lora_like=True, relora=True
        )
        return [capability]

    def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
        """Contribute a marker kwarg, but only for the fake adapter."""
        if cfg.adapter == "fake-adapter":
            return {"fake_kwarg": "from-plugin"}
        return {}
class TestAdapterPluginRegistry:
    """Adapter resolution and validation through the plugin registry."""

    @pytest.fixture(autouse=True)
    def _isolated_plugin_registry(self):
        """Snapshot and restore the global plugin registry around each test.

        Two tests register ``FakeAdapterPlugin`` under the key ``"fake"``
        in the process-wide ``PluginManager`` singleton; without teardown
        the registration leaks into later tests (including
        ``test_unknown_adapter_error_mentions_plugin_registry``, which
        would then silently depend on test ordering).
        """
        registry = PluginManager.get_instance().plugins
        saved = dict(registry)  # assumes `plugins` is a plain mapping — TODO confirm
        yield
        registry.clear()
        registry.update(saved)

    def test_lora_like_plugin_adapter_contributes_peft_kwargs(self, monkeypatch):
        """Plugin-supplied kwargs are merged into the LoraConfig call."""
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "fake-adapter",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.0,
            }
        )
        PluginManager.get_instance().plugins["fake"] = FakeAdapterPlugin()

        captured = {}

        class FakeLoraConfig:
            def __init__(self, **kwargs):
                captured.update(kwargs)
                self.__dict__.update(kwargs)

        monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
        monkeypatch.setattr(adapter_module, "get_peft_model", Mock())

        _, config = load_adapter(model, cfg, "fake-adapter", config_only=True)

        assert config is not None
        assert captured["fake_kwarg"] == "from-plugin"
        assert captured["task_type"].name == "CAUSAL_LM"

    def test_unknown_adapter_error_mentions_plugin_registry(self):
        """An unregistered adapter name produces an actionable error."""
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault({"adapter": "missing-adapter"})

        with pytest.raises(NotImplementedError, match="registered by a plugin"):
            load_adapter(model, cfg, "missing-adapter")

    def test_relora_accepts_plugin_adapter_capability(self, min_base_cfg):
        """Config validation honors a plugin's relora capability."""
        PluginManager.get_instance().plugins["fake"] = FakeAdapterPlugin()
        cfg = min_base_cfg | DictDefault(
            {
                "adapter": "fake-adapter",
                "relora": True,
                "jagged_restart_steps": 100,
            }
        )

        validated = validate_config(cfg)

        assert validated.adapter == "fake-adapter"
        assert validated.relora is True

View File

@@ -0,0 +1,100 @@
"""Validation tests for the MoRA / ReMoRA integration."""
import pytest
from axolotl.integrations.mora import MoraType
from axolotl.utils.config import prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
class TestMoraValidation:
    """MoRA-specific config validation."""

    @staticmethod
    def _mora_cfg(base_cfg, mora_type, **extra):
        """Build a minimal MoRA config on top of *base_cfg*.

        *mora_type* goes into the nested ``mora`` block; any *extra*
        keys (e.g. relora fields) are merged at the top level.
        """
        overrides = {
            "adapter": "mora",
            "plugins": ["axolotl.integrations.mora.MoraPlugin"],
            "mora": {"use_mora": True, "mora_type": mora_type},
        }
        overrides.update(extra)
        return base_cfg | DictDefault(overrides)

    def test_mora_block_round_trips(self, min_base_cfg):
        """A valid mora block survives validation intact."""
        config = self._mora_cfg(min_base_cfg, "rope")
        prepare_plugins(config)

        result = validate_config(config)

        assert result.adapter == "mora"
        assert result.mora.use_mora is True
        assert result.mora.mora_type == MoraType.ROPE

    def test_mora_type_accepts_legacy_supported_numbers(self, min_base_cfg):
        """Legacy numeric variant 1 normalizes to MoraType.SHARING."""
        config = self._mora_cfg(min_base_cfg, 1)
        prepare_plugins(config)

        result = validate_config(config)

        assert result.mora.mora_type == MoraType.SHARING

    def test_mora_rejects_unsupported_variant_numbers(self, min_base_cfg):
        """Numbers outside the supported enum fail validation."""
        config = self._mora_cfg(min_base_cfg, 2)
        prepare_plugins(config)

        with pytest.raises(ValueError, match="mora_type"):
            validate_config(config)

    def test_remora_uses_core_relora_fields(self, min_base_cfg):
        """ReMoRA reuses core relora / jagged_restart_steps fields."""
        config = self._mora_cfg(
            min_base_cfg, "rope", relora=True, jagged_restart_steps=2000
        )
        prepare_plugins(config)

        result = validate_config(config)

        assert result.relora is True
        assert result.jagged_restart_steps == 2000

    def test_remora_still_requires_core_restart_steps(self, min_base_cfg):
        """relora without jagged_restart_steps still fails validation."""
        config = self._mora_cfg(min_base_cfg, "rope", relora=True)
        prepare_plugins(config)

        with pytest.raises(ValueError, match="jagged_restart_steps"):
            validate_config(config)