From 7dc580b837d1597912a3401080d0969f127ed9a2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 12 Jun 2023 00:18:21 -0400 Subject: [PATCH 1/2] add axolotl trainer and quadratic warmup --- src/axolotl/utils/schedulers.py | 60 ++++++++++++++++++++++++++++++++- src/axolotl/utils/trainer.py | 38 +++++++++++++++++++-- 2 files changed, 94 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/schedulers.py b/src/axolotl/utils/schedulers.py index f9b9e3583..4c14a358a 100644 --- a/src/axolotl/utils/schedulers.py +++ b/src/axolotl/utils/schedulers.py @@ -1,6 +1,9 @@ """Module for custom LRScheduler class""" +import math +from functools import partial -from torch.optim.lr_scheduler import LRScheduler +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR, LRScheduler class InterpolatingLogScheduler(LRScheduler): @@ -42,3 +45,58 @@ class InterpolatingLogScheduler(LRScheduler): lrs = [self.max_lr for base_lr in self.base_lrs] return lrs + + +def _get_cosine_schedule_with_quadratic_warmup_lr_lambda( + current_step: int, + *, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float +): + if current_step < num_warmup_steps: + return (float(current_step) / float(max(1, num_warmup_steps))) ** 2 + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + return max( + 0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) + ) + + +def get_cosine_schedule_with_quadratic_warmup( + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, + last_epoch: int = -1, +): + """ + Create a schedule with a learning rate that decreases following the values of the cosine function between the + initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the + initial lr set in the optimizer. 
+ + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + num_warmup_steps (`int`): + The number of steps for the warmup phase. + num_training_steps (`int`): + The total number of training steps. + num_cycles (`float`, *optional*, defaults to 0.5): + The number of waves in the cosine schedule (the default is to just decrease from the max value to 0 + following a half-cosine). + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + lr_lambda = partial( + _get_cosine_schedule_with_quadratic_warmup_lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + ) + return LambdaLR(optimizer, lr_lambda, last_epoch) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 1250ad4f6..4881a4334 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -17,10 +17,42 @@ from transformers import EarlyStoppingCallback, Trainer from transformers.trainer_pt_utils import get_parameter_names from axolotl.utils.callbacks import SavePeftModelCallback -from axolotl.utils.schedulers import InterpolatingLogScheduler +from axolotl.utils.schedulers import ( + InterpolatingLogScheduler, + get_cosine_schedule_with_quadratic_warmup, +) -class OneCycleLRSchedulerTrainer(Trainer): +class AxolotlTrainer(Trainer): + """ + Extend the base Trainer for axolotl helpers + """ + + def create_scheduler( + self, num_training_steps: int, optimizer: torch.optim.Optimizer = None + ): + """ + Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or + passed as an argument. + + Args: + num_training_steps (int): The number of training steps to do.
+ """ + + if self.lr_scheduler is None: # pylint: disable=access-member-before-definition + """# type: ignore""" + if self.args.lr_scheduler_type == "cosine_with_quadratic": + self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + ) + else: + return super().create_scheduler(num_training_steps, optimizer) + return self.lr_scheduler + + +class OneCycleLRSchedulerTrainer(AxolotlTrainer): """ Trainer subclass that uses the OneCycleLR scheduler """ @@ -259,7 +291,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): trainer_cls = ( OneCycleLRSchedulerTrainer if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora") - else transformers.Trainer + else AxolotlTrainer ) trainer = trainer_cls( model=model, From c49729d2bc863c4d32f8a5e9cf81274a21d4df21 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 10 Jul 2023 11:52:59 -0400 Subject: [PATCH 2/2] better configuration for quadratic warmup --- src/axolotl/utils/trainer.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 4881a4334..d231bd0ef 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -5,6 +5,7 @@ import logging import math import os import sys +from dataclasses import field from pathlib import Path from typing import Optional @@ -13,7 +14,7 @@ import torch.cuda import transformers from torch import nn from torch.optim.lr_scheduler import OneCycleLR -from transformers import EarlyStoppingCallback, Trainer +from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers.trainer_pt_utils import get_parameter_names from axolotl.utils.callbacks import SavePeftModelCallback @@ -23,11 +24,24 @@ from axolotl.utils.schedulers import ( ) 
+class AxolotlTrainingArguments(TrainingArguments): + """ + Extend the base TrainingArguments for axolotl helpers + """ + + lr_quadratic_warmup: bool = field( + default=False, + metadata={"help": "Use quadratic warmup for cosine scheduling."}, + ) + + class AxolotlTrainer(Trainer): """ Extend the base Trainer for axolotl helpers """ + args = None # type: AxolotlTrainingArguments + def create_scheduler( self, num_training_steps: int, optimizer: torch.optim.Optimizer = None ): @@ -37,11 +51,16 @@ class AxolotlTrainer(Trainer): Args: num_training_steps (int): The number of training steps to do. + optimizer (torch.optim.Optimizer): The training optimizer """ - if self.lr_scheduler is None: # pylint: disable=access-member-before-definition - """# type: ignore""" - if self.args.lr_scheduler_type == "cosine_with_quadratic": + # fmt: off + if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition + # fmt: on + if ( + self.args.lr_scheduler_type == "cosine" + and self.args.lr_quadratic_warmup is True + ): self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init optimizer, num_warmup_steps=self.args.get_warmup_steps(num_training_steps), @@ -132,6 +151,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): if cfg.fsdp_config: training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config) + if cfg.lr_quadratic_warmup is not None: + training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup + # deepspeed if ( os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" @@ -144,7 +166,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): # TODO search Path("./") for one training_arguments_kwargs["deepspeed"] = "./ds_config.json" - training_args = transformers.TrainingArguments( + training_args = AxolotlTrainingArguments( per_device_train_batch_size=cfg.micro_batch_size, per_device_eval_batch_size=cfg.eval_batch_size if 
cfg.eval_batch_size is not None