diff --git a/_quarto.yml b/_quarto.yml
index 250596d52..bfef13afb 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -59,7 +59,6 @@ quartodoc:
         - core.trainers.base
         - core.trainers.trl
         - core.trainers.mamba
-        - core.trainers.relora
         - core.trainers.dpo.trainer
         - core.trainers.grpo.trainer
         - core.trainers.grpo.sampler
diff --git a/examples/llama-2/relora.yml b/examples/llama-2/relora.yml
index b0e905340..fabdf0e0f 100644
--- a/examples/llama-2/relora.yml
+++ b/examples/llama-2/relora.yml
@@ -25,9 +25,12 @@ lora_alpha: 16
 lora_dropout: 0.05
 lora_target_linear: true
 
-relora_steps: 150
-relora_warmup_ratio: 0.1
+relora: true
+relora_prune_ratio: 0.9
 relora_cpu_offload: false
+jagged_restart_steps: 150
+jagged_restart_warmup_steps: 10
+jagged_restart_anneal_steps: 10
 
 wandb_project:
 wandb_entity:
diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index 00cee35a7..b461e9009 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -19,7 +19,6 @@ from axolotl.core.trainers import (
     AxolotlPRMTrainer,
     AxolotlRewardTrainer,
     AxolotlTrainer,
-    ReLoRATrainer,
 )
 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
@@ -58,7 +57,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
     def get_callbacks(self):
         callbacks = super().get_callbacks()
-        if self.cfg.relora_steps:
+        if self.cfg.relora:
             callbacks.append(ReLoRACallback(self.cfg))
 
         if (
@@ -131,8 +130,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
         if trainer_cls:
             return trainer_cls
-        if self.cfg.relora_steps:
-            return ReLoRATrainer
         if self.cfg.model_config_type == "mamba":
             return AxolotlMambaTrainer
         if self.cfg.reward_model:
@@ -271,20 +268,25 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 self.cfg.sample_packing_eff_est
             )
 
-        if self.cfg.relora_steps:
-            training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
-            training_arguments_kwargs["relora_warmup_steps"] = (
-                self.cfg.relora_warmup_steps
-            )
-            if self.cfg.relora_anneal_steps:
-                training_arguments_kwargs["relora_anneal_steps"] = (
-                    self.cfg.relora_anneal_steps
-                )
+        if self.cfg.relora and self.cfg.jagged_restart_steps:
             if self.cfg.relora_prune_ratio:
                 training_arguments_kwargs["relora_prune_ratio"] = (
                     self.cfg.relora_prune_ratio
                 )
 
+            if self.cfg.jagged_restart_steps:
+                training_arguments_kwargs["jagged_restart_steps"] = (
+                    self.cfg.jagged_restart_steps
+                )
+            if self.cfg.jagged_restart_warmup_steps:
+                training_arguments_kwargs["jagged_restart_warmup_steps"] = (
+                    self.cfg.jagged_restart_warmup_steps
+                )
+            if self.cfg.jagged_restart_anneal_steps:
+                training_arguments_kwargs["jagged_restart_anneal_steps"] = (
+                    self.cfg.jagged_restart_anneal_steps
+                )
+
         if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
             training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
             training_arguments_kwargs["lisa_step_interval"] = (
diff --git a/src/axolotl/core/trainers/__init__.py b/src/axolotl/core/trainers/__init__.py
index 46b9b15ed..5f97e387a 100644
--- a/src/axolotl/core/trainers/__init__.py
+++ b/src/axolotl/core/trainers/__init__.py
@@ -7,7 +7,6 @@ from .base import AxolotlTrainer
 from .dpo.trainer import AxolotlDPOTrainer
 from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
 from .mamba import AxolotlMambaTrainer
-from .relora import ReLoRATrainer
 from .trl import (
     AxolotlCPOTrainer,
     AxolotlKTOTrainer,
diff --git a/src/axolotl/core/trainers/mixins/scheduler.py b/src/axolotl/core/trainers/mixins/scheduler.py
index 90070ab78..399bf5947 100644
--- a/src/axolotl/core/trainers/mixins/scheduler.py
+++ b/src/axolotl/core/trainers/mixins/scheduler.py
@@ -7,6 +7,7 @@ from transformers.trainer import Trainer
 from axolotl.integrations.base import PluginManager
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schedulers import (
+    JaggedLRRestartScheduler,
     RexLR,
     get_cosine_schedule_with_min_lr,
     get_cosine_schedule_with_quadratic_warmup,
@@ -113,7 +114,7 @@
                     min_lr_ratio=self.args.cosine_min_lr_ratio,
                 )
             else:
-                return super().create_scheduler(num_training_steps, optimizer=optimizer)
+                super().create_scheduler(num_training_steps, optimizer=optimizer)
         else:
             if use_cosine_quadratic:
                 LOG.warning(
@@ -123,4 +124,22 @@
                 LOG.warning(
                     "axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed.")
 
+        if self.args.jagged_restart_steps:
+            warmup_steps = (
+                self.args.jagged_restart_warmup_steps or 10
+            )
+            anneal_steps = (
+                self.args.jagged_restart_anneal_steps or 1
+            )
+            if not self.lr_scheduler:
+                super().create_scheduler(num_training_steps, optimizer)
+            self.lr_scheduler = JaggedLRRestartScheduler(  # pylint: disable=attribute-defined-outside-init
+                optimizer,
+                self.lr_scheduler,
+                self.args.jagged_restart_steps,
+                warmup_steps,
+                anneal_steps,
+                min_lr_scale=self.args.cosine_min_lr_ratio or 0.001,
+            )
+
         return self.lr_scheduler  # type: ignore
diff --git a/src/axolotl/core/trainers/relora.py b/src/axolotl/core/trainers/relora.py
deleted file mode 100644
index 890278f49..000000000
--- a/src/axolotl/core/trainers/relora.py
+++ /dev/null
@@ -1,46 +0,0 @@
-"""Module for ReLoRA trainer"""
-
-import torch
-from torch.optim.lr_scheduler import LRScheduler
-
-from axolotl.core.trainers.base import AxolotlTrainer
-from axolotl.monkeypatch.relora import ReLoRAScheduler
-
-
-class ReLoRATrainer(AxolotlTrainer):
-    """Trainer subclass that uses the `OneCycleLR` scheduler"""
-
-    tag_names = ["axolotl", "relora"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.lr_scheduler = None
-
-    def create_scheduler(
-        self,
-        num_training_steps: int,
-        optimizer: torch.optim.Optimizer | None = None,
-    ) -> LRScheduler:
-        optimizer = self.optimizer if optimizer is None else optimizer
-        lr_scheduler: LRScheduler = super().create_scheduler(
-            num_training_steps, optimizer
-        )
-
-        if self.args.relora_steps:
-            warmup_steps = (
-                self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
-            )
-            anneal_steps = (
-                self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
-            )
-            self.lr_scheduler = ReLoRAScheduler(  # type: ignore
-                optimizer,
-                lr_scheduler,
-                self.args.relora_steps,
-                anneal_steps,
-                warmup_steps,
-            )
-        else:
-            self.lr_scheduler = lr_scheduler  # type: ignore
-
-        return self.lr_scheduler  # type: ignore
diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py
index 4b74676ce..66649deef 100644
--- a/src/axolotl/core/training_args_base.py
+++ b/src/axolotl/core/training_args_base.py
@@ -82,18 +82,26 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "how often to reset for ReLoRA"},
     )
-    relora_warmup_steps: Optional[int] = field(
-        default=None,
-        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
-    )
-    relora_anneal_steps: Optional[int] = field(
-        default=None,
-        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
-    )
     relora_prune_ratio: Optional[float] = field(
         default=0.9,
         metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
     )
+    jagged_restart_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to reset for jagged restarts"},
+    )
+    jagged_restart_warmup_steps: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "how many warmup steps to take after reset for jagged restarts"
+        },
+    )
+    jagged_restart_anneal_steps: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "how many anneal steps to take before reset for jagged restarts"
+        },
+    )
     bench_split: Optional[str] = field(
         default="eval", metadata={"help": "The benchmark split to run on"}
     )
diff --git a/src/axolotl/monkeypatch/relora.py b/src/axolotl/monkeypatch/relora.py
index 5b7418e39..0028a0cf6 100644
--- a/src/axolotl/monkeypatch/relora.py
+++ b/src/axolotl/monkeypatch/relora.py
@@ -6,7 +6,7 @@ import os.path
 import shutil
 from functools import partial
 from pathlib import Path
-from typing import Dict, List, Sequence, Union
+from typing import Dict, List, Union
 
 import bitsandbytes as bnb
 import peft
@@ -14,8 +14,6 @@ import safetensors.torch as st
 import torch
 from huggingface_hub import snapshot_download
 from torch.distributed.optim import ZeroRedundancyOptimizer
-from torch.optim.lr_scheduler import LRScheduler
-from torch.optim.optimizer import Optimizer
 from transformers import (
     TrainerCallback,
     TrainerControl,
@@ -84,7 +82,7 @@ class ReLoRACallback(TrainerCallback):
     """Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
 
     def __init__(self, cfg: DictDefault):
-        self.relora_steps = cfg.relora_steps
+        self.relora_steps = cfg.jagged_restart_steps
         self.cpu_offload = cfg.relora_cpu_offload
         self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
         self.last_full_model = cfg.base_model
@@ -255,51 +253,6 @@ class ReLoRACallback(TrainerCallback):
         return control
 
 
-class ReLoRAScheduler(LRScheduler):
-    """Wraps another scheduler to apply per-lora-restart learning rate warmups."""
-
-    def __init__(
-        self,
-        optimizer: Optimizer,
-        inner_schedule: LRScheduler,
-        relora_steps: int,
-        warmup_steps: int,
-        anneal_steps: int = 1,
-        min_lr_scale: float = 0.001,
-    ) -> None:
-        self.inner_schedule = inner_schedule
-        self.relora_steps = relora_steps
-        self.warmup_steps = warmup_steps
-        self.anneal_steps = anneal_steps
-        self.min_lr_scale = min_lr_scale
-        super().__init__(optimizer, inner_schedule.last_epoch)
-
-    def get_lr(self) -> float:
-        self.inner_schedule.last_epoch = self.last_epoch
-
-        original = self.inner_schedule.get_lr()
-        step = self.last_epoch
-
-        if step < self.relora_steps - self.warmup_steps:
-            scale = 1
-        else:
-            per_relora_progress = step % self.relora_steps
-            if per_relora_progress < self.warmup_steps:
-                cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
-            elif per_relora_progress > (self.relora_steps - self.anneal_steps):
-                cycle_t = min(
-                    1.0,
-                    (self.relora_steps - per_relora_progress) / self.anneal_steps,
-                )
-            else:
-                cycle_t = 1
-            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
-
-        if isinstance(original, Sequence):
-            return [lr * scale for lr in original]
-        return original * scale
-
-
 def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
     model_name = "model.safetensors"
     if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index d57cb463e..b507c2c7b 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -267,7 +267,7 @@ def save_trained_model(
            "your model weights with `axolotl quantize`."
        )
    # Handle ReLoRA early return case
-    if cfg.relora_steps:
+    if cfg.relora:
        if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
            model = model.merge_and_unload()
        else:
diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py
index b550ac02c..b9d09ad9c 100644
--- a/src/axolotl/utils/schedulers.py
+++ b/src/axolotl/utils/schedulers.py
@@ -2,6 +2,7 @@
 
 import math
 from functools import partial
+from typing import Sequence
 
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -292,3 +293,50 @@ def get_cosine_schedule_with_warmup_decay_constant(
         num_cycles=num_cycles,
     )
     return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+class JaggedLRRestartScheduler(LRScheduler):
+    """Wraps another scheduler to apply per-restart learning rate warmup and anneal."""
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        inner_schedule: LRScheduler,
+        jagged_restart_steps: int,
+        jagged_restart_warmup_steps: int,
+        jagged_restart_anneal_steps: int = 1,
+        min_lr_scale: float = 0.001,
+    ) -> None:
+        # pylint: disable=duplicate-code
+        self.inner_schedule = inner_schedule
+        self.restarts_steps = jagged_restart_steps
+        self.warmup_steps = jagged_restart_warmup_steps
+        self.anneal_steps = jagged_restart_anneal_steps
+        self.min_lr_scale = min_lr_scale
+        super().__init__(optimizer, inner_schedule.last_epoch)
+
+    def get_lr(self) -> float | Sequence[float]:
+        self.inner_schedule.last_epoch = self.last_epoch
+
+        original = self.inner_schedule.get_lr()
+        step = self.last_epoch
+
+        if step < self.restarts_steps - self.anneal_steps:
+            scale = 1
+        else:
+            per_restart_progress = step % self.restarts_steps
+            if per_restart_progress < self.warmup_steps:
+                cycle_t = min(1.0, (per_restart_progress) / self.warmup_steps)
+            elif per_restart_progress > (self.restarts_steps - self.anneal_steps):
+                cycle_t = min(
+                    1.0,
+                    (self.restarts_steps - per_restart_progress) / self.anneal_steps,
+                )
+            else:
+                cycle_t = 1
+            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
+
+        if isinstance(original, Sequence):
+            return [lr * scale for lr in original]
+
+        return original * scale
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 0afeaa2a8..f8746692c 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -43,7 +43,7 @@ from axolotl.utils.schemas.model import (
 from axolotl.utils.schemas.multimodal import MultiModalConfig
 from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
 from axolotl.utils.schemas.quantization import PTQConfig, QATConfig
-from axolotl.utils.schemas.training import HyperparametersConfig
+from axolotl.utils.schemas.training import HyperparametersConfig, JaggedLRConfig
 from axolotl.utils.schemas.trl import TRLConfig
 from axolotl.utils.schemas.validation import ValidationMixin
 from axolotl.utils.schemas.vllm import VllmConfig
@@ -57,6 +57,7 @@ class AxolotlInputConfig(
     ModelOutputConfig,
     LoraConfig,
     ReLoRAConfig,
+    JaggedLRConfig,
     HyperparametersConfig,
     WandbConfig,
     MLFlowConfig,
diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py
index 4b31ce018..341397b42 100644
--- a/src/axolotl/utils/schemas/peft.py
+++ b/src/axolotl/utils/schemas/peft.py
@@ -187,18 +187,10 @@ class LoraConfig(BaseModel):
 class ReLoRAConfig(BaseModel):
     """ReLoRA configuration subset"""
 
-    relora_steps: int | None = Field(
-        default=None,
-        json_schema_extra={"description": "Number of steps per ReLoRA restart"},
-    )
-    relora_warmup_steps: int | None = Field(
-        default=None,
-        json_schema_extra={"description": "Number of per-restart warmup steps"},
-    )
-    relora_anneal_steps: int | None = Field(
+    relora: bool | None = Field(
         default=None,
         json_schema_extra={
-            "description": "Number of anneal steps for each relora cycle"
+            "description": "Whether to use ReLoRA. Use with jagged_restart_*steps options."
         },
     )
     relora_prune_ratio: float | None = Field(
diff --git a/src/axolotl/utils/schemas/training.py b/src/axolotl/utils/schemas/training.py
index 4d88cc9e6..6ee863397 100644
--- a/src/axolotl/utils/schemas/training.py
+++ b/src/axolotl/utils/schemas/training.py
@@ -160,3 +160,24 @@ class HyperparametersConfig(BaseModel):
         if learning_rate and isinstance(learning_rate, str):
             learning_rate = float(learning_rate)
         return learning_rate
+
+
+class JaggedLRConfig(BaseModel):
+    """JaggedLR configuration subset; can be used with ReLoRA training"""
+
+    jagged_restart_steps: int | None = Field(
+        default=None,
+        json_schema_extra={"description": "how often to reset for jagged restarts"},
+    )
+    jagged_restart_warmup_steps: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "how many warmup steps to take after reset for jagged restarts"
+        },
+    )
+    jagged_restart_anneal_steps: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "how many anneal steps to take before reset for jagged restarts"
+        },
+    )
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 9ca33f456..063690c59 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -1164,7 +1164,9 @@ class ComplexValidationMixin:
 
     @model_validator(mode="after")
     def check_relora(self):
-        if self.relora_steps:
+        if self.relora:
+            if not self.jagged_restart_steps:
+                raise ValueError("jagged_restart_steps must be set to use ReLoRA")
             if self.adapter not in ("lora", "qlora"):
                 raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
 
diff --git a/tests/e2e/solo/test_relora_llama.py b/tests/e2e/solo/test_relora_llama.py
index f6fcad841..b399b4680 100644
--- a/tests/e2e/solo/test_relora_llama.py
+++ b/tests/e2e/solo/test_relora_llama.py
@@ -34,9 +34,10 @@ class TestReLoraLlama(unittest.TestCase):
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_target_modules": ["q_proj", "v_proj"],
-                "relora_steps": 50,
-                "relora_warmup_steps": 10,
-                "relora_anneal_steps": 10,
+                "relora": True,
+                "jagged_restart_steps": 50,
+                "jagged_restart_warmup_steps": 10,
+                "jagged_restart_anneal_steps": 10,
                "relora_prune_ratio": 0.9,
                "relora_cpu_offload": True,
                "val_set_size": 0.0,