jagged lr restart scheduler (#1680) [skip ci]

* jagged lr restart scheduler

var name fix
make sure to create scheduler first

* wire things together

* more fixes

* fix for nesting scheduler and first anneal phase

* no need for relora trainer anymore since we've generalized the relora scheduler

* remove redundant relora scheduler and lint

* update relora e2e test for updated params

* need restart steps for relora test

* update quarto docs for dropped relora trainer

* update example yaml

* drop verbose arg

* min lr scale support for jagged lr

* don't let min_lr be nonetype

* cleanup args
Wing Lian
2025-07-31 13:50:03 -04:00
committed by GitHub
parent 32a7890231
commit 7b68dfafd7
15 changed files with 139 additions and 137 deletions

View File

@@ -19,7 +19,6 @@ from axolotl.core.trainers import (
     AxolotlPRMTrainer,
     AxolotlRewardTrainer,
     AxolotlTrainer,
-    ReLoRATrainer,
 )
 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
@@ -58,7 +57,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
     def get_callbacks(self):
         callbacks = super().get_callbacks()
-        if self.cfg.relora_steps:
+        if self.cfg.relora:
             callbacks.append(ReLoRACallback(self.cfg))
         if (
@@ -131,8 +130,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
         if trainer_cls:
             return trainer_cls
-        if self.cfg.relora_steps:
-            return ReLoRATrainer
         if self.cfg.model_config_type == "mamba":
             return AxolotlMambaTrainer
         if self.cfg.reward_model:
@@ -271,20 +268,25 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 self.cfg.sample_packing_eff_est
             )
-        if self.cfg.relora_steps:
-            training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
-            training_arguments_kwargs["relora_warmup_steps"] = (
-                self.cfg.relora_warmup_steps
-            )
-            if self.cfg.relora_anneal_steps:
-                training_arguments_kwargs["relora_anneal_steps"] = (
-                    self.cfg.relora_anneal_steps
-                )
+        if self.cfg.relora and self.cfg.jagged_restart_steps:
             if self.cfg.relora_prune_ratio:
                 training_arguments_kwargs["relora_prune_ratio"] = (
                     self.cfg.relora_prune_ratio
                 )
+        if self.cfg.jagged_restart_steps:
+            training_arguments_kwargs["jagged_restart_steps"] = (
+                self.cfg.jagged_restart_steps
+            )
+            if self.cfg.jagged_restart_warmup_steps:
+                training_arguments_kwargs["jagged_restart_warmup_steps"] = (
+                    self.cfg.jagged_restart_warmup_steps
+                )
+            if self.cfg.jagged_restart_anneal_steps:
+                training_arguments_kwargs["jagged_restart_anneal_steps"] = (
+                    self.cfg.jagged_restart_anneal_steps
+                )
         if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
             training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
             training_arguments_kwargs["lisa_step_interval"] = (

View File

@@ -7,7 +7,6 @@ from .base import AxolotlTrainer
 from .dpo.trainer import AxolotlDPOTrainer
 from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
 from .mamba import AxolotlMambaTrainer
-from .relora import ReLoRATrainer
 from .trl import (
     AxolotlCPOTrainer,
     AxolotlKTOTrainer,

View File

@@ -7,6 +7,7 @@ from transformers.trainer import Trainer
 from axolotl.integrations.base import PluginManager
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schedulers import (
+    JaggedLRRestartScheduler,
     RexLR,
     get_cosine_schedule_with_min_lr,
     get_cosine_schedule_with_quadratic_warmup,
@@ -113,7 +114,7 @@ class SchedulerMixin(Trainer):
                     min_lr_ratio=self.args.cosine_min_lr_ratio,
                 )
             else:
-                return super().create_scheduler(num_training_steps, optimizer=optimizer)
+                super().create_scheduler(num_training_steps, optimizer=optimizer)
         else:
             if use_cosine_quadratic:
                 LOG.warning(
@@ -123,4 +124,22 @@ class SchedulerMixin(Trainer):
                 LOG.warning(
                     "axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed)."
                 )
+        if self.args.jagged_restart_steps:
+            warmup_steps = (
+                self.args.jagged_restart_warmup_steps or 10
+            )
+            anneal_steps = (
+                self.args.jagged_restart_anneal_steps or 1
+            )
+            if not self.lr_scheduler:
+                super().create_scheduler(num_training_steps, optimizer)
+            self.lr_scheduler = JaggedLRRestartScheduler(  # pylint: disable=attribute-defined-outside-init
+                optimizer,
+                self.lr_scheduler,
+                self.args.jagged_restart_steps,
+                warmup_steps,
+                anneal_steps,
+                min_lr_scale=self.args.cosine_min_lr_ratio or 0.001,
+            )
         return self.lr_scheduler  # type: ignore
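
For intuition about what the wiring above produces, here is a small standalone trace of the scale factor the wrapper applies. It mirrors the get_lr arithmetic added in this commit; the helper name and the sample step values are illustrative only:

# Illustrative sketch: mirrors JaggedLRRestartScheduler.get_lr from this commit.
def jagged_scale(step, restart_steps=100, warmup_steps=10, anneal_steps=5, min_lr_scale=0.001):
    if step < restart_steps - anneal_steps:  # untouched until the first anneal phase
        return 1.0
    t = step % restart_steps
    if t < warmup_steps:  # re-warm right after a restart
        cycle_t = min(1.0, t / warmup_steps)
    elif t > restart_steps - anneal_steps:  # anneal down toward the next restart
        cycle_t = min(1.0, (restart_steps - t) / anneal_steps)
    else:  # plateau between warmup and anneal
        cycle_t = 1.0
    return cycle_t * (1 - min_lr_scale) + min_lr_scale

for s in (0, 96, 100, 105, 150, 196):
    print(s, round(jagged_scale(s), 4))  # 1.0, 0.8002, 0.001, 0.5005, 1.0, 0.8002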

View File

@@ -1,46 +0,0 @@
"""Module for ReLoRA trainer"""
import torch
from torch.optim.lr_scheduler import LRScheduler
from axolotl.core.trainers.base import AxolotlTrainer
from axolotl.monkeypatch.relora import ReLoRAScheduler
class ReLoRATrainer(AxolotlTrainer):
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
tag_names = ["axolotl", "relora"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: torch.optim.Optimizer | None = None,
) -> LRScheduler:
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler: LRScheduler = super().create_scheduler(
num_training_steps, optimizer
)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler( # type: ignore
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler # type: ignore
return self.lr_scheduler # type: ignore

View File

@@ -82,18 +82,26 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "how often to reset for ReLoRA"},
     )
-    relora_warmup_steps: Optional[int] = field(
-        default=None,
-        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
-    )
-    relora_anneal_steps: Optional[int] = field(
-        default=None,
-        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
-    )
     relora_prune_ratio: Optional[float] = field(
         default=0.9,
         metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
     )
+    jagged_restart_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to reset for jagged restarts"},
+    )
+    jagged_restart_warmup_steps: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "how many warmup steps to take after reset for jagged restarts"
+        },
+    )
+    jagged_restart_anneal_steps: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "how many anneal steps to take before reset for jagged restarts"
+        },
+    )
     bench_split: Optional[str] = field(
         default="eval", metadata={"help": "The benchmark split to run on"}
     )

View File

@@ -6,7 +6,7 @@ import os.path
 import shutil
 from functools import partial
 from pathlib import Path
-from typing import Dict, List, Sequence, Union
+from typing import Dict, List, Union
 
 import bitsandbytes as bnb
 import peft
@@ -14,8 +14,6 @@ import safetensors.torch as st
 import torch
 from huggingface_hub import snapshot_download
 from torch.distributed.optim import ZeroRedundancyOptimizer
-from torch.optim.lr_scheduler import LRScheduler
-from torch.optim.optimizer import Optimizer
 from transformers import (
     TrainerCallback,
     TrainerControl,
@@ -84,7 +82,7 @@ class ReLoRACallback(TrainerCallback):
"""Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
def __init__(self, cfg: DictDefault):
self.relora_steps = cfg.relora_steps
self.relora_steps = cfg.jagged_restart_steps
self.cpu_offload = cfg.relora_cpu_offload
self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
self.last_full_model = cfg.base_model
@@ -255,51 +253,6 @@ class ReLoRACallback(TrainerCallback):
         return control
 
 
-class ReLoRAScheduler(LRScheduler):
-    """Wraps another scheduler to apply per-lora-restart learning rate warmups."""
-
-    def __init__(
-        self,
-        optimizer: Optimizer,
-        inner_schedule: LRScheduler,
-        relora_steps: int,
-        warmup_steps: int,
-        anneal_steps: int = 1,
-        min_lr_scale: float = 0.001,
-    ) -> None:
-        self.inner_schedule = inner_schedule
-        self.relora_steps = relora_steps
-        self.warmup_steps = warmup_steps
-        self.anneal_steps = anneal_steps
-        self.min_lr_scale = min_lr_scale
-        super().__init__(optimizer, inner_schedule.last_epoch)
-
-    def get_lr(self) -> float:
-        self.inner_schedule.last_epoch = self.last_epoch
-        original = self.inner_schedule.get_lr()
-        step = self.last_epoch
-
-        if step < self.relora_steps - self.warmup_steps:
-            scale = 1
-        else:
-            per_relora_progress = step % self.relora_steps
-            if per_relora_progress < self.warmup_steps:
-                cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
-            elif per_relora_progress > (self.relora_steps - self.anneal_steps):
-                cycle_t = min(
-                    1.0,
-                    (self.relora_steps - per_relora_progress) / self.anneal_steps,
-                )
-            else:
-                cycle_t = 1
-            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
-
-        if isinstance(original, Sequence):
-            return [lr * scale for lr in original]
-        return original * scale
-
-
 def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
     model_name = "model.safetensors"
     if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(

View File

@@ -267,7 +267,7 @@ def save_trained_model(
"your model weights with `axolotl quantize`."
)
# Handle ReLoRA early return case
if cfg.relora_steps:
if cfg.relora:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload()
else:

View File

@@ -2,6 +2,7 @@
 import math
 from functools import partial
+from typing import Sequence
 
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -292,3 +293,50 @@ def get_cosine_schedule_with_warmup_decay_constant(
         num_cycles=num_cycles,
     )
     return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+class JaggedLRRestartScheduler(LRScheduler):
+    """Wraps another scheduler to apply per-lora-restart learning rate warmups."""
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        inner_schedule: LRScheduler,
+        jagged_restart_steps: int,
+        jagged_restart_warmup_steps: int,
+        jagged_restart_anneal_steps: int = 1,
+        min_lr_scale: float = 0.001,
+    ) -> None:
+        # pylint: disable=duplicate-code
+        self.inner_schedule = inner_schedule
+        self.restarts_steps = jagged_restart_steps
+        self.warmup_steps = jagged_restart_warmup_steps
+        self.anneal_steps = jagged_restart_anneal_steps
+        self.min_lr_scale = min_lr_scale
+        super().__init__(optimizer, inner_schedule.last_epoch)
+
+    def get_lr(self) -> float | Sequence[float]:
+        self.inner_schedule.last_epoch = self.last_epoch
+        original = self.inner_schedule.get_lr()
+        step = self.last_epoch
+
+        if step < self.restarts_steps - self.anneal_steps:
+            scale = 1
+        else:
+            per_restart_progress = step % self.restarts_steps
+            if per_restart_progress < self.warmup_steps:
+                cycle_t = min(1.0, (per_restart_progress) / self.warmup_steps)
+            elif per_restart_progress > (self.restarts_steps - self.anneal_steps):
+                cycle_t = min(
+                    1.0,
+                    (self.restarts_steps - per_restart_progress) / self.anneal_steps,
+                )
+            else:
+                cycle_t = 1
+            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
+
+        if isinstance(original, Sequence):
+            return [lr * scale for lr in original]
+        return original * scale
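
A minimal usage sketch of the new wrapper outside the trainer (assumes an installed axolotl; the toy model, optimizer, and hyperparameter values are arbitrary, not recommendations):

# Toy example of wrapping an inner schedule with the jagged restart wrapper.
import torch
from torch.optim.lr_scheduler import LambdaLR

from axolotl.utils.schedulers import JaggedLRRestartScheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
inner = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)  # any inner schedule works

scheduler = JaggedLRRestartScheduler(
    optimizer,
    inner,
    jagged_restart_steps=100,
    jagged_restart_warmup_steps=10,
    jagged_restart_anneal_steps=5,
    min_lr_scale=0.001,
)

for step in range(300):
    optimizer.step()
    scheduler.step()  # LR anneals before each restart and re-warms after it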

View File

@@ -43,7 +43,7 @@ from axolotl.utils.schemas.model import (
 from axolotl.utils.schemas.multimodal import MultiModalConfig
 from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
 from axolotl.utils.schemas.quantization import PTQConfig, QATConfig
-from axolotl.utils.schemas.training import HyperparametersConfig
+from axolotl.utils.schemas.training import HyperparametersConfig, JaggedLRConfig
 from axolotl.utils.schemas.trl import TRLConfig
 from axolotl.utils.schemas.validation import ValidationMixin
 from axolotl.utils.schemas.vllm import VllmConfig
@@ -57,6 +57,7 @@ class AxolotlInputConfig(
     ModelOutputConfig,
     LoraConfig,
     ReLoRAConfig,
+    JaggedLRConfig,
     HyperparametersConfig,
     WandbConfig,
     MLFlowConfig,

View File

@@ -187,18 +187,10 @@ class LoraConfig(BaseModel):
 class ReLoRAConfig(BaseModel):
     """ReLoRA configuration subset"""
 
-    relora_steps: int | None = Field(
-        default=None,
-        json_schema_extra={"description": "Number of steps per ReLoRA restart"},
-    )
-    relora_warmup_steps: int | None = Field(
-        default=None,
-        json_schema_extra={"description": "Number of per-restart warmup steps"},
-    )
-    relora_anneal_steps: int | None = Field(
+    relora: bool | None = Field(
         default=None,
         json_schema_extra={
-            "description": "Number of anneal steps for each relora cycle"
+            "description": "Whether to use ReLoRA. Use with jagged_restart_*steps options."
         },
     )
     relora_prune_ratio: float | None = Field(

View File

@@ -160,3 +160,24 @@ class HyperparametersConfig(BaseModel):
         if learning_rate and isinstance(learning_rate, str):
             learning_rate = float(learning_rate)
         return learning_rate
+
+
+class JaggedLRConfig(BaseModel):
+    """JaggedLR configuration subset, can be used w/ ReLoRA training"""
+
+    jagged_restart_steps: int | None = Field(
+        default=None,
+        json_schema_extra={"description": "how often to reset for jagged restarts"},
+    )
+    jagged_restart_warmup_steps: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "how many warmup steps to take after reset for jagged restarts"
+        },
+    )
+    jagged_restart_anneal_steps: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "how many anneal steps to take before reset for jagged restarts"
+        },
+    )
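
The new subset can be exercised directly as a quick sanity check (field values are illustrative):

# Quick check of the JaggedLRConfig subset defined above; values are illustrative.
from axolotl.utils.schemas.training import JaggedLRConfig

cfg = JaggedLRConfig(
    jagged_restart_steps=200,        # reset every 200 steps
    jagged_restart_warmup_steps=10,  # re-warm the LR after each reset
    jagged_restart_anneal_steps=20,  # anneal the LR before each reset
)
print(cfg.model_dump())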

View File

@@ -1164,7 +1164,9 @@ class ComplexValidationMixin:
     @model_validator(mode="after")
     def check_relora(self):
-        if self.relora_steps:
+        if self.relora:
+            if not self.jagged_restart_steps:
+                raise ValueError("jagged_restart_steps must be set to use ReLoRA")
             if self.adapter not in ("lora", "qlora"):
                 raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")