From 45848a92858a19a4287fa94849168a335aa630fb Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 28 Aug 2023 11:29:59 -0400
Subject: [PATCH] gather benchmarks from all ranks

---
 src/axolotl/utils/callbacks.py   | 72 ++++++++++++++++++++++++--------
 src/axolotl/utils/distributed.py | 41 ++++++++++++++++++
 2 files changed, 96 insertions(+), 17 deletions(-)

diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index 1896df2de..e8c9eebfc 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -4,12 +4,13 @@ from __future__ import annotations
 
 import logging
 import os
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, List
 
 import evaluate
 import numpy as np
 import pandas as pd
 import torch
+import torch.distributed as dist
 from datasets import load_dataset
 from optimum.bettertransformer import BetterTransformer
 from tqdm import tqdm
@@ -22,7 +23,13 @@ from transformers import (
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 from axolotl.utils.bench import log_gpu_memory_usage
-from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.distributed import (
+    barrier,
+    gather_scalar_from_all_ranks,
+    get_world_size,
+    is_main_process,
+    zero_first,
+)
 
 if TYPE_CHECKING:
     from axolotl.utils.trainer import AxolotlTrainingArguments
@@ -193,7 +200,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
             add_special_tokens=False,
         )
         input_ids = tokenized_source["input_ids"] + tokenized_target["input_ids"]
-        labels = [-100] * len(tokenized_source["input_ids"]) + tokenized_target[
+        labels = [IGNORE_INDEX] * len(tokenized_source["input_ids"]) + tokenized_target[
             "input_ids"
         ]
 
@@ -205,7 +212,7 @@
 
     with zero_first(is_main_process()):
         bench_dataset = bench_dataset.map(tokenize_evals)
-        bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx)
+        bench_dataset = bench_dataset.filter(lambda x: x["labels"][-1] in abcd_idx)
 
     class BenchEvalCallback(TrainerCallback):
         """
@@ -234,7 +241,9 @@
                 )
                 # There are two tokens, the output, and eos token.
                 for i, logit in enumerate(logits):
-                    label_non_zero_id = (batch["labels"][i] != -100).nonzero()[0][0]
+                    label_non_zero_id = (batch["labels"][i] != IGNORE_INDEX).nonzero()[
+                        0
+                    ][0]
                     logit_abcd = logit[label_non_zero_id - 1][abcd_idx]
                     preds.append(torch.argmax(logit_abcd).item())
                 labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0]
@@ -244,22 +253,51 @@
                 refs += [
                     abcd_idx.index(label) if label in abcd_idx else -1
                     for label in labels
                 ]
                 loss_bench += loss.item()
             # Extract results by subject.
- results = {"bench_loss": loss_bench / len(data_loader)} bench_name = bench_dataset["name"] bench_names: dict = {s: {"refs": [], "preds": []} for s in set(bench_name)} for s, p, r in zip(bench_name, preds, refs): # pylint: disable=invalid-name bench_names[s]["preds"].append(p) bench_names[s]["refs"].append(r) - bench_scores = [] - for bench_name in bench_names: - bench_score = accuracy.compute( - references=bench_names[bench_name]["refs"], - predictions=bench_names[bench_name]["preds"], - )["accuracy"] - if not pd.isna(bench_score): - results[f"bench_{bench_split}_accuracy_{bench_name}"] = bench_score - bench_scores.append(bench_score) - results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores) - trainer.log(results) + barrier() + bench_loss = sum( + gather_scalar_from_all_ranks(lambda: loss_bench, get_world_size()) + ) / sum( + gather_scalar_from_all_ranks(lambda: len(data_loader), get_world_size()) + ) + results = {"bench_loss": bench_loss} + + local_bench_names = bench_names + gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())] + # Gather results from all GPUs to GPU 0 + dist.gather_object(local_bench_names, gathered_bench_names, dst=0) + + if is_main_process(): + # Combine results from all GPUs + combined_bench_names: Dict[str, Dict[str, List]] = {} + for bench_name in gathered_bench_names: + for name, data in bench_name.items(): + if name not in combined_bench_names: + combined_bench_names[name] = {"refs": [], "preds": []} + combined_bench_names[name]["refs"].extend(data["refs"]) + combined_bench_names[name]["preds"].extend(data["preds"]) + + bench_scores = [] + for ( + bench_name + ) in combined_bench_names: # pylint: disable=consider-using-dict-items + bench_score = accuracy.compute( + references=combined_bench_names[bench_name]["refs"], + predictions=combined_bench_names[bench_name]["preds"], + )["accuracy"] + if not pd.isna(bench_score): + results[ + f"bench_{bench_split}_accuracy_{bench_name}" + ] = bench_score + bench_scores.append(bench_score) + else: + results[f"bench_{bench_split}_accuracy_{bench_name}"] = 0.0 + bench_scores.append(0.0) + results[f"bench_{bench_split}_accuracy"] = np.mean(bench_scores) + trainer.log(results) return BenchEvalCallback diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index b3ea07c05..3c6653ee4 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -1,8 +1,10 @@ """ utility helpers for distributed checks """ +import os from contextlib import contextmanager +import torch import torch.distributed as dist from accelerate import Accelerator @@ -43,6 +45,10 @@ def is_main_process(): return dist.get_rank() == 0 +def get_world_size(): + return int(os.getenv("WORLD_SIZE", "1")) + + @contextmanager def zero_first(is_main): """ @@ -53,3 +59,38 @@ def zero_first(is_main): yield if is_main: # then rank 0 waits after it has run the context barrier() + + +def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name + """ + Run a callable 'fn' on all ranks and gather the results on the specified rank. + + Args: + - fn (callable): A function that computes the value. This should not have any side effects. + - rank (int, optional): The rank that gathers the values. Default is 0. + - world_size (int, optional): Total number of processes in the current distributed setup. + + Returns: + - A list of computed values from all ranks if on the gathering rank, otherwise None. 
+ """ + value_scalar = fn() + value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float() + + # Placeholder tensor for gathering results + if is_main_process(): + gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)] + else: + gathered_tensors = None + + dist.gather(value_tensor, gather_list=gathered_tensors, dst=0) + + if is_main_process(): + # Convert tensors back to their original type (int or float) + gathered_values = [] + for tensor in gathered_tensors: + if tensor == tensor.int(): + gathered_values.append(int(tensor.item())) + else: + gathered_values.append(float(tensor.item())) + return gathered_values + return None