more fixes
This commit is contained in:
@@ -22,7 +22,7 @@ from transformers import (
|
|||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||||
|
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first, zero_only
|
from axolotl.utils.distributed import barrier, is_main_process, zero_first
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.utils.trainer import AxolotlTrainingArguments
|
from axolotl.utils.trainer import AxolotlTrainingArguments
|
||||||
@@ -119,6 +119,9 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
tokenizer("B", add_special_tokens=False).input_ids[0],
|
tokenizer("B", add_special_tokens=False).input_ids[0],
|
||||||
tokenizer("C", add_special_tokens=False).input_ids[0],
|
tokenizer("C", add_special_tokens=False).input_ids[0],
|
||||||
tokenizer("D", add_special_tokens=False).input_ids[0],
|
tokenizer("D", add_special_tokens=False).input_ids[0],
|
||||||
|
tokenizer("E", add_special_tokens=False).input_ids[0],
|
||||||
|
tokenizer("F", add_special_tokens=False).input_ids[0],
|
||||||
|
tokenizer("G", add_special_tokens=False).input_ids[0],
|
||||||
]
|
]
|
||||||
mmlu_split = "eval"
|
mmlu_split = "eval"
|
||||||
if trainer.args.mmlu_dataset == "sampled":
|
if trainer.args.mmlu_dataset == "sampled":
|
||||||
@@ -185,6 +188,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
|
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
mmlu_dataset = mmlu_dataset.map(tokenize_evals)
|
mmlu_dataset = mmlu_dataset.map(tokenize_evals)
|
||||||
|
mmlu_dataset = mmlu_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
|
||||||
|
|
||||||
class BenchEvalCallback(TrainerCallback):
|
class BenchEvalCallback(TrainerCallback):
|
||||||
"""
|
"""
|
||||||
@@ -199,7 +203,7 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
metrics: Dict[str, float], # pylint: disable=unused-argument
|
metrics: Dict[str, float], # pylint: disable=unused-argument
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
with zero_only(is_main_process()):
|
if is_main_process():
|
||||||
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
|
data_loader = trainer.get_eval_dataloader(mmlu_dataset)
|
||||||
source_max_len = trainer.data_collator.max_length
|
source_max_len = trainer.data_collator.max_length
|
||||||
trainer.data_collator.max_length = args.mmlu_source_max_len
|
trainer.data_collator.max_length = args.mmlu_source_max_len
|
||||||
@@ -218,7 +222,10 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
|
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
|
||||||
preds.append(torch.argmax(logit_abcd).item())
|
preds.append(torch.argmax(logit_abcd).item())
|
||||||
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
|
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
|
||||||
refs += [abcd_idx.index(label) for label in labels.tolist()]
|
refs += [
|
||||||
|
abcd_idx.index(label) if labels in abcd_idx else -1
|
||||||
|
for label in labels.tolist()
|
||||||
|
]
|
||||||
loss_mmlu += loss.item()
|
loss_mmlu += loss.item()
|
||||||
# Extract results by subject.
|
# Extract results by subject.
|
||||||
results = {"bench_loss": loss_mmlu / len(data_loader)}
|
results = {"bench_loss": loss_mmlu / len(data_loader)}
|
||||||
@@ -243,5 +250,6 @@ def mmlu_eval_callback_factory(trainer, tokenizer):
|
|||||||
results[f"bench_{mmlu_split}_accuracy"] = np.mean(subject_scores)
|
results[f"bench_{mmlu_split}_accuracy"] = np.mean(subject_scores)
|
||||||
trainer.log(results)
|
trainer.log(results)
|
||||||
trainer.data_collator.max_length = source_max_len
|
trainer.data_collator.max_length = source_max_len
|
||||||
|
barrier()
|
||||||
|
|
||||||
return BenchEvalCallback
|
return BenchEvalCallback
|
||||||
|
|||||||
@@ -53,16 +53,3 @@ def zero_first(is_main):
|
|||||||
yield
|
yield
|
||||||
if is_main: # then rank 0 waits after it has run the context
|
if is_main: # then rank 0 waits after it has run the context
|
||||||
barrier()
|
barrier()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def zero_only(is_main):
|
|
||||||
"""
|
|
||||||
Context manager that ensures only the Rank 0 process executes the wrapped code.
|
|
||||||
Other processes will simply bypass the code inside the context.
|
|
||||||
All ranks will synchronize (wait) at the end before proceeding.
|
|
||||||
"""
|
|
||||||
if is_main:
|
|
||||||
yield
|
|
||||||
# All ranks will wait here until Rank 0 completes the code block.
|
|
||||||
barrier()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user