fixes
This commit is contained in:
@@ -217,7 +217,9 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
metrics: Dict[str, float], # pylint: disable=unused-argument
|
metrics: Dict[str, float], # pylint: disable=unused-argument
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
data_loader = trainer.get_bench_dataloader(bench_dataset)
|
data_loader = trainer.get_bench_dataloader(
|
||||||
|
bench_dataset.remove_columns(["input", "subject", "output"])
|
||||||
|
)
|
||||||
trainer.model.eval()
|
trainer.model.eval()
|
||||||
preds, refs = [], []
|
preds, refs = [], []
|
||||||
loss_bench = 0
|
loss_bench = 0
|
||||||
|
|||||||
@@ -687,7 +687,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
**data_collator_kwargs,
|
**data_collator_kwargs,
|
||||||
),
|
),
|
||||||
bench_data_collat0r=transformers.DataCollatorForSeq2Seq(
|
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
**data_collator_kwargs,
|
**data_collator_kwargs,
|
||||||
|
|||||||
Reference in New Issue
Block a user