diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
index 48260cb4f..5d6517d8d 100644
--- a/src/axolotl/integrations/base.py
+++ b/src/axolotl/integrations/base.py
@@ -23,9 +23,10 @@ from __future__ import annotations
 import collections
 import importlib
 import traceback
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Callable, OrderedDict, Union

-from peft import PeftModel
+from peft import PeftConfig, PeftMixedModel, PeftModel
 from torch import nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
@@ -41,6 +42,15 @@ if TYPE_CHECKING:
     from axolotl.common.datasets import TrainDatasetMeta


+@dataclass(frozen=True)
+class AdapterCapabilities:
+    """Capabilities for an adapter contributed by a plugin."""
+
+    name: str
+    lora_like: bool = False
+    relora: bool = False
+
+
 class BasePlugin:
     """Base class for all plugins. Defines the interface for plugin methods.
@@ -91,6 +101,26 @@ class BasePlugin:
         Returns a dataclass model for the plugin's training arguments.
         """

+    def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
+        """Returns adapter capabilities contributed by the plugin."""
+        return []
+
+    def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
+        """Returns extra PEFT LoraConfig kwargs for plugin LoRA-like adapters."""
+        return {}
+
+    def load_adapter(
+        self,
+        model: PreTrainedModel,
+        cfg: DictDefault,
+        inference: bool = False,
+        config_only: bool = False,
+    ) -> (
+        tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]
+        | None
+    ):
+        """Optionally load a plugin adapter instead of the generic loader."""
+
     def load_datasets(
         self, cfg: DictDefault, preprocess: bool = False
     ) -> Union["TrainDatasetMeta", None]:
@@ -414,6 +444,58 @@ class PluginManager:
             training_args.append(training_args_from_plugin)
         return training_args

+    def adapter_capabilities(self) -> dict[str, AdapterCapabilities]:
+        """Returns adapter capabilities by adapter name."""
+        capabilities = {}
+        for plugin in self.plugins.values():
+            for adapter_capability in plugin.get_adapter_capabilities():
+                capabilities[adapter_capability.name] = adapter_capability
+        return capabilities
+
+    def get_adapter_capability(self, adapter: str) -> AdapterCapabilities | None:
+        """Returns capabilities for a registered plugin adapter."""
+        return self.adapter_capabilities().get(adapter)
+
+    def supports_adapter(self, adapter: str) -> bool:
+        """Returns whether a plugin has registered the adapter name."""
+        return adapter in self.adapter_capabilities()
+
+    def adapter_supports_relora(self, adapter: str) -> bool:
+        """Returns whether a plugin adapter supports ReLoRA restart semantics."""
+        capability = self.get_adapter_capability(adapter)
+        return bool(capability and capability.relora)
+
+    def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
+        """Returns extra LoraConfig kwargs from plugins for the configured adapter."""
+        lora_config_kwargs = {}
+        for plugin in self.plugins.values():
+            plugin_kwargs = plugin.get_lora_config_kwargs(cfg)
+            if plugin_kwargs:
+                lora_config_kwargs.update(plugin_kwargs)
+        return lora_config_kwargs
+
+    def load_adapter(
+        self,
+        model: PreTrainedModel,
+        cfg: DictDefault,
+        inference: bool = False,
+        config_only: bool = False,
+    ) -> (
+        tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]
+        | None
+    ):
+        """Returns the first plugin adapter loader result, if any."""
+        for plugin in self.plugins.values():
+            loaded = plugin.load_adapter(
+                model,
+                cfg,
+                inference=inference,
+                config_only=config_only,
+            )
+            if loaded is not None:
+                return loaded
+        return None
+
     def load_datasets(
         self, cfg: DictDefault, preprocess: bool = False
     ) -> Union["TrainDatasetMeta", None]:
diff --git a/src/axolotl/integrations/mora/__init__.py b/src/axolotl/integrations/mora/__init__.py
new file mode 100644
index 000000000..8f50258e2
--- /dev/null
+++ b/src/axolotl/integrations/mora/__init__.py
@@ -0,0 +1,6 @@
+"""MoRA / ReMoRA integration for Axolotl."""
+
+from .args import MoraArgs, MoraConfig, MoraType
+from .plugin import MoraPlugin
+
+__all__ = ["MoraArgs", "MoraConfig", "MoraPlugin", "MoraType"]
diff --git a/src/axolotl/integrations/mora/args.py b/src/axolotl/integrations/mora/args.py
new file mode 100644
index 000000000..01549e1db
--- /dev/null
+++ b/src/axolotl/integrations/mora/args.py
@@ -0,0 +1,66 @@
+"""Config args for MoRA / ReMoRA."""
+
+from __future__ import annotations
+
+from enum import Enum
+
+from pydantic import BaseModel, Field, model_validator
+
+
+class MoraType(str, Enum):
+    """MoRA variants supported by the reference implementation."""
+
+    SHARING = "sharing"
+    ROPE = "rope"
+
+    @property
+    def peft_value(self) -> int:
+        return {
+            MoraType.SHARING: 1,
+            MoraType.ROPE: 6,
+        }[self]
+
+
+class MoraConfig(BaseModel):
+    """Nested MoRA configuration available under the `mora` key."""
+
+    use_mora: bool = Field(
+        default=True,
+        description=(
+            "Enable MoRA adapter construction. Requires a PEFT build with MoRA "
+            "support (for example, the MoRA fork)."
+        ),
+    )
+    mora_type: MoraType = Field(
+        default=MoraType.ROPE,
+        description=(
+            "MoRA variant selector. Supported values are `sharing` for type 1 "
+            "and `rope` for type 6. Numeric values 1 and 6 are accepted for "
+            "backwards compatibility."
+        ),
+    )
+
+    @model_validator(mode="before")
+    @classmethod
+    def normalize_mora_type(cls, data):
+        if not isinstance(data, dict) or "mora_type" not in data:
+            return data
+        data = data.copy()
+        mora_type = data["mora_type"]
+        if mora_type == 1:
+            data["mora_type"] = MoraType.SHARING
+        elif mora_type == 6:
+            data["mora_type"] = MoraType.ROPE
+        return data
+
+
+class MoraArgs(BaseModel):
+    """Plugin entry that exposes the nested `mora` block to the core config."""
+
+    mora: MoraConfig = Field(
+        default_factory=MoraConfig,
+        description=(
+            "MoRA / ReMoRA training configuration. Register the "
+            "`axolotl.integrations.mora.MoraPlugin` plugin to enable this block."
+        ),
+    )
diff --git a/src/axolotl/integrations/mora/plugin.py b/src/axolotl/integrations/mora/plugin.py
new file mode 100644
index 000000000..8ca6068c1
--- /dev/null
+++ b/src/axolotl/integrations/mora/plugin.py
@@ -0,0 +1,97 @@
+"""MoRA / ReMoRA plugin for Axolotl."""
+
+import inspect
+
+from peft import LoraConfig, PeftModel
+from transformers import PreTrainedModel
+
+from axolotl.integrations.base import AdapterCapabilities, BasePlugin
+from axolotl.integrations.mora.args import MoraType
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+
+def _peft_supports_mora() -> bool:
+    try:
+        params = inspect.signature(LoraConfig).parameters
+    except (TypeError, ValueError):
+        return False
+    return "use_mora" in params and "mora_type" in params
+
+
+def _mora_type_peft_value(mora_type: MoraType | str | int) -> int:
+    if isinstance(mora_type, MoraType):
+        return mora_type.peft_value
+    if mora_type == 1 or mora_type == MoraType.SHARING.value:
+        return MoraType.SHARING.peft_value
+    if mora_type == 6 or mora_type == MoraType.ROPE.value:
+        return MoraType.ROPE.peft_value
+    raise ValueError("mora_type must be one of `sharing`, `rope`, 1, or 6")
+
+
+def _mora_type_label(mora_type: MoraType | str | int) -> str:
+    if isinstance(mora_type, MoraType):
+        return mora_type.value
+    if mora_type == 1:
+        return MoraType.SHARING.value
+    if mora_type == 6:
+        return MoraType.ROPE.value
+    return str(mora_type)
+
+
+class MoraPlugin(BasePlugin):
+    """Plugin that exposes MoRA-specific config and validates runtime support."""
+
+    def get_input_args(self) -> str:
+        return "axolotl.integrations.mora.MoraArgs"
+
+    def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
+        return [AdapterCapabilities(name="mora", lora_like=True, relora=True)]
+
+    def _validate_mora_config(self, cfg: DictDefault):
+        mora_cfg = getattr(cfg, "mora", None)
+        if mora_cfg is None:
+            raise ValueError("adapter: mora requires a nested mora configuration block")
+        if not getattr(mora_cfg, "use_mora", False):
+            raise ValueError("mora.use_mora must be true when adapter: mora is set")
+        if cfg.load_in_4bit or cfg.load_in_8bit:
+            raise ValueError(
+                "adapter: mora currently requires a full-precision base model. "
+                "Use adapter: lora or qlora for quantized training."
+            )
+        if cfg.gptq:
+            raise ValueError(
+                "adapter: mora is not compatible with GPTQ quantized base models."
+            )
+
+    def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
+        if cfg.adapter != "mora":
+            return {}
+        self._validate_mora_config(cfg)
+        if not _peft_supports_mora():
+            raise ImportError(
+                "adapter: mora requires a PEFT build with MoRA support "
+                "(LoraConfig(use_mora=..., mora_type=...)). "
+                "Install the MoRA fork or another PEFT distribution that exposes "
+                "those fields."
+            )
+        mora_cfg = cfg.mora
+        return {
+            "use_mora": mora_cfg.use_mora,
+            "mora_type": _mora_type_peft_value(mora_cfg.mora_type),
+        }
+
+    def pre_model_load(self, cfg: DictDefault):
+        if cfg.adapter != "mora":
+            return
+        LOG.info("MoRA plugin enabled for adapter: mora")
+
+    def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
+        if cfg.adapter == "mora" and getattr(cfg, "mora", None):
+            LOG.debug(
+                "Loaded MoRA model with mora_type=%s, relora=%s",
+                _mora_type_label(cfg.mora.mora_type),
+                cfg.relora,
+            )
diff --git a/src/axolotl/loaders/adapter.py b/src/axolotl/loaders/adapter.py
index 3d662c0bb..71506f740 100644
--- a/src/axolotl/loaders/adapter.py
+++ b/src/axolotl/loaders/adapter.py
@@ -19,12 +19,14 @@ from peft import (
 )
 from transformers import PreTrainedModel

+from axolotl.integrations.base import PluginManager
 from axolotl.loaders.utils import get_linear_embedding_layers
 from axolotl.telemetry.errors import send_errors
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger

 LOG = get_logger(__name__)
+PLUGIN_MANAGER = PluginManager.get_instance()


 def setup_quantized_meta_for_peft(model: torch.nn.Module):
@@ -124,6 +126,76 @@ def _patch_peft_clippable_linear():
     LoraModel._axolotl_clippable_patched = True


+def _get_peft_task_type(model: PreTrainedModel) -> TaskType:
+    model_cls = type(model).__name__
+    if "SequenceClassification" in model_cls:
+        return TaskType.SEQ_CLS
+    if "TokenClassification" in model_cls:
+        return TaskType.TOKEN_CLS
+    return TaskType.CAUSAL_LM
+
+
+def _build_lora_config_kwargs(cfg: DictDefault) -> dict[str, Any]:
+    lora_config_kwargs: dict[str, Any] = {}
+    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
+    if loftq_bits:
+        lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
+        lora_config_kwargs["init_lora_weights"] = "loftq"
+    if cfg.peft_init_lora_weights:
+        lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
+    if cfg.peft_use_dora:
+        lora_config_kwargs["use_dora"] = cfg.peft_use_dora
+        LOG.info("Initializing LoRA weights using dora. This might take longer.")
+    if cfg.peft_use_rslora:
+        lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
+    if cfg.peft_layer_replication:
+        lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
+    if cfg.peft_trainable_token_indices:
+        lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
+    if cfg.peft_ensure_weight_tying is not None:
+        lora_config_kwargs["ensure_weight_tying"] = cfg.peft_ensure_weight_tying
+
+    return lora_config_kwargs
+
+
+def _build_peft_lora_config(
+    model: PreTrainedModel,
+    cfg: DictDefault,
+) -> PeftConfig:
+    lora_target_modules = cfg.lora_target_modules or []
+    lora_target_parameters = cfg.lora_target_parameters or []
+
+    if cfg.lora_target_linear:
+        linear_names = find_all_linear_names(model)
+        LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
+        lora_target_modules_as_list = (
+            lora_target_modules
+            if isinstance(lora_target_modules, list)
+            else [lora_target_modules]
+        )
+        lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
+
+    lora_config_kwargs = _build_lora_config_kwargs(cfg)
+    lora_config_kwargs.update(PLUGIN_MANAGER.get_lora_config_kwargs(cfg))
+
+    lora_config = LoraConfig(
+        r=cfg.lora_r,
+        lora_alpha=cfg.lora_alpha,
+        target_modules=lora_target_modules,
+        target_parameters=lora_target_parameters,
+        layers_to_transform=cfg.peft_layers_to_transform,
+        layers_pattern=cfg.peft_layers_pattern,
+        lora_dropout=cfg.lora_dropout,
+        fan_in_fan_out=cfg.lora_fan_in_fan_out,
+        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
+        exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
+        bias="none",
+        task_type=_get_peft_task_type(model),
+        **lora_config_kwargs,
+    )
+    return lora_config
+
+
 def _peft_will_auto_convert_target_params(model, lora_config) -> bool:
     """Check whether PEFT will auto-populate target_parameters for this model.

@@ -226,62 +298,7 @@ def load_lora(
     config_only: bool = False,
 ) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
     _patch_peft_clippable_linear()
-    lora_target_modules = cfg.lora_target_modules or []
-    lora_target_parameters = cfg.lora_target_parameters or []
-
-    if cfg.lora_target_linear:
-        linear_names = find_all_linear_names(model)
-        LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
-        lora_target_modules_as_list = (
-            lora_target_modules
-            if isinstance(lora_target_modules, list)
-            else [lora_target_modules]
-        )
-        lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
-
-    lora_config_kwargs = {}
-    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
-    if loftq_bits:
-        lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
-        lora_config_kwargs["init_lora_weights"] = "loftq"
-    if cfg.peft_init_lora_weights:
-        lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
-    if cfg.peft_use_dora:
-        lora_config_kwargs["use_dora"] = cfg.peft_use_dora
-        LOG.info("Initializing LoRA weights using dora. This might take longer.")
-    if cfg.peft_use_rslora:
-        lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
-    if cfg.peft_layer_replication:
-        lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
-    if cfg.peft_trainable_token_indices:
-        lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
-    if cfg.peft_ensure_weight_tying is not None:
-        lora_config_kwargs["ensure_weight_tying"] = cfg.peft_ensure_weight_tying
-
-    # Determine the correct PEFT task type
-    model_cls = type(model).__name__
-    if "SequenceClassification" in model_cls:
-        task_type = TaskType.SEQ_CLS
-    elif "TokenClassification" in model_cls:
-        task_type = TaskType.TOKEN_CLS
-    else:
-        task_type = TaskType.CAUSAL_LM
-
-    lora_config = LoraConfig(
-        r=cfg.lora_r,
-        lora_alpha=cfg.lora_alpha,
-        target_modules=lora_target_modules,
-        target_parameters=lora_target_parameters,
-        layers_to_transform=cfg.peft_layers_to_transform,
-        layers_pattern=cfg.peft_layers_pattern,
-        lora_dropout=cfg.lora_dropout,
-        fan_in_fan_out=cfg.lora_fan_in_fan_out,
-        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
-        exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
-        bias="none",
-        task_type=task_type,
-        **lora_config_kwargs,
-    )
+    lora_config = _build_peft_lora_config(model, cfg)

     if config_only:
         return None, lora_config
@@ -315,7 +332,7 @@ def load_lora(
         model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype

     if cfg.lora_model_dir:
-        LOG.debug("Loading pretrained PEFT - LoRA")
+        LOG.debug("Loading pretrained PEFT adapter")
         if cfg.lora_on_cpu:
             model_kwargs["max_memory"] = {"cpu": "256GiB"}
             model_kwargs["device_map"] = {"": "cpu"}
@@ -364,30 +381,60 @@ def load_adapter(
     cfg: DictDefault,
     adapter: str | None,
     inference: bool = False,
-) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel, PeftConfig | None]:
+    config_only: bool = False,
+) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
     if adapter is None:
         return model, None
     if hasattr(model, "enable_input_require_grads"):
         model.enable_input_require_grads()
     if adapter in ["lora", "qlora"]:
-        peft_model, lora_config = load_lora(model, cfg, inference=inference)
+        peft_model, lora_config = load_lora(
+            model, cfg, inference=inference, config_only=config_only
+        )
         return peft_model, lora_config
     if adapter == "llama-adapter":
+        if config_only:
+            _, lora_config = load_llama_adapter(model, cfg, config_only=True)
+            return None, lora_config
         peft_model, lora_config = load_llama_adapter(model, cfg)
         return peft_model, lora_config

-    raise NotImplementedError(f"{adapter} PEFT adapter not available")
+    plugin_loaded = PLUGIN_MANAGER.load_adapter(
+        model,
+        cfg,
+        inference=inference,
+        config_only=config_only,
+    )
+    if plugin_loaded is not None:
+        return plugin_loaded
+
+    adapter_capability = PLUGIN_MANAGER.get_adapter_capability(adapter)
+    if adapter_capability and adapter_capability.lora_like:
+        peft_model, lora_config = load_lora(
+            model, cfg, inference=inference, config_only=config_only
+        )
+        return peft_model, lora_config
+
+    registered = sorted(PLUGIN_MANAGER.adapter_capabilities())
+    registered_msg = ", ".join(registered) if registered else "none"
+    raise NotImplementedError(
+        f"Adapter '{adapter}' is not built in and was not registered by a plugin "
+        f"with loader support. Registered plugin adapters: {registered_msg}"
+    )


 def load_llama_adapter(
-    model: PreTrainedModel, cfg: DictDefault
-) -> tuple[PeftModel | PeftMixedModel, PeftConfig]:
+    model: PreTrainedModel, cfg: DictDefault, config_only: bool = False
+) -> tuple[PeftModel | PeftMixedModel | None, PeftConfig]:
     peft_config = AdaptionPromptConfig(
         adapter_layers=cfg.peft_adapter.layers,  # layers (L)
         adapter_len=cfg.peft_adapter.len,  # prompt length (K)
         task_type="CAUSAL_LM",
     )
+    if config_only:
+        return None, peft_config
+
     if cfg.lora_model_dir:
         LOG.debug("Loading pretrained PEFT - llama_adapter")
         peft_model = PeftModel.from_pretrained(
diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 061509d39..84a021a9f 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -39,7 +39,7 @@ from transformers.integrations.deepspeed import (
 )
 from axolotl.common.architectures import MOE_ARCH_BLOCK
 from axolotl.integrations.base import PluginManager
-from axolotl.loaders.adapter import load_adapter, load_lora
+from axolotl.loaders.adapter import load_adapter
 from axolotl.loaders.constants import MULTIMODAL_AUTO_MODEL_MAPPING
 from axolotl.loaders.patch_manager import PatchManager
 from axolotl.loaders.utils import (
@@ -386,8 +386,12 @@ class ModelLoader:
             and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO]
             and not self.cfg.merge_lora
         ):
-            _, lora_config = load_lora(
-                self.model, self.cfg, inference=False, config_only=True
+            _, lora_config = load_adapter(
+                self.model,
+                self.cfg,
+                self.cfg.adapter,
+                inference=False,
+                config_only=True,
             )
         else:
             self.model, lora_config = load_adapter(
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 0dfeb0c7f..bda07ade9 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -285,7 +285,9 @@ def save_trained_model(
     )
     # Handle ReLoRA early return case
     if cfg.relora:
-        if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
+        if hasattr(model, "merge_and_unload") and not (
+            cfg.load_in_4bit or cfg.load_in_8bit
+        ):
             model = model.merge_and_unload()
         else:
             # final model weights have already been saved by `ReLoRACallback.on_train_end`
diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py
index 42fa628e0..4b637c081 100644
--- a/src/axolotl/utils/schemas/peft.py
+++ b/src/axolotl/utils/schemas/peft.py
@@ -38,10 +38,10 @@ class LoraConfig(BaseModel):
         default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"}
     )

-    adapter: Literal["lora", "qlora", "llama-adapter"] | None = Field(
+    adapter: str | None = Field(
         default=None,
         json_schema_extra={
-            "description": "If you want to use 'lora', 'qlora', or 'llama-adapter', or leave blank to train all parameters in original model"
+            "description": "If you want to use a built-in or plugin adapter, or leave blank to train all parameters in original model"
         },
     )
     lora_model_dir: str | None = Field(
@@ -174,6 +174,16 @@ class LoraConfig(BaseModel):
                 "load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
                 "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
             )
+        adapter = data.get("adapter")
+        if adapter and adapter not in ("lora", "qlora", "llama-adapter"):
+            from axolotl.integrations.base import PluginManager
+
+            plugin_manager = PluginManager.get_instance()
+            if not plugin_manager.supports_adapter(adapter):
+                raise ValueError(
+                    f"Adapter '{adapter}' is not built in and was not registered by "
+                    "a plugin. Add the plugin that provides this adapter to `plugins:`."
+                )
         return data

     @model_validator(mode="after")
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index ec11d9658..2c580292a 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -1478,8 +1478,19 @@ class ComplexValidationMixin:
         if self.relora:
             if not self.jagged_restart_steps:
                 raise ValueError("jagged_restart_steps must be set to use ReLoRA")
-            if self.adapter not in ("lora", "qlora"):
-                raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
+
+            adapter_supports_relora = self.adapter in ("lora", "qlora")
+            if self.adapter and not adapter_supports_relora:
+                from axolotl.integrations.base import PluginManager
+
+                plugin_manager = PluginManager.get_instance()
+                adapter_supports_relora = plugin_manager.adapter_supports_relora(
+                    self.adapter
+                )
+            if not adapter_supports_relora:
+                raise ValueError(
+                    "cfg.adapter must support ReLoRA to use ReLoRA restart semantics"
+                )

             if self.fsdp or self.fsdp_config:
                 raise ValueError("fsdp not supported with ReLoRA")
diff --git a/tests/integrations/mora/test_mora.py b/tests/integrations/mora/test_mora.py
new file mode 100644
index 000000000..464979cb6
--- /dev/null
+++ b/tests/integrations/mora/test_mora.py
@@ -0,0 +1,160 @@
+"""Integration tests for the MoRA / ReMoRA adapter path."""
+
+from types import SimpleNamespace
+from unittest.mock import Mock
+
+import pytest
+import torch
+
+from axolotl.integrations.base import PluginManager
+from axolotl.integrations.mora import plugin as mora_plugin
+from axolotl.loaders import adapter as adapter_module
+from axolotl.loaders.adapter import load_adapter
+from axolotl.utils.dict import DictDefault
+
+
+class TestMoraAdapterLoading:
+    """MoRA adapter selection and config wiring."""
+
+    def test_load_adapter_uses_plugin_lora_like_registration(self, monkeypatch):
+        model = torch.nn.Linear(4, 4)
+        cfg = DictDefault(
+            {
+                "adapter": "mora",
+                "mora": {"use_mora": True, "mora_type": "rope"},
+            }
+        )
+
+        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
+            mora_plugin.MoraPlugin()
+        )
+
+        calls = []
+
+        def fake_load_lora(*args, **kwargs):
+            calls.append((args, kwargs))
+            return args[0], "adapter-config"
+
+        monkeypatch.setattr(adapter_module, "load_lora", fake_load_lora)
+
+        _, config = load_adapter(model, cfg, "mora")
+
+        assert config == "adapter-config"
+        assert calls[0][1]["config_only"] is False
+
+    def test_mora_plugin_raises_when_peft_missing_support(self):
+        model = torch.nn.Linear(4, 4)
+        cfg = DictDefault(
+            {
+                "adapter": "mora",
+                "mora": {"use_mora": True, "mora_type": "rope"},
+            }
+        )
+        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
+            mora_plugin.MoraPlugin()
+        )
+
+        with pytest.raises(ImportError, match="MoRA support"):
+            load_adapter(model, cfg, "mora", config_only=True)
+
+    def test_mora_plugin_rejects_quantized_base_model(self):
+        model = torch.nn.Linear(4, 4)
+        cfg = DictDefault(
+            {
+                "adapter": "mora",
+                "load_in_4bit": True,
+                "mora": {"use_mora": True, "mora_type": "rope"},
+            }
+        )
+        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
+            mora_plugin.MoraPlugin()
+        )
+
+        with pytest.raises(ValueError, match="full-precision base model"):
+            load_adapter(model, cfg, "mora", config_only=True)
+
+    def test_mora_plugin_builds_mora_config_when_supported(self, monkeypatch):
+        model = torch.nn.Linear(4, 4)
+        cfg = DictDefault(
+            {
+                "adapter": "mora",
+                "mora": {
+                    "use_mora": True,
+                    "mora_type": "rope",
+                },
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.0,
+            }
+        )
+
+        captured = {}
+
+        class FakeLoraConfig:
+            def __init__(self, **kwargs):
+                captured.update(kwargs)
+                self.__dict__.update(kwargs)
+
+        fake_model = SimpleNamespace(print_trainable_parameters=Mock())
+        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
+            mora_plugin.MoraPlugin()
+        )
+        monkeypatch.setattr(mora_plugin, "_peft_supports_mora", lambda: True)
+        monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
+        monkeypatch.setattr(
+            adapter_module, "get_peft_model", Mock(return_value=fake_model)
+        )
+
+        _, config = load_adapter(model, cfg, "mora", config_only=True)
+
+        assert captured["use_mora"] is True
+        assert captured["mora_type"] == 6
+        assert captured["task_type"].name == "CAUSAL_LM"
+        assert config is not None
+        assert config.use_mora is True
+        assert config.mora_type == 6
+
+    def test_mora_plugin_uses_lora_model_dir_resume_path(self, monkeypatch):
+        model = torch.nn.Linear(4, 4)
+        cfg = DictDefault(
+            {
+                "adapter": "mora",
+                "mora": {"use_mora": True, "mora_type": "rope"},
+                "lora_model_dir": "adapter-checkpoint",
+                "lora_on_cpu": False,
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.0,
+            }
+        )
+
+        class FakeLoraConfig:
+            def __init__(self, **kwargs):
+                self.__dict__.update(kwargs)
+
+        class FakePeftModel:
+            def print_trainable_parameters(self):
+                pass
+
+            def named_parameters(self):
+                return []
+
+        from_pretrained = Mock(return_value=FakePeftModel())
+        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
+            mora_plugin.MoraPlugin()
+        )
+        monkeypatch.setattr(mora_plugin, "_peft_supports_mora", lambda: True)
+        monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
+        monkeypatch.setattr(
+            adapter_module.PeftModel,
+            "from_pretrained",
+            from_pretrained,
+        )
+
+        peft_model, config = load_adapter(model, cfg, "mora")
+
+        assert isinstance(peft_model, FakePeftModel)
+        assert config.use_mora is True
+        from_pretrained.assert_called_once()
+        assert from_pretrained.call_args.args[:2] == (model, "adapter-checkpoint")
+        assert from_pretrained.call_args.kwargs["is_trainable"] is True
diff --git a/tests/integrations/test_adapter_plugin_registry.py b/tests/integrations/test_adapter_plugin_registry.py
new file mode 100644
index 000000000..e3102a90e
--- /dev/null
+++ b/tests/integrations/test_adapter_plugin_registry.py
@@ -0,0 +1,73 @@
+"""Core adapter plugin registry tests."""
+
+from unittest.mock import Mock
+
+import pytest
+import torch
+
+from axolotl.integrations.base import AdapterCapabilities, BasePlugin, PluginManager
+from axolotl.loaders import adapter as adapter_module
+from axolotl.loaders.adapter import load_adapter
+from axolotl.utils.config import validate_config
+from axolotl.utils.dict import DictDefault
+
+
+class FakeAdapterPlugin(BasePlugin):
+    def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
+        return [AdapterCapabilities(name="fake-adapter", lora_like=True, relora=True)]
+
+    def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
+        if cfg.adapter != "fake-adapter":
+            return {}
+        return {"fake_kwarg": "from-plugin"}
+
+
+class TestAdapterPluginRegistry:
+    def test_lora_like_plugin_adapter_contributes_peft_kwargs(self, monkeypatch):
+        model = torch.nn.Linear(4, 4)
+        cfg = DictDefault(
+            {
+                "adapter": "fake-adapter",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.0,
+            }
+        )
+        PluginManager.get_instance().plugins["fake"] = FakeAdapterPlugin()
+        captured = {}
+
+        class FakeLoraConfig:
+            def __init__(self, **kwargs):
+                captured.update(kwargs)
+                self.__dict__.update(kwargs)
+
+        monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
+        monkeypatch.setattr(adapter_module, "get_peft_model", Mock())
+
+        _, config = load_adapter(model, cfg, "fake-adapter", config_only=True)
+
+        assert config is not None
+        assert captured["fake_kwarg"] == "from-plugin"
+        assert captured["task_type"].name == "CAUSAL_LM"
+
+    def test_unknown_adapter_error_mentions_plugin_registry(self):
+        model = torch.nn.Linear(4, 4)
+        cfg = DictDefault({"adapter": "missing-adapter"})
+
+        with pytest.raises(NotImplementedError, match="registered by a plugin"):
+            load_adapter(model, cfg, "missing-adapter")
+
+    def test_relora_accepts_plugin_adapter_capability(self, min_base_cfg):
+        PluginManager.get_instance().plugins["fake"] = FakeAdapterPlugin()
+        cfg = min_base_cfg | DictDefault(
+            {
+                "adapter": "fake-adapter",
+                "relora": True,
+                "jagged_restart_steps": 100,
+            }
+        )
+
+        validated = validate_config(cfg)
+
+        assert validated.adapter == "fake-adapter"
+        assert validated.relora is True
diff --git a/tests/utils/schemas/validation/mora/test_mora_validation.py b/tests/utils/schemas/validation/mora/test_mora_validation.py
new file mode 100644
index 000000000..da1169f50
--- /dev/null
+++ b/tests/utils/schemas/validation/mora/test_mora_validation.py
@@ -0,0 +1,100 @@
+"""Validation tests for the MoRA / ReMoRA integration."""
+
+import pytest
+
+from axolotl.integrations.mora import MoraType
+from axolotl.utils.config import prepare_plugins, validate_config
+from axolotl.utils.dict import DictDefault
+
+
+class TestMoraValidation:
+    """MoRA-specific config validation."""
+
+    def test_mora_block_round_trips(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(
+            {
+                "adapter": "mora",
+                "plugins": ["axolotl.integrations.mora.MoraPlugin"],
+                "mora": {
+                    "use_mora": True,
+                    "mora_type": "rope",
+                },
+            }
+        )
+
+        prepare_plugins(cfg)
+        validated = validate_config(cfg)
+
+        assert validated.adapter == "mora"
+        assert validated.mora.use_mora is True
+        assert validated.mora.mora_type == MoraType.ROPE
+
+    def test_mora_type_accepts_legacy_supported_numbers(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(
+            {
+                "adapter": "mora",
+                "plugins": ["axolotl.integrations.mora.MoraPlugin"],
+                "mora": {
+                    "use_mora": True,
+                    "mora_type": 1,
+                },
+            }
+        )
+
+        prepare_plugins(cfg)
+        validated = validate_config(cfg)
+
+        assert validated.mora.mora_type == MoraType.SHARING
+
+    def test_mora_rejects_unsupported_variant_numbers(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(
+            {
+                "adapter": "mora",
+                "plugins": ["axolotl.integrations.mora.MoraPlugin"],
+                "mora": {
+                    "use_mora": True,
+                    "mora_type": 2,
+                },
+            }
+        )
+
+        prepare_plugins(cfg)
+        with pytest.raises(ValueError, match="mora_type"):
+            validate_config(cfg)
+
+    def test_remora_uses_core_relora_fields(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(
+            {
+                "adapter": "mora",
+                "plugins": ["axolotl.integrations.mora.MoraPlugin"],
+                "relora": True,
+                "jagged_restart_steps": 2000,
+                "mora": {
+                    "use_mora": True,
+                    "mora_type": "rope",
+                },
+            }
+        )
+
+        prepare_plugins(cfg)
+        validated = validate_config(cfg)
+
+        assert validated.relora is True
+        assert validated.jagged_restart_steps == 2000
+
+    def test_remora_still_requires_core_restart_steps(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(
+            {
+                "adapter": "mora",
+                "plugins": ["axolotl.integrations.mora.MoraPlugin"],
+                "relora": True,
+                "mora": {
+                    "use_mora": True,
+                    "mora_type": "rope",
+                },
+            }
+        )
+
+        prepare_plugins(cfg)
+        with pytest.raises(ValueError, match="jagged_restart_steps"):
+            validate_config(cfg)
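
Usage sketch (illustrative, not part of the patch): the snippet below mirrors the validation tests above. It assumes the MoRA plugin module is importable and that `base_cfg` stands in for the minimal base-config fields (the tests' `min_base_cfg` fixture) that a real config would also need.

```python
from axolotl.utils.config import prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault

# Equivalent YAML keys in an axolotl config:
#   plugins: [axolotl.integrations.mora.MoraPlugin]
#   adapter: mora
#   mora: {use_mora: true, mora_type: rope}  # or `sharing`; legacy 1/6 accepted
#   relora: true                             # ReMoRA reuses the core ReLoRA fields
#   jagged_restart_steps: 2000
cfg = base_cfg | DictDefault(  # `base_cfg` is a placeholder for required base fields
    {
        "plugins": ["axolotl.integrations.mora.MoraPlugin"],
        "adapter": "mora",
        "mora": {"use_mora": True, "mora_type": "rope"},
        "relora": True,
        "jagged_restart_steps": 2000,
    }
)

prepare_plugins(cfg)  # registers MoraPlugin, exposing the nested `mora` block
validated = validate_config(cfg)  # exercises the adapter and ReLoRA validators above
```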