diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index f69c56117..4c3c3fccd 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -62,8 +62,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         if cfg.logging_steps is not None
         else max(min(int(0.005 * total_num_steps), 10), 1)
     )
-    save_steps = cfg.save_steps
-    eval_steps = cfg.eval_steps
 
     training_arguments_kwargs = {}
     if cfg.bf16 == "full":
@@ -123,16 +121,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         num_train_epochs=cfg.num_epochs,
         learning_rate=cfg.learning_rate,
         evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
-        save_strategy="steps" if save_steps else "epoch",
-        eval_steps=eval_steps if cfg.val_set_size > 0 else None,
-        save_steps=save_steps,
+        save_strategy="steps" if cfg.save_steps else "epoch",
+        eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
+        save_steps=cfg.save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=3,
         load_best_model_at_end=(
             cfg.load_best_model_at_end is not False
             and cfg.val_set_size > 0
-            and save_steps
-            and save_steps % eval_steps == 0
+            and cfg.save_steps
+            and cfg.save_steps % cfg.eval_steps == 0
             and cfg.load_in_8bit is not True
         )
         or False,