Add optional Axolotl MoRA/ReMoRA integration (#3647) [skip ci]
* Add optional Axolotl MoRA/ReMoRA integration

Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* Isolate MoRA adapter behavior in plugin

Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>

* Constrain MoRA variants to supported enum values

* Keep MoRA validation out of core config

---------

Co-authored-by: Swarm <swarm@localhost>
Co-authored-by: Axolotl Swarm <no-reply@axolotl.ai>
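For orientation, the change can be exercised with a config like the one the validation tests below build. This is a minimal sketch: the `plugins`, `adapter`, and `mora` keys come straight from those tests, while the remaining values are illustrative assumptions.

from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        # Register the plugin that contributes the adapter, then select it by name.
        "plugins": ["axolotl.integrations.mora.MoraPlugin"],
        "adapter": "mora",
        "mora": {"use_mora": True, "mora_type": "rope"},
        # Standard LoRA-style hyperparameters are reused by the lora-like adapter path
        # (values here are illustrative).
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.0,
    }
)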
@@ -23,9 +23,10 @@ from __future__ import annotations

import collections
import importlib
import traceback
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, OrderedDict, Union

from peft import PeftModel
from peft import PeftConfig, PeftMixedModel, PeftModel
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
@@ -41,6 +42,15 @@ if TYPE_CHECKING:
    from axolotl.common.datasets import TrainDatasetMeta


@dataclass(frozen=True)
class AdapterCapabilities:
    """Capabilities for an adapter contributed by a plugin."""

    name: str
    lora_like: bool = False
    relora: bool = False


class BasePlugin:
    """Base class for all plugins. Defines the interface for plugin methods.

@@ -91,6 +101,26 @@ class BasePlugin:
        Returns a dataclass model for the plugin's training arguments.
        """

    def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
        """Returns adapter capabilities contributed by the plugin."""
        return []

    def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
        """Returns extra PEFT LoraConfig kwargs for plugin LoRA-like adapters."""
        return {}

    def load_adapter(
        self,
        model: PreTrainedModel,
        cfg: DictDefault,
        inference: bool = False,
        config_only: bool = False,
    ) -> (
        tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]
        | None
    ):
        """Optionally load a plugin adapter instead of the generic loader."""

    def load_datasets(
        self, cfg: DictDefault, preprocess: bool = False
    ) -> Union["TrainDatasetMeta", None]:
@@ -414,6 +444,58 @@ class PluginManager:
            training_args.append(training_args_from_plugin)
        return training_args

    def adapter_capabilities(self) -> dict[str, AdapterCapabilities]:
        """Returns adapter capabilities by adapter name."""
        capabilities = {}
        for plugin in self.plugins.values():
            for adapter_capability in plugin.get_adapter_capabilities():
                capabilities[adapter_capability.name] = adapter_capability
        return capabilities

    def get_adapter_capability(self, adapter: str) -> AdapterCapabilities | None:
        """Returns capabilities for a registered plugin adapter."""
        return self.adapter_capabilities().get(adapter)

    def supports_adapter(self, adapter: str) -> bool:
        """Returns whether a plugin has registered the adapter name."""
        return adapter in self.adapter_capabilities()

    def adapter_supports_relora(self, adapter: str) -> bool:
        """Returns whether a plugin adapter supports ReLoRA restart semantics."""
        capability = self.get_adapter_capability(adapter)
        return bool(capability and capability.relora)

    def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
        """Returns extra LoraConfig kwargs from plugins for the configured adapter."""
        lora_config_kwargs = {}
        for plugin in self.plugins.values():
            plugin_kwargs = plugin.get_lora_config_kwargs(cfg)
            if plugin_kwargs:
                lora_config_kwargs.update(plugin_kwargs)
        return lora_config_kwargs

    def load_adapter(
        self,
        model: PreTrainedModel,
        cfg: DictDefault,
        inference: bool = False,
        config_only: bool = False,
    ) -> (
        tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]
        | None
    ):
        """Returns the first plugin adapter loader result, if any."""
        for plugin in self.plugins.values():
            loaded = plugin.load_adapter(
                model,
                cfg,
                inference=inference,
                config_only=config_only,
            )
            if loaded is not None:
                return loaded
        return None

    def load_datasets(
        self, cfg: DictDefault, preprocess: bool = False
    ) -> Union["TrainDatasetMeta", None]:
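The new BasePlugin hooks and PluginManager registry are easiest to see with a toy plugin. The sketch below mirrors the FakeAdapterPlugin used in the registry tests further down; the adapter name and extra kwarg are illustrative, not part of the shipped integration.

from axolotl.integrations.base import AdapterCapabilities, BasePlugin, PluginManager
from axolotl.utils.dict import DictDefault


class ExampleAdapterPlugin(BasePlugin):
    """Hypothetical plugin contributing a LoRA-like adapter named "example-adapter"."""

    def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
        # lora_like routes the adapter through the generic LoRA loader;
        # relora additionally opts it into ReLoRA restart semantics.
        return [AdapterCapabilities(name="example-adapter", lora_like=True, relora=True)]

    def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
        if cfg.adapter != "example-adapter":
            return {}
        # Extra kwargs merged into the LoraConfig built by the loader.
        return {"init_lora_weights": "gaussian"}


# Plugins are normally registered via the `plugins:` config key; assigning to the
# singleton directly (as the tests do) is shown only for illustration.
PluginManager.get_instance().plugins["example_adapter_plugin"] = ExampleAdapterPlugin()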
src/axolotl/integrations/mora/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
"""MoRA / ReMoRA integration for Axolotl."""

from .args import MoraArgs, MoraConfig, MoraType
from .plugin import MoraPlugin

__all__ = ["MoraArgs", "MoraConfig", "MoraPlugin", "MoraType"]
src/axolotl/integrations/mora/args.py (new file, 66 lines)
@@ -0,0 +1,66 @@
"""Config args for MoRA / ReMoRA."""

from __future__ import annotations

from enum import Enum

from pydantic import BaseModel, Field, model_validator


class MoraType(str, Enum):
    """MoRA variants supported by the reference implementation."""

    SHARING = "sharing"
    ROPE = "rope"

    @property
    def peft_value(self) -> int:
        return {
            MoraType.SHARING: 1,
            MoraType.ROPE: 6,
        }[self]


class MoraConfig(BaseModel):
    """Nested MoRA configuration available under the `mora` key."""

    use_mora: bool = Field(
        default=True,
        description=(
            "Enable MoRA adapter construction. Requires a PEFT build with MoRA "
            "support (for example, the MoRA fork)."
        ),
    )
    mora_type: MoraType = Field(
        default=MoraType.ROPE,
        description=(
            "MoRA variant selector. Supported values are `sharing` for type 1 "
            "and `rope` for type 6. Numeric values 1 and 6 are accepted for "
            "backwards compatibility."
        ),
    )

    @model_validator(mode="before")
    @classmethod
    def normalize_mora_type(cls, data):
        if not isinstance(data, dict) or "mora_type" not in data:
            return data
        data = data.copy()
        mora_type = data["mora_type"]
        if mora_type == 1:
            data["mora_type"] = MoraType.SHARING
        elif mora_type == 6:
            data["mora_type"] = MoraType.ROPE
        return data


class MoraArgs(BaseModel):
    """Plugin entry that exposes the nested `mora` block to the core config."""

    mora: MoraConfig = Field(
        default_factory=MoraConfig,
        description=(
            "MoRA / ReMoRA training configuration. Register the "
            "`axolotl.integrations.mora.MoraPlugin` plugin to enable this block."
        ),
    )
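A quick illustration of the enum and validator behaviour defined above, using only values present in the class definitions:

from axolotl.integrations.mora.args import MoraConfig, MoraType

# Legacy numeric variants are normalised to the enum by the `before` validator.
assert MoraConfig(mora_type=1).mora_type == MoraType.SHARING
assert MoraConfig(mora_type=6).mora_type == MoraType.ROPE

# The enum maps back to the integer codes expected by a MoRA-capable PEFT build.
assert MoraType.SHARING.peft_value == 1
assert MoraType.ROPE.peft_value == 6

# Unsupported numbers (e.g. 2) are left untouched and rejected by pydantic's
# enum coercion, which is what the validation tests below rely on.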
src/axolotl/integrations/mora/plugin.py (new file, 97 lines)
@@ -0,0 +1,97 @@
"""MoRA / ReMoRA plugin for Axolotl."""

import inspect

from peft import LoraConfig, PeftModel
from transformers import PreTrainedModel

from axolotl.integrations.base import AdapterCapabilities, BasePlugin
from axolotl.integrations.mora.args import MoraType
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def _peft_supports_mora() -> bool:
    try:
        params = inspect.signature(LoraConfig).parameters
    except (TypeError, ValueError):
        return False
    return "use_mora" in params and "mora_type" in params


def _mora_type_peft_value(mora_type: MoraType | str | int) -> int:
    if isinstance(mora_type, MoraType):
        return mora_type.peft_value
    if mora_type == 1 or mora_type == MoraType.SHARING.value:
        return MoraType.SHARING.peft_value
    if mora_type == 6 or mora_type == MoraType.ROPE.value:
        return MoraType.ROPE.peft_value
    raise ValueError("mora_type must be one of `sharing`, `rope`, 1, or 6")


def _mora_type_label(mora_type: MoraType | str | int) -> str:
    if isinstance(mora_type, MoraType):
        return mora_type.value
    if mora_type == 1:
        return MoraType.SHARING.value
    if mora_type == 6:
        return MoraType.ROPE.value
    return str(mora_type)


class MoraPlugin(BasePlugin):
    """Plugin that exposes MoRA-specific config and validates runtime support."""

    def get_input_args(self) -> str:
        return "axolotl.integrations.mora.MoraArgs"

    def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
        return [AdapterCapabilities(name="mora", lora_like=True, relora=True)]

    def _validate_mora_config(self, cfg: DictDefault):
        mora_cfg = getattr(cfg, "mora", None)
        if mora_cfg is None:
            raise ValueError("adapter: mora requires a nested mora configuration block")
        if not getattr(mora_cfg, "use_mora", False):
            raise ValueError("mora.use_mora must be true when adapter: mora is set")
        if cfg.load_in_4bit or cfg.load_in_8bit:
            raise ValueError(
                "adapter: mora currently requires a full-precision base model. "
                "Use adapter: lora or qlora for quantized training."
            )
        if cfg.gptq:
            raise ValueError(
                "adapter: mora is not compatible with GPTQ quantized base models."
            )

    def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
        if cfg.adapter != "mora":
            return {}
        self._validate_mora_config(cfg)
        if not _peft_supports_mora():
            raise ImportError(
                "adapter: mora requires a PEFT build with MoRA support "
                "(LoraConfig(use_mora=..., mora_type=...)). "
                "Install the MoRA fork or another PEFT distribution that exposes "
                "those fields."
            )
        mora_cfg = cfg.mora
        return {
            "use_mora": mora_cfg.use_mora,
            "mora_type": _mora_type_peft_value(mora_cfg.mora_type),
        }

    def pre_model_load(self, cfg: DictDefault):
        if cfg.adapter != "mora":
            return
        LOG.info("MoRA plugin enabled for adapter: mora")

    def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
        if cfg.adapter == "mora" and getattr(cfg, "mora", None):
            LOG.debug(
                "Loaded MoRA model with mora_type=%s, relora=%s",
                _mora_type_label(cfg.mora.mora_type),
                cfg.relora,
            )
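With `adapter: mora`, the plugin contributes exactly two extra LoraConfig kwargs once its runtime checks pass. A small sketch (assuming a PEFT build that actually exposes `use_mora` and `mora_type`; the support check is stubbed out here the same way the tests do via monkeypatching):

from axolotl.integrations.mora import plugin as mora_plugin
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"adapter": "mora", "mora": {"use_mora": True, "mora_type": "rope"}})

# Pretend the installed PEFT build supports MoRA (the tests patch this the same way).
mora_plugin._peft_supports_mora = lambda: True

kwargs = mora_plugin.MoraPlugin().get_lora_config_kwargs(cfg)
# `rope` maps to MoRA type 6 in the reference implementation.
assert kwargs == {"use_mora": True, "mora_type": 6}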
@@ -19,12 +19,14 @@ from peft import (
)
from transformers import PreTrainedModel

from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.telemetry.errors import send_errors
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()


def setup_quantized_meta_for_peft(model: torch.nn.Module):
@@ -124,6 +126,76 @@ def _patch_peft_clippable_linear():
    LoraModel._axolotl_clippable_patched = True


def _get_peft_task_type(model: PreTrainedModel) -> TaskType:
    model_cls = type(model).__name__
    if "SequenceClassification" in model_cls:
        return TaskType.SEQ_CLS
    if "TokenClassification" in model_cls:
        return TaskType.TOKEN_CLS
    return TaskType.CAUSAL_LM


def _build_lora_config_kwargs(cfg: DictDefault) -> dict[str, Any]:
    lora_config_kwargs: dict[str, Any] = {}
    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
    if loftq_bits:
        lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
        lora_config_kwargs["init_lora_weights"] = "loftq"
    if cfg.peft_init_lora_weights:
        lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
    if cfg.peft_use_dora:
        lora_config_kwargs["use_dora"] = cfg.peft_use_dora
        LOG.info("Initializing LoRA weights using dora. This might take longer.")
    if cfg.peft_use_rslora:
        lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
    if cfg.peft_layer_replication:
        lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
    if cfg.peft_trainable_token_indices:
        lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
    if cfg.peft_ensure_weight_tying is not None:
        lora_config_kwargs["ensure_weight_tying"] = cfg.peft_ensure_weight_tying

    return lora_config_kwargs


def _build_peft_lora_config(
    model: PreTrainedModel,
    cfg: DictDefault,
) -> PeftConfig:
    lora_target_modules = cfg.lora_target_modules or []
    lora_target_parameters = cfg.lora_target_parameters or []

    if cfg.lora_target_linear:
        linear_names = find_all_linear_names(model)
        LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
        lora_target_modules_as_list = (
            lora_target_modules
            if isinstance(lora_target_modules, list)
            else [lora_target_modules]
        )
        lora_target_modules = list(set(lora_target_modules_as_list + linear_names))

    lora_config_kwargs = _build_lora_config_kwargs(cfg)
    lora_config_kwargs.update(PLUGIN_MANAGER.get_lora_config_kwargs(cfg))

    lora_config = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        target_modules=lora_target_modules,
        target_parameters=lora_target_parameters,
        layers_to_transform=cfg.peft_layers_to_transform,
        layers_pattern=cfg.peft_layers_pattern,
        lora_dropout=cfg.lora_dropout,
        fan_in_fan_out=cfg.lora_fan_in_fan_out,
        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
        exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
        bias="none",
        task_type=_get_peft_task_type(model),
        **lora_config_kwargs,
    )
    return lora_config


def _peft_will_auto_convert_target_params(model, lora_config) -> bool:
    """Check whether PEFT will auto-populate target_parameters for this model.
@@ -226,62 +298,7 @@ def load_lora(
    config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
    _patch_peft_clippable_linear()
    lora_target_modules = cfg.lora_target_modules or []
    lora_target_parameters = cfg.lora_target_parameters or []

    if cfg.lora_target_linear:
        linear_names = find_all_linear_names(model)
        LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
        lora_target_modules_as_list = (
            lora_target_modules
            if isinstance(lora_target_modules, list)
            else [lora_target_modules]
        )
        lora_target_modules = list(set(lora_target_modules_as_list + linear_names))

    lora_config_kwargs = {}
    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
    if loftq_bits:
        lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
        lora_config_kwargs["init_lora_weights"] = "loftq"
    if cfg.peft_init_lora_weights:
        lora_config_kwargs["init_lora_weights"] = cfg.peft_init_lora_weights
    if cfg.peft_use_dora:
        lora_config_kwargs["use_dora"] = cfg.peft_use_dora
        LOG.info("Initializing LoRA weights using dora. This might take longer.")
    if cfg.peft_use_rslora:
        lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
    if cfg.peft_layer_replication:
        lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
    if cfg.peft_trainable_token_indices:
        lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
    if cfg.peft_ensure_weight_tying is not None:
        lora_config_kwargs["ensure_weight_tying"] = cfg.peft_ensure_weight_tying

    # Determine the correct PEFT task type
    model_cls = type(model).__name__
    if "SequenceClassification" in model_cls:
        task_type = TaskType.SEQ_CLS
    elif "TokenClassification" in model_cls:
        task_type = TaskType.TOKEN_CLS
    else:
        task_type = TaskType.CAUSAL_LM

    lora_config = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        target_modules=lora_target_modules,
        target_parameters=lora_target_parameters,
        layers_to_transform=cfg.peft_layers_to_transform,
        layers_pattern=cfg.peft_layers_pattern,
        lora_dropout=cfg.lora_dropout,
        fan_in_fan_out=cfg.lora_fan_in_fan_out,
        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
        exclude_modules=getattr(cfg, "lora_exclude_modules", None) or None,
        bias="none",
        task_type=task_type,
        **lora_config_kwargs,
    )
    lora_config = _build_peft_lora_config(model, cfg)

    if config_only:
        return None, lora_config
@@ -315,7 +332,7 @@ def load_lora(
        model_kwargs["autocast_adapter_dtype"] = cfg.peft_autocast_adapter_dtype

    if cfg.lora_model_dir:
        LOG.debug("Loading pretrained PEFT - LoRA")
        LOG.debug("Loading pretrained PEFT adapter")
        if cfg.lora_on_cpu:
            model_kwargs["max_memory"] = {"cpu": "256GiB"}
            model_kwargs["device_map"] = {"": "cpu"}
@@ -364,30 +381,60 @@ def load_adapter(
    cfg: DictDefault,
    adapter: str | None,
    inference: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel, PeftConfig | None]:
    config_only: bool = False,
) -> tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]:
    if adapter is None:
        return model, None
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    if adapter in ["lora", "qlora"]:
        peft_model, lora_config = load_lora(model, cfg, inference=inference)
        peft_model, lora_config = load_lora(
            model, cfg, inference=inference, config_only=config_only
        )
        return peft_model, lora_config
    if adapter == "llama-adapter":
        if config_only:
            _, lora_config = load_llama_adapter(model, cfg, config_only=True)
            return None, lora_config
        peft_model, lora_config = load_llama_adapter(model, cfg)
        return peft_model, lora_config

    raise NotImplementedError(f"{adapter} PEFT adapter not available")
    plugin_loaded = PLUGIN_MANAGER.load_adapter(
        model,
        cfg,
        inference=inference,
        config_only=config_only,
    )
    if plugin_loaded is not None:
        return plugin_loaded

    adapter_capability = PLUGIN_MANAGER.get_adapter_capability(adapter)
    if adapter_capability and adapter_capability.lora_like:
        peft_model, lora_config = load_lora(
            model, cfg, inference=inference, config_only=config_only
        )
        return peft_model, lora_config

    registered = sorted(PLUGIN_MANAGER.adapter_capabilities())
    registered_msg = ", ".join(registered) if registered else "none"
    raise NotImplementedError(
        f"Adapter '{adapter}' is not built in and was not registered by a plugin "
        f"with loader support. Registered plugin adapters: {registered_msg}"
    )


def load_llama_adapter(
    model: PreTrainedModel, cfg: DictDefault
) -> tuple[PeftModel | PeftMixedModel, PeftConfig]:
    model: PreTrainedModel, cfg: DictDefault, config_only: bool = False
) -> tuple[PeftModel | PeftMixedModel | None, PeftConfig]:
    peft_config = AdaptionPromptConfig(
        adapter_layers=cfg.peft_adapter.layers,  # layers (L)
        adapter_len=cfg.peft_adapter.len,  # prompt length (K)
        task_type="CAUSAL_LM",
    )

    if config_only:
        return None, peft_config

    if cfg.lora_model_dir:
        LOG.debug("Loading pretrained PEFT - llama_adapter")
        peft_model = PeftModel.from_pretrained(
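The reworked load_adapter above resolves an adapter name in a fixed order: built-in adapters first, then plugin-provided loaders, then the generic LoRA path for plugin adapters marked lora_like, and finally an error naming the registered plugin adapters. A condensed sketch of that control flow (names taken from the diff; this is a summary, not the verbatim function):

from axolotl.loaders.adapter import PLUGIN_MANAGER, load_llama_adapter, load_lora


def resolve_adapter(model, cfg, adapter, inference=False, config_only=False):
    """Condensed view of the dispatch order in load_adapter."""
    if adapter is None:
        return model, None  # full finetune, no PEFT wrapper
    if adapter in ("lora", "qlora"):
        return load_lora(model, cfg, inference=inference, config_only=config_only)
    if adapter == "llama-adapter":
        return load_llama_adapter(model, cfg, config_only=config_only)

    # 1) a plugin may load the adapter entirely on its own ...
    loaded = PLUGIN_MANAGER.load_adapter(
        model, cfg, inference=inference, config_only=config_only
    )
    if loaded is not None:
        return loaded

    # 2) ... or simply declare itself lora_like and reuse the generic LoRA path.
    capability = PLUGIN_MANAGER.get_adapter_capability(adapter)
    if capability and capability.lora_like:
        return load_lora(model, cfg, inference=inference, config_only=config_only)

    # 3) Otherwise the adapter is unknown to both core and plugins.
    raise NotImplementedError(f"Adapter '{adapter}' was not registered by a plugin")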
@@ -39,7 +39,7 @@ from transformers.integrations.deepspeed import (

from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.integrations.base import PluginManager
from axolotl.loaders.adapter import load_adapter, load_lora
from axolotl.loaders.adapter import load_adapter
from axolotl.loaders.constants import MULTIMODAL_AUTO_MODEL_MAPPING
from axolotl.loaders.patch_manager import PatchManager
from axolotl.loaders.utils import (
@@ -386,8 +386,12 @@ class ModelLoader:
            and self.cfg.rl in [RLType.DPO, RLType.IPO, RLType.KTO]
            and not self.cfg.merge_lora
        ):
            _, lora_config = load_lora(
                self.model, self.cfg, inference=False, config_only=True
            _, lora_config = load_adapter(
                self.model,
                self.cfg,
                self.cfg.adapter,
                inference=False,
                config_only=True,
            )
        else:
            self.model, lora_config = load_adapter(
@@ -285,7 +285,9 @@ def save_trained_model(
    )
    # Handle ReLoRA early return case
    if cfg.relora:
        if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
        if hasattr(model, "merge_and_unload") and not (
            cfg.load_in_4bit or cfg.load_in_8bit
        ):
            model = model.merge_and_unload()
        else:
            # final model weights have already been saved by `ReLoRACallback.on_train_end`
@@ -38,10 +38,10 @@ class LoraConfig(BaseModel):
        default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"}
    )

    adapter: Literal["lora", "qlora", "llama-adapter"] | None = Field(
    adapter: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "If you want to use 'lora', 'qlora', or 'llama-adapter', or leave blank to train all parameters in original model"
            "description": "If you want to use a built-in or plugin adapter, or leave blank to train all parameters in original model"
        },
    )
    lora_model_dir: str | None = Field(
@@ -174,6 +174,16 @@ class LoraConfig(BaseModel):
                "load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
                "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
            )
        adapter = data.get("adapter")
        if adapter and adapter not in ("lora", "qlora", "llama-adapter"):
            from axolotl.integrations.base import PluginManager

            plugin_manager = PluginManager.get_instance()
            if not plugin_manager.supports_adapter(adapter):
                raise ValueError(
                    f"Adapter '{adapter}' is not built in and was not registered by "
                    "a plugin. Add the plugin that provides this adapter to `plugins:`."
                )
        return data

    @model_validator(mode="after")
@@ -1478,8 +1478,19 @@ class ComplexValidationMixin:
        if self.relora:
            if not self.jagged_restart_steps:
                raise ValueError("jagged_restart_steps must be set to use ReLoRA")
            if self.adapter not in ("lora", "qlora"):
                raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")

            adapter_supports_relora = self.adapter in ("lora", "qlora")
            if self.adapter and not adapter_supports_relora:
                from axolotl.integrations.base import PluginManager

                plugin_manager = PluginManager.get_instance()
                adapter_supports_relora = plugin_manager.adapter_supports_relora(
                    self.adapter
                )
            if not adapter_supports_relora:
                raise ValueError(
                    "cfg.adapter must support ReLoRA to use ReLoRA restart semantics"
                )

            if self.fsdp or self.fsdp_config:
                raise ValueError("fsdp not supported with ReLoRA")
tests/integrations/mora/test_mora.py (new file, 160 lines)
@@ -0,0 +1,160 @@
"""Integration tests for the MoRA / ReMoRA adapter path."""

from types import SimpleNamespace
from unittest.mock import Mock

import pytest
import torch

from axolotl.integrations.base import PluginManager
from axolotl.integrations.mora import plugin as mora_plugin
from axolotl.loaders import adapter as adapter_module
from axolotl.loaders.adapter import load_adapter
from axolotl.utils.dict import DictDefault


class TestMoraAdapterLoading:
    """MoRA adapter selection and config wiring."""

    def test_load_adapter_uses_plugin_lora_like_registration(self, monkeypatch):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {"use_mora": True, "mora_type": "rope"},
            }
        )

        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
            mora_plugin.MoraPlugin()
        )

        calls = []

        def fake_load_lora(*args, **kwargs):
            calls.append((args, kwargs))
            return args[0], "adapter-config"

        monkeypatch.setattr(adapter_module, "load_lora", fake_load_lora)

        _, config = load_adapter(model, cfg, "mora")

        assert config == "adapter-config"
        assert calls[0][1]["config_only"] is False

    def test_mora_plugin_raises_when_peft_missing_support(self):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {"use_mora": True, "mora_type": "rope"},
            }
        )
        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
            mora_plugin.MoraPlugin()
        )

        with pytest.raises(ImportError, match="MoRA support"):
            load_adapter(model, cfg, "mora", config_only=True)

    def test_mora_plugin_rejects_quantized_base_model(self):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "load_in_4bit": True,
                "mora": {"use_mora": True, "mora_type": "rope"},
            }
        )
        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
            mora_plugin.MoraPlugin()
        )

        with pytest.raises(ValueError, match="full-precision base model"):
            load_adapter(model, cfg, "mora", config_only=True)

    def test_mora_plugin_builds_mora_config_when_supported(self, monkeypatch):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {
                    "use_mora": True,
                    "mora_type": "rope",
                },
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.0,
            }
        )

        captured = {}

        class FakeLoraConfig:
            def __init__(self, **kwargs):
                captured.update(kwargs)
                self.__dict__.update(kwargs)

        fake_model = SimpleNamespace(print_trainable_parameters=Mock())
        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
            mora_plugin.MoraPlugin()
        )
        monkeypatch.setattr(mora_plugin, "_peft_supports_mora", lambda: True)
        monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
        monkeypatch.setattr(
            adapter_module, "get_peft_model", Mock(return_value=fake_model)
        )

        _, config = load_adapter(model, cfg, "mora", config_only=True)

        assert captured["use_mora"] is True
        assert captured["mora_type"] == 6
        assert captured["task_type"].name == "CAUSAL_LM"
        assert config is not None
        assert config.use_mora is True
        assert config.mora_type == 6

    def test_mora_plugin_uses_lora_model_dir_resume_path(self, monkeypatch):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "mora",
                "mora": {"use_mora": True, "mora_type": "rope"},
                "lora_model_dir": "adapter-checkpoint",
                "lora_on_cpu": False,
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.0,
            }
        )

        class FakeLoraConfig:
            def __init__(self, **kwargs):
                self.__dict__.update(kwargs)

        class FakePeftModel:
            def print_trainable_parameters(self):
                pass

            def named_parameters(self):
                return []

        from_pretrained = Mock(return_value=FakePeftModel())
        PluginManager.get_instance().plugins["axolotl.integrations.mora.MoraPlugin"] = (
            mora_plugin.MoraPlugin()
        )
        monkeypatch.setattr(mora_plugin, "_peft_supports_mora", lambda: True)
        monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
        monkeypatch.setattr(
            adapter_module.PeftModel,
            "from_pretrained",
            from_pretrained,
        )

        peft_model, config = load_adapter(model, cfg, "mora")

        assert isinstance(peft_model, FakePeftModel)
        assert config.use_mora is True
        from_pretrained.assert_called_once()
        assert from_pretrained.call_args.args[:2] == (model, "adapter-checkpoint")
        assert from_pretrained.call_args.kwargs["is_trainable"] is True
tests/integrations/test_adapter_plugin_registry.py (new file, 73 lines)
@@ -0,0 +1,73 @@
"""Core adapter plugin registry tests."""

from unittest.mock import Mock

import pytest
import torch

from axolotl.integrations.base import AdapterCapabilities, BasePlugin, PluginManager
from axolotl.loaders import adapter as adapter_module
from axolotl.loaders.adapter import load_adapter
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


class FakeAdapterPlugin(BasePlugin):
    def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
        return [AdapterCapabilities(name="fake-adapter", lora_like=True, relora=True)]

    def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
        if cfg.adapter != "fake-adapter":
            return {}
        return {"fake_kwarg": "from-plugin"}


class TestAdapterPluginRegistry:
    def test_lora_like_plugin_adapter_contributes_peft_kwargs(self, monkeypatch):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault(
            {
                "adapter": "fake-adapter",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.0,
            }
        )
        PluginManager.get_instance().plugins["fake"] = FakeAdapterPlugin()
        captured = {}

        class FakeLoraConfig:
            def __init__(self, **kwargs):
                captured.update(kwargs)
                self.__dict__.update(kwargs)

        monkeypatch.setattr(adapter_module, "LoraConfig", FakeLoraConfig)
        monkeypatch.setattr(adapter_module, "get_peft_model", Mock())

        _, config = load_adapter(model, cfg, "fake-adapter", config_only=True)

        assert config is not None
        assert captured["fake_kwarg"] == "from-plugin"
        assert captured["task_type"].name == "CAUSAL_LM"

    def test_unknown_adapter_error_mentions_plugin_registry(self):
        model = torch.nn.Linear(4, 4)
        cfg = DictDefault({"adapter": "missing-adapter"})

        with pytest.raises(NotImplementedError, match="registered by a plugin"):
            load_adapter(model, cfg, "missing-adapter")

    def test_relora_accepts_plugin_adapter_capability(self, min_base_cfg):
        PluginManager.get_instance().plugins["fake"] = FakeAdapterPlugin()
        cfg = min_base_cfg | DictDefault(
            {
                "adapter": "fake-adapter",
                "relora": True,
                "jagged_restart_steps": 100,
            }
        )

        validated = validate_config(cfg)

        assert validated.adapter == "fake-adapter"
        assert validated.relora is True
tests/utils/schemas/validation/mora/test_mora_validation.py (new file, 100 lines)
@@ -0,0 +1,100 @@
"""Validation tests for the MoRA / ReMoRA integration."""

import pytest

from axolotl.integrations.mora import MoraType
from axolotl.utils.config import prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault


class TestMoraValidation:
    """MoRA-specific config validation."""

    def test_mora_block_round_trips(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            {
                "adapter": "mora",
                "plugins": ["axolotl.integrations.mora.MoraPlugin"],
                "mora": {
                    "use_mora": True,
                    "mora_type": "rope",
                },
            }
        )

        prepare_plugins(cfg)
        validated = validate_config(cfg)

        assert validated.adapter == "mora"
        assert validated.mora.use_mora is True
        assert validated.mora.mora_type == MoraType.ROPE

    def test_mora_type_accepts_legacy_supported_numbers(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            {
                "adapter": "mora",
                "plugins": ["axolotl.integrations.mora.MoraPlugin"],
                "mora": {
                    "use_mora": True,
                    "mora_type": 1,
                },
            }
        )

        prepare_plugins(cfg)
        validated = validate_config(cfg)

        assert validated.mora.mora_type == MoraType.SHARING

    def test_mora_rejects_unsupported_variant_numbers(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            {
                "adapter": "mora",
                "plugins": ["axolotl.integrations.mora.MoraPlugin"],
                "mora": {
                    "use_mora": True,
                    "mora_type": 2,
                },
            }
        )

        prepare_plugins(cfg)
        with pytest.raises(ValueError, match="mora_type"):
            validate_config(cfg)

    def test_remora_uses_core_relora_fields(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            {
                "adapter": "mora",
                "plugins": ["axolotl.integrations.mora.MoraPlugin"],
                "relora": True,
                "jagged_restart_steps": 2000,
                "mora": {
                    "use_mora": True,
                    "mora_type": "rope",
                },
            }
        )

        prepare_plugins(cfg)
        validated = validate_config(cfg)

        assert validated.relora is True
        assert validated.jagged_restart_steps == 2000

    def test_remora_still_requires_core_restart_steps(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            {
                "adapter": "mora",
                "plugins": ["axolotl.integrations.mora.MoraPlugin"],
                "relora": True,
                "mora": {
                    "use_mora": True,
                    "mora_type": "rope",
                },
            }
        )

        prepare_plugins(cfg)
        with pytest.raises(ValueError, match="jagged_restart_steps"):
            validate_config(cfg)