From 943b84c4906ab3014339c93f4fc2c6a6b5925f10 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 19 Aug 2023 21:36:24 -0400 Subject: [PATCH] another callback fix for collator max len attribute --- src/axolotl/utils/callbacks.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 89411676f..57ddcb759 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -158,8 +158,9 @@ def mmlu_eval_callback_factory(trainer, tokenizer): **kwargs, # pylint: disable=unused-argument ): data_loader = trainer.get_eval_dataloader(mmlu_dataset) - source_max_len = trainer.data_collator.source_max_len - trainer.data_collator.source_max_len = args.mmlu_source_max_len + source_max_len = trainer.data_collator.max_length + source_max_len = args.max_seq_length + trainer.data_collator.max_length = args.mmlu_source_max_len trainer.model.eval() preds, refs = [], [] loss_mmlu = 0 @@ -194,6 +195,6 @@ def mmlu_eval_callback_factory(trainer, tokenizer): subject_scores.append(subject_score) results[f"mmlu_{mmlu_split}_accuracy"] = np.mean(subject_scores) trainer.log(results) - trainer.data_collator.source_max_len = source_max_len + trainer.data_collator.max_length = source_max_len return MMLUEvalCallback