diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 57ddcb759..2e9280c03 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -21,6 +21,7 @@ from transformers import (
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.distributed import is_main_process, zero_first, zero_only
 
 if TYPE_CHECKING:
     from axolotl.utils.trainer import AxolotlTrainingArguments
@@ -127,7 +128,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
                 "test": "zero_shot_mmlu_test.json",
             },
         )
-        mmlu_dataset = mmlu_dataset.remove_columns("subject")
+        # mmlu_dataset = mmlu_dataset.remove_columns("subject")  # keep "subject" for per-subject metrics below
     # MMLU Five-shot (Eval/Test only)
     elif trainer.args.mmlu_dataset in ["mmlu", "mmlu-fs"]:
         mmlu_dataset = load_dataset(
@@ -144,6 +145,36 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
     if trainer.args.max_mmlu_samples is not None:
         mmlu_dataset = mmlu_dataset.select(range(trainer.args.max_mmlu_samples))
 
+    def tokenize_evals(example):
+        source = f"{tokenizer.bos_token}{example['input']}"
+        target = f"{example['output']}{tokenizer.eos_token}"
+
+        tokenized_source = tokenizer(
+            source,
+            max_length=2048,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        tokenized_target = tokenizer(
+            target,
+            max_length=2048,
+            truncation=True,
+            add_special_tokens=False,
+        )
+        input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
+        labels = [-100] * len(tokenized_source["input_ids"]) + tokenized_target[
+            "input_ids"
+        ]
+
+        return {
+            "input_ids": input_ids,
+            "labels": labels,
+            "subject": example["subject"],
+        }
+
+    with zero_first(is_main_process()):
+        mmlu_dataset = mmlu_dataset.map(tokenize_evals)
+
     class MMLUEvalCallback(TrainerCallback):
         """
         TrainerCallback that runs the MMLU evals
@@ -157,44 +188,50 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
             metrics: Dict[str, float],  # pylint: disable=unused-argument
             **kwargs,  # pylint: disable=unused-argument
         ):
-            data_loader = trainer.get_eval_dataloader(mmlu_dataset)
-            source_max_len = trainer.data_collator.max_length
-            source_max_len = args.max_seq_length
-            trainer.data_collator.max_length = args.mmlu_source_max_len
-            trainer.model.eval()
-            preds, refs = [], []
-            loss_mmlu = 0
-            for batch in tqdm(data_loader, total=len(data_loader)):
-                (loss, logits, labels) = trainer.prediction_step(
-                    trainer.model,
-                    batch,
-                    prediction_loss_only=False,
-                )
-                # There are two tokens, the output, and eos token.
-                for i, logit in enumerate(logits):
-                    label_non_zero_id = (batch["labels"][i] != -100).nonzero()[0][0]
-                    logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
-                    preds.append(torch.argmax(logit_abcd).item())
-                labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
-                refs += [abcd_idx.index(label) for label in labels.tolist()]
-                loss_mmlu += loss.item()
-            # Extract results by subject.
- results = {"mmlu_loss": loss_mmlu / len(data_loader)} - subject = mmlu_dataset["subject"] - subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)} - for s, p, r in zip(subject, preds, refs): # pylint: disable=invalid-name - subjects[s]["preds"].append(p) - subjects[s]["refs"].append(r) - subject_scores = [] - for subject in subjects: - subject_score = accuracy.compute( - references=subjects[subject]["refs"], - predictions=subjects[subject]["preds"], - )["accuracy"] - results[f"mmlu_{mmlu_split}_accuracy_{subject}"] = subject_score - subject_scores.append(subject_score) - results[f"mmlu_{mmlu_split}_accuracy"] = np.mean(subject_scores) - trainer.log(results) - trainer.data_collator.max_length = source_max_len + with zero_only(is_main_process()): + data_loader = trainer.get_eval_dataloader(mmlu_dataset) + source_max_len = trainer.data_collator.max_length + trainer.data_collator.max_length = args.mmlu_source_max_len + trainer.model.eval() + preds, refs = [], [] + loss_mmlu = 0 + for batch in tqdm(data_loader, total=len(data_loader)): + (loss, logits, labels) = trainer.prediction_step( + trainer.model, + batch, + prediction_loss_only=False, + ) + # There are two tokens, the output, and eos token. + for i, logit in enumerate(logits): + label_non_zero_id = (batch["labels"][i] != -100).nonzero()[0][0] + logit_abcd = logit[label_non_zero_id - 1][abcd_idx] + preds.append(torch.argmax(logit_abcd).item()) + labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0] + refs += [abcd_idx.index(label) for label in labels.tolist()] + loss_mmlu += loss.item() + # Extract results by subject. + results = {"mmlu_loss": loss_mmlu / len(data_loader)} + subject = mmlu_dataset["subject"] + subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)} + for s, p, r in zip( # pylint: disable=invalid-name + subject, preds, refs + ): + subjects[s]["preds"].append(p) + subjects[s]["refs"].append(r) + subject_scores = [] + for subject in subjects: + subject_score = accuracy.compute( + references=subjects[subject]["refs"], + predictions=subjects[subject]["preds"], + )["accuracy"] + results[f"mmlu_{mmlu_split}_accuracy_{subject}"] = subject_score + subject_scores.append(subject_score) + results[f"mmlu_{mmlu_split}_accuracy"] = np.mean(subject_scores) + trainer.log(results) + trainer.data_collator.max_length = source_max_len return MMLUEvalCallback diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index b3ea07c05..19b911717 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -53,3 +53,16 @@ def zero_first(is_main): yield if is_main: # then rank 0 waits after it has run the context barrier() + + +@contextmanager +def zero_only(is_main): + """ + Context manager that ensures only the Rank 0 process executes the wrapped code. + Other processes will simply bypass the code inside the context. + All ranks will synchronize (wait) at the end before proceeding. + """ + if is_main: + yield + # All ranks will wait here until Rank 0 completes the code block. + barrier()