Compare commits
1 Commits
activeblue
...
jagged-res
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fe12aa79c8 |
@@ -8,6 +8,7 @@ from transformers.trainer import Trainer
|
|||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.utils.schedulers import (
|
from axolotl.utils.schedulers import (
|
||||||
|
JaggedLRRestartScheduler,
|
||||||
RexLR,
|
RexLR,
|
||||||
get_cosine_schedule_with_min_lr,
|
get_cosine_schedule_with_min_lr,
|
||||||
get_cosine_schedule_with_quadratic_warmup,
|
get_cosine_schedule_with_quadratic_warmup,
|
||||||
@@ -112,7 +113,22 @@ class SchedulerMixin(Trainer):
|
|||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||||
|
if self.args.jagged_restart_steps:
|
||||||
|
warmup_steps = (
|
||||||
|
self.args.jagged_restarts_warmup_steps or 10
|
||||||
|
)
|
||||||
|
anneal_steps = (
|
||||||
|
self.args.jagged_restarts_anneal_steps or 1
|
||||||
|
)
|
||||||
|
self.lr_scheduler = JaggedLRRestartScheduler( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
self.lr_scheduler,
|
||||||
|
self.args.jagged_restart_steps,
|
||||||
|
warmup_steps,
|
||||||
|
anneal_steps,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if use_cosine_quadratic:
|
if use_cosine_quadratic:
|
||||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||||
|
|||||||
@@ -86,6 +86,22 @@ class AxolotlTrainingMixins:
|
|||||||
default=0.9,
|
default=0.9,
|
||||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||||
)
|
)
|
||||||
|
jagged_restart_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how often to reset for jagged restarts"},
|
||||||
|
)
|
||||||
|
jagged_restarts_warmup_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "how many warmup steps to take after reset for jagged restarts"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
jagged_restarts_anneal_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "how many anneal steps to take before reset for jagged restarts"
|
||||||
|
},
|
||||||
|
)
|
||||||
bench_split: Optional[str] = field(
|
bench_split: Optional[str] = field(
|
||||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||||
@@ -292,3 +293,47 @@ def get_cosine_schedule_with_warmup_decay_constant(
|
|||||||
num_cycles=num_cycles,
|
num_cycles=num_cycles,
|
||||||
)
|
)
|
||||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||||
|
|
||||||
|
|
||||||
|
class JaggedLRRestartScheduler(LRScheduler):
|
||||||
|
"""Wraps another scheduler to apply per-lora-restart learning rate warmups."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
inner_schedule: LRScheduler,
|
||||||
|
jagged_restarts_steps: int,
|
||||||
|
jagged_restarts_warmup_steps: int,
|
||||||
|
jagged_restarts_anneal_steps: int = 1,
|
||||||
|
min_lr_scale: float = 0.001,
|
||||||
|
) -> None:
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
self.inner_schedule = inner_schedule
|
||||||
|
self.restarts_steps = jagged_restarts_steps
|
||||||
|
self.warmup_steps = jagged_restarts_warmup_steps
|
||||||
|
self.anneal_steps = jagged_restarts_anneal_steps
|
||||||
|
self.min_lr_scale = min_lr_scale
|
||||||
|
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
|
||||||
|
|
||||||
|
def get_lr(self) -> List[float]:
|
||||||
|
self.inner_schedule.last_epoch = self.last_epoch
|
||||||
|
|
||||||
|
original: List[float] = self.inner_schedule.get_lr()
|
||||||
|
step = self.last_epoch
|
||||||
|
|
||||||
|
if step < self.restarts_steps:
|
||||||
|
scale = 1
|
||||||
|
else:
|
||||||
|
per_restart_progress = step % self.restarts_steps
|
||||||
|
if per_restart_progress < self.warmup_steps:
|
||||||
|
cycle_t = min(1.0, per_restart_progress / self.warmup_steps)
|
||||||
|
elif per_restart_progress > (self.restarts_steps - self.anneal_steps):
|
||||||
|
cycle_t = min(
|
||||||
|
1.0,
|
||||||
|
(self.restarts_steps - per_restart_progress) / self.anneal_steps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cycle_t = 1
|
||||||
|
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
|
||||||
|
|
||||||
|
return original * scale
|
||||||
|
|||||||
Reference in New Issue
Block a user