diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py
index f9b9e3583..4c14a358a 100644
--- a/src/axolotl/utils/schedulers.py
+++ b/src/axolotl/utils/schedulers.py
@@ -1,6 +1,9 @@
 """Module for custom LRScheduler class"""
+import math
+from functools import partial
 
-from torch.optim.lr_scheduler import LRScheduler
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR, LRScheduler
 
 
 class InterpolatingLogScheduler(LRScheduler):
@@ -42,3 +45,62 @@ class InterpolatingLogScheduler(LRScheduler):
             lrs = [self.max_lr for base_lr in self.base_lrs]
 
         return lrs
+
+
+def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float
+):
+    """
+    Compute the LR multiplier at `current_step`: a quadratic ramp from 0 towards 1 over the warmup steps,
+    followed by a cosine decay over the remaining training steps (clamped so it never goes below 0).
+    """
+    if current_step < num_warmup_steps:
+        return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
+    progress = float(current_step - num_warmup_steps) / float(
+        max(1, num_training_steps - num_warmup_steps)
+    )
+    return max(
+        0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    )
+
+
+def get_cosine_schedule_with_quadratic_warmup(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function between the
+    initial lr set in the optimizer to 0, after a warmup period during which it increases quadratically between 0
+    and the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+            following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    lr_lambda = partial(
+        _get_cosine_schedule_with_quadratic_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=num_cycles,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 263d6c78d..98ff9b3b9 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -5,6 +5,7 @@ import logging
 import math
 import os
 import sys
+from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Optional
 
@@ -13,17 +14,69 @@ import torch.cuda
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from transformers import EarlyStoppingCallback, Trainer
+from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.callbacks import (
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
-from axolotl.utils.schedulers import InterpolatingLogScheduler
+from axolotl.utils.schedulers import (
+    InterpolatingLogScheduler,
+    get_cosine_schedule_with_quadratic_warmup,
+)
 
 
-class OneCycleLRSchedulerTrainer(Trainer):
+@dataclass  # required so lr_quadratic_warmup becomes a real field accepted by the generated __init__
+class AxolotlTrainingArguments(TrainingArguments):
+    """
+    Extend the base TrainingArguments for axolotl helpers
+    """
+
+    lr_quadratic_warmup: bool = field(
+        default=False,
+        metadata={"help": "Use quadratic warmup for cosine scheduling."},
+    )
+
+
+class AxolotlTrainer(Trainer):
+    """
+    Extend the base Trainer for axolotl helpers
+    """
+
+    args = None  # type: AxolotlTrainingArguments
+
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
+    ):
+        """
+        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+        passed as an argument.
+
+        Args:
+            num_training_steps (int): The number of training steps to do.
+            optimizer (torch.optim.Optimizer): The training optimizer
+        """
+
+        # fmt: off
+        if self.lr_scheduler is None:  # type: ignore  # pylint: disable=access-member-before-definition
+            # fmt: on
+            if (
+                self.args.lr_scheduler_type == "cosine"
+                and self.args.lr_quadratic_warmup is True
+            ):
+                self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
+                    # fall back to the trainer's optimizer when none is passed in
+                    self.optimizer if optimizer is None else optimizer,
+                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+                    num_training_steps=num_training_steps,
+                )
+            else:
+                return super().create_scheduler(num_training_steps, optimizer)
+        return self.lr_scheduler
+
+
+class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """
     Trainer subclass that uses the OneCycleLR scheduler
     """
@@ -103,6 +156,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.fsdp_config:
         training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
 
+    if cfg.lr_quadratic_warmup is not None:
+        training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
+
     # deepspeed
     if (
         os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
@@ -128,7 +184,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
         training_arguments_kwargs["push_to_hub"] = True
 
-    training_args = transformers.TrainingArguments(
+    training_args = AxolotlTrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
         if cfg.eval_batch_size is not None
@@ -278,7 +334,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     trainer_cls = (
         OneCycleLRSchedulerTrainer
         if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
-        else transformers.Trainer
+        else AxolotlTrainer
     )
     trainer = trainer_cls(
         model=model,