gather benchmarks from all ranks
This commit is contained in:
@@ -4,12 +4,13 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Dict
|
from typing import TYPE_CHECKING, Dict, List
|
||||||
|
|
||||||
import evaluate
|
import evaluate
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -22,7 +23,13 @@ from transformers import (
|
|||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||||
|
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
from axolotl.utils.distributed import (
|
||||||
|
barrier,
|
||||||
|
gather_scalar_from_all_ranks,
|
||||||
|
get_world_size,
|
||||||
|
is_main_process,
|
||||||
|
zero_first,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.utils.trainer import AxolotlTrainingArguments
|
from axolotl.utils.trainer import AxolotlTrainingArguments
|
||||||
@@ -193,7 +200,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
add_special_tokens=False,
|
add_special_tokens=False,
|
||||||
)
|
)
|
||||||
input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
|
input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
|
||||||
labels = [-100] * len(tokenized_source["input_ids"]) + tokenized_target[
|
labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[
|
||||||
"input_ids"
|
"input_ids"
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -205,7 +212,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
|
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
bench_dataset = bench_dataset.map(tokenize_evals)
|
bench_dataset = bench_dataset.map(tokenize_evals)
|
||||||
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
|
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-1] in abcd_idx)
|
||||||
|
|
||||||
class BenchEvalCallback(TrainerCallback):
|
class BenchEvalCallback(TrainerCallback):
|
||||||
"""
|
"""
|
||||||
@@ -234,7 +241,9 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
)
|
)
|
||||||
# There are two tokens, the output, and eos token.
|
# There are two tokens, the output, and eos token.
|
||||||
for i, logit in enumerate(logits):
|
for i, logit in enumerate(logits):
|
||||||
label_non_zero_id = (batch["labels"][i] != -100).nonzero()[0][0]
|
label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
|
||||||
|
0
|
||||||
|
][0]
|
||||||
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
|
logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
|
||||||
preds.append(torch.argmax(logit_abcd).item())
|
preds.append(torch.argmax(logit_abcd).item())
|
||||||
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
|
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
|
||||||
@@ -244,22 +253,51 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
]
|
]
|
||||||
loss_bench += loss.item()
|
loss_bench += loss.item()
|
||||||
# Extract results by subject.
|
# Extract results by subject.
|
||||||
results = {"bench_loss": loss_bench / len(data_loader)}
|
|
||||||
bench_name = bench_dataset["name"]
|
bench_name = bench_dataset["name"]
|
||||||
bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
|
bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
|
||||||
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
|
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
|
||||||
bench_names[s]["preds"].append(p)
|
bench_names[s]["preds"].append(p)
|
||||||
bench_names[s]["refs"].append(r)
|
bench_names[s]["refs"].append(r)
|
||||||
bench_scores = []
|
barrier()
|
||||||
for bench_name in bench_names:
|
bench_loss = sum(
|
||||||
bench_score = accuracy.compute(
|
gather_scalar_from_all_ranks(lambda: loss_bench, get_world_size())
|
||||||
references=bench_names[bench_name]["refs"],
|
) / sum(
|
||||||
predictions=bench_names[bench_name]["preds"],
|
gather_scalar_from_all_ranks(lambda: len(data_loader), get_world_size())
|
||||||
)["accuracy"]
|
)
|
||||||
if not pd.isna(bench_score):
|
results = {"bench_loss": bench_loss}
|
||||||
results[f"bench_{bench_split}_accuracy_{bench_name}"] = bench_score
|
|
||||||
bench_scores.append(bench_score)
|
local_bench_names = bench_names
|
||||||
results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
|
gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
|
||||||
trainer.log(results)
|
# Gather results from all GPUs to GPU 0
|
||||||
|
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
|
||||||
|
|
||||||
|
if is_main_process():
|
||||||
|
# Combine results from all GPUs
|
||||||
|
combined_bench_names: Dict[str, Dict[str, List]] = {}
|
||||||
|
for bench_name in gathered_bench_names:
|
||||||
|
for name, data in bench_name.items():
|
||||||
|
if name not in combined_bench_names:
|
||||||
|
combined_bench_names[name] = {"refs": [], "preds": []}
|
||||||
|
combined_bench_names[name]["refs"].extend(data["refs"])
|
||||||
|
combined_bench_names[name]["preds"].extend(data["preds"])
|
||||||
|
|
||||||
|
bench_scores = []
|
||||||
|
for (
|
||||||
|
bench_name
|
||||||
|
) in combined_bench_names: # pylint: disable=consider-using-dict-items
|
||||||
|
bench_score = accuracy.compute(
|
||||||
|
references=combined_bench_names[bench_name]["refs"],
|
||||||
|
predictions=combined_bench_names[bench_name]["preds"],
|
||||||
|
)["accuracy"]
|
||||||
|
if not pd.isna(bench_score):
|
||||||
|
results[
|
||||||
|
f"bench_{bench_split}_accuracy_{bench_name}"
|
||||||
|
] = bench_score
|
||||||
|
bench_scores.append(bench_score)
|
||||||
|
else:
|
||||||
|
results[f"bench_{bench_split}_accuracy_{bench_name}"] = 0.0
|
||||||
|
bench_scores.append(0.0)
|
||||||
|
results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
|
||||||
|
trainer.log(results)
|
||||||
|
|
||||||
return BenchEvalCallback
|
return BenchEvalCallback
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
utility helpers for distributed checks
|
utility helpers for distributed checks
|
||||||
"""
|
"""
|
||||||
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
|
|
||||||
@@ -43,6 +45,10 @@ def is_main_process():
|
|||||||
return dist.get_rank() == 0
|
return dist.get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size():
    """
    Return the number of processes reported by the ``WORLD_SIZE`` environment variable.

    Returns:
    - int: the integer value of ``WORLD_SIZE``, or 1 when the variable is not set.
    """
    return int(os.getenv("WORLD_SIZE", "1"))
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def zero_first(is_main):
|
def zero_first(is_main):
|
||||||
"""
|
"""
|
||||||
@@ -53,3 +59,38 @@ def zero_first(is_main):
|
|||||||
yield
|
yield
|
||||||
if is_main: # then rank 0 waits after it has run the context
|
if is_main: # then rank 0 waits after it has run the context
|
||||||
barrier()
|
barrier()
|
||||||
|
|
||||||
|
|
||||||
|
def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
    """
    Run a callable 'fn' on all ranks and gather the results on rank 0.

    Args:
    - fn (callable): A function that computes the value. This should not have any side effects.
    - world_size (int, optional): Total number of processes in the current distributed setup.
      It sizes the receive buffer on rank 0, so it must match the size of the process group.

    Returns:
    - A list of computed values from all ranks if on the gathering rank (rank 0), otherwise None.
      When no process group is initialized, a single-element list holding this process's
      value, unchanged.
    """
    value_scalar = fn()

    # Without an initialized process group there are no peers to gather from, and
    # dist.get_rank() would raise; the local value is the complete result.
    if not (dist.is_available() and dist.is_initialized()):
        return [value_scalar]

    # NOTE(review): the global rank is used as the device index here; this presumably
    # assumes one node where rank == local device index -- confirm for multi-node runs.
    value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()

    on_main = is_main_process()

    # Only the destination rank supplies a receive buffer; other ranks pass None.
    gathered_tensors = (
        [torch.zeros_like(value_tensor) for _ in range(world_size)]
        if on_main
        else None
    )

    dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)

    if not on_main:
        return None

    # Values travel as float tensors; report whole numbers as int and the rest as float.
    return [
        int(tensor.item()) if tensor == tensor.int() else float(tensor.item())
        for tensor in gathered_tensors
    ]
|
||||||
|
|||||||
Reference in New Issue
Block a user