diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 4f633cd9e..a2cb3a62e 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -210,7 +210,7 @@ def bench_eval_callback_factory(trainer, tokenizer): "subject": example["subject"], } - with dist_state.main_process_first: + with dist_state.main_process_first(): bench_dataset = bench_dataset.map(tokenize_evals) bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)