another callback fix for collator max len attribute
This commit is contained in:
@@ -158,8 +158,9 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
|
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
|
||||||
source_max_len = trainer.data_collator.source_max_len
|
source_max_len = trainer.data_collator.max_length
|
||||||
trainer.data_collator.source_max_len = args.mmlu_source_max_len
|
source_max_len = args.max_seq_length
|
||||||
|
trainer.data_collator.max_length = args.mmlu_source_max_len
|
||||||
trainer.model.eval()
|
trainer.model.eval()
|
||||||
preds, refs = [], []
|
preds, refs = [], []
|
||||||
loss_mmlu = 0
|
loss_mmlu = 0
|
||||||
@@ -194,6 +195,6 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
subject_scores.append(subject_score)
|
subject_scores.append(subject_score)
|
||||||
results[f"mmlu_{mmlu_split}_accuracy"] = np.mean(subject_scores)
|
results[f"mmlu_{mmlu_split}_accuracy"] = np.mean(subject_scores)
|
||||||
trainer.log(results)
|
trainer.log(results)
|
||||||
trainer.data_collator.source_max_len = source_max_len
|
trainer.data_collator.max_length = source_max_len
|
||||||
|
|
||||||
return MMLUEvalCallback
|
return MMLUEvalCallback
|
||||||
|
|||||||
Reference in New Issue
Block a user