Add REX LR Scheduler (#2380)

* Update trainer_builder.py

* Update base.py

* Update __init__.py

* Update base.py

* Update base.py

* Update config.qmd

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* Update base.py

* lint

* lint

* lint

* lint

* lint

* lint

* Update base.py

* Update base.py

* lint

* Update base.py

* Update base.py

* Move RexLR to `schedulers.py`

* Remove RexLR from `base.py`

* Fix tooltip formatting

* lint

* Create test_schedulers.py

* Use a default optimizer in test

* lint

* lint

* Add `warmup_steps` and `cosine_min_lr_ratio` to test

* lint
This commit is contained in:
xzuyn
2025-03-05 10:26:11 -05:00
committed by GitHub
parent d4de93a7bb
commit 0134093acc
6 changed files with 160 additions and 3 deletions

View File

@@ -451,7 +451,7 @@ gradient_checkpointing: false
early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)

View File

@@ -572,7 +572,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[
"alternate_lr_scheduler_type"

View File

@@ -25,6 +25,7 @@ from trl.trainer.utils import pad_to_length
from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
@@ -115,6 +116,17 @@ class SchedulerMixin(Trainer):
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif self.args.alternate_lr_scheduler_type == "rex":
if use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")

View File

@@ -518,7 +518,7 @@ class HyperparametersConfig(BaseModel):
)
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[
Union[SchedulerType, Literal["one_cycle"]]
Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]
] = SchedulerType.COSINE
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None

View File

@@ -6,6 +6,80 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
class RexLR(LRScheduler):
"""
Reflected Exponential (REX) learning rate scheduler.
- Original implementation: https://github.com/IvanVassi/REX_LR
- Original license: Apache 2.0
- Based on: https://arxiv.org/abs/2107.04197
Args:
optimizer (torch.optim.Optimizer): The optimizer to schedule the learning rate for.
max_lr (float): The maximum learning rate.
min_lr (float): The minimum learning rate.
total_steps (int): The total number of training steps.
num_warmup_steps (int): The number of warmup steps.
last_step (int): The index of last step.
"""
def __init__(
self, optimizer, max_lr, min_lr, total_steps=0, num_warmup_steps=0, last_step=0
):
if min_lr > max_lr:
raise ValueError(
f'Value of "min_lr" should be less than value of "max_lr". Got min_lr={min_lr} and max_lr={max_lr}'
)
if num_warmup_steps > total_steps:
raise ValueError(
f"num_warmup_steps ({num_warmup_steps}) must be less than or equal to total_steps ({total_steps})."
)
self.min_lr = min_lr
self.max_lr = max_lr
self.total_steps = total_steps
self.num_warmup_steps = num_warmup_steps
self.last_step = last_step - 1
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups:
group.setdefault("initial_lr", group["lr"])
# Pass self.last_step as last_epoch to the parent.
super().__init__(optimizer, last_epoch=self.last_step)
@property
def last_step(self):
return self.last_epoch
@last_step.setter
def last_step(self, value):
self.last_epoch = value
def get_lr(self):
# Warmup phase: if defined, increase lr linearly from 0 to max_lr.
if 1 <= self.last_step <= self.num_warmup_steps:
return [
base_lr * self.last_step / self.num_warmup_steps
for base_lr in self.base_lrs
]
# Post-warmup phase: adjust step relative to the end of warmup.
step_after = self.last_step - self.num_warmup_steps
remaining_steps = self.total_steps - self.num_warmup_steps
# Avoid LR spiking
if step_after >= remaining_steps or step_after == -1 or remaining_steps <= 0:
return [self.min_lr for _ in self.base_lrs]
mod_iter = step_after % remaining_steps
z = (remaining_steps - mod_iter) / remaining_steps
rex_factor = self.min_lr / self.max_lr + (1.0 - self.min_lr / self.max_lr) * (
z / (0.1 + 0.9 * z)
)
return [base_lr * rex_factor for base_lr in self.base_lrs]
class InterpolatingLogScheduler(LRScheduler):
"""
A scheduler that interpolates learning rates in a logarithmic fashion

View File

@@ -0,0 +1,71 @@
"""
E2E tests for custom schedulers using Llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestCustomSchedulers(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_rex_scheduler(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_hf",
"max_steps": 20,
"lr_scheduler": "rex",
"warmup_steps": 5,
"cosine_min_lr_ratio": 0.05,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)