fixes
This commit is contained in:
@@ -217,7 +217,9 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
||||
metrics: Dict[str, float], # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
data_loader = trainer.get_bench_dataloader(bench_dataset)
|
||||
data_loader = trainer.get_bench_dataloader(
|
||||
bench_dataset.remove_columns(["input", "subject", "output"])
|
||||
)
|
||||
trainer.model.eval()
|
||||
preds, refs = [], []
|
||||
loss_bench = 0
|
||||
|
||||
@@ -687,7 +687,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
bench_data_collat0r=transformers.DataCollatorForSeq2Seq(
|
||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
|
||||
Reference in New Issue
Block a user