Add REX LR Scheduler (#2380)
* Update trainer_builder.py * Update base.py * Update __init__.py * Update base.py * Update base.py * Update config.qmd * Update base.py * Update base.py * Update base.py * Update base.py * Update base.py * Update base.py * Update base.py * lint * lint * lint * lint * lint * lint * Update base.py * Update base.py * lint * Update base.py * Update base.py * Move RexLR to `schedulers.py` * Remove RexLR from `base.py` * Fix tooltip formatting * lint * Create test_schedulers.py * Use a default optimizer in test * lint * lint * Add `warmup_steps` and `cosine_min_lr_ratio` to test * lint
This commit is contained in:
@@ -451,7 +451,7 @@ gradient_checkpointing: false
|
||||
early_stopping_patience: 3
|
||||
|
||||
# Specify a scheduler and kwargs to use with the optimizer
|
||||
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
||||
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
|
||||
lr_scheduler_kwargs:
|
||||
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
|
||||
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
|
||||
|
||||
@@ -572,7 +572,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
|
||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||
training_arguments_kwargs[
|
||||
"alternate_lr_scheduler_type"
|
||||
|
||||
@@ -25,6 +25,7 @@ from trl.trainer.utils import pad_to_length
|
||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
from axolotl.utils.schedulers import (
|
||||
RexLR,
|
||||
get_cosine_schedule_with_min_lr,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
get_cosine_schedule_with_warmup_decay_constant,
|
||||
@@ -115,6 +116,17 @@ class SchedulerMixin(Trainer):
|
||||
**extra_lr_kwargs,
|
||||
**self.args.lr_scheduler_kwargs,
|
||||
)
|
||||
elif self.args.alternate_lr_scheduler_type == "rex":
|
||||
if use_cosine_min_lr:
|
||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||
|
||||
self.lr_scheduler = RexLR(
|
||||
optimizer=optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||
total_steps=num_training_steps,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
)
|
||||
elif use_cosine_quadratic:
|
||||
if use_cosine_min_lr:
|
||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||
|
||||
@@ -518,7 +518,7 @@ class HyperparametersConfig(BaseModel):
|
||||
)
|
||||
torchdistx_path: Optional[str] = None
|
||||
lr_scheduler: Optional[
|
||||
Union[SchedulerType, Literal["one_cycle"]]
|
||||
Union[SchedulerType, Literal["one_cycle"], Literal["rex"]]
|
||||
] = SchedulerType.COSINE
|
||||
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
||||
lr_quadratic_warmup: Optional[bool] = None
|
||||
|
||||
@@ -6,6 +6,80 @@ from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
|
||||
class RexLR(LRScheduler):
|
||||
"""
|
||||
Reflected Exponential (REX) learning rate scheduler.
|
||||
|
||||
- Original implementation: https://github.com/IvanVassi/REX_LR
|
||||
- Original license: Apache 2.0
|
||||
- Based on: https://arxiv.org/abs/2107.04197
|
||||
|
||||
Args:
|
||||
optimizer (torch.optim.Optimizer): The optimizer to schedule the learning rate for.
|
||||
max_lr (float): The maximum learning rate.
|
||||
min_lr (float): The minimum learning rate.
|
||||
total_steps (int): The total number of training steps.
|
||||
num_warmup_steps (int): The number of warmup steps.
|
||||
last_step (int): The index of last step.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, optimizer, max_lr, min_lr, total_steps=0, num_warmup_steps=0, last_step=0
|
||||
):
|
||||
if min_lr > max_lr:
|
||||
raise ValueError(
|
||||
f'Value of "min_lr" should be less than value of "max_lr". Got min_lr={min_lr} and max_lr={max_lr}'
|
||||
)
|
||||
if num_warmup_steps > total_steps:
|
||||
raise ValueError(
|
||||
f"num_warmup_steps ({num_warmup_steps}) must be less than or equal to total_steps ({total_steps})."
|
||||
)
|
||||
|
||||
self.min_lr = min_lr
|
||||
self.max_lr = max_lr
|
||||
self.total_steps = total_steps
|
||||
self.num_warmup_steps = num_warmup_steps
|
||||
self.last_step = last_step - 1
|
||||
|
||||
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
|
||||
for group in optimizer.param_groups:
|
||||
group.setdefault("initial_lr", group["lr"])
|
||||
|
||||
# Pass self.last_step as last_epoch to the parent.
|
||||
super().__init__(optimizer, last_epoch=self.last_step)
|
||||
|
||||
@property
|
||||
def last_step(self):
|
||||
return self.last_epoch
|
||||
|
||||
@last_step.setter
|
||||
def last_step(self, value):
|
||||
self.last_epoch = value
|
||||
|
||||
def get_lr(self):
|
||||
# Warmup phase: if defined, increase lr linearly from 0 to max_lr.
|
||||
if 1 <= self.last_step <= self.num_warmup_steps:
|
||||
return [
|
||||
base_lr * self.last_step / self.num_warmup_steps
|
||||
for base_lr in self.base_lrs
|
||||
]
|
||||
|
||||
# Post-warmup phase: adjust step relative to the end of warmup.
|
||||
step_after = self.last_step - self.num_warmup_steps
|
||||
remaining_steps = self.total_steps - self.num_warmup_steps
|
||||
|
||||
# Avoid LR spiking
|
||||
if step_after >= remaining_steps or step_after == -1 or remaining_steps <= 0:
|
||||
return [self.min_lr for _ in self.base_lrs]
|
||||
|
||||
mod_iter = step_after % remaining_steps
|
||||
z = (remaining_steps - mod_iter) / remaining_steps
|
||||
rex_factor = self.min_lr / self.max_lr + (1.0 - self.min_lr / self.max_lr) * (
|
||||
z / (0.1 + 0.9 * z)
|
||||
)
|
||||
return [base_lr * rex_factor for base_lr in self.base_lrs]
|
||||
|
||||
|
||||
class InterpolatingLogScheduler(LRScheduler):
|
||||
"""
|
||||
A scheduler that interpolates learning rates in a logarithmic fashion
|
||||
|
||||
71
tests/e2e/test_schedulers.py
Normal file
71
tests/e2e/test_schedulers.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
E2E tests for custom schedulers using Llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestCustomSchedulers(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_rex_scheduler(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_hf",
|
||||
"max_steps": 20,
|
||||
"lr_scheduler": "rex",
|
||||
"warmup_steps": 5,
|
||||
"cosine_min_lr_ratio": 0.05,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
Reference in New Issue
Block a user