Compare commits

...

1 Commit

Author SHA1 Message Date
Wing Lian
fe12aa79c8 jagged lr restart scheduler 2025-05-15 20:20:59 -04:00
3 changed files with 78 additions and 1 deletion

View File

@@ -8,6 +8,7 @@ from transformers.trainer import Trainer
from axolotl.integrations.base import PluginManager
from axolotl.utils.schedulers import (
JaggedLRRestartScheduler,
RexLR,
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
@@ -112,7 +113,22 @@ class SchedulerMixin(Trainer):
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer=optimizer)
super().create_scheduler(num_training_steps, optimizer=optimizer)
if self.args.jagged_restart_steps:
warmup_steps = (
self.args.jagged_restarts_warmup_steps or 10
)
anneal_steps = (
self.args.jagged_restarts_anneal_steps or 1
)
self.lr_scheduler = JaggedLRRestartScheduler( # pylint: disable=attribute-defined-outside-init
optimizer,
self.lr_scheduler,
self.args.jagged_restart_steps,
warmup_steps,
anneal_steps,
)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")

View File

@@ -86,6 +86,22 @@ class AxolotlTrainingMixins:
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
# NOTE(review): this field is named `jagged_restart_steps` (singular "restart")
# while the two below use `jagged_restarts_*` (plural) — confirm the asymmetry
# is intentional before renaming; consumers read these exact attribute names.
jagged_restart_steps: Optional[int] = field(
    default=None,
    metadata={"help": "how often to reset for jagged restarts"},
)
# Optional; the scheduler mixin appears to substitute 10 when this is unset — verify.
jagged_restarts_warmup_steps: Optional[int] = field(
    default=None,
    metadata={
        "help": "how many warmup steps to take after reset for jagged restarts"
    },
)
# Optional; the scheduler mixin appears to substitute 1 when this is unset — verify.
jagged_restarts_anneal_steps: Optional[int] = field(
    default=None,
    metadata={
        "help": "how many anneal steps to take before reset for jagged restarts"
    },
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)

View File

@@ -2,6 +2,7 @@
import math
from functools import partial
from typing import List
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -292,3 +293,47 @@ def get_cosine_schedule_with_warmup_decay_constant(
num_cycles=num_cycles,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
class JaggedLRRestartScheduler(LRScheduler):
    """Wraps another scheduler to apply per-restart learning rate warmups.

    The first ``jagged_restarts_steps`` steps are left entirely to the inner
    schedule. Each subsequent cycle of ``jagged_restarts_steps`` steps linearly
    re-warms the LR from ``min_lr_scale`` of the inner scheduler's value over
    the first ``jagged_restarts_warmup_steps`` steps, and linearly anneals it
    back toward ``min_lr_scale`` over the final ``jagged_restarts_anneal_steps``
    steps before the next restart.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        inner_schedule: LRScheduler,
        jagged_restarts_steps: int,
        jagged_restarts_warmup_steps: int,
        jagged_restarts_anneal_steps: int = 1,
        min_lr_scale: float = 0.001,
    ) -> None:
        """
        Args:
            optimizer: Optimizer whose learning rate is scheduled.
            inner_schedule: Wrapped scheduler providing the base LR curve.
            jagged_restarts_steps: Length of one restart cycle, in steps.
            jagged_restarts_warmup_steps: Steps to re-warm the LR after a restart.
            jagged_restarts_anneal_steps: Steps to anneal the LR before a restart.
            min_lr_scale: Floor for the scaling factor applied to the inner LR.
        """
        # pylint: disable=duplicate-code
        self.inner_schedule = inner_schedule
        self.restarts_steps = jagged_restarts_steps
        self.warmup_steps = jagged_restarts_warmup_steps
        self.anneal_steps = jagged_restarts_anneal_steps
        self.min_lr_scale = min_lr_scale
        # Adopt the inner scheduler's position so wrapping mid-training stays in sync.
        super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)

    def get_lr(self) -> List[float]:
        # Keep the wrapped scheduler in lockstep before querying its LRs.
        self.inner_schedule.last_epoch = self.last_epoch
        original: List[float] = self.inner_schedule.get_lr()
        step = self.last_epoch

        if step < self.restarts_steps:
            # First cycle: defer entirely to the inner schedule.
            scale = 1.0
        else:
            per_restart_progress = step % self.restarts_steps
            if per_restart_progress < self.warmup_steps:
                # Linear re-warmup immediately after a restart.
                cycle_t = min(1.0, per_restart_progress / self.warmup_steps)
            elif per_restart_progress > (self.restarts_steps - self.anneal_steps):
                # Linear anneal over the last `anneal_steps` of the cycle.
                cycle_t = min(
                    1.0,
                    (self.restarts_steps - per_restart_progress) / self.anneal_steps,
                )
            else:
                cycle_t = 1.0
            # Map cycle_t from [0, 1] onto [min_lr_scale, 1].
            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale

        # BUG FIX: `original * scale` treated the list as a scalar — `list * int`
        # repeats the list and `list * float` raises TypeError (which fired on the
        # first restart, where scale becomes a float). Scale each group's LR instead.
        return [lr * scale for lr in original]