From c3de28942c84b49684c2608591af65a33b32fa7a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 29 Aug 2023 06:57:28 -0700 Subject: [PATCH] fix for gather across multiple gpus --- src/axolotl/utils/callbacks.py | 27 ++++++++++++++++----------- src/axolotl/utils/distributed.py | 11 ++++------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index e8c9eebfc..92333f4ca 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -212,7 +212,7 @@ def bench_eval_callback_factory(trainer, tokenizer): with zero_first(is_main_process()): bench_dataset = bench_dataset.map(tokenize_evals) - bench_dataset = bench_dataset.filter(lambda x: x["labels"][-1] in abcd_idx) + bench_dataset = bench_dataset.filter(lambda x: x["labels"][-2] in abcd_idx) class BenchEvalCallback(TrainerCallback): """ @@ -248,7 +248,7 @@ def bench_eval_callback_factory(trainer, tokenizer): preds.append(torch.argmax(logit_abcd).item()) labels = labels[labels != IGNORE_INDEX].view(-1, 2)[:, 0] refs += [ - abcd_idx.index(label) if labels in abcd_idx else -1 + abcd_idx.index(label) if label in abcd_idx else -1 for label in labels.tolist() ] loss_bench += loss.item() @@ -259,19 +259,24 @@ def bench_eval_callback_factory(trainer, tokenizer): bench_names[s]["preds"].append(p) bench_names[s]["refs"].append(r) barrier() - bench_loss = sum( - gather_scalar_from_all_ranks(lambda: loss_bench, get_world_size()) - ) / sum( - gather_scalar_from_all_ranks(lambda: len(data_loader), get_world_size()) - ) - results = {"bench_loss": bench_loss} - local_bench_names = bench_names gathered_bench_names: List[Dict] = [{} for _ in range(get_world_size())] # Gather results from all GPUs to GPU 0 - dist.gather_object(local_bench_names, gathered_bench_names, dst=0) - if is_main_process(): + loss_bench_ranks = gather_scalar_from_all_ranks( + lambda: loss_bench, get_world_size() + ) + len_data_loader_ranks = gather_scalar_from_all_ranks( + lambda: len(data_loader), get_world_size() + ) + + if not is_main_process(): + dist.gather_object(local_bench_names, dst=0) + else: + dist.gather_object(local_bench_names, gathered_bench_names, dst=0) + bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks) + results = {"bench_loss": bench_loss} + # Combine results from all GPUs combined_bench_names: Dict[str, Dict[str, List]] = {} for bench_name in gathered_bench_names: diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py index 3c6653ee4..38d0d1e05 100644 --- a/src/axolotl/utils/distributed.py +++ b/src/axolotl/utils/distributed.py @@ -76,15 +76,12 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n value_scalar = fn() value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float() - # Placeholder tensor for gathering results - if is_main_process(): - gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)] + if not is_main_process(): + dist.gather(value_tensor, dst=0) else: - gathered_tensors = None + gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)] + dist.gather(value_tensor, gather_list=gathered_tensors, dst=0) - dist.gather(value_tensor, gather_list=gathered_tensors, dst=0) - - if is_main_process(): # Convert tensors back to their original type (int or float) gathered_values = [] for tensor in gathered_tensors: