use save_strategy from config if available (#434)
* use save_strategy from config if available * update docs for save_strategy
This commit is contained in:
@@ -472,6 +472,7 @@ warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
lr_quadratic_warmup:
|
||||
logging_steps:
|
||||
save_strategy: # set to `no` to skip checkpoint saves
|
||||
save_steps: # leave empty to save at each epoch
|
||||
eval_steps:
|
||||
save_total_limit: # maximum number of checkpoints kept at a time (falls back to 4 when unset)
|
||||
|
||||
@@ -457,6 +457,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
# we have an eval set, but no steps defined, use epoch
|
||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||
|
||||
# Pick the checkpoint save strategy: an explicit `save_strategy` in the config
# takes precedence; otherwise fall back to "steps" when `save_steps` is set
# and "epoch" when it is not.
if cfg.save_strategy:
    training_arguments_kwargs["save_strategy"] = cfg.save_strategy
else:
    # NOTE: no trailing comma inside the parentheses -- a trailing comma
    # would turn this value into a 1-tuple like ("epoch",) instead of a str.
    training_arguments_kwargs["save_strategy"] = (
        "steps" if cfg.save_steps else "epoch"
    )
|
||||
|
||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
max_steps=total_num_steps if cfg.max_steps else -1,
|
||||
max_seq_length=cfg.sequence_len,
|
||||
@@ -468,7 +475,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
num_train_epochs=cfg.num_epochs,
|
||||
learning_rate=cfg.learning_rate,
|
||||
save_strategy="steps" if cfg.save_steps else "epoch",
|
||||
save_steps=cfg.save_steps,
|
||||
output_dir=cfg.output_dir,
|
||||
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
||||
|
||||
Reference in New Issue
Block a user