From b3f5e00ff5ac240678b3e8554be86a6db197455d Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 18 Aug 2023 20:28:23 -0400
Subject: [PATCH] use save_strategy from config if available (#434)

* use save_strategy from config if available

* update docs for save_strategy
---
Review note: the fallback branch must assign a plain string. A trailing comma
inside the parentheses (`"epoch",`) would turn the value into a one-element
tuple instead of "steps"/"epoch", so it is omitted here. The post-image blob
id on the trainer.py `index` line is carried over from the original revision.

 README.md                    | 1 +
 src/axolotl/utils/trainer.py | 8 +++++++-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index c61899c13..ef28983f1 100644
--- a/README.md
+++ b/README.md
@@ -472,6 +472,7 @@ warmup_steps: 100
 learning_rate: 0.00003
 lr_quadratic_warmup:
 logging_steps:
+save_strategy: # set to `no` to skip checkpoint saves
 save_steps: # leave empty to save at each epoch
 eval_steps:
 save_total_limit: # checkpoints saved at a time
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index fa0f0e384..af17684e2 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -457,6 +457,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         # we have an eval set, but no steps defined, use epoch
         training_arguments_kwargs["evaluation_strategy"] = "epoch"
 
+    if cfg.save_strategy:
+        training_arguments_kwargs["save_strategy"] = cfg.save_strategy
+    else:
+        training_arguments_kwargs["save_strategy"] = (
+            "steps" if cfg.save_steps else "epoch"
+        )
+
     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
         max_steps=total_num_steps if cfg.max_steps else -1,
         max_seq_length=cfg.sequence_len,
@@ -468,7 +475,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         eval_accumulation_steps=cfg.gradient_accumulation_steps,
         num_train_epochs=cfg.num_epochs,
         learning_rate=cfg.learning_rate,
-        save_strategy="steps" if cfg.save_steps else "epoch",
         save_steps=cfg.save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,