Lint trainer.py

This commit is contained in:
NanoCode012
2023-05-29 13:46:49 +09:00
parent 1a2bd7ff62
commit ddb86ea821

View File

@@ -1,3 +1,5 @@
"""Module containing the Trainer class and related functions"""
import importlib
import math
import os
@@ -17,12 +19,19 @@ from axolotl.utils.callbacks import SavePeftModelCallback
class OneCycleLRSchedulerTrainer(Trainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
num_training_steps = num_training_steps
pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR(
@@ -58,11 +67,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
training_arguments_kwargs["bf16_full_eval"] = True
else:
training_arguments_kwargs["bf16"] = cfg.bf16
training_arguments_kwargs["fp16"] = True if cfg.fp16 and not cfg.bf16 else False
training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
training_arguments_kwargs["tf32"] = cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
if cfg.gradient_checkpointing is not None:
if cfg.gradient_checkpointing:
if cfg.gptq:
from alpaca_lora_4bit.gradient_checkpointing import (
apply_gradient_checkpointing,
@@ -112,13 +121,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
save_steps=save_steps,
output_dir=cfg.output_dir,
save_total_limit=3,
load_best_model_at_end=True
if cfg.load_best_model_at_end is not False # if explicitly set to False, it should be resort to False
and cfg.val_set_size > 0
and save_steps is not None
and save_steps % eval_steps == 0
and cfg.load_in_8bit is not True
else False,
load_best_model_at_end=(
cfg.val_set_size > 0
and save_steps
and save_steps % eval_steps == 0
and cfg.load_in_8bit is not True
)
or False,
ddp_find_unused_parameters=False if cfg.ddp else None,
group_by_length=cfg.group_by_length,
report_to="wandb" if cfg.use_wandb else None,
@@ -140,7 +149,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if (
cfg.optimizer == "adamw_bnb_8bit"
and not cfg.gptq
and not "deepspeed" in training_arguments_kwargs
and "deepspeed" not in training_arguments_kwargs
and not cfg.fsdp
):
decay_parameters = get_parameter_names(model, [nn.LayerNorm])