gather benchmarks from all ranks

This commit is contained in:
Wing Lian
2023-08-28 11:29:59 -04:00
parent d6cea18034
commit 45848a9285
2 changed files with 96 additions and 17 deletions

View File

@@ -4,12 +4,13 @@ from __future__ import annotations
import logging import logging
import os import os
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict, List
import evaluate import evaluate
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import torch import torch
import torch.distributed as dist
from datasets import load_dataset from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm from tqdm import tqdm
@@ -22,7 +23,13 @@ from transformers import (
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.distributed import (
barrier,
gather_scalar_from_all_ranks,
get_world_size,
is_main_process,
zero_first,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from axolotl.utils.trainer import AxolotlTrainingArguments from axolotl.utils.trainer import AxolotlTrainingArguments
@@ -193,7 +200,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
add_special_tokens=False, add_special_tokens=False,
) )
input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"] input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
labels = [-100] * len(tokenized_source["input_ids"]) + tokenized_target[ labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[
"input_ids" "input_ids"
] ]
@@ -205,7 +212,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
with zero_first(is_main_process()): with zero_first(is_main_process()):
bench_dataset = bench_dataset.map(tokenize_evals) bench_dataset = bench_dataset.map(tokenize_evals)
bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx) bench_dataset = bench_dataset.filter(lambda x: x["labels"][-1] in abcd_idx)
class BenchEvalCallback(TrainerCallback): class BenchEvalCallback(TrainerCallback):
""" """
@@ -234,7 +241,9 @@ def bench_eval_callback_factory(trainer, tokenizer):
) )
# There are two tokens, the output, and eos token. # There are two tokens, the output, and eos token.
for i, logit in enumerate(logits): for i, logit in enumerate(logits):
label_non_zero_id = (batch["labels"][i] != -100).nonzero()[0][0] label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
0
][0]
logit_abcd = logit[label_non_zero_id - 1][abcd_idx] logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
preds.append(torch.argmax(logit_abcd).item()) preds.append(torch.argmax(logit_abcd).item())
labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0] labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
@@ -244,22 +253,51 @@ def bench_eval_callback_factory(trainer, tokenizer):
] ]
loss_bench += loss.item() loss_bench += loss.item()
# Extract results by subject. # Extract results by subject.
results = {"bench_loss": loss_bench / len(data_loader)}
bench_name = bench_dataset["name"] bench_name = bench_dataset["name"]
bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)} bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)}
for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name
bench_names[s]["preds"].append(p) bench_names[s]["preds"].append(p)
bench_names[s]["refs"].append(r) bench_names[s]["refs"].append(r)
bench_scores = [] barrier()
for bench_name in bench_names: bench_loss = sum(
bench_score = accuracy.compute( gather_scalar_from_all_ranks(lambda: loss_bench, get_world_size())
references=bench_names[bench_name]["refs"], ) / sum(
predictions=bench_names[bench_name]["preds"], gather_scalar_from_all_ranks(lambda: len(data_loader), get_world_size())
)["accuracy"] )
if not pd.isna(bench_score): results = {"bench_loss": bench_loss}
results[f"bench_{bench_split}_accuracy_{bench_name}"] = bench_score
bench_scores.append(bench_score) local_bench_names = bench_names
results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores) gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())]
trainer.log(results) # Gather results from all GPUs to GPU 0
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
if is_main_process():
# Combine results from all GPUs
combined_bench_names: Dict[str, Dict[str, List]] = {}
for bench_name in gathered_bench_names:
for name, data in bench_name.items():
if name not in combined_bench_names:
combined_bench_names[name] = {"refs": [], "preds": []}
combined_bench_names[name]["refs"].extend(data["refs"])
combined_bench_names[name]["preds"].extend(data["preds"])
bench_scores = []
for (
bench_name
) in combined_bench_names: # pylint: disable=consider-using-dict-items
bench_score = accuracy.compute(
references=combined_bench_names[bench_name]["refs"],
predictions=combined_bench_names[bench_name]["preds"],
)["accuracy"]
if not pd.isna(bench_score):
results[
f"bench_{bench_split}_accuracy_{bench_name}"
] = bench_score
bench_scores.append(bench_score)
else:
results[f"bench_{bench_split}_accuracy_{bench_name}"] = 0.0
bench_scores.append(0.0)
results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores)
trainer.log(results)
return BenchEvalCallback return BenchEvalCallback

View File

@@ -1,8 +1,10 @@
""" """
utility helpers for distributed checks utility helpers for distributed checks
""" """
import os
from contextlib import contextmanager from contextlib import contextmanager
import torch
import torch.distributed as dist import torch.distributed as dist
from accelerate import Accelerator from accelerate import Accelerator
@@ -43,6 +45,10 @@ def is_main_process():
return dist.get_rank() == 0 return dist.get_rank() == 0
def get_world_size():
    """
    Return the number of processes as advertised by the ``WORLD_SIZE``
    environment variable, falling back to 1 when the variable is not set.
    """
    raw_world_size = os.environ.get("WORLD_SIZE", "1")
    return int(raw_world_size)
@contextmanager @contextmanager
def zero_first(is_main): def zero_first(is_main):
""" """
@@ -53,3 +59,38 @@ def zero_first(is_main):
yield yield
if is_main: # then rank 0 waits after it has run the context if is_main: # then rank 0 waits after it has run the context
barrier() barrier()
def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
    """
    Run a callable 'fn' on every rank and gather the resulting scalars on rank 0.

    Args:
    - fn (callable): A function that computes the value. This should not have any side effects.
    - world_size (int, optional): Total number of processes in the current distributed setup.
      Used to size the receive buffer on rank 0.

    Returns:
    - A list of computed values from all ranks if on the gathering rank (rank 0), otherwise None.
      When no process group is initialized, there is nothing to gather and a
      single-element list holding this process's own value is returned unchanged.
    """
    value_scalar = fn()

    # Single-process run: the collective calls below need an initialized
    # process group, so short-circuit with just the local value.
    if not (dist.is_available() and dist.is_initialized()):
        return [value_scalar]

    # Use the device this process is already bound to. The global rank is not a
    # valid CUDA ordinal on multi-node jobs (e.g. rank 8 on an 8-GPU node).
    if torch.cuda.is_available():
        device = torch.device("cuda", torch.cuda.current_device())
    else:
        device = torch.device("cpu")
    # NOTE: values travel as float32, so very large ints or high-precision
    # floats may come back rounded.
    value_tensor = torch.tensor(value_scalar, device=device).float()

    if not is_main_process():
        # Non-gathering ranks only contribute their value.
        dist.gather(value_tensor, dst=0)
        return None

    # Placeholder tensors for gathering results
    gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
    dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)

    # Convert tensors back to their original type (int or float)
    gathered_values = []
    for tensor in gathered_tensors:
        if tensor == tensor.int():
            gathered_values.append(int(tensor.item()))
        else:
            gathered_values.append(float(tensor.item()))
    return gathered_values