This commit is contained in:
Wing Lian
2023-08-25 21:49:04 -04:00
parent 99d844f215
commit f5db88a10d
2 changed files with 4 additions and 2 deletions

View File

@@ -217,7 +217,9 @@ def bench_eval_callback_factory(trainer, tokenizer):
metrics: Dict[str, float], # pylint: disable=unused-argument metrics: Dict[str, float], # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument
): ):
data_loader = trainer.get_bench_dataloader(bench_dataset) data_loader = trainer.get_bench_dataloader(
bench_dataset.remove_columns(["input", "subject", "output"])
)
trainer.model.eval() trainer.model.eval()
preds, refs = [], [] preds, refs = [], []
loss_bench = 0 loss_bench = 0

View File

@@ -687,7 +687,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
return_tensors="pt", return_tensors="pt",
**data_collator_kwargs, **data_collator_kwargs,
), ),
bench_data_collat0r=transformers.DataCollatorForSeq2Seq( bench_data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, tokenizer,
return_tensors="pt", return_tensors="pt",
**data_collator_kwargs, **data_collator_kwargs,