use micro batch size for eval size if not specified

This commit is contained in:
Wing Lian
2023-05-07 18:26:05 -04:00
parent fae36c7111
commit 550502b321

View File

@@ -85,7 +85,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
training_args = transformers.TrainingArguments(
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size if cfg.eval_batch_size is not None else cfg.micro_batch_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,