add axolotl trainer and quadratic warmup

This commit is contained in:
Wing Lian
2023-06-12 00:18:21 -04:00
parent 93dacba228
commit 7dc580b837
2 changed files with 94 additions and 4 deletions

View File

@@ -1,6 +1,9 @@
"""Module for custom LRScheduler class"""
import math
from functools import partial
from torch.optim.lr_scheduler import LRScheduler
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
class InterpolatingLogScheduler(LRScheduler):
@@ -42,3 +45,58 @@ class InterpolatingLogScheduler(LRScheduler):
lrs = [self.max_lr for base_lr in self.base_lrs]
return lrs
def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
current_step: int,
*,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float
):
if current_step < num_warmup_steps:
return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
progress = float(current_step - num_warmup_steps) / float(
max(1, num_training_steps - num_warmup_steps)
)
return max(
0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
)
def get_cosine_schedule_with_quadratic_warmup(
optimizer: Optimizer,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float = 0.5,
last_epoch: int = -1,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
num_cycles (`float`, *optional*, defaults to 0.5):
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
following a half-cosine).
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
lr_lambda = partial(
_get_cosine_schedule_with_quadratic_warmup_lr_lambda,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
num_cycles=num_cycles,
)
return LambdaLR(optimizer, lr_lambda, last_epoch)

View File

@@ -17,10 +17,42 @@ from transformers import EarlyStoppingCallback, Trainer
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.callbacks import SavePeftModelCallback
from axolotl.utils.schedulers import InterpolatingLogScheduler
from axolotl.utils.schedulers import (
InterpolatingLogScheduler,
get_cosine_schedule_with_quadratic_warmup,
)
class OneCycleLRSchedulerTrainer(Trainer):
class AxolotlTrainer(Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
"""
if self.lr_scheduler is None: # pylint: disable=access-member-before-definition
"""# type: ignore"""
if self.args.lr_scheduler_type == "cosine_with_quadratic":
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
@@ -259,7 +291,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
trainer_cls = (
OneCycleLRSchedulerTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
else transformers.Trainer
else AxolotlTrainer
)
trainer = trainer_cls(
model=model,