fix elif and add better messaging
This commit is contained in:
@@ -129,7 +129,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
# mmlu_dataset = mmlu_dataset.remove_columns("subject")
|
# mmlu_dataset = mmlu_dataset.remove_columns("subject")
|
||||||
if trainer.args.mmlu_dataset == "mmlu-zs":
|
elif trainer.args.mmlu_dataset == "mmlu-zs":
|
||||||
mmlu_dataset = load_dataset(
|
mmlu_dataset = load_dataset(
|
||||||
"openaccess-ai-collective/mmlu-evals",
|
"openaccess-ai-collective/mmlu-evals",
|
||||||
data_files={
|
data_files={
|
||||||
@@ -149,7 +149,9 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
)
|
)
|
||||||
# mmlu_dataset = mmlu_dataset.remove_columns('subject')
|
# mmlu_dataset = mmlu_dataset.remove_columns('subject')
|
||||||
else:
|
else:
|
||||||
raise ValueError("unhandled value for mmlu_dataset training args")
|
raise ValueError(
|
||||||
|
f"unhandled value `{trainer.args.mmlu_dataset}` for mmlu_dataset training args"
|
||||||
|
)
|
||||||
mmlu_dataset = mmlu_dataset[trainer.args.mmlu_split]
|
mmlu_dataset = mmlu_dataset[trainer.args.mmlu_split]
|
||||||
if trainer.args.max_mmlu_samples is not None:
|
if trainer.args.max_mmlu_samples is not None:
|
||||||
mmlu_dataset = mmlu_dataset.select(range(trainer.args.max_mmlu_samples))
|
mmlu_dataset = mmlu_dataset.select(range(trainer.args.max_mmlu_samples))
|
||||||
|
|||||||
Reference in New Issue
Block a user