fix eval steps and strategy (#403)
This commit is contained in:
@@ -452,13 +452,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
] = cfg.sample_packing_eff_est
|
] = cfg.sample_packing_eff_est
|
||||||
|
|
||||||
if cfg.val_set_size == 0:
|
if cfg.val_set_size == 0:
|
||||||
evaluation_strategy = "no"
|
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||||
elif cfg.eval_steps < 1:
|
elif cfg.eval_steps:
|
||||||
# eval every epoch
|
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||||
evaluation_strategy = "epoch"
|
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
|
||||||
else:
|
else:
|
||||||
# eval every eval_steps steps
|
# we have an eval set, but no steps defined, use epoch
|
||||||
evaluation_strategy = "steps"
|
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||||
|
|
||||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
max_steps=total_num_steps if cfg.max_steps else -1,
|
max_steps=total_num_steps if cfg.max_steps else -1,
|
||||||
@@ -471,9 +471,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
num_train_epochs=cfg.num_epochs,
|
num_train_epochs=cfg.num_epochs,
|
||||||
learning_rate=cfg.learning_rate,
|
learning_rate=cfg.learning_rate,
|
||||||
evaluation_strategy=evaluation_strategy,
|
|
||||||
save_strategy="steps" if cfg.save_steps else "epoch",
|
save_strategy="steps" if cfg.save_steps else "epoch",
|
||||||
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
|
||||||
save_steps=cfg.save_steps,
|
save_steps=cfg.save_steps,
|
||||||
output_dir=cfg.output_dir,
|
output_dir=cfg.output_dir,
|
||||||
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
||||||
|
|||||||
Reference in New Issue
Block a user