another callback fix for collator max len attribute

This commit is contained in:
Wing Lian
2023-08-19 21:36:24 -04:00
parent 6f166464d8
commit 943b84c490

View File

@@ -158,8 +158,9 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
**kwargs, # pylint: disable=unused-argument
):
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
source_max_len = trainer.data_collator.source_max_len
trainer.data_collator.source_max_len = args.mmlu_source_max_len
source_max_len = trainer.data_collator.max_length
source_max_len = args.max_seq_length
trainer.data_collator.max_length = args.mmlu_source_max_len
trainer.model.eval()
preds, refs = [], []
loss_mmlu = 0
@@ -194,6 +195,6 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
subject_scores.append(subject_score)
results[f"mmlu_{mmlu_split}_accuracy"] = np.mean(subject_scores)
trainer.log(results)
trainer.data_collator.source_max_len = source_max_len
trainer.data_collator.max_length = source_max_len
return MMLUEvalCallback