This commit is contained in:
Wing Lian
2023-08-25 21:49:04 -04:00
parent 99d844f215
commit f5db88a10d
2 changed files with 4 additions and 2 deletions

View File

@@ -217,7 +217,9 @@ def bench_eval_callback_factory(trainer, tokenizer):
metrics: Dict[str, float], # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
data_loader = trainer.get_bench_dataloader(bench_dataset)
data_loader = trainer.get_bench_dataloader(
bench_dataset.remove_columns(["input", "subject", "output"])
)
trainer.model.eval()
preds, refs = [], []
loss_bench = 0

View File

@@ -687,7 +687,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
return_tensors="pt",
**data_collator_kwargs,
),
bench_data_collat0r=transformers.DataCollatorForSeq2Seq(
bench_data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
**data_collator_kwargs,