jagged lr restart scheduler (#1680) [skip ci]
* jagged lr restart scheduler var name fix, make sure to create scheduler first
* wire things together
* more fixes
* fix for nesting scheduler and first anneal phase
* no need for relora trainer anymore since we've generalized the relora scheduler
* remove redundant relora scheduler and lint
* update relora e2e test for updated params
* need restart steps for relora test
* update quarto docs for dropped relora trainer
* update example yaml
* drop verbose arg
* min lr scale support for jagged lr
* don't let min_lr be nonetype
* cleanup args
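The net effect on user configs, sketched below as a before/after mapping. Field names come from the schema changes in this diff; the numeric values are only illustrative.

# Hypothetical before/after of the schedule knobs (values made up).
old_cfg = {
    "relora_steps": 200,        # was both the on/off switch and the interval
    "relora_warmup_steps": 10,
    "relora_anneal_steps": 5,
}
new_cfg = {
    "relora": True,                     # boolean gate for ReLoRA itself
    "jagged_restart_steps": 200,        # generalized restart interval
    "jagged_restart_warmup_steps": 10,  # per-restart warmup
    "jagged_restart_anneal_steps": 5,   # pre-restart anneal
}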
@@ -19,7 +19,6 @@ from axolotl.core.trainers import (
     AxolotlPRMTrainer,
     AxolotlRewardTrainer,
     AxolotlTrainer,
-    ReLoRATrainer,
 )
 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
@@ -58,7 +57,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
     def get_callbacks(self):
         callbacks = super().get_callbacks()
 
-        if self.cfg.relora_steps:
+        if self.cfg.relora:
             callbacks.append(ReLoRACallback(self.cfg))
 
         if (
@@ -131,8 +130,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
         if trainer_cls:
             return trainer_cls
-        if self.cfg.relora_steps:
-            return ReLoRATrainer
         if self.cfg.model_config_type == "mamba":
             return AxolotlMambaTrainer
         if self.cfg.reward_model:
@@ -271,20 +268,25 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 self.cfg.sample_packing_eff_est
             )
 
-        if self.cfg.relora_steps:
-            training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
-            training_arguments_kwargs["relora_warmup_steps"] = (
-                self.cfg.relora_warmup_steps
-            )
-            if self.cfg.relora_anneal_steps:
-                training_arguments_kwargs["relora_anneal_steps"] = (
-                    self.cfg.relora_anneal_steps
-                )
+        if self.cfg.relora and self.cfg.jagged_restart_steps:
             if self.cfg.relora_prune_ratio:
                 training_arguments_kwargs["relora_prune_ratio"] = (
                     self.cfg.relora_prune_ratio
                 )
 
+        if self.cfg.jagged_restart_steps:
+            training_arguments_kwargs["jagged_restart_steps"] = (
+                self.cfg.jagged_restart_steps
+            )
+            if self.cfg.jagged_restart_warmup_steps:
+                training_arguments_kwargs["jagged_restart_warmup_steps"] = (
+                    self.cfg.jagged_restart_warmup_steps
+                )
+            if self.cfg.jagged_restart_anneal_steps:
+                training_arguments_kwargs["jagged_restart_anneal_steps"] = (
+                    self.cfg.jagged_restart_anneal_steps
+                )
+
         if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
             training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
             training_arguments_kwargs["lisa_step_interval"] = (
@@ -7,7 +7,6 @@ from .base import AxolotlTrainer
 from .dpo.trainer import AxolotlDPOTrainer
 from .grpo.trainer import AxolotlGRPOSequenceParallelTrainer, AxolotlGRPOTrainer
 from .mamba import AxolotlMambaTrainer
-from .relora import ReLoRATrainer
 from .trl import (
     AxolotlCPOTrainer,
     AxolotlKTOTrainer,
@@ -7,6 +7,7 @@ from transformers.trainer import Trainer
 from axolotl.integrations.base import PluginManager
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schedulers import (
+    JaggedLRRestartScheduler,
     RexLR,
     get_cosine_schedule_with_min_lr,
     get_cosine_schedule_with_quadratic_warmup,
@@ -113,7 +114,7 @@ class SchedulerMixin(Trainer):
                     min_lr_ratio=self.args.cosine_min_lr_ratio,
                 )
             else:
-                return super().create_scheduler(num_training_steps, optimizer=optimizer)
+                super().create_scheduler(num_training_steps, optimizer=optimizer)
         else:
             if use_cosine_quadratic:
                 LOG.warning(
@@ -123,4 +124,22 @@ class SchedulerMixin(Trainer):
                 LOG.warning(
                     "axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
 
+        if self.args.jagged_restart_steps:
+            warmup_steps = (
+                self.args.jagged_restart_warmup_steps or 10
+            )
+            anneal_steps = (
+                self.args.jagged_restart_anneal_steps or 1
+            )
+            if not self.lr_scheduler:
+                super().create_scheduler(num_training_steps, optimizer)
+            self.lr_scheduler = JaggedLRRestartScheduler(  # pylint: disable=attribute-defined-outside-init
+                optimizer,
+                self.lr_scheduler,
+                self.args.jagged_restart_steps,
+                warmup_steps,
+                anneal_steps,
+                min_lr_scale=self.args.cosine_min_lr_ratio or 0.001,
+            )
+
         return self.lr_scheduler  # type: ignore
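A minimal sanity-check sketch (not part of the diff) of the create-then-wrap order the mixin depends on: the inner schedule must exist before JaggedLRRestartScheduler can delegate to it, and the `or 0.001` fallback keeps min_lr_scale from ever being None. A constant LambdaLR stands in for the real inner schedule; all values are illustrative and assume an axolotl install.

import torch
from torch.optim.lr_scheduler import LambdaLR

from axolotl.utils.schedulers import JaggedLRRestartScheduler

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=2e-4)
inner = LambdaLR(opt, lambda step: 1.0)  # stand-in inner schedule

# create the inner schedule first, then wrap it, mirroring the mixin above
sched = JaggedLRRestartScheduler(
    opt,
    inner,
    100,  # jagged_restart_steps
    10,   # jagged_restart_warmup_steps
    5,    # jagged_restart_anneal_steps
    min_lr_scale=0.001,
)

for step in range(110):
    opt.step()
    sched.step()
    if step in (94, 97, 100, 104):
        # LR anneals into the restart, collapses to the floor, then warms back up
        print(step, sched.get_last_lr())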
@@ -1,46 +0,0 @@
-"""Module for ReLoRA trainer"""
-
-import torch
-from torch.optim.lr_scheduler import LRScheduler
-
-from axolotl.core.trainers.base import AxolotlTrainer
-from axolotl.monkeypatch.relora import ReLoRAScheduler
-
-
-class ReLoRATrainer(AxolotlTrainer):
-    """Trainer subclass that uses the `OneCycleLR` scheduler"""
-
-    tag_names = ["axolotl", "relora"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.lr_scheduler = None
-
-    def create_scheduler(
-        self,
-        num_training_steps: int,
-        optimizer: torch.optim.Optimizer | None = None,
-    ) -> LRScheduler:
-        optimizer = self.optimizer if optimizer is None else optimizer
-        lr_scheduler: LRScheduler = super().create_scheduler(
-            num_training_steps, optimizer
-        )
-
-        if self.args.relora_steps:
-            warmup_steps = (
-                self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
-            )
-            anneal_steps = (
-                self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
-            )
-            self.lr_scheduler = ReLoRAScheduler(  # type: ignore
-                optimizer,
-                lr_scheduler,
-                self.args.relora_steps,
-                anneal_steps,
-                warmup_steps,
-            )
-        else:
-            self.lr_scheduler = lr_scheduler  # type: ignore
-
-        return self.lr_scheduler  # type: ignore
@@ -82,18 +82,26 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "how often to reset for ReLoRA"},
     )
-    relora_warmup_steps: Optional[int] = field(
-        default=None,
-        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
-    )
-    relora_anneal_steps: Optional[int] = field(
-        default=None,
-        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
-    )
     relora_prune_ratio: Optional[float] = field(
         default=0.9,
         metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
     )
+    jagged_restart_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to reset for jagged restarts"},
+    )
+    jagged_restart_warmup_steps: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "how many warmup steps to take after reset for jagged restarts"
+        },
+    )
+    jagged_restart_anneal_steps: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "how many anneal steps to take before reset for jagged restarts"
+        },
+    )
     bench_split: Optional[str] = field(
         default="eval", metadata={"help": "The benchmark split to run on"}
     )
@@ -6,7 +6,7 @@ import os.path
 import shutil
 from functools import partial
 from pathlib import Path
-from typing import Dict, List, Sequence, Union
+from typing import Dict, List, Union
 
 import bitsandbytes as bnb
 import peft
@@ -14,8 +14,6 @@ import safetensors.torch as st
 import torch
 from huggingface_hub import snapshot_download
 from torch.distributed.optim import ZeroRedundancyOptimizer
-from torch.optim.lr_scheduler import LRScheduler
-from torch.optim.optimizer import Optimizer
 from transformers import (
     TrainerCallback,
     TrainerControl,
@@ -84,7 +82,7 @@ class ReLoRACallback(TrainerCallback):
     """Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
 
     def __init__(self, cfg: DictDefault):
-        self.relora_steps = cfg.relora_steps
+        self.relora_steps = cfg.jagged_restart_steps
         self.cpu_offload = cfg.relora_cpu_offload
         self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
         self.last_full_model = cfg.base_model
@@ -255,51 +253,6 @@ class ReLoRACallback(TrainerCallback):
         return control
 
 
-class ReLoRAScheduler(LRScheduler):
-    """Wraps another scheduler to apply per-lora-restart learning rate warmups."""
-
-    def __init__(
-        self,
-        optimizer: Optimizer,
-        inner_schedule: LRScheduler,
-        relora_steps: int,
-        warmup_steps: int,
-        anneal_steps: int = 1,
-        min_lr_scale: float = 0.001,
-    ) -> None:
-        self.inner_schedule = inner_schedule
-        self.relora_steps = relora_steps
-        self.warmup_steps = warmup_steps
-        self.anneal_steps = anneal_steps
-        self.min_lr_scale = min_lr_scale
-        super().__init__(optimizer, inner_schedule.last_epoch)
-
-    def get_lr(self) -> float:
-        self.inner_schedule.last_epoch = self.last_epoch
-
-        original = self.inner_schedule.get_lr()
-        step = self.last_epoch
-
-        if step < self.relora_steps - self.warmup_steps:
-            scale = 1
-        else:
-            per_relora_progress = step % self.relora_steps
-            if per_relora_progress < self.warmup_steps:
-                cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
-            elif per_relora_progress > (self.relora_steps - self.anneal_steps):
-                cycle_t = min(
-                    1.0,
-                    (self.relora_steps - per_relora_progress) / self.anneal_steps,
-                )
-            else:
-                cycle_t = 1
-            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
-
-        if isinstance(original, Sequence):
-            return [lr * scale for lr in original]
-        return original * scale
-
-
 def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
     model_name = "model.safetensors"
     if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
@@ -267,7 +267,7 @@ def save_trained_model(
                 "your model weights with `axolotl quantize`."
             )
     # Handle ReLoRA early return case
-    if cfg.relora_steps:
+    if cfg.relora:
         if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
             model = model.merge_and_unload()
         else:
@@ -2,6 +2,7 @@
 
 import math
 from functools import partial
+from typing import Sequence
 
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -292,3 +293,50 @@ def get_cosine_schedule_with_warmup_decay_constant(
         num_cycles=num_cycles,
     )
     return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+class JaggedLRRestartScheduler(LRScheduler):
+    """Wraps another scheduler to apply per-lora-restart learning rate warmups."""
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        inner_schedule: LRScheduler,
+        jagged_restart_steps: int,
+        jagged_restart_warmup_steps: int,
+        jagged_restart_anneal_steps: int = 1,
+        min_lr_scale: float = 0.001,
+    ) -> None:
+        # pylint: disable=duplicate-code
+        self.inner_schedule = inner_schedule
+        self.restarts_steps = jagged_restart_steps
+        self.warmup_steps = jagged_restart_warmup_steps
+        self.anneal_steps = jagged_restart_anneal_steps
+        self.min_lr_scale = min_lr_scale
+        super().__init__(optimizer, inner_schedule.last_epoch)
+
+    def get_lr(self) -> float | Sequence[float]:
+        self.inner_schedule.last_epoch = self.last_epoch
+
+        original = self.inner_schedule.get_lr()
+        step = self.last_epoch
+
+        if step < self.restarts_steps - self.anneal_steps:
+            scale = 1
+        else:
+            per_restart_progress = step % self.restarts_steps
+            if per_restart_progress < self.warmup_steps:
+                cycle_t = min(1.0, (per_restart_progress) / self.warmup_steps)
+            elif per_restart_progress > (self.restarts_steps - self.anneal_steps):
+                cycle_t = min(
+                    1.0,
+                    (self.restarts_steps - per_restart_progress) / self.anneal_steps,
+                )
+            else:
+                cycle_t = 1
+            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
+
+        if isinstance(original, Sequence):
+            return [lr * scale for lr in original]
+
+        return original * scale
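For reference, the piecewise scale get_lr applies on top of the inner schedule, pulled out into a standalone sketch with assumed values (restart every 100 steps, 10 warmup, 5 anneal, 0.001 floor):

def jagged_scale(step, restart=100, warmup=10, anneal=5, floor=0.001):
    """Mirror of the scale logic above, for eyeballing the LR shape."""
    if step < restart - anneal:  # flat until the first anneal window
        return 1.0
    progress = step % restart
    if progress < warmup:  # ramp back up after a restart
        cycle_t = progress / warmup
    elif progress > restart - anneal:  # ramp down just before a restart
        cycle_t = (restart - progress) / anneal
    else:
        cycle_t = 1.0
    return cycle_t * (1 - floor) + floor

print(jagged_scale(97))   # 0.6004 -> annealing toward the restart
print(jagged_scale(100))  # 0.0010 -> floored right at the restart
print(jagged_scale(105))  # 0.5005 -> halfway through warmup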
@@ -43,7 +43,7 @@ from axolotl.utils.schemas.model import (
 from axolotl.utils.schemas.multimodal import MultiModalConfig
 from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
 from axolotl.utils.schemas.quantization import PTQConfig, QATConfig
-from axolotl.utils.schemas.training import HyperparametersConfig
+from axolotl.utils.schemas.training import HyperparametersConfig, JaggedLRConfig
 from axolotl.utils.schemas.trl import TRLConfig
 from axolotl.utils.schemas.validation import ValidationMixin
 from axolotl.utils.schemas.vllm import VllmConfig
@@ -57,6 +57,7 @@ class AxolotlInputConfig(
     ModelOutputConfig,
     LoraConfig,
     ReLoRAConfig,
+    JaggedLRConfig,
     HyperparametersConfig,
     WandbConfig,
     MLFlowConfig,
@@ -187,18 +187,10 @@ class LoraConfig(BaseModel):
 class ReLoRAConfig(BaseModel):
     """ReLoRA configuration subset"""
 
-    relora_steps: int | None = Field(
-        default=None,
-        json_schema_extra={"description": "Number of steps per ReLoRA restart"},
-    )
-    relora_warmup_steps: int | None = Field(
-        default=None,
-        json_schema_extra={"description": "Number of per-restart warmup steps"},
-    )
-    relora_anneal_steps: int | None = Field(
+    relora: bool | None = Field(
         default=None,
         json_schema_extra={
-            "description": "Number of anneal steps for each relora cycle"
+            "description": "Whether to use ReLoRA. Use with jagged_restart_*steps options."
         },
     )
     relora_prune_ratio: float | None = Field(
@@ -160,3 +160,24 @@ class HyperparametersConfig(BaseModel):
         if learning_rate and isinstance(learning_rate, str):
             learning_rate = float(learning_rate)
         return learning_rate
+
+
+class JaggedLRConfig(BaseModel):
+    """JaggedLR configuration subset, can be used w/ ReLoRA training"""
+
+    jagged_restart_steps: int | None = Field(
+        default=None,
+        json_schema_extra={"description": "how often to reset for jagged restarts"},
+    )
+    jagged_restart_warmup_steps: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "how many warmup steps to take after reset for jagged restarts"
+        },
+    )
+    jagged_restart_anneal_steps: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "how many anneal steps to take before reset for jagged restarts"
+        },
+    )
@@ -1164,7 +1164,9 @@ class ComplexValidationMixin:
 
     @model_validator(mode="after")
     def check_relora(self):
-        if self.relora_steps:
+        if self.relora:
+            if not self.jagged_restart_steps:
+                raise ValueError("jagged_restart_steps must be set to use ReLoRA")
             if self.adapter not in ("lora", "qlora"):
                 raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
 
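Illustrative examples of what the updated validator accepts and rejects (plain dicts with field names from the diff; not run through the full schema here):

ok = {"relora": True, "jagged_restart_steps": 200, "adapter": "lora"}
missing_steps = {"relora": True, "adapter": "lora"}
# -> ValueError: jagged_restart_steps must be set to use ReLoRA
wrong_adapter = {"relora": True, "jagged_restart_steps": 200, "adapter": None}
# -> ValueError: cfg.adapter must be lora or qlora to use ReLoRA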