diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 89411676f..57ddcb759 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -158,8 +158,9 @@ def mmlu_eval_callback_factory(trainer, tokenizer): **kwargs, # pylint: disable=unused-argument ): data_loader = trainer.get_eval_dataloader(mmlu_dataset) - source_max_len = trainer.data_collator.source_max_len - trainer.data_collator.source_max_len = args.mmlu_source_max_len + source_max_len = trainer.data_collator.max_length + source_max_len = args.max_seq_length + trainer.data_collator.max_length = args.mmlu_source_max_len trainer.model.eval() preds, refs = [], [] loss_mmlu = 0 @@ -194,6 +195,6 @@ def mmlu_eval_callback_factory(trainer, tokenizer): subject_scores.append(subject_score) results[f"mmlu_{mmlu_split}_accuracy"] = np.mean(subject_scores) trainer.log(results) - trainer.data_collator.source_max_len = source_max_len + trainer.data_collator.max_length = source_max_len return MMLUEvalCallback