better configuration for quadratic warmup

This commit is contained in:
Wing Lian
2023-07-10 11:52:59 -04:00
parent 7dc580b837
commit c49729d2bc

View File

@@ -5,6 +5,7 @@ import logging
import math import math
import os import os
import sys import sys
from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@@ -13,7 +14,7 @@ import torch.cuda
import transformers import transformers
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback, Trainer from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import SavePeftModelCallback from axolotl.utils.callbacks import SavePeftModelCallback
@@ -23,11 +24,24 @@ from axolotl.utils.schedulers import (
) )
@dataclass
class AxolotlTrainingArguments(TrainingArguments):
    """
    Extend the base TrainingArguments for axolotl helpers

    The ``@dataclass`` decorator is required here: without it the annotated
    ``field(...)`` below stays a plain class attribute holding a ``Field``
    object and is never added to the generated ``__init__``, so
    ``lr_quadratic_warmup=...`` could not be passed as a keyword argument.
    """

    # Opt-in flag; the trainer's scheduler setup checks it together with
    # ``lr_scheduler_type == "cosine"`` to pick the quadratic-warmup schedule.
    lr_quadratic_warmup: bool = field(
        default=False,
        metadata={"help": "Use quadratic warmup for cosine scheduling."},
    )
class AxolotlTrainer(Trainer): class AxolotlTrainer(Trainer):
""" """
Extend the base Trainer for axolotl helpers Extend the base Trainer for axolotl helpers
""" """
args = None # type: AxolotlTrainingArguments
def create_scheduler( def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
): ):
@@ -37,11 +51,16 @@ class AxolotlTrainer(Trainer):
Args: Args:
num_training_steps (int): The number of training steps to do. num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
""" """
if self.lr_scheduler is None: # pylint: disable=access-member-before-definition # fmt: off
"""# type: ignore""" if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
if self.args.lr_scheduler_type == "cosine_with_quadratic": # fmt: on
if (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
):
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer, optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
@@ -132,6 +151,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.fsdp_config: if cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config) training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
if cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
# deepspeed # deepspeed
if ( if (
os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
@@ -144,7 +166,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
# TODO search Path("./") for one # TODO search Path("./") for one
training_arguments_kwargs["deepspeed"] = "./ds_config.json" training_arguments_kwargs["deepspeed"] = "./ds_config.json"
training_args = transformers.TrainingArguments( training_args = AxolotlTrainingArguments(
per_device_train_batch_size=cfg.micro_batch_size, per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size per_device_eval_batch_size=cfg.eval_batch_size
if cfg.eval_batch_size is not None if cfg.eval_batch_size is not None