use save_strategy from config if available (#434)

* use save_strategy from config if available

* update docs for save_strategy
This commit is contained in:
Wing Lian
2023-08-18 20:28:23 -04:00
committed by GitHub
parent 5247c5004e
commit b3f5e00ff5
2 changed files with 8 additions and 1 deletions

View File

@@ -472,6 +472,7 @@ warmup_steps: 100
learning_rate: 0.00003
lr_quadratic_warmup:
logging_steps:
save_strategy: # set to `no` to skip checkpoint saves
save_steps: # leave empty to save at each epoch
eval_steps:
save_total_limit: # checkpoints saved at a time

View File

@@ -457,6 +457,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
# we have an eval set, but no steps defined, use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
if cfg.save_strategy:
training_arguments_kwargs["save_strategy"] = cfg.save_strategy
else:
training_arguments_kwargs["save_strategy"] = (
"steps" if cfg.save_steps else "epoch",
)
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
@@ -468,7 +475,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
save_strategy="steps" if cfg.save_steps else "epoch",
save_steps=cfg.save_steps,
output_dir=cfg.output_dir,
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,