No gather single gpu (#523)
* don't attempt to gather on multi-gpu * also check distributed status in bench callback
This commit is contained in:
@@ -27,6 +27,7 @@ from axolotl.utils.distributed import (
|
||||
barrier,
|
||||
gather_scalar_from_all_ranks,
|
||||
get_world_size,
|
||||
is_distributed,
|
||||
is_main_process,
|
||||
zero_first,
|
||||
)
|
||||
@@ -270,10 +271,13 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
||||
lambda: len(data_loader), get_world_size()
|
||||
)
|
||||
|
||||
if not is_main_process():
|
||||
if is_distributed() and not is_main_process():
|
||||
dist.gather_object(local_bench_names, dst=0)
|
||||
else:
|
||||
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
|
||||
if is_distributed():
|
||||
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
|
||||
else:
|
||||
gathered_bench_names = [local_bench_names]
|
||||
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
|
||||
results = {f"{bench_split}_bench_loss": bench_loss}
|
||||
|
||||
|
||||
@@ -74,6 +74,8 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
|
||||
- A list of computed values from all ranks if on the gathering rank, otherwise None.
|
||||
"""
|
||||
value_scalar = fn()
|
||||
if not is_distributed():
|
||||
return [value_scalar]
|
||||
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
|
||||
|
||||
if not is_main_process():
|
||||
|
||||
Reference in New Issue
Block a user