Add optional Axolotl MoRA/ReMoRA integration (#3647) [skip ci]

* Add optional Axolotl MoRA/ReMoRA integration

Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* Isolate MoRA adapter behavior in plugin

Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* Constrain MoRA variants to supported enum values

* Keep MoRA validation out of core config

---------

Co-authored-by: Swarm <swarm@localhost>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
Wing Lian authored 2026-05-12 07:19:55 -04:00, committed by GitHub
parent e2f01de0e8
commit b7ec06b8a1
12 changed files with 729 additions and 71 deletions
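As context for the diff below, a minimal sketch of how the integration is meant to be enabled, based on the configuration exercised by the new validation tests; any keys those tests inherit from their `min_base_cfg` fixture (base model, datasets, and so on) are omitted here for brevity, so this is not a complete config.

from axolotl.utils.config import prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault

# Register the plugin, select the adapter, and (optionally) reuse the existing
# ReLoRA fields for ReMoRA-style restarts; no new restart options are introduced.
cfg = DictDefault(
    {
        "adapter": "mora",
        "plugins": ["axolotl.integrations.mora.MoraPlugin"],
        "mora": {"use_mora": True, "mora_type": "rope"},
        "relora": True,
        "jagged_restart_steps": 2000,
    }
)
prepare_plugins(cfg)  # loads MoraPlugin so the `mora` block and the adapter name validate
validated = validate_config(cfg)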

View File

@@ -23,9 +23,10 @@ from __future__ import annotations
import collections
import importlib
import traceback
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
from peft import PeftModel
from peft import PeftConfig, PeftMixedModel, PeftModel
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
@@ -41,6 +42,15 @@ if TYPE_CHECKING:
from axolotl.common.datasets import TrainDatasetMeta
@dataclass(frozen=True)
class AdapterCapabilities:
"""Capabilities for an adapter contributed by a plugin."""
name: str
lora_like: bool = False
relora: bool = False
class BasePlugin:
"""Base class for all plugins. Defines the interface for plugin methods.
@@ -91,6 +101,26 @@ class BasePlugin:
Returns a dataclass model for the plugin's training arguments.
"""
def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
"""Returns adapter capabilities contributed by the plugin."""
return []
def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
"""Returns extra PEFT LoraConfig kwargs for plugin LoRA-like adapters."""
return {}
def load_adapter(
self,
model: PreTrainedModel,
cfg: DictDefault,
inference: bool = False,
config_only: bool = False,
) -> (
tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]
| None
):
"""Optionally load a plugin adapter instead of the generic loader."""
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
@@ -414,6 +444,58 @@ class PluginManager:
training_args.append(training_args_from_plugin)
return training_args
def adapter_capabilities(self) -> dict[str, AdapterCapabilities]:
"""Returns adapter capabilities by adapter name."""
capabilities = {}
for plugin in self.plugins.values():
for adapter_capability in plugin.get_adapter_capabilities():
capabilities[adapter_capability.name] = adapter_capability
return capabilities
def get_adapter_capability(self, adapter: str) -> AdapterCapabilities | None:
"""Returns capabilities for a registered plugin adapter."""
return self.adapter_capabilities().get(adapter)
def supports_adapter(self, adapter: str) -> bool:
"""Returns whether a plugin has registered the adapter name."""
return adapter in self.adapter_capabilities()
def adapter_supports_relora(self, adapter: str) -> bool:
"""Returns whether a plugin adapter supports ReLoRA restart semantics."""
capability = self.get_adapter_capability(adapter)
return bool(capability and capability.relora)
def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
"""Returns extra LoraConfig kwargs from plugins for the configured adapter."""
lora_config_kwargs = {}
for plugin in self.plugins.values():
plugin_kwargs = plugin.get_lora_config_kwargs(cfg)
if plugin_kwargs:
lora_config_kwargs.update(plugin_kwargs)
return lora_config_kwargs
def load_adapter(
self,
model: PreTrainedModel,
cfg: DictDefault,
inference: bool = False,
config_only: bool = False,
) -> (
tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]
| None
):
"""Returns the first plugin adapter loader result, if any."""
for plugin in self.plugins.values():
loaded = plugin.load_adapter(
model,
cfg,
inference=inference,
config_only=config_only,
)
if loaded is not None:
return loaded
return None
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
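To illustrate the new registry surface (editor's sketch, not part of the change): a plugin advertises adapter names and capability flags through `get_adapter_capabilities()`, and the manager exposes the lookups used by the loaders and validators later in this diff. The plugin class and adapter name below are hypothetical.

from axolotl.integrations.base import AdapterCapabilities, BasePlugin, PluginManager

class ExampleAdapterPlugin(BasePlugin):  # illustrative plugin, not shipped by this PR
    def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
        return [AdapterCapabilities(name="example-adapter", lora_like=True, relora=True)]

manager = PluginManager.get_instance()
manager.plugins["example"] = ExampleAdapterPlugin()

assert manager.supports_adapter("example-adapter")
assert manager.adapter_supports_relora("example-adapter")
assert manager.get_adapter_capability("unknown-adapter") is None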

View File

@@ -0,0 +1,6 @@
"""MoRA / ReMoRA integration for Axolotl."""
from .args import MoraArgs, MoraConfig, MoraType
from .plugin import MoraPlugin
__all__ = ["MoraArgs", "MoraConfig", "MoraPlugin", "MoraType"]

View File

@@ -0,0 +1,66 @@
"""Config args for MoRA / ReMoRA."""
from __future__ import annotations
from enum import Enum
from pydantic import BaseModel, Field, model_validator
class MoraType(str, Enum):
"""MoRA variants supported by the reference implementation."""
SHARING = "sharing"
ROPE = "rope"
@property
def peft_value(self) -> int:
return {
MoraType.SHARING: 1,
MoraType.ROPE: 6,
}[self]
class MoraConfig(BaseModel):
"""Nested MoRA configuration available under the `mora` key."""
use_mora: bool = Field(
default=True,
description=(
"Enable MoRA adapter construction. Requires a PEFT build with MoRA "
"support (for example, the MoRA fork)."
),
)
mora_type: MoraType = Field(
default=MoraType.ROPE,
description=(
"MoRA variant selector. Supported values are `sharing` for type 1 "
"and `rope` for type 6. Numeric values 1 and 6 are accepted for "
"backwards compatibility."
),
)
@model_validator(mode="before")
@classmethod
def normalize_mora_type(cls, data):
if not isinstance(data, dict) or "mora_type" not in data:
return data
data = data.copy()
mora_type = data["mora_type"]
if mora_type == 1:
data["mora_type"] = MoraType.SHARING
elif mora_type == 6:
data["mora_type"] = MoraType.ROPE
return data
class MoraArgs(BaseModel):
"""Plugin entry that exposes the nested `mora` block to the core config."""
mora: MoraConfig = Field(
default_factory=MoraConfig,
description=(
"MoRA / ReMoRA training configuration. Register the "
"`axolotl.integrations.mora.MoraPlugin` plugin to enable this block."
),
)
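For reference, the normalization and mapping defined above behave as follows (editor's sketch of behavior already present in this file):

from axolotl.integrations.mora import MoraConfig, MoraType

assert MoraConfig(mora_type=1).mora_type is MoraType.SHARING  # legacy numeric value
assert MoraConfig(mora_type="rope").mora_type is MoraType.ROPE
assert MoraType.ROPE.peft_value == 6  # forwarded to PEFT as LoraConfig(mora_type=6)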

View File

@@ -0,0 +1,97 @@
"""MoRA / ReMoRA plugin for Axolotl."""
import inspect
from peft import LoraConfig, PeftModel
from transformers import PreTrainedModel
from axolotl.integrations.base import AdapterCapabilities, BasePlugin
from axolotl.integrations.mora.args import MoraType
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _peft_supports_mora() -> bool:
try:
params = inspect.signature(LoraConfig).parameters
except (TypeError, ValueError):
return False
return "use_mora" in params and "mora_type" in params
def _mora_type_peft_value(mora_type: MoraType | str | int) -> int:
if isinstance(mora_type, MoraType):
return mora_type.peft_value
if mora_type == 1 or mora_type == MoraType.SHARING.value:
return MoraType.SHARING.peft_value
if mora_type == 6 or mora_type == MoraType.ROPE.value:
return MoraType.ROPE.peft_value
raise ValueError("mora_type must be one of `sharing`, `rope`, 1, or 6")
def _mora_type_label(mora_type: MoraType | str | int) -> str:
if isinstance(mora_type, MoraType):
return mora_type.value
if mora_type == 1:
return MoraType.SHARING.value
if mora_type == 6:
return MoraType.ROPE.value
return str(mora_type)
class MoraPlugin(BasePlugin):
"""Plugin that exposes MoRA-specific config and validates runtime support."""
def get_input_args(self) -> str:
return "axolotl.integrations.mora.MoraArgs"
def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
return [AdapterCapabilities(name="mora", lora_like=True, relora=True)]
def _validate_mora_config(self, cfg: DictDefault):
mora_cfg = getattr(cfg, "mora", None)
if mora_cfg is None:
raise ValueError("adapter: mora requires a nested mora configuration block")
if not getattr(mora_cfg, "use_mora", False):
raise ValueError("mora.use_mora must be true when adapter: mora is set")
if cfg.load_in_4bit or cfg.load_in_8bit:
raise ValueError(
"adapter: mora currently requires a full-precision base model. "
"Use adapter: lora or qlora for quantized training."
)
if cfg.gptq:
raise ValueError(
"adapter: mora is not compatible with GPTQ quantized base models."
)
def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
if cfg.adapter != "mora":
return {}
self._validate_mora_config(cfg)
if not _peft_supports_mora():
raise ImportError(
"adapter: mora requires a PEFT build with MoRA support "
"(LoraConfig(use_mora=..., mora_type=...)). "
"Install the MoRA fork or another PEFT distribution that exposes "
"those fields."
)
mora_cfg = cfg.mora
return {
"use_mora": mora_cfg.use_mora,
"mora_type": _mora_type_peft_value(mora_cfg.mora_type),
}
def pre_model_load(self, cfg: DictDefault):
if cfg.adapter != "mora":
return
LOG.info("MoRA plugin enabled for adapter: mora")
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
if cfg.adapter == "mora" and getattr(cfg, "mora", None):
LOG.debug(
"Loaded MoRA model with mora_type=%s, relora=%s",
_mora_type_label(cfg.mora.mora_type),
cfg.relora,
)
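These kwargs reach PEFT via `PluginManager.get_lora_config_kwargs`, which the refactored loader below merges into the `LoraConfig` it builds. A rough sketch of that handoff, assuming the plugin has been registered:

from axolotl.integrations.base import PluginManager
from axolotl.integrations.mora.plugin import MoraPlugin
from axolotl.utils.dict import DictDefault

manager = PluginManager.get_instance()
manager.plugins["axolotl.integrations.mora.MoraPlugin"] = MoraPlugin()

cfg = DictDefault({"adapter": "mora", "mora": {"use_mora": True, "mora_type": "rope"}})
# On a PEFT build with MoRA support this yields {"use_mora": True, "mora_type": 6};
# on stock PEFT the plugin raises ImportError instead of silently dropping the fields.
extra_kwargs = manager.get_lora_config_kwargs(cfg)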

View File

@@ -19,12 +19,14 @@ from peft import (
)
from transformers import PreTrainedModel
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()
def setup_quantized_meta_for_peft(model: torch.nn.Module):
@@ -124,6 +126,76 @@ def _patch_peft_clippable_linear():
LoraModel._axolotl_clippable_patched = True
def _get_peft_task_type(model: PreTrainedModel) -> TaskType:
model_cls = type(model).__name__
if "SequenceClassification" in model_cls:
return TaskType.SEQ_CLS
if "TokenClassification" in model_cls:
return TaskType.TOKEN_CLS
return TaskType.CAUSAL_LM
def _build_lora_config_kwargs(cfg: DictDefault) -> dict[str, Any]:
lora_config_kwargs: dict[str, Any] = {}
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_init_lora_weights:
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
LOG.info("Initializing LoRA weights using dora. This might take longer.")
if cfg.peft_use_rslora:
lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
if cfg.peft_layer_replication:
lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
if cfg.peft_trainable_token_indices:
lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
if cfg.peft_ensure_weight_tying is not None:
lora_config_kwargs["ensure_weight_tying"] = cfg.peft_ensure_weight_tying
return lora_config_kwargs
def _build_peft_lora_config(
model: PreTrainedModel,
cfg: DictDefault,
) -> PeftConfig:
lora_target_modules = cfg.lora_target_modules or []
lora_target_parameters = cfg.lora_target_parameters or []
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
lora_target_modules_as_list = (
lora_target_modules
if isinstance(lora_target_modules, list)
else [lora_target_modules]
)
lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
lora_config_kwargs = _build_lora_config_kwargs(cfg)
lora_config_kwargs.update(PLUGIN_MANAGER.get_lora_config_kwargs(cfg))
lora_config = LoraConfig(
r=cfg.lora_r,
lora_alpha=cfg.lora_alpha,
target_modules=lora_target_modules,
target_parameters=lora_target_parameters,
layers_to_transform=cfg.peft_layers_to_transform,
layers_pattern=cfg.peft_layers_pattern,
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
bias="none",
task_type=_get_peft_task_type(model),
**lora_config_kwargs,
)
return lora_config
def _peft_will_auto_convert_target_params(model, lora_config) -> bool:
"""Check whether PEFT will auto-populate target_parameters for this model.
@@ -226,62 +298,7 @@ def load_lora(
config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
_patch_peft_clippable_linear()
lora_target_modules = cfg.lora_target_modules or []
lora_target_parameters = cfg.lora_target_parameters or []
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
lora_target_modules_as_list = (
lora_target_modules
if isinstance(lora_target_modules, list)
else [lora_target_modules]
)
lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
lora_config_kwargs = {}
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_init_lora_weights:
lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
LOG.info("Initializing LoRA weights using dora. This might take longer.")
if cfg.peft_use_rslora:
lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
if cfg.peft_layer_replication:
lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
if cfg.peft_trainable_token_indices:
lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
if cfg.peft_ensure_weight_tying is not None:
lora_config_kwargs["ensure_weight_tying"] = cfg.peft_ensure_weight_tying
# Determine the correct PEFT task type
model_cls = type(model).__name__
if "SequenceClassification" in model_cls:
task_type = TaskType.SEQ_CLS
elif "TokenClassification" in model_cls:
task_type = TaskType.TOKEN_CLS
else:
task_type = TaskType.CAUSAL_LM
lora_config = LoraConfig(
r=cfg.lora_r,
lora_alpha=cfg.lora_alpha,
target_modules=lora_target_modules,
target_parameters=lora_target_parameters,
layers_to_transform=cfg.peft_layers_to_transform,
layers_pattern=cfg.peft_layers_pattern,
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
bias="none",
task_type=task_type,
**lora_config_kwargs,
)
lora_config = _build_peft_lora_config(model, cfg)
if config_only:
return None, lora_config
@@ -315,7 +332,7 @@ def load_lora(
model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype
if cfg.lora_model_dir:
LOG.debug("Loading pretrained PEFT - LoRA")
LOG.debug("Loading pretrained PEFT adapter")
if cfg.lora_on_cpu:
model_kwargs["max_memory"] = {"cpu": "256GiB"}
model_kwargs["device_map"] = {"": "cpu"}
@@ -364,30 +381,60 @@ def load_adapter(
cfg: DictDefault,
adapter: str | None,
inference: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel, PeftConfig | None]:
config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
if adapter is None:
return model, None
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
if adapter in ["lora", "qlora"]:
peft_model, lora_config = load_lora(model, cfg, inference=inference)
peft_model, lora_config = load_lora(
model, cfg, inference=inference, config_only=config_only
)
return peft_model, lora_config
if adapter == "llama-adapter":
if config_only:
_, lora_config = load_llama_adapter(model, cfg, config_only=True)
return None, lora_config
peft_model, lora_config = load_llama_adapter(model, cfg)
return peft_model, lora_config
raise NotImplementedError(f"{adapter} PEFT adapter not available")
plugin_loaded = PLUGIN_MANAGER.load_adapter(
model,
cfg,
inference=inference,
config_only=config_only,
)
if plugin_loaded is not None:
return plugin_loaded
adapter_capability = PLUGIN_MANAGER.get_adapter_capability(adapter)
if adapter_capability and adapter_capability.lora_like:
peft_model, lora_config = load_lora(
model, cfg, inference=inference, config_only=config_only
)
return peft_model, lora_config
registered = sorted(PLUGIN_MANAGER.adapter_capabilities())
registered_msg = ", ".join(registered) if registered else "none"
raise NotImplementedError(
f"Adapter '{adapter}' is not built in and was not registered by a plugin "
f"with loader support. Registered plugin adapters: {registered_msg}"
)
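# Editor's sketch (illustrative only, not part of this diff): a plugin can also take
# over loading entirely by implementing the new BasePlugin.load_adapter hook; the
# dispatcher above returns the first non-None plugin result before trying the
# lora_like capability fallback. The plugin class and adapter name are hypothetical;
# LoraConfig and get_peft_model come from the peft imports at the top of this module.
from axolotl.integrations.base import AdapterCapabilities, BasePlugin  # sketch-only import

class ExampleTakeoverPlugin(BasePlugin):
    def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
        return [AdapterCapabilities(name="takeover-adapter")]

    def load_adapter(self, model, cfg, inference=False, config_only=False):
        if cfg.adapter != "takeover-adapter":
            return None  # defer to other plugins and the built-in loaders
        peft_config = LoraConfig(
            r=cfg.lora_r, lora_alpha=cfg.lora_alpha, task_type="CAUSAL_LM"
        )
        if config_only:
            return None, peft_config
        return get_peft_model(model, peft_config), peft_config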
def load_llama_adapter(
model: PreTrainedModel, cfg: DictDefault
) -> tuple[PeftModel | PeftMixedModel, PeftConfig]:
model: PreTrainedModel, cfg: DictDefault, config_only: bool = False
) -> tuple[PeftModel | PeftMixedModel | None, PeftConfig]:
peft_config = AdaptionPromptConfig(
adapter_layers=cfg.peft_adapter.layers, # layers (L)
adapter_len=cfg.peft_adapter.len, # prompt length (K)
task_type="CAUSAL_LM",
)
if config_only:
return None, peft_config
if cfg.lora_model_dir:
LOG.debug("Loading pretrained PEFT - llama_adapter")
peft_model = PeftModel.from_pretrained(

View File

@@ -39,7 +39,7 @@ from transformers.integrations.deepspeed import (
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.integrations.base import PluginManager
from axolotl.loaders.adapter import load_adapter, load_lora
from axolotl.loaders.adapter import load_adapter
from axolotl.loaders.constants import MULTIMODAL_AUTO_MODEL_MAPPING
from axolotl.loaders.patch_manager import PatchManager
from axolotl.loaders.utils import (
@@ -386,8 +386,12 @@ class ModelLoader:
and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO]
and not self.cfg.merge_lora
):
_, lora_config = load_lora(
self.model, self.cfg, inference=False, config_only=True
_, lora_config = load_adapter(
self.model,
self.cfg,
self.cfg.adapter,
inference=False,
config_only=True,
)
else:
self.model, lora_config = load_adapter(

View File

@@ -285,7 +285,9 @@ def save_trained_model(
)
# Handle ReLoRA early return case
if cfg.relora:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
if hasattr(model, "merge_and_unload") and not (
cfg.load_in_4bit or cfg.load_in_8bit
):
model = model.merge_and_unload()
else:
# final model weights have already been saved by `ReLoRACallback.on_train_end`

View File

@@ -38,10 +38,10 @@ class LoraConfig(BaseModel):
default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"}
)
adapter: Literal["lora", "qlora", "llama-adapter"] | None = Field(
adapter: str | None = Field(
default=None,
json_schema_extra={
"description": "If you want to use 'lora', 'qlora', or 'llama-adapter', or leave blank to train all parameters in original model"
"description": "If you want to use a built-in or plugin-registered adapter, or leave blank to train all parameters of the original model"
},
)
lora_model_dir: str | None = Field(
@@ -174,6 +174,16 @@ class LoraConfig(BaseModel):
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training. "
"If you want to do a full finetune, please turn off load_in_8bit and load_in_4bit."
)
adapter = data.get("adapter")
if adapter and adapter not in ("lora", "qlora", "llama-adapter"):
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
if not plugin_manager.supports_adapter(adapter):
raise ValueError(
f"Adapter '{adapter}' is not built in and was not registered by "
"a plugin. Add the plugin that provides this adapter to `plugins:`."
)
return data
@model_validator(mode="after")

View File

@@ -1478,8 +1478,19 @@ class ComplexValidationMixin:
if self.relora:
if not self.jagged_restart_steps:
raise ValueError("jagged_restart_steps must be set to use ReLoRA")
if self.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
adapter_supports_relora = self.adapter in ("lora", "qlora")
if self.adapter and not adapter_supports_relora:
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
adapter_supports_relora = plugin_manager.adapter_supports_relora(
self.adapter
)
if not adapter_supports_relora:
raise ValueError(
"cfg.adapter must support ReLoRA to use ReLoRA restart semantics"
)
if self.fsdp or self.fsdp_config:
raise ValueError("fsdp not supported with ReLoRA")

View File

@@ -0,0 +1,160 @@
"""Integration tests for the MoRA / ReMoRA adapter path."""
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
import torch
from axolotl.integrations.base import PluginManager
from axolotl.integrations.mora import plugin as mora_plugin
from axolotl.loaders import adapter as adapter_module
from axolotl.loaders.adapter import load_adapter
from axolotl.utils.dict import DictDefault
class TestMoraAdapterLoading:
"""MoRA adapter selection and config wiring."""
def test_load_adapter_uses_plugin_lora_like_registration(self, monkeypatch):
model = torch.nn.Linear(4, 4)
cfg = DictDefault(
{
"adapter": "mora",
"mora": {"use_mora": True, "mora_type": "rope"},
}
)
PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
mora_plugin.MoraPlugin()
)
calls = []
def fake_load_lora(*args, **kwargs):
calls.append((args, kwargs))
return args[0], "adapter-config"
monkeypatch.setattr(adapter_module, "load_lora", fake_load_lora)
_, config = load_adapter(model, cfg, "mora")
assert config == "adapter-config"
assert calls[0][1]["config_only"] is False
def test_mora_plugin_raises_when_peft_missing_support(self):
model = torch.nn.Linear(4, 4)
cfg = DictDefault(
{
"adapter": "mora",
"mora": {"use_mora": True, "mora_type": "rope"},
}
)
PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
mora_plugin.MoraPlugin()
)
with pytest.raises(ImportError, match="MoRA support"):
load_adapter(model, cfg, "mora", config_only=True)
def test_mora_plugin_rejects_quantized_base_model(self):
model = torch.nn.Linear(4, 4)
cfg = DictDefault(
{
"adapter": "mora",
"load_in_4bit": True,
"mora": {"use_mora": True, "mora_type": "rope"},
}
)
PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
mora_plugin.MoraPlugin()
)
with pytest.raises(ValueError, match="full-precision base model"):
load_adapter(model, cfg, "mora", config_only=True)
def test_mora_plugin_builds_mora_config_when_supported(self, monkeypatch):
model = torch.nn.Linear(4, 4)
cfg = DictDefault(
{
"adapter": "mora",
"mora": {
"use_mora": True,
"mora_type": "rope",
},
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
}
)
captured = {}
class FakeLoraConfig:
def __init__(self, **kwargs):
captured.update(kwargs)
self.__dict__.update(kwargs)
fake_model = SimpleNamespace(print_trainable_parameters=Mock())
PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
mora_plugin.MoraPlugin()
)
monkeypatch.setattr(mora_plugin, "_peft_supports_mora", lambda: True)
monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
monkeypatch.setattr(
adapter_module, "get_peft_model", Mock(return_value=fake_model)
)
_, config = load_adapter(model, cfg, "mora", config_only=True)
assert captured["use_mora"] is True
assert captured["mora_type"] == 6
assert captured["task_type"].name == "CAUSAL_LM"
assert config is not None
assert config.use_mora is True
assert config.mora_type == 6
def test_mora_plugin_uses_lora_model_dir_resume_path(self, monkeypatch):
model = torch.nn.Linear(4, 4)
cfg = DictDefault(
{
"adapter": "mora",
"mora": {"use_mora": True, "mora_type": "rope"},
"lora_model_dir": "adapter-checkpoint",
"lora_on_cpu": False,
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
}
)
class FakeLoraConfig:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
class FakePeftModel:
def print_trainable_parameters(self):
pass
def named_parameters(self):
return []
from_pretrained = Mock(return_value=FakePeftModel())
PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
mora_plugin.MoraPlugin()
)
monkeypatch.setattr(mora_plugin, "_peft_supports_mora", lambda: True)
monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
monkeypatch.setattr(
adapter_module.PeftModel,
"from_pretrained",
from_pretrained,
)
peft_model, config = load_adapter(model, cfg, "mora")
assert isinstance(peft_model, FakePeftModel)
assert config.use_mora is True
from_pretrained.assert_called_once()
assert from_pretrained.call_args.args[:2] == (model, "adapter-checkpoint")
assert from_pretrained.call_args.kwargs["is_trainable"] is True

View File

@@ -0,0 +1,73 @@
"""Core adapter plugin registry tests."""
from unittest.mock import Mock
import pytest
import torch
from axolotl.integrations.base import AdapterCapabilities, BasePlugin, PluginManager
from axolotl.loaders import adapter as adapter_module
from axolotl.loaders.adapter import load_adapter
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
class FakeAdapterPlugin(BasePlugin):
def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
return [AdapterCapabilities(name="fake-adapter", lora_like=True, relora=True)]
def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
if cfg.adapter != "fake-adapter":
return {}
return {"fake_kwarg": "from-plugin"}
class TestAdapterPluginRegistry:
def test_lora_like_plugin_adapter_contributes_peft_kwargs(self, monkeypatch):
model = torch.nn.Linear(4, 4)
cfg = DictDefault(
{
"adapter": "fake-adapter",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
}
)
PluginManager.get_instance().plugins["fake"] = FakeAdapterPlugin()
captured = {}
class FakeLoraConfig:
def __init__(self, **kwargs):
captured.update(kwargs)
self.__dict__.update(kwargs)
monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
monkeypatch.setattr(adapter_module, "get_peft_model", Mock())
_, config = load_adapter(model, cfg, "fake-adapter", config_only=True)
assert config is not None
assert captured["fake_kwarg"] == "from-plugin"
assert captured["task_type"].name == "CAUSAL_LM"
def test_unknown_adapter_error_mentions_plugin_registry(self):
model = torch.nn.Linear(4, 4)
cfg = DictDefault({"adapter": "missing-adapter"})
with pytest.raises(NotImplementedError, match="registered by a plugin"):
load_adapter(model, cfg, "missing-adapter")
def test_relora_accepts_plugin_adapter_capability(self, min_base_cfg):
PluginManager.get_instance().plugins["fake"] = FakeAdapterPlugin()
cfg = min_base_cfg | DictDefault(
{
"adapter": "fake-adapter",
"relora": True,
"jagged_restart_steps": 100,
}
)
validated = validate_config(cfg)
assert validated.adapter == "fake-adapter"
assert validated.relora is True

View File

@@ -0,0 +1,100 @@
"""Validation tests for the MoRA / ReMoRA integration."""
import pytest
from axolotl.integrations.mora import MoraType
from axolotl.utils.config import prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
class TestMoraValidation:
"""MoRA-specific config validation."""
def test_mora_block_round_trips(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
{
"adapter": "mora",
"plugins": ["axolotl.integrations.mora.MoraPlugin"],
"mora": {
"use_mora": True,
"mora_type": "rope",
},
}
)
prepare_plugins(cfg)
validated = validate_config(cfg)
assert validated.adapter == "mora"
assert validated.mora.use_mora is True
assert validated.mora.mora_type == MoraType.ROPE
def test_mora_type_accepts_legacy_supported_numbers(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
{
"adapter": "mora",
"plugins": ["axolotl.integrations.mora.MoraPlugin"],
"mora": {
"use_mora": True,
"mora_type": 1,
},
}
)
prepare_plugins(cfg)
validated = validate_config(cfg)
assert validated.mora.mora_type == MoraType.SHARING
def test_mora_rejects_unsupported_variant_numbers(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
{
"adapter": "mora",
"plugins": ["axolotl.integrations.mora.MoraPlugin"],
"mora": {
"use_mora": True,
"mora_type": 2,
},
}
)
prepare_plugins(cfg)
with pytest.raises(ValueError, match="mora_type"):
validate_config(cfg)
def test_remora_uses_core_relora_fields(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
{
"adapter": "mora",
"plugins": ["axolotl.integrations.mora.MoraPlugin"],
"relora": True,
"jagged_restart_steps": 2000,
"mora": {
"use_mora": True,
"mora_type": "rope",
},
}
)
prepare_plugins(cfg)
validated = validate_config(cfg)
assert validated.relora is True
assert validated.jagged_restart_steps == 2000
def test_remora_still_requires_core_restart_steps(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
{
"adapter": "mora",
"plugins": ["axolotl.integrations.mora.MoraPlugin"],
"relora": True,
"mora": {
"use_mora": True,
"mora_type": "rope",
},
}
)
prepare_plugins(cfg)
with pytest.raises(ValueError, match="jagged_restart_steps"):
validate_config(cfg)