diff --git a/README.md b/README.md
index c61899c13..ef28983f1 100644
--- a/README.md
+++ b/README.md
@@ -472,6 +472,7 @@ warmup_steps: 100
 learning_rate: 0.00003
 lr_quadratic_warmup:
 logging_steps:
+save_strategy: # set to `no` to skip checkpoint saves
 save_steps: # leave empty to save at each epoch
 eval_steps:
 save_total_limit: # checkpoints saved at a time
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index fa0f0e384..af17684e2 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -457,6 +457,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         # we have an eval set, but no steps defined, use epoch
         training_arguments_kwargs["evaluation_strategy"] = "epoch"
 
+    if cfg.save_strategy:
+        # an explicit strategy (e.g. `no`) always wins over the save_steps heuristic
+        training_arguments_kwargs["save_strategy"] = cfg.save_strategy
+    else:
+        # keep the previous default: by steps when save_steps is set, else per epoch
+        training_arguments_kwargs["save_strategy"] = (
+            "steps" if cfg.save_steps else "epoch"
+        )
+
     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
         max_steps=total_num_steps if cfg.max_steps else -1,
         max_seq_length=cfg.sequence_len,
@@ -468,7 +477,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         eval_accumulation_steps=cfg.gradient_accumulation_steps,
         num_train_epochs=cfg.num_epochs,
         learning_rate=cfg.learning_rate,
-        save_strategy="steps" if cfg.save_steps else "epoch",
         save_steps=cfg.save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,