diff --git a/src/axolotl/core/trainers/mixins/scheduler.py b/src/axolotl/core/trainers/mixins/scheduler.py index 0c36f9f95..0abeffaee 100644 --- a/src/axolotl/core/trainers/mixins/scheduler.py +++ b/src/axolotl/core/trainers/mixins/scheduler.py @@ -8,6 +8,7 @@ from transformers.trainer import Trainer from axolotl.integrations.base import PluginManager from axolotl.utils.schedulers import ( + JaggedLRRestartScheduler, RexLR, get_cosine_schedule_with_min_lr, get_cosine_schedule_with_quadratic_warmup, @@ -112,7 +113,22 @@ class SchedulerMixin(Trainer): min_lr_ratio=self.args.cosine_min_lr_ratio, ) else: - return super().create_scheduler(num_training_steps, optimizer=optimizer) + super().create_scheduler(num_training_steps, optimizer=optimizer) + if self.args.jagged_restart_steps: + warmup_steps = ( + self.args.jagged_restarts_warmup_steps or 10 + ) + anneal_steps = ( + self.args.jagged_restarts_anneal_steps or 1 + ) + self.lr_scheduler = JaggedLRRestartScheduler( # pylint: disable=attribute-defined-outside-init + optimizer, + self.lr_scheduler, + self.args.jagged_restart_steps, + warmup_steps, + anneal_steps, + ) + else: if use_cosine_quadratic: LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 0b14e7661..98b4a14f5 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -86,6 +86,22 @@ class AxolotlTrainingMixins: default=0.9, metadata={"help": "prune ratio for magnitude pruning of the optimizer"}, ) + jagged_restart_steps: Optional[int] = field( + default=None, + metadata={"help": "how often to reset for jagged restarts"}, + ) + jagged_restarts_warmup_steps: Optional[int] = field( + default=None, + metadata={ + "help": "how many warmup steps to take after reset for jagged restarts" + }, + ) + jagged_restarts_anneal_steps: Optional[int] = field( + default=None, + metadata={ + "help": "how many 
anneal steps to take before reset for jagged restarts" + }, + ) bench_split: Optional[str] = field( default="eval", metadata={"help": "The benchmark split to run on"} ) diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py index b550ac02c..f6a56e793 100644 --- a/src/axolotl/utils/schedulers.py +++ b/src/axolotl/utils/schedulers.py @@ -2,6 +2,7 @@ import math from functools import partial +from typing import List from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR, LRScheduler @@ -292,3 +293,47 @@ def get_cosine_schedule_with_warmup_decay_constant( num_cycles=num_cycles, ) return LambdaLR(optimizer, lr_lambda, last_epoch) + + +class JaggedLRRestartScheduler(LRScheduler): + """Wraps another scheduler to apply per-lora-restart learning rate warmups.""" + + def __init__( + self, + optimizer: Optimizer, + inner_schedule: LRScheduler, + jagged_restarts_steps: int, + jagged_restarts_warmup_steps: int, + jagged_restarts_anneal_steps: int = 1, + min_lr_scale: float = 0.001, + ) -> None: + # pylint: disable=duplicate-code + self.inner_schedule = inner_schedule + self.restarts_steps = jagged_restarts_steps + self.warmup_steps = jagged_restarts_warmup_steps + self.anneal_steps = jagged_restarts_anneal_steps + self.min_lr_scale = min_lr_scale + super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose) + + def get_lr(self) -> List[float]: + self.inner_schedule.last_epoch = self.last_epoch + + original: List[float] = self.inner_schedule.get_lr() + step = self.last_epoch + + if step < self.restarts_steps: + scale = 1 + else: + per_restart_progress = step % self.restarts_steps + if per_restart_progress < self.warmup_steps: + cycle_t = min(1.0, per_restart_progress / self.warmup_steps) + elif per_restart_progress > (self.restarts_steps - self.anneal_steps): + cycle_t = min( + 1.0, + (self.restarts_steps - per_restart_progress) / self.anneal_steps, + ) + else: + cycle_t = 1 + scale = cycle_t * (1 - 
self.min_lr_scale) + self.min_lr_scale + + return [group_lr * scale for group_lr in original]