From 0134093acc6b78f5a5cde0570472f114a2c432de Mon Sep 17 00:00:00 2001 From: xzuyn <16216325+xzuyn@users.noreply.github.com> Date: Wed, 5 Mar 2025 10:26:11 -0500 Subject: [PATCH] Add REX LR Scheduler (#2380) * Update trainer_builder.py * Update base.py * Update __init__.py * Update base.py * Update base.py * Update config.qmd * Update base.py * Update base.py * Update base.py * Update base.py * Update base.py * Update base.py * Update base.py * lint * lint * lint * lint * lint * lint * Update base.py * Update base.py * lint * Update base.py * Update base.py * Move RexLR to `schedulers.py` * Remove RexLR from `base.py` * Fix tooltip formatting * lint * Create test_schedulers.py * Use a default optimizer in test * lint * lint * Add `warmup_steps` and `cosine_min_lr_ratio` to test * lint --- docs/config.qmd | 2 +- src/axolotl/core/trainer_builder.py | 2 +- src/axolotl/core/trainers/base.py | 12 +++ .../config/models/input/v0_4_1/__init__.py | 2 +- src/axolotl/utils/schedulers.py | 74 +++++++++++++++++++ tests/e2e/test_schedulers.py | 71 ++++++++++++++++++ 6 files changed, 160 insertions(+), 3 deletions(-) create mode 100644 tests/e2e/test_schedulers.py diff --git a/docs/config.qmd b/docs/config.qmd index 1f3a50b2e..cfd137ff0 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -451,7 +451,7 @@ gradient_checkpointing: false early_stopping_patience: 3 # Specify a scheduler and kwargs to use with the optimizer -lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine +lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine lr_scheduler_kwargs: cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index d4ddc9bf3..fe9c8bcae 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -572,7 +572,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups - if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]: + if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]: training_arguments_kwargs["lr_scheduler_type"] = "cosine" training_arguments_kwargs[ "alternate_lr_scheduler_type" diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index ee2545b21..27f00f1fd 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -25,6 +25,7 @@ from trl.trainer.utils import pad_to_length from axolotl.monkeypatch.relora import ReLoRAScheduler from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.schedulers import ( + RexLR, get_cosine_schedule_with_min_lr, get_cosine_schedule_with_quadratic_warmup, get_cosine_schedule_with_warmup_decay_constant, @@ -115,6 +116,17 @@ class SchedulerMixin(Trainer): **extra_lr_kwargs, **self.args.lr_scheduler_kwargs, ) + elif self.args.alternate_lr_scheduler_type == "rex": + if use_cosine_min_lr: + assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" + + self.lr_scheduler = RexLR( + optimizer=optimizer, + max_lr=self.args.learning_rate, + min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio), + total_steps=num_training_steps, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + ) elif use_cosine_quadratic: if use_cosine_min_lr: LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.") diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index c7803b8cc..180e02823 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -518,7 +518,7 @@ class HyperparametersConfig(BaseModel): ) torchdistx_path: Optional[str] = None lr_scheduler: Optional[ - Union[SchedulerType, Literal["one_cycle"]] + Union[SchedulerType, Literal["one_cycle"], Literal["rex"]] ] = SchedulerType.COSINE lr_scheduler_kwargs: Optional[Dict[str, Any]] = None lr_quadratic_warmup: Optional[bool] = None diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py index 94387e5ab..6f057fbd9 100644 --- a/src/axolotl/utils/schedulers.py +++ b/src/axolotl/utils/schedulers.py @@ -6,6 +6,80 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR, LRScheduler +class RexLR(LRScheduler): + """ + Reflected Exponential (REX) learning rate scheduler. + + - Original implementation: https://github.com/IvanVassi/REX_LR + - Original license: Apache 2.0 + - Based on: https://arxiv.org/abs/2107.04197 + + Args: + optimizer (torch.optim.Optimizer): The optimizer to schedule the learning rate for. + max_lr (float): The maximum learning rate. + min_lr (float): The minimum learning rate. + total_steps (int): The total number of training steps. + num_warmup_steps (int): The number of warmup steps. + last_step (int): The index of last step. + """ + + def __init__( + self, optimizer, max_lr, min_lr, total_steps=0, num_warmup_steps=0, last_step=0 + ): + if min_lr > max_lr: + raise ValueError( + f'Value of "min_lr" should be less than value of "max_lr". Got min_lr={min_lr} and max_lr={max_lr}' + ) + if num_warmup_steps > total_steps: + raise ValueError( + f"num_warmup_steps ({num_warmup_steps}) must be less than or equal to total_steps ({total_steps})." + ) + + self.min_lr = min_lr + self.max_lr = max_lr + self.total_steps = total_steps + self.num_warmup_steps = num_warmup_steps + self.last_step = last_step - 1 + + # Ensure each parameter group has an "initial_lr" key to avoid issues when resuming. + for group in optimizer.param_groups: + group.setdefault("initial_lr", group["lr"]) + + # Pass self.last_step as last_epoch to the parent. + super().__init__(optimizer, last_epoch=self.last_step) + + @property + def last_step(self): + return self.last_epoch + + @last_step.setter + def last_step(self, value): + self.last_epoch = value + + def get_lr(self): + # Warmup phase: if defined, increase lr linearly from 0 to max_lr. + if 1 <= self.last_step <= self.num_warmup_steps: + return [ + base_lr * self.last_step / self.num_warmup_steps + for base_lr in self.base_lrs + ] + + # Post-warmup phase: adjust step relative to the end of warmup. + step_after = self.last_step - self.num_warmup_steps + remaining_steps = self.total_steps - self.num_warmup_steps + + # Avoid LR spiking + if step_after >= remaining_steps or step_after == -1 or remaining_steps <= 0: + return [self.min_lr for _ in self.base_lrs] + + mod_iter = step_after % remaining_steps + z = (remaining_steps - mod_iter) / remaining_steps + rex_factor = self.min_lr / self.max_lr + (1.0 - self.min_lr / self.max_lr) * ( + z / (0.1 + 0.9 * z) + ) + return [base_lr * rex_factor for base_lr in self.base_lrs] + + class InterpolatingLogScheduler(LRScheduler): """ A scheduler that interpolates learning rates in a logarithmic fashion diff --git a/tests/e2e/test_schedulers.py b/tests/e2e/test_schedulers.py new file mode 100644 index 000000000..c492fdccf --- /dev/null +++ b/tests/e2e/test_schedulers.py @@ -0,0 +1,71 @@ +""" +E2E tests for custom schedulers using Llama +""" + +import logging +import os +import unittest + +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from .utils import check_model_output_exists, with_temp_dir + +LOG = logging.getLogger("axolotl.tests.e2e") +os.environ["WANDB_DISABLED"] = "true" + + +class TestCustomSchedulers(unittest.TestCase): + """ + Test case for Llama models using LoRA + """ + + @with_temp_dir + def test_rex_scheduler(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "JackFram/llama-68m", + "tokenizer_type": "LlamaTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.05, + "lora_target_linear": True, + "val_set_size": 0.1, + "special_tokens": { + "unk_token": "", + "bos_token": "", + "eos_token": "", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_hf", + "max_steps": 20, + "lr_scheduler": "rex", + "warmup_steps": 5, + "cosine_min_lr_ratio": 0.05, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg)