diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 2d1c7ceea..9891be713 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -22,7 +22,7 @@ from transformers import ( from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from axolotl.utils.bench import log_gpu_memory_usage -from axolotl.utils.distributed import is_main_process, zero_first, zero_only +from axolotl.utils.distributed import barrier, is_main_process, zero_first if TYPE_CHECKING: from axolotl.utils.trainer import AxolotlTrainingArguments @@ -119,6 +119,9 @@ def mmlu_eval_callback_factory(trainer, tokenizer): tokenizer("B", add_special_tokens=False).input_ids[0], tokenizer("C", add_special_tokens=False).input_ids[0], tokenizer("D", add_special_tokens=False).input_ids[0], + tokenizer("E", add_special_tokens=False).input_ids[0], + tokenizer("F", add_special_tokens=False).input_ids[0], + tokenizer("G", add_special_tokens=False).input_ids[0], ] mmlu_split = "eval" if trainer.args.mmlu_dataset == "sampled": @@ -185,6 +188,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer): with zero_first(is_main_process()): mmlu_dataset = mmlu_dataset.map(tokenize_evals) + mmlu_dataset = mmlu_dataset.filter(lambda x: x["labels"][-2] in abcd_idx) class BenchEvalCallback(TrainerCallback): """ @@ -199,7 +203,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer): metrics: Dict[str, float], # pylint: disable=unused-argument **kwargs, # pylint: disable=unused-argument ): - with zero_only(is_main_process()): + if is_main_process(): data_loader = trainer.get_eval_dataloader(mmlu_dataset) source_max_len = trainer.data_collator.max_length trainer.data_collator.max_length = args.mmlu_source_max_len @@ -218,7 +222,10 @@ def mmlu_eval_callback_factory(trainer, tokenizer): logit_abcd = logit[label_non_zero_id - 1][abcd_idx] preds.append(torch.argmax(logit_abcd).item()) labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0] - refs += [abcd_idx.index(label) for label in labels.tolist()] + refs += [ + abcd_idx.index(label) if labels in abcd_idx else -1 + for label in labels.tolist() + ] loss_mmlu += loss.item() # Extract results by subject. results = {"bench_loss": loss_mmlu / len(data_loader)} @@ -243,5 +250,6 @@ def mmlu_eval_callback_factory(trainer, tokenizer): results[f"bench_{mmlu_split}_accuracy"] = np.mean(subject_scores) trainer.log(results) trainer.data_collator.max_length = source_max_len + barrier() return BenchEvalCallback diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index 19b911717..b3ea07c05 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -53,16 +53,3 @@ def zero_first(is_main): yield if is_main: # then rank 0 waits after it has run the context barrier() - - -@contextmanager -def zero_only(is_main): - """ - Context manager that ensures only the Rank 0 process executes the wrapped code. - Other processes will simply bypass the code inside the context. - All ranks will synchronize (wait) at the end before proceeding. - """ - if is_main: - yield - # All ranks will wait here until Rank 0 completes the code block. - barrier()