Add REX LR Scheduler (#2380)
* Update trainer_builder.py * Update base.py * Update __init__.py * Update base.py * Update base.py * Update config.qmd * Update base.py * Update base.py * Update base.py * Update base.py * Update base.py * Update base.py * Update base.py * lint * lint * lint * lint * lint * lint * Update base.py * Update base.py * lint * Update base.py * Update base.py * Move RexLR to `schedulers.py` * Remove RexLR from `base.py` * Fix tooltip formatting * lint * Create test_schedulers.py * Use a default optimizer in test * lint * lint * Add `warmup_steps` and `cosine_min_lr_ratio` to test * lint
This commit is contained in:
@@ -572,7 +572,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
|
||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||
training_arguments_kwargs[
|
||||
"alternate_lr_scheduler_type"
|
||||
|
||||
@@ -25,6 +25,7 @@ from trl.trainer.utils import pad_to_length
|
||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
from axolotl.utils.schedulers import (
|
||||
RexLR,
|
||||
get_cosine_schedule_with_min_lr,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
get_cosine_schedule_with_warmup_decay_constant,
|
||||
@@ -115,6 +116,17 @@ class SchedulerMixin(Trainer):
|
||||
**extra_lr_kwargs,
|
||||
**self.args.lr_scheduler_kwargs,
|
||||
)
|
||||
elif self.args.alternate_lr_scheduler_type == "rex":
|
||||
if use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
|
||||
self.lr_scheduler = RexLR(
|
||||
optimizer=optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||
total_steps=num_training_steps,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
)
|
||||
elif use_cosine_quadratic:
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||
|
||||
@@ -518,7 +518,7 @@ class HyperparametersConfig(BaseModel):
|
||||
)
|
||||
torchdistx_path: Optional[str] = None
|
||||
lr_scheduler: Optional[
|
||||
Union[SchedulerType, Literal["one_cycle"]]
|
||||
Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]
|
||||
] = SchedulerType.COSINE
|
||||
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
||||
lr_quadratic_warmup: Optional[bool] = None
|
||||
|
||||
@@ -6,6 +6,80 @@ from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
|
||||
class RexLR(LRScheduler):
|
||||
"""
|
||||
Reflected Exponential (REX) learning rate scheduler.
|
||||
|
||||
- Original implementation: https://github.com/IvanVassi/REX_LR
|
||||
- Original license: Apache 2.0
|
||||
- Based on: https://arxiv.org/abs/2107.04197
|
||||
|
||||
Args:
|
||||
optimizer (torch.optim.Optimizer): The optimizer to schedule the learning rate for.
|
||||
max_lr (float): The maximum learning rate.
|
||||
min_lr (float): The minimum learning rate.
|
||||
total_steps (int): The total number of training steps.
|
||||
num_warmup_steps (int): The number of warmup steps.
|
||||
last_step (int): The index of last step.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, optimizer, max_lr, min_lr, total_steps=0, num_warmup_steps=0, last_step=0
|
||||
):
|
||||
if min_lr > max_lr:
|
||||
raise ValueError(
|
||||
f'Value of "min_lr" should be less than value of "max_lr". Got min_lr={min_lr} and max_lr={max_lr}'
|
||||
)
|
||||
if num_warmup_steps > total_steps:
|
||||
raise ValueError(
|
||||
f"num_warmup_steps ({num_warmup_steps}) must be less than or equal to total_steps ({total_steps})."
|
||||
)
|
||||
|
||||
self.min_lr = min_lr
|
||||
self.max_lr = max_lr
|
||||
self.total_steps = total_steps
|
||||
self.num_warmup_steps = num_warmup_steps
|
||||
self.last_step = last_step - 1
|
||||
|
||||
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault("initial_lr", group["lr"])
|
||||
|
||||
# Pass self.last_step as last_epoch to the parent.
|
||||
super().__init__(optimizer, last_epoch=self.last_step)
|
||||
|
||||
@property
|
||||
def last_step(self):
|
||||
return self.last_epoch
|
||||
|
||||
@last_step.setter
|
||||
def last_step(self, value):
|
||||
self.last_epoch = value
|
||||
|
||||
def get_lr(self):
|
||||
# Warmup phase: if defined, increase lr linearly from 0 to max_lr.
|
||||
if 1 <= self.last_step <= self.num_warmup_steps:
|
||||
return [
|
||||
base_lr * self.last_step / self.num_warmup_steps
|
||||
for base_lr in self.base_lrs
|
||||
]
|
||||
|
||||
# Post-warmup phase: adjust step relative to the end of warmup.
|
||||
step_after = self.last_step - self.num_warmup_steps
|
||||
remaining_steps = self.total_steps - self.num_warmup_steps
|
||||
|
||||
# Avoid LR spiking
|
||||
if step_after >= remaining_steps or step_after == -1 or remaining_steps <= 0:
|
||||
return [self.min_lr for _ in self.base_lrs]
|
||||
|
||||
mod_iter = step_after % remaining_steps
|
||||
z = (remaining_steps - mod_iter) / remaining_steps
|
||||
rex_factor = self.min_lr / self.max_lr + (1.0 - self.min_lr / self.max_lr) * (
|
||||
z / (0.1 + 0.9 * z)
|
||||
)
|
||||
return [base_lr * rex_factor for base_lr in self.base_lrs]
|
||||
|
||||
|
||||
class InterpolatingLogScheduler(LRScheduler):
|
||||
"""
|
||||
A scheduler that interpolates learning rates in a logarithmic fashion
|
||||
|
||||
Reference in New Issue
Block a user