Compare commits: sdpa-cp...jagged-res
1 commit: fe12aa79c8
@@ -8,6 +8,7 @@ from transformers.trainer import Trainer
 from axolotl.integrations.base import PluginManager
 from axolotl.utils.schedulers import (
+    JaggedLRRestartScheduler,
     RexLR,
     get_cosine_schedule_with_min_lr,
     get_cosine_schedule_with_quadratic_warmup,
@@ -112,7 +113,22 @@ class SchedulerMixin(Trainer):
                     min_lr_ratio=self.args.cosine_min_lr_ratio,
                 )
             else:
-                return super().create_scheduler(num_training_steps, optimizer=optimizer)
+                super().create_scheduler(num_training_steps, optimizer=optimizer)
+                if self.args.jagged_restart_steps:
+                    warmup_steps = (
+                        self.args.jagged_restarts_warmup_steps or 10
+                    )
+                    anneal_steps = (
+                        self.args.jagged_restarts_anneal_steps or 1
+                    )
+                    self.lr_scheduler = JaggedLRRestartScheduler(  # pylint: disable=attribute-defined-outside-init
+                        optimizer,
+                        self.lr_scheduler,
+                        self.args.jagged_restart_steps,
+                        warmup_steps,
+                        anneal_steps,
+                    )
+
         else:
             if use_cosine_quadratic:
                 LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
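
For orientation, a minimal sketch of the wrapping that `create_scheduler` performs above, written outside the Trainer. The model, optimizer, base `LambdaLR`, and the step counts (1000/10/5) are made-up stand-ins; the real inner schedule comes from `super().create_scheduler()` and the real counts from the training args.

```python
# Illustrative only: dummy model/optimizer and a constant base schedule;
# 1000/10/5 are hypothetical values for restart period, warmup, and anneal.
import torch
from torch.optim.lr_scheduler import LambdaLR

from axolotl.utils.schedulers import JaggedLRRestartScheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
base = LambdaLR(optimizer, lambda step: 1.0)  # plays the role of super().create_scheduler()

scheduler = JaggedLRRestartScheduler(
    optimizer,
    base,   # inner schedule whose LRs get rescaled per restart cycle
    1000,   # jagged_restart_steps: reset period
    10,     # warmup steps after each reset
    5,      # anneal steps before each reset
)

for _ in range(3):  # normal training-loop usage
    optimizer.step()
    scheduler.step()
```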
@@ -86,6 +86,22 @@ class AxolotlTrainingMixins:
         default=0.9,
         metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
     )
+    jagged_restart_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to reset for jagged restarts"},
+    )
+    jagged_restarts_warmup_steps: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "how many warmup steps to take after reset for jagged restarts"
+        },
+    )
+    jagged_restarts_anneal_steps: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "how many anneal steps to take before reset for jagged restarts"
+        },
+    )
     bench_split: Optional[str] = field(
         default="eval", metadata={"help": "The benchmark split to run on"}
     )
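
Worth noting how these `None` defaults interact with the `or` fallbacks in `create_scheduler` above: leaving the warmup/anneal counts unset yields 10 and 1, and an explicit `0` is coerced to the same fallbacks. A small sketch of that behaviour, with hypothetical values:

```python
# Mirrors the fallback logic in create_scheduler above; values are hypothetical.
jagged_restarts_warmup_steps = None  # field default (unset)
jagged_restarts_anneal_steps = 0     # explicit zero

warmup_steps = jagged_restarts_warmup_steps or 10  # -> 10 (unset falls back)
anneal_steps = jagged_restarts_anneal_steps or 1   # -> 1  (`or` also overrides an explicit 0)
print(warmup_steps, anneal_steps)  # 10 1
```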
@@ -2,6 +2,7 @@
 
 import math
 from functools import partial
+from typing import List
 
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
@@ -292,3 +293,47 @@ def get_cosine_schedule_with_warmup_decay_constant(
         num_cycles=num_cycles,
     )
     return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+class JaggedLRRestartScheduler(LRScheduler):
+    """Wraps another scheduler to apply per-lora-restart learning rate warmups."""
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        inner_schedule: LRScheduler,
+        jagged_restarts_steps: int,
+        jagged_restarts_warmup_steps: int,
+        jagged_restarts_anneal_steps: int = 1,
+        min_lr_scale: float = 0.001,
+    ) -> None:
+        # pylint: disable=duplicate-code
+        self.inner_schedule = inner_schedule
+        self.restarts_steps = jagged_restarts_steps
+        self.warmup_steps = jagged_restarts_warmup_steps
+        self.anneal_steps = jagged_restarts_anneal_steps
+        self.min_lr_scale = min_lr_scale
+        super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
+
+    def get_lr(self) -> List[float]:
+        self.inner_schedule.last_epoch = self.last_epoch
+
+        original: List[float] = self.inner_schedule.get_lr()
+        step = self.last_epoch
+
+        if step < self.restarts_steps:
+            scale = 1.0
+        else:
+            per_restart_progress = step % self.restarts_steps
+            if per_restart_progress < self.warmup_steps:
+                cycle_t = min(1.0, per_restart_progress / self.warmup_steps)
+            elif per_restart_progress > (self.restarts_steps - self.anneal_steps):
+                cycle_t = min(
+                    1.0,
+                    (self.restarts_steps - per_restart_progress) / self.anneal_steps,
+                )
+            else:
+                cycle_t = 1.0
+            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
+
+        return [lr * scale for lr in original]
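
To make the shape of the resulting schedule concrete, here is a small worked example of the scale factor computed by `get_lr`, using made-up numbers (restart period 100, warmup 10, anneal 5, `min_lr_scale` 0.001). The first period passes the inner schedule through unchanged; from then on each cycle ramps the LR up from the floor over the warmup steps and back down over the final anneal steps.

```python
# Standalone re-derivation of the scale factor above; 100/10/5 are hypothetical.
restarts_steps, warmup_steps, anneal_steps, min_lr_scale = 100, 10, 5, 0.001

def scale_at(step: int) -> float:
    # Same branching as JaggedLRRestartScheduler.get_lr().
    if step < restarts_steps:
        return 1.0
    progress = step % restarts_steps
    if progress < warmup_steps:
        cycle_t = min(1.0, progress / warmup_steps)
    elif progress > restarts_steps - anneal_steps:
        cycle_t = min(1.0, (restarts_steps - progress) / anneal_steps)
    else:
        cycle_t = 1.0
    return cycle_t * (1 - min_lr_scale) + min_lr_scale

print(scale_at(50))   # 1.0    first period: inner schedule passes through unchanged
print(scale_at(100))  # 0.001  restart boundary: LR drops to the floor, then warms up
print(scale_at(105))  # ~0.50  halfway through the 10-step warmup
print(scale_at(150))  # 1.0    mid-cycle: back to the inner schedule's LR
print(scale_at(197))  # ~0.60  3 steps before the next restart, annealing down
```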