more fixes
This commit is contained in:
@@ -22,7 +22,7 @@ from transformers import (
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.distributed import is_main_process, zero_first, zero_only
|
||||
from axolotl.utils.distributed import barrier, is_main_process, zero_first
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.utils.trainer import AxolotlTrainingArguments
|
||||
@@ -119,6 +119,9 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
||||
tokenizer("B", add_special_tokens=False).input_ids[0],
|
||||
tokenizer("C", add_special_tokens=False).input_ids[0],
|
||||
tokenizer("D", add_special_tokens=False).input_ids[0],
|
||||
tokenizer("E", add_special_tokens=False).input_ids[0],
|
||||
tokenizer("F", add_special_tokens=False).input_ids[0],
|
||||
tokenizer("G", add_special_tokens=False).input_ids[0],
|
||||
]
|
||||
mmlu_split = "eval"
|
||||
if trainer.args.mmlu_dataset == "sampled":
|
||||
@@ -185,6 +188,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
||||
|
||||
with zero_first(is_main_process()):
|
||||
mmlu_dataset = mmlu_dataset.map(tokenize_evals)
|
||||
mmlu_dataset = mmlu_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
|
||||
|
||||
class BenchEvalCallback(TrainerCallback):
|
||||
"""
|
||||
@@ -199,7 +203,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
||||
metrics: Dict[str, float], # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
with zero_only(is_main_process()):
|
||||
if is_main_process():
|
||||
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
|
||||
source_max_len = trainer.data_collator.max_length
|
||||
trainer.data_collator.max_length = args.mmlu_source_max_len
|
||||
@@ -218,7 +222,10 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
||||
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
|
||||
preds.append(torch.argmax(logit_abcd).item())
|
||||
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
|
||||
refs += [abcd_idx.index(label) for label in labels.tolist()]
|
||||
refs += [
|
||||
abcd_idx.index(label) if labels in abcd_idx else -1
|
||||
for label in labels.tolist()
|
||||
]
|
||||
loss_mmlu += loss.item()
|
||||
# Extract results by subject.
|
||||
results = {"bench_loss": loss_mmlu / len(data_loader)}
|
||||
@@ -243,5 +250,6 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
||||
results[f"bench_{mmlu_split}_accuracy"] = np.mean(subject_scores)
|
||||
trainer.log(results)
|
||||
trainer.data_collator.max_length = source_max_len
|
||||
barrier()
|
||||
|
||||
return BenchEvalCallback
|
||||
|
||||
@@ -53,16 +53,3 @@ def zero_first(is_main):
|
||||
yield
|
||||
if is_main: # then rank 0 waits after it has run the context
|
||||
barrier()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def zero_only(is_main):
|
||||
"""
|
||||
Context manager that ensures only the Rank 0 process executes the wrapped code.
|
||||
Other processes will simply bypass the code inside the context.
|
||||
All ranks will synchronize (wait) at the end before proceeding.
|
||||
"""
|
||||
if is_main:
|
||||
yield
|
||||
# All ranks will wait here until Rank 0 completes the code block.
|
||||
barrier()
|
||||
|
||||
Reference in New Issue
Block a user