use save_strategy from config if available (#434)
* use save_strategy from config if available * update docs for save_strategy
This commit is contained in:
@@ -472,6 +472,7 @@ warmup_steps: 100
|
|||||||
learning_rate: 0.00003
|
learning_rate: 0.00003
|
||||||
lr_quadratic_warmup:
|
lr_quadratic_warmup:
|
||||||
logging_steps:
|
logging_steps:
|
||||||
|
save_strategy: # set to `no` to skip checkpoint saves
|
||||||
save_steps: # leave empty to save at each epoch
|
save_steps: # leave empty to save at each epoch
|
||||||
eval_steps:
|
eval_steps:
|
||||||
save_total_limit: # checkpoints saved at a time
|
save_total_limit: # checkpoints saved at a time
|
||||||
|
|||||||
@@ -457,6 +457,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
# we have an eval set, but no steps defined, use epoch
|
# we have an eval set, but no steps defined, use epoch
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||||
|
|
||||||
|
if cfg.save_strategy:
|
||||||
|
training_arguments_kwargs["save_strategy"] = cfg.save_strategy
|
||||||
|
else:
|
||||||
|
training_arguments_kwargs["save_strategy"] = (
|
||||||
|
"steps" if cfg.save_steps else "epoch",
|
||||||
|
)
|
||||||
|
|
||||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
max_steps=total_num_steps if cfg.max_steps else -1,
|
max_steps=total_num_steps if cfg.max_steps else -1,
|
||||||
max_seq_length=cfg.sequence_len,
|
max_seq_length=cfg.sequence_len,
|
||||||
@@ -468,7 +475,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
num_train_epochs=cfg.num_epochs,
|
num_train_epochs=cfg.num_epochs,
|
||||||
learning_rate=cfg.learning_rate,
|
learning_rate=cfg.learning_rate,
|
||||||
save_strategy="steps" if cfg.save_steps else "epoch",
|
|
||||||
save_steps=cfg.save_steps,
|
save_steps=cfg.save_steps,
|
||||||
output_dir=cfg.output_dir,
|
output_dir=cfg.output_dir,
|
||||||
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
||||||
|
|||||||
Reference in New Issue
Block a user