diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py
index ee5acfd55..8fc5a918b 100644
--- a/src/axolotl/utils/callbacks.py
+++ b/src/axolotl/utils/callbacks.py
@@ -27,6 +27,7 @@ from axolotl.utils.distributed import (
     barrier,
     gather_scalar_from_all_ranks,
     get_world_size,
+    is_distributed,
     is_main_process,
     zero_first,
 )
@@ -270,10 +271,13 @@ def bench_eval_callback_factory(trainer, tokenizer):
             lambda: len(data_loader), get_world_size()
         )
 
-        if not is_main_process():
+        if is_distributed() and not is_main_process():
             dist.gather_object(local_bench_names, dst=0)
         else:
-            dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
+            if is_distributed():
+                dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
+            else:
+                gathered_bench_names = [local_bench_names]
 
         bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
         results = {f"{bench_split}_bench_loss": bench_loss}
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index 38d0d1e05..5e527f3b9 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -74,6 +74,8 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-n
     - A list of computed values from all ranks if on the gathering rank, otherwise None.
     """
     value_scalar = fn()
+    if not is_distributed():
+        return [value_scalar]
     value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
     if not is_main_process():