fix distributed devices (#612)

* fix distributed devices

* Update distributed.py

* Update distributed.py
This commit is contained in:
Maxime
2023-09-21 15:11:34 +02:00
committed by GitHub
parent c1382e79b6
commit 2fe95cdcc1

View File

@@ -77,7 +77,9 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
value_scalar = fn()
if not is_distributed():
return [value_scalar]
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device()
).float()
if not is_main_process():
dist.gather(value_tensor, dst=0)
@@ -137,9 +139,13 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
"""
if is_main_process():
value_scalar = fn()
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device()
).float()
else:
value_tensor = torch.tensor(0.0, device=dist.get_rank()) # Placeholder tensor
value_tensor = torch.tensor(
0.0, device=torch.cuda.current_device()
) # Placeholder tensor
# Broadcast the tensor to all processes.
barrier()
@@ -164,7 +170,9 @@ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
- A list of computed values from all ranks if on the gathering rank, otherwise None.
"""
value_scalar = fn()
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device()
).float()
# Placeholder tensor for gathering results
if is_main_process():