diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 4d100a4be..2d1c7ceea 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -129,7 +129,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer): }, ) # mmlu_dataset = mmlu_dataset.remove_columns("subject") - if trainer.args.mmlu_dataset == "mmlu-zs": + elif trainer.args.mmlu_dataset == "mmlu-zs": mmlu_dataset = load_dataset( "openaccess-ai-collective/mmlu-evals", data_files={ @@ -149,7 +149,9 @@ def mmlu_eval_callback_factory(trainer, tokenizer): ) # mmlu_dataset = mmlu_dataset.remove_columns('subject') else: - raise ValueError("unhandled value for mmlu_dataset training args") + raise ValueError( + f"unhandled value `{trainer.args.mmlu_dataset}` for mmlu_dataset training args" + ) mmlu_dataset = mmlu_dataset[trainer.args.mmlu_split] if trainer.args.max_mmlu_samples is not None: mmlu_dataset = mmlu_dataset.select(range(trainer.args.max_mmlu_samples))