diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 6ffd207ca..3fb90f1e1 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -121,20 +121,20 @@ def mmlu_eval_callback_factory(trainer, tokenizer): mmlu_split = "eval" if trainer.args.mmlu_dataset == "mmlu-zs": mmlu_dataset = load_dataset( - "json", + "openaccess-ai-collective/mmlu-evals", data_files={ - "eval": "data/mmlu/zero_shot_mmlu_val.json", - "test": "data/mmlu/zero_shot_mmlu_test.json", + "eval": "zero_shot_mmlu_val.json", + "test": "zero_shot_mmlu_test.json", }, ) mmlu_dataset = mmlu_dataset.remove_columns("subject") # MMLU Five-shot (Eval/Test only) elif trainer.args.mmlu_dataset in ["mmlu", "mmlu-fs"]: mmlu_dataset = load_dataset( - "json", + "openaccess-ai-collective/mmlu-evals", data_files={ - "eval": "data/mmlu/five_shot_mmlu_val.json", - "test": "data/mmlu/five_shot_mmlu_test.json", + "eval": "five_shot_mmlu_val.json", + "test": "five_shot_mmlu_test.json", }, ) # mmlu_dataset = mmlu_dataset.remove_columns('subject')