No gather single gpu (#523)
* don't attempt to gather on multi-gpu * also check distributed status in bench callback
This commit is contained in:
@@ -27,6 +27,7 @@ from axolotl.utils.distributed import (
|
|||||||
barrier,
|
barrier,
|
||||||
gather_scalar_from_all_ranks,
|
gather_scalar_from_all_ranks,
|
||||||
get_world_size,
|
get_world_size,
|
||||||
|
is_distributed,
|
||||||
is_main_process,
|
is_main_process,
|
||||||
zero_first,
|
zero_first,
|
||||||
)
|
)
|
||||||
@@ -270,10 +271,13 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
lambda: len(data_loader), get_world_size()
|
lambda: len(data_loader), get_world_size()
|
||||||
)
|
)
|
||||||
|
|
||||||
if not is_main_process():
|
if is_distributed() and not is_main_process():
|
||||||
dist.gather_object(local_bench_names, dst=0)
|
dist.gather_object(local_bench_names, dst=0)
|
||||||
else:
|
else:
|
||||||
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
|
if is_distributed():
|
||||||
|
dist.gather_object(local_bench_names, gathered_bench_names, dst=0)
|
||||||
|
else:
|
||||||
|
gathered_bench_names = [local_bench_names]
|
||||||
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
|
bench_loss = sum(loss_bench_ranks) / sum(len_data_loader_ranks)
|
||||||
results = {f"{bench_split}_bench_loss": bench_loss}
|
results = {f"{bench_split}_bench_loss": bench_loss}
|
||||||
|
|
||||||
|
|||||||
@@ -74,6 +74,8 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
|
|||||||
- A list of computed values from all ranks if on the gathering rank, otherwise None.
|
- A list of computed values from all ranks if on the gathering rank, otherwise None.
|
||||||
"""
|
"""
|
||||||
value_scalar = fn()
|
value_scalar = fn()
|
||||||
|
if not is_distributed():
|
||||||
|
return [value_scalar]
|
||||||
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
|
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
|
||||||
|
|
||||||
if not is_main_process():
|
if not is_main_process():
|
||||||
|
|||||||
Reference in New Issue
Block a user