another callback fix for collator max len attribute
This commit is contained in:
@@ -158,8 +158,9 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
|
||||
source_max_len = trainer.data_collator.source_max_len
|
||||
trainer.data_collator.source_max_len = args.mmlu_source_max_len
|
||||
source_max_len = trainer.data_collator.max_length
|
||||
source_max_len = args.max_seq_length
|
||||
trainer.data_collator.max_length = args.mmlu_source_max_len
|
||||
trainer.model.eval()
|
||||
preds, refs = [], []
|
||||
loss_mmlu = 0
|
||||
@@ -194,6 +195,6 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
||||
subject_scores.append(subject_score)
|
||||
results[f"mmlu_{mmlu_split}_accuracy"] = np.mean(subject_scores)
|
||||
trainer.log(results)
|
||||
trainer.data_collator.source_max_len = source_max_len
|
||||
trainer.data_collator.max_length = source_max_len
|
||||
|
||||
return MMLUEvalCallback
|
||||
|
||||
Reference in New Issue
Block a user