Lint trainer.py

This commit is contained in:
NanoCode012
2023-05-29 13:46:49 +09:00
parent 1a2bd7ff62
commit ddb86ea821

View File

@@ -1,3 +1,5 @@
"""Module containing the Trainer class and related functions"""
import importlib import importlib
import math import math
import os import os
@@ -17,12 +19,19 @@ from axolotl.utils.callbacks import SavePeftModelCallback
class OneCycleLRSchedulerTrainer(Trainer): class OneCycleLRSchedulerTrainer(Trainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler( def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
): ):
optimizer = self.optimizer if optimizer is None else optimizer optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps) num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
num_training_steps = num_training_steps
pct_start = num_warmup_steps / num_training_steps pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR( self.lr_scheduler = OneCycleLR(
@@ -58,11 +67,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
training_arguments_kwargs["bf16_full_eval"] = True training_arguments_kwargs["bf16_full_eval"] = True
else: else:
training_arguments_kwargs["bf16"] = cfg.bf16 training_arguments_kwargs["bf16"] = cfg.bf16
training_arguments_kwargs["fp16"] = True if cfg.fp16 and not cfg.bf16 else False training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
training_arguments_kwargs["tf32"] = cfg.tf32 training_arguments_kwargs["tf32"] = cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps training_arguments_kwargs["logging_steps"] = logging_steps
if cfg.gradient_checkpointing is not None: if cfg.gradient_checkpointing:
if cfg.gptq: if cfg.gptq:
from alpaca_lora_4bit.gradient_checkpointing import ( from alpaca_lora_4bit.gradient_checkpointing import (
apply_gradient_checkpointing, apply_gradient_checkpointing,
@@ -112,13 +121,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
save_steps=save_steps, save_steps=save_steps,
output_dir=cfg.output_dir, output_dir=cfg.output_dir,
save_total_limit=3, save_total_limit=3,
load_best_model_at_end=True load_best_model_at_end=(
if cfg.load_best_model_at_end is not False # if explicitly set to False, it should be resort to False cfg.val_set_size > 0
and cfg.val_set_size > 0 and save_steps
and save_steps is not None and save_steps % eval_steps == 0
and save_steps % eval_steps == 0 and cfg.load_in_8bit is not True
and cfg.load_in_8bit is not True )
else False, or False,
ddp_find_unused_parameters=False if cfg.ddp else None, ddp_find_unused_parameters=False if cfg.ddp else None,
group_by_length=cfg.group_by_length, group_by_length=cfg.group_by_length,
report_to="wandb" if cfg.use_wandb else None, report_to="wandb" if cfg.use_wandb else None,
@@ -140,7 +149,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if ( if (
cfg.optimizer == "adamw_bnb_8bit" cfg.optimizer == "adamw_bnb_8bit"
and not cfg.gptq and not cfg.gptq
and not "deepspeed" in training_arguments_kwargs and "deepspeed" not in training_arguments_kwargs
and not cfg.fsdp and not cfg.fsdp
): ):
decay_parameters = get_parameter_names(model, [nn.LayerNorm]) decay_parameters = get_parameter_names(model, [nn.LayerNorm])