From c49729d2bc863c4d32f8a5e9cf81274a21d4df21 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 10 Jul 2023 11:52:59 -0400
Subject: [PATCH] better configuration for quadratic warmup

---
 src/axolotl/utils/trainer.py | 32 +++++++++++++++++++++++++++-----
 1 file changed, 27 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 4881a4334..d231bd0ef 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -5,6 +5,7 @@ import logging
 import math
 import os
 import sys
+from dataclasses import field
 from pathlib import Path
 from typing import Optional
 
@@ -13,7 +14,7 @@ import torch.cuda
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from transformers import EarlyStoppingCallback, Trainer
+from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.callbacks import SavePeftModelCallback
@@ -23,11 +24,24 @@ from axolotl.utils.schedulers import (
 )
 
 
+class AxolotlTrainingArguments(TrainingArguments):
+    """
+    Extend the base TrainingArguments for axolotl helpers
+    """
+
+    lr_quadratic_warmup: bool = field(
+        default=False,
+        metadata={"help": "Use quadratic warmup for cosine scheduling."},
+    )
+
+
 class AxolotlTrainer(Trainer):
     """
     Extend the base Trainer for axolotl helpers
     """
 
+    args = None  # type: AxolotlTrainingArguments
+
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
     ):
@@ -37,11 +51,16 @@ class AxolotlTrainer(Trainer):
         Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or passed as an argument.
 
         Args:
             num_training_steps (int): The number of training steps to do.
+            optimizer (torch.optim.Optimizer): The training optimizer
         """
-        if self.lr_scheduler is None:  # pylint: disable=access-member-before-definition
-            """# type: ignore"""
-            if self.args.lr_scheduler_type == "cosine_with_quadratic":
+        # fmt: off
+        if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
+            # fmt: on
+            if (
+                self.args.lr_scheduler_type == "cosine"
+                and self.args.lr_quadratic_warmup is True
+            ):
                 self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
                     optimizer,
                     num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
@@ -132,6 +151,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.fsdp_config:
         training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
 
+    if cfg.lr_quadratic_warmup is not None:
+        training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
+
     # deepspeed
     if (
         os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
@@ -144,7 +166,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         # TODO search Path("./") for one
         training_arguments_kwargs["deepspeed"] = "./ds_config.json"
 
-    training_args = transformers.TrainingArguments(
+    training_args = AxolotlTrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
         if cfg.eval_batch_size is not None
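
Notes:

TrainingArguments is a dataclass, so a subclass that declares a new field() normally re-applies the @dataclass decorator for the field to become an __init__ parameter; as written here, passing lr_quadratic_warmup through to AxolotlTrainingArguments(...) would raise a TypeError once the config actually sets the option, while configs that leave it unset never pass the kwarg because of the "is not None" guard.

For context, get_cosine_schedule_with_quadratic_warmup is imported from axolotl.utils.schedulers and is not defined in this diff. Below is a minimal sketch of such a scheduler, assuming it follows the shape of transformers' get_cosine_schedule_with_warmup with the linear warmup ratio squared; the exact names and defaults are illustrative, not the patch's own:

import math

from torch.optim.lr_scheduler import LambdaLR


def get_cosine_schedule_with_quadratic_warmup(
    optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
):
    """Cosine decay preceded by a quadratic (rather than linear) warmup ramp."""

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # quadratic warmup: square the usual linear ratio current_step / num_warmup_steps
            return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
        # past warmup: standard cosine decay over the remaining steps
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return max(
            0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)

Since (current_step / num_warmup_steps) ** 2 is at most the linear ratio on [0, 1], the quadratic ramp holds the learning rate lower through most of the warmup window before handing off to the cosine decay.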