more fixes

This commit is contained in:
Wing Lian
2023-08-21 04:31:15 -04:00
parent d4c8b66f3d
commit ef062d8fcb
2 changed files with 11 additions and 16 deletions

View File

@@ -22,7 +22,7 @@ from transformers import (
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import is_main_process, zero_first, zero_only
from axolotl.utils.distributed import barrier, is_main_process, zero_first
if TYPE_CHECKING:
from axolotl.utils.trainer import AxolotlTrainingArguments
@@ -119,6 +119,9 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
tokenizer("B", add_special_tokens=False).input_ids[0],
tokenizer("C", add_special_tokens=False).input_ids[0],
tokenizer("D", add_special_tokens=False).input_ids[0],
tokenizer("E", add_special_tokens=False).input_ids[0],
tokenizer("F", add_special_tokens=False).input_ids[0],
tokenizer("G", add_special_tokens=False).input_ids[0],
]
mmlu_split = "eval"
if trainer.args.mmlu_dataset == "sampled":
@@ -185,6 +188,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
with zero_first(is_main_process()):
mmlu_dataset = mmlu_dataset.map(tokenize_evals)
mmlu_dataset = mmlu_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
class BenchEvalCallback(TrainerCallback):
"""
@@ -199,7 +203,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
metrics: Dict[str, float], # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
with zero_only(is_main_process()):
if is_main_process():
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
source_max_len = trainer.data_collator.max_length
trainer.data_collator.max_length = args.mmlu_source_max_len
@@ -218,7 +222,10 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
preds.append(torch.argmax(logit_abcd).item())
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
refs += [abcd_idx.index(label) for label in labels.tolist()]
refs += [
abcd_idx.index(label) if labels in abcd_idx else -1
for label in labels.tolist()
]
loss_mmlu += loss.item()
# Extract results by subject.
results = {"bench_loss": loss_mmlu / len(data_loader)}
@@ -243,5 +250,6 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
results[f"bench_{mmlu_split}_accuracy"] = np.mean(subject_scores)
trainer.log(results)
trainer.data_collator.max_length = source_max_len
barrier()
return BenchEvalCallback

View File

@@ -53,16 +53,3 @@ def zero_first(is_main):
yield
if is_main: # then rank 0 waits after it has run the context
barrier()
@contextmanager
def zero_only(is_main):
"""
Context manager that ensures only the Rank 0 process executes the wrapped code.
Other processes will simply bypass the code inside the context.
All ranks will synchronize (wait) at the end before proceeding.
"""
if is_main:
yield
# All ranks will wait here until Rank 0 completes the code block.
barrier()