From 7dc580b837d1597912a3401080d0969f127ed9a2 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 12 Jun 2023 00:18:21 -0400
Subject: [PATCH] add axolotl trainer and quadratic warmup

---
 src/axolotl/utils/schedulers.py | 64 ++++++++++++++++++++++++++++++++-
 src/axolotl/utils/trainer.py    | 42 +++++++++++++++++++--
 2 files changed, 102 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py
index f9b9e3583..4c14a358a 100644
--- a/src/axolotl/utils/schedulers.py
+++ b/src/axolotl/utils/schedulers.py
@@ -1,6 +1,9 @@
 """Module for custom LRScheduler class"""
+import math
+from functools import partial
 
-from torch.optim.lr_scheduler import LRScheduler
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR, LRScheduler
 
 
 class InterpolatingLogScheduler(LRScheduler):
@@ -42,3 +45,62 @@ class InterpolatingLogScheduler(LRScheduler):
             lrs = [self.max_lr for base_lr in self.base_lrs]
 
         return lrs
+
+
+def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float
+):
+    """
+    Compute the LR multiplier for `current_step`: the squared warmup fraction while warming up, then a cosine
+    curve (clamped at 0.0) over the remaining training steps.
+    """
+    if current_step < num_warmup_steps:
+        return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
+    progress = float(current_step - num_warmup_steps) / float(
+        max(1, num_training_steps - num_warmup_steps)
+    )
+    return max(
+        0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    )
+
+
+def get_cosine_schedule_with_quadratic_warmup(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function between the
+    initial lr set in the optimizer to 0, after a warmup period during which it increases quadratically between 0 and
+    the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+            following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    lr_lambda = partial(
+        _get_cosine_schedule_with_quadratic_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=num_cycles,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 1250ad4f6..4881a4334 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -17,10 +17,46 @@ from transformers import EarlyStoppingCallback, Trainer
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.callbacks import SavePeftModelCallback
-from axolotl.utils.schedulers import InterpolatingLogScheduler
+from axolotl.utils.schedulers import (
+    InterpolatingLogScheduler,
+    get_cosine_schedule_with_quadratic_warmup,
+)
 
 
-class OneCycleLRSchedulerTrainer(Trainer):
+class AxolotlTrainer(Trainer):
+    """
+    Extend the base Trainer for axolotl helpers
+    """
+
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
+    ):
+        """
+        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+        passed as an argument.
+
+        Args:
+            num_training_steps (int): The number of training steps to do.
+            optimizer (torch.optim.Optimizer): The optimizer to schedule. Defaults to `self.optimizer` when `None`.
+
+        Returns:
+            The learning rate scheduler in use (an existing one is returned untouched).
+        """
+
+        if self.lr_scheduler is not None:  # type: ignore  # pylint: disable=access-member-before-definition
+            return self.lr_scheduler
+        if self.args.lr_scheduler_type != "cosine_with_quadratic":
+            return super().create_scheduler(num_training_steps, optimizer)
+        # `optimizer` defaults to None, so fall back to the trainer's own optimizer like the base Trainer does
+        self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
+            self.optimizer if optimizer is None else optimizer,
+            num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+            num_training_steps=num_training_steps,
+        )
+        return self.lr_scheduler
+
+
+class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """
     Trainer subclass that uses the OneCycleLR scheduler
     """
@@ -259,7 +295,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     trainer_cls = (
         OneCycleLRSchedulerTrainer
         if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
-        else transformers.Trainer
+        else AxolotlTrainer
     )
     trainer = trainer_cls(
         model=model,