From 09f154397eeed6fd86d887c2b9bdd0f49c885630 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 3 Sep 2023 23:24:28 -0400
Subject: [PATCH] No gather single gpu (#523)

* don't attempt to gather when not running multi-gpu
* also check distributed status in bench callback

---
 src/axolotl/utils/callbacks.py   | 8 ++++++--
 src/axolotl/utils/distributed.py | 2 ++
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index ee5acfd55..8fc5a918b 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -27,6 +27,7 @@ from axolotl.utils.distributed import (
     barrier,
     gather_scalar_from_all_ranks,
     get_world_size,
+    is_distributed,
     is_main_process,
     zero_first,
 )
@@ -270,10 +271,13 @@ def bench_eval_callback_factory(trainer, tokenizer):
         lambda: len(data_loader), get_world_size()
     )

-    if not is_main_process():
+    if is_distributed() and not is_main_process():
         dist.gather_object(local_bench_names, dst=0)
     else:
-        dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
+        if is_distributed():
+            dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
+        else:
+            gathered_bench_names = [local_bench_names]

     bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
     results = {f"{bench_split}_bench_loss": bench_loss}
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 38d0d1e05..5e527f3b9 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -74,6 +74,8 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     - A list of computed values from all ranks if on the gathering rank, otherwise None.
     """
     value_scalar = fn()
+    if not is_distributed():
+        return [value_scalar]
     value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()

     if not is_main_process():
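
For readers applying the same fix elsewhere: the pattern above degrades a cross-rank gather to a no-op when there is no process group, so single-GPU runs never hit torch.distributed. Below is a minimal standalone sketch of that pattern; `is_distributed` is assumed to check whether torch.distributed has an initialized process group (the conventional way to detect a single-process run), and `gather_objects` is a hypothetical helper for illustration, not part of axolotl.

    import torch.distributed as dist


    def is_distributed() -> bool:
        # Assumption: single-GPU / single-process runs never call
        # dist.init_process_group(), so is_initialized() is False there.
        return dist.is_available() and dist.is_initialized()


    def gather_objects(local_obj, world_size):
        # Hypothetical helper mirroring the patch: gather local_obj from
        # all ranks onto rank 0, skipping torch.distributed entirely in a
        # single process, where dist.gather_object() would raise.
        if not is_distributed():
            # Same fallback as the patch: wrap the local value so callers
            # see the same list-of-per-rank-values shape as on multi-GPU.
            return [local_obj]
        if dist.get_rank() == 0:
            gathered = [None] * world_size
            dist.gather_object(local_obj, gathered, dst=0)
            return gathered
        dist.gather_object(local_obj, dst=0)  # non-zero ranks only send
        return None

Because the single-process branch returns `[local_obj]`, downstream code such as the bench-loss reduction (`sum(...) / sum(...)`) works unchanged in both the distributed and single-GPU cases.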