Lint trainer.py
This commit is contained in:
@@ -1,3 +1,5 @@
|
|||||||
|
"""Module containing the Trainer class and related functions"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -17,12 +19,19 @@ from axolotl.utils.callbacks import SavePeftModelCallback
|
|||||||
|
|
||||||
|
|
||||||
class OneCycleLRSchedulerTrainer(Trainer):
|
class OneCycleLRSchedulerTrainer(Trainer):
|
||||||
|
"""
|
||||||
|
Trainer subclass that uses the OneCycleLR scheduler
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.lr_scheduler = None
|
||||||
|
|
||||||
def create_scheduler(
|
def create_scheduler(
|
||||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||||
):
|
):
|
||||||
optimizer = self.optimizer if optimizer is None else optimizer
|
optimizer = self.optimizer if optimizer is None else optimizer
|
||||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||||
num_training_steps = num_training_steps
|
|
||||||
pct_start = num_warmup_steps / num_training_steps
|
pct_start = num_warmup_steps / num_training_steps
|
||||||
|
|
||||||
self.lr_scheduler = OneCycleLR(
|
self.lr_scheduler = OneCycleLR(
|
||||||
@@ -58,11 +67,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
training_arguments_kwargs["bf16_full_eval"] = True
|
training_arguments_kwargs["bf16_full_eval"] = True
|
||||||
else:
|
else:
|
||||||
training_arguments_kwargs["bf16"] = cfg.bf16
|
training_arguments_kwargs["bf16"] = cfg.bf16
|
||||||
training_arguments_kwargs["fp16"] = True if cfg.fp16 and not cfg.bf16 else False
|
training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
|
||||||
training_arguments_kwargs["tf32"] = cfg.tf32
|
training_arguments_kwargs["tf32"] = cfg.tf32
|
||||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||||
if cfg.gradient_checkpointing is not None:
|
if cfg.gradient_checkpointing:
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
from alpaca_lora_4bit.gradient_checkpointing import (
|
from alpaca_lora_4bit.gradient_checkpointing import (
|
||||||
apply_gradient_checkpointing,
|
apply_gradient_checkpointing,
|
||||||
@@ -112,13 +121,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
save_steps=save_steps,
|
save_steps=save_steps,
|
||||||
output_dir=cfg.output_dir,
|
output_dir=cfg.output_dir,
|
||||||
save_total_limit=3,
|
save_total_limit=3,
|
||||||
load_best_model_at_end=True
|
load_best_model_at_end=(
|
||||||
if cfg.load_best_model_at_end is not False # if explicitly set to False, it should be resort to False
|
cfg.val_set_size > 0
|
||||||
and cfg.val_set_size > 0
|
and save_steps
|
||||||
and save_steps is not None
|
and save_steps % eval_steps == 0
|
||||||
and save_steps % eval_steps == 0
|
and cfg.load_in_8bit is not True
|
||||||
and cfg.load_in_8bit is not True
|
)
|
||||||
else False,
|
or False,
|
||||||
ddp_find_unused_parameters=False if cfg.ddp else None,
|
ddp_find_unused_parameters=False if cfg.ddp else None,
|
||||||
group_by_length=cfg.group_by_length,
|
group_by_length=cfg.group_by_length,
|
||||||
report_to="wandb" if cfg.use_wandb else None,
|
report_to="wandb" if cfg.use_wandb else None,
|
||||||
@@ -140,7 +149,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
if (
|
if (
|
||||||
cfg.optimizer == "adamw_bnb_8bit"
|
cfg.optimizer == "adamw_bnb_8bit"
|
||||||
and not cfg.gptq
|
and not cfg.gptq
|
||||||
and not "deepspeed" in training_arguments_kwargs
|
and "deepspeed" not in training_arguments_kwargs
|
||||||
and not cfg.fsdp
|
and not cfg.fsdp
|
||||||
):
|
):
|
||||||
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
|
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
|
||||||
|
|||||||
Reference in New Issue
Block a user