diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 12fe93fe4..5569a38ae 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -85,7 +85,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): training_args = transformers.TrainingArguments( per_device_train_batch_size=cfg.micro_batch_size, - per_device_eval_batch_size=cfg.eval_batch_size, + per_device_eval_batch_size=cfg.eval_batch_size if cfg.eval_batch_size is not None else cfg.micro_batch_size, gradient_accumulation_steps=cfg.gradient_accumulation_steps, eval_accumulation_steps=cfg.gradient_accumulation_steps, num_train_epochs=cfg.num_epochs,