fix mmlu evals
This commit is contained in:
@@ -21,6 +21,7 @@ from transformers import (
|
|||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||||
|
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
|
from axolotl.utils.distributed import is_main_process, zero_first, zero_only
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.utils.trainer import AxolotlTrainingArguments
|
from axolotl.utils.trainer import AxolotlTrainingArguments
|
||||||
@@ -127,7 +128,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
"test": "zero_shot_mmlu_test.json",
|
"test": "zero_shot_mmlu_test.json",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
mmlu_dataset = mmlu_dataset.remove_columns("subject")
|
# mmlu_dataset = mmlu_dataset.remove_columns("subject")
|
||||||
# MMLU Five-shot (Eval/Test only)
|
# MMLU Five-shot (Eval/Test only)
|
||||||
elif trainer.args.mmlu_dataset in ["mmlu", "mmlu-fs"]:
|
elif trainer.args.mmlu_dataset in ["mmlu", "mmlu-fs"]:
|
||||||
mmlu_dataset = load_dataset(
|
mmlu_dataset = load_dataset(
|
||||||
@@ -144,6 +145,36 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
if trainer.args.max_mmlu_samples is not None:
|
if trainer.args.max_mmlu_samples is not None:
|
||||||
mmlu_dataset = mmlu_dataset.select(range(trainer.args.max_mmlu_samples))
|
mmlu_dataset = mmlu_dataset.select(range(trainer.args.max_mmlu_samples))
|
||||||
|
|
||||||
|
def tokenize_evals(example):
    """
    Turn one MMLU record into ``input_ids`` / ``labels`` for evaluation.

    The prompt is ``example["input"]`` prefixed with the tokenizer's BOS
    token, and the answer is ``example["output"]`` followed by its EOS token.
    Each part is tokenized separately (truncated to 2048 tokens, without
    automatically added special tokens) and the two are concatenated.

    Every prompt position is labeled ``-100`` so that only the answer tokens
    carry real label ids. The ``subject`` field is passed through unchanged.
    """

    def _encode(text):
        # BOS/EOS are already spelled out in the text, so the tokenizer is
        # told not to add its own special tokens.
        return tokenizer(
            text,
            max_length=2048,
            truncation=True,
            add_special_tokens=False,
        )

    # Build both strings before tokenizing anything.
    prompt_text = f"{tokenizer.bos_token}{example['input']}"
    answer_text = f"{example['output']}{tokenizer.eos_token}"

    prompt_ids = _encode(prompt_text)["input_ids"]
    answer_ids = _encode(answer_text)["input_ids"]

    masked_prompt = [-100] * len(prompt_ids)

    return {
        "input_ids": prompt_ids + answer_ids,
        "labels": masked_prompt + answer_ids,
        "subject": example["subject"],
    }
|
||||||
|
|
||||||
|
with zero_first(is_main_process()):
|
||||||
|
mmlu_dataset = mmlu_dataset.map(tokenize_evals)
|
||||||
|
|
||||||
class MMLUEvalCallback(TrainerCallback):
|
class MMLUEvalCallback(TrainerCallback):
|
||||||
"""
|
"""
|
||||||
TrainerCallback that runs the MMLU evals
|
TrainerCallback that runs the MMLU evals
|
||||||
@@ -157,44 +188,46 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
metrics: Dict[str, float], # pylint: disable=unused-argument
|
metrics: Dict[str, float], # pylint: disable=unused-argument
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
|
with zero_only(is_main_process()):
|
||||||
source_max_len = trainer.data_collator.max_length
|
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
|
||||||
source_max_len = args.max_seq_length
|
source_max_len = trainer.data_collator.max_length
|
||||||
trainer.data_collator.max_length = args.mmlu_source_max_len
|
trainer.data_collator.max_length = args.mmlu_source_max_len
|
||||||
trainer.model.eval()
|
trainer.model.eval()
|
||||||
preds, refs = [], []
|
preds, refs = [], []
|
||||||
loss_mmlu = 0
|
loss_mmlu = 0
|
||||||
for batch in tqdm(data_loader, total=len(data_loader)):
|
for batch in tqdm(data_loader, total=len(data_loader)):
|
||||||
(loss, logits, labels) = trainer.prediction_step(
|
(loss, logits, labels) = trainer.prediction_step(
|
||||||
trainer.model,
|
trainer.model,
|
||||||
batch,
|
batch,
|
||||||
prediction_loss_only=False,
|
prediction_loss_only=False,
|
||||||
)
|
)
|
||||||
# There are two tokens, the output, and eos token.
|
# There are two tokens, the output, and eos token.
|
||||||
for i, logit in enumerate(logits):
|
for i, logit in enumerate(logits):
|
||||||
label_non_zero_id = (batch["labels"][i] != -100).nonzero()[0][0]
|
label_non_zero_id = (batch["labels"][i] != -100).nonzero()[0][0]
|
||||||
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
|
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
|
||||||
preds.append(torch.argmax(logit_abcd).item())
|
preds.append(torch.argmax(logit_abcd).item())
|
||||||
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
|
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
|
||||||
refs += [abcd_idx.index(label) for label in labels.tolist()]
|
refs += [abcd_idx.index(label) for label in labels.tolist()]
|
||||||
loss_mmlu += loss.item()
|
loss_mmlu += loss.item()
|
||||||
# Extract results by subject.
|
# Extract results by subject.
|
||||||
results = {"mmlu_loss": loss_mmlu / len(data_loader)}
|
results = {"mmlu_loss": loss_mmlu / len(data_loader)}
|
||||||
subject = mmlu_dataset["subject"]
|
subject = mmlu_dataset["subject"]
|
||||||
subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
|
subjects: dict = {s: {"refs": [], "preds": []} for s in set(subject)}
|
||||||
for s, p, r in zip(subject, preds, refs): # pylint: disable=invalid-name
|
for s, p, r in zip( # pylint: disable=invalid-name
|
||||||
subjects[s]["preds"].append(p)
|
subject, preds, refs
|
||||||
subjects[s]["refs"].append(r)
|
):
|
||||||
subject_scores = []
|
subjects[s]["preds"].append(p)
|
||||||
for subject in subjects:
|
subjects[s]["refs"].append(r)
|
||||||
subject_score = accuracy.compute(
|
subject_scores = []
|
||||||
references=subjects[subject]["refs"],
|
for subject in subjects:
|
||||||
predictions=subjects[subject]["preds"],
|
subject_score = accuracy.compute(
|
||||||
)["accuracy"]
|
references=subjects[subject]["refs"],
|
||||||
results[f"mmlu_{mmlu_split}_accuracy_{subject}"] = subject_score
|
predictions=subjects[subject]["preds"],
|
||||||
subject_scores.append(subject_score)
|
)["accuracy"]
|
||||||
results[f"mmlu_{mmlu_split}_accuracy"] = np.mean(subject_scores)
|
results[f"mmlu_{mmlu_split}_accuracy_{subject}"] = subject_score
|
||||||
trainer.log(results)
|
subject_scores.append(subject_score)
|
||||||
trainer.data_collator.max_length = source_max_len
|
results[f"mmlu_{mmlu_split}_accuracy"] = np.mean(subject_scores)
|
||||||
|
trainer.log(results)
|
||||||
|
trainer.data_collator.max_length = source_max_len
|
||||||
|
|
||||||
return MMLUEvalCallback
|
return MMLUEvalCallback
|
||||||
|
|||||||
@@ -53,3 +53,16 @@ def zero_first(is_main):
|
|||||||
yield
|
yield
|
||||||
if is_main: # then rank 0 waits after it has run the context
|
if is_main: # then rank 0 waits after it has run the context
|
||||||
barrier()
|
barrier()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def zero_only(is_main):
    """
    Context manager that synchronizes all ranks after the wrapped code and
    yields whether the current process is Rank 0.

    A ``@contextmanager`` generator must yield exactly once and cannot make
    the ``with`` statement skip its body. Yielding only when ``is_main`` is
    true made every other rank fail with
    ``RuntimeError("generator didn't yield")`` before reaching the barrier.
    This version therefore yields on every rank and hands ``is_main`` back,
    so callers that need Rank-0-only work can gate it explicitly:

        with zero_only(is_main_process()) as is_rank_zero:
            if is_rank_zero:
                ...

    Callers that do not bind the yielded value run the wrapped code on every
    rank. All ranks wait at the barrier at the end before proceeding.
    """
    # Yield unconditionally: skipping the yield is an error, not a bypass.
    yield is_main

    # All ranks will wait here until every rank has finished the code block.
    barrier()
|
||||||
|
|||||||
Reference in New Issue
Block a user