diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 18896d5cd..1aa0168ca 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -217,7 +217,9 @@ def bench_eval_callback_factory(trainer, tokenizer): metrics: Dict[str, float], # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument ): - data_loader = trainer.get_bench_dataloader(bench_dataset) + data_loader = trainer.get_bench_dataloader( + bench_dataset.remove_columns(["input", "subject", "output"]) + ) trainer.model.eval() preds, refs = [], [] loss_bench = 0 diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index a7933dc35..a9f4dec26 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -687,7 +687,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_ return_tensors="pt", **data_collator_kwargs, ), - bench_data_collat0r=transformers.DataCollatorForSeq2Seq( + bench_data_collator=transformers.DataCollatorForSeq2Seq( tokenizer, return_tensors="pt", **data_collator_kwargs,