No gather single gpu (#523)

* don't attempt to gather when not running distributed (single gpu)

* also check distributed status in bench callback
Author: Wing Lian
Date: 2023-09-03 23:24:28 -04:00
Committed by: GitHub
Parent: 1991946c5a
Commit: 09f154397e
2 changed files with 8 additions and 2 deletions

src/axolotl/utils/callbacks.py

@@ -27,6 +27,7 @@ from axolotl.utils.distributed import (
     barrier,
     gather_scalar_from_all_ranks,
     get_world_size,
+    is_distributed,
     is_main_process,
     zero_first,
 )
@@ -270,10 +271,13 @@ def bench_eval_callback_factory(trainer, tokenizer):
                 lambda: len(data_loader), get_world_size()
             )
-            if not is_main_process():
+            if is_distributed() and not is_main_process():
                 dist.gather_object(local_bench_names, dst=0)
             else:
-                dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
+                if is_distributed():
+                    dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
+                else:
+                    gathered_bench_names = [local_bench_names]
             bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
             results = {f"{bench_split}_bench_loss": bench_loss}

src/axolotl/utils/distributed.py

@@ -74,6 +74,8 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-n
     - A list of computed values from all ranks if on the gathering rank, otherwise None.
     """
     value_scalar = fn()
+    if not is_distributed():
+        return [value_scalar]
     value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
     if not is_main_process():
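
The hunk is truncated above; as a hedged sketch of what the new early return gives callers, here is a self-contained version. The `_sketch` name is made up, and the distributed branch below is a typical completion of such a gather, not the file's verbatim code.

import torch
import torch.distributed as dist


def gather_scalar_from_all_ranks_sketch(fn, world_size=1):
    value_scalar = fn()
    if not (dist.is_available() and dist.is_initialized()):
        # the new early return: single-process runs skip every collective op
        return [value_scalar]
    value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
    if dist.get_rank() != 0:
        dist.gather(value_tensor, dst=0)  # non-main ranks only send
        return None
    gathered = [torch.zeros_like(value_tensor) for _ in range(world_size)]
    dist.gather(value_tensor, gather_list=gathered, dst=0)
    return [t.item() for t in gathered]


# single-process usage: prints [3.0] without touching torch.distributed state,
# so callers like sum(len_data_loader_ranks) keep working on one GPU
print(gather_scalar_from_all_ranks_sketch(lambda: 3.0))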