diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 97b02baba..299e39664 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -1,3 +1,5 @@
+"""Module containing the Trainer class and related functions"""
+
 import importlib
 import math
 import os
@@ -17,12 +19,19 @@ from axolotl.utils.callbacks import SavePeftModelCallback
 
 
 class OneCycleLRSchedulerTrainer(Trainer):
+    """
+    Trainer subclass that uses the OneCycleLR scheduler
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.lr_scheduler = None
+
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
     ):
         optimizer = self.optimizer if optimizer is None else optimizer
         num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
-        num_training_steps = num_training_steps
         pct_start = num_warmup_steps / num_training_steps
 
         self.lr_scheduler = OneCycleLR(
@@ -58,11 +67,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         training_arguments_kwargs["bf16_full_eval"] = True
     else:
         training_arguments_kwargs["bf16"] = cfg.bf16
-    training_arguments_kwargs["fp16"] = True if cfg.fp16 and not cfg.bf16 else False
+    training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
     training_arguments_kwargs["tf32"] = cfg.tf32
     training_arguments_kwargs["warmup_steps"] = warmup_steps
     training_arguments_kwargs["logging_steps"] = logging_steps
-    if cfg.gradient_checkpointing is not None:
+    if cfg.gradient_checkpointing:
         if cfg.gptq:
             from alpaca_lora_4bit.gradient_checkpointing import (
                 apply_gradient_checkpointing,
@@ -112,13 +121,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         save_steps=save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=3,
-        load_best_model_at_end=True
-        if cfg.load_best_model_at_end is not False  # if explicitly set to False, it should be resort to False
-        and cfg.val_set_size > 0
-        and save_steps is not None
-        and save_steps % eval_steps == 0
-        and cfg.load_in_8bit is not True
-        else False,
+        load_best_model_at_end=(
+            cfg.val_set_size > 0
+            and save_steps
+            and save_steps % eval_steps == 0
+            and cfg.load_in_8bit is not True
+        )
+        or False,
         ddp_find_unused_parameters=False if cfg.ddp else None,
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
@@ -140,7 +149,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if (
         cfg.optimizer == "adamw_bnb_8bit"
        and not cfg.gptq
-        and not "deepspeed" in training_arguments_kwargs
+        and "deepspeed" not in training_arguments_kwargs
         and not cfg.fsdp
     ):
         decay_parameters = get_parameter_names(model, [nn.LayerNorm])