diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 641db3fb5..7039089b8 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -22,7 +22,7 @@ from transformers import (
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
 from axolotl.utils.bench import log_gpu_memory_usage
-from axolotl.utils.distributed import barrier, is_main_process, zero_first
+from axolotl.utils.distributed import is_main_process, zero_first
 
 if TYPE_CHECKING:
     from axolotl.utils.trainer import AxolotlTrainingArguments
@@ -203,53 +203,47 @@ def bench_eval_callback_factory(trainer, tokenizer):
             metrics: Dict[str, float],  # pylint: disable=unused-argument
             **kwargs,  # pylint: disable=unused-argument
         ):
-            if is_main_process():
-                data_loader = trainer.get_eval_dataloader(bench_dataset)
-                source_max_len = trainer.data_collator.max_length
-                trainer.data_collator.max_length = args.bench_source_max_len
-                trainer.model.eval()
-                preds, refs = [], []
-                loss_bench = 0
-                for batch in tqdm(data_loader, total=len(data_loader)):
-                    (loss, logits, labels) = trainer.prediction_step(
-                        trainer.model,
-                        batch,
-                        prediction_loss_only=False,
-                    )
-                    # There are two tokens, the output, and eos token.
-                    for i, logit in enumerate(logits):
-                        label_non_zero_id = (batch["labels"][i] != -100).nonzero()[0][0]
-                        logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
-                        preds.append(torch.argmax(logit_abcd).item())
-                    labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
-                    refs += [
-                        abcd_idx.index(label) if label in abcd_idx else -1
-                        for label in labels.tolist()
-                    ]
-                    loss_bench += loss.item()
-                # Extract results by subject.
-                results = {"bench_loss": loss_bench / len(data_loader)}
-                subject = bench_dataset["subject"]
-                subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
-                for s, p, r in zip(  # pylint: disable=invalid-name
-                    subject, preds, refs
-                ):
-                    subjects[s]["preds"].append(p)
-                    subjects[s]["refs"].append(r)
-                subject_scores = []
-                for subject in subjects:
-                    subject_score = accuracy.compute(
-                        references=subjects[subject]["refs"],
-                        predictions=subjects[subject]["preds"],
-                    )["accuracy"]
-                    if not pd.isna(subject_score):
-                        results[
-                            f"bench_{bench_split}_accuracy_{subject}"
-                        ] = subject_score
-                        subject_scores.append(subject_score)
-                results[f"bench_{bench_split}_accuracy"] = np.mean(subject_scores)
-                trainer.log(results)
-                trainer.data_collator.max_length = source_max_len
-                barrier()
+            data_loader = trainer.get_eval_dataloader(bench_dataset)
+            source_max_len = trainer.data_collator.max_length
+            trainer.data_collator.max_length = args.bench_source_max_len
+            trainer.model.eval()
+            preds, refs = [], []
+            loss_bench = 0
+            for batch in tqdm(data_loader, total=len(data_loader)):
+                (loss, logits, labels) = trainer.prediction_step(
+                    trainer.model,
+                    batch,
+                    prediction_loss_only=False,
+                )
+                # There are two tokens, the output, and eos token.
+                for i, logit in enumerate(logits):
+                    label_non_zero_id = (batch["labels"][i] != -100).nonzero()[0][0]
+                    logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
+                    preds.append(torch.argmax(logit_abcd).item())
+                labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
+                refs += [
+                    abcd_idx.index(label) if label in abcd_idx else -1
+                    for label in labels.tolist()
+                ]
+                loss_bench += loss.item()
+            # Extract results by subject.
+            results = {"bench_loss": loss_bench / len(data_loader)}
+            subject = bench_dataset["subject"]
+            subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
+            for s, p, r in zip(subject, preds, refs):  # pylint: disable=invalid-name
+                subjects[s]["preds"].append(p)
+                subjects[s]["refs"].append(r)
+            subject_scores = []
+            for subject in subjects:
+                subject_score = accuracy.compute(
+                    references=subjects[subject]["refs"],
+                    predictions=subjects[subject]["preds"],
+                )["accuracy"]
+                if not pd.isna(subject_score):
+                    results[f"bench_{bench_split}_accuracy_{subject}"] = subject_score
+                    subject_scores.append(subject_score)
+            results[f"bench_{bench_split}_accuracy"] = np.mean(subject_scores)
+            trainer.log(results)
+            trainer.data_collator.max_length = source_max_len
 
     return BenchEvalCallback
 
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index b4828c884..27dba92ef 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -147,7 +147,7 @@ class AxolotlTrainingArguments(TrainingArguments):
         },
     )
     bench_source_max_len: int = field(
-        default=2048, metadata={"help": "Maximum source sequence length for mmlu."}
+        default=2048, metadata={"help": "Maximum source sequence length for bench."}
    )
 
 
@@ -540,9 +540,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         )
 
     if cfg.do_bench_eval:
-        training_arguments_kwargs["do_mmlu_eval"] = cfg.do_bench_eval
+        training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
         if cfg.bench_dataset:
-            training_arguments_kwargs["mmlu_dataset"] = cfg.bench_dataset
+            training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
 
     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
         max_steps=total_num_steps if cfg.max_steps else -1,