fix distributed devices (#612)
* fix distributed devices * Update distributed.py * Update distributed.py
This commit is contained in:
@@ -77,7 +77,9 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
|
||||
value_scalar = fn()
|
||||
if not is_distributed():
|
||||
return [value_scalar]
|
||||
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
|
||||
value_tensor = torch.tensor(
|
||||
value_scalar, device=torch.cuda.current_device()
|
||||
).float()
|
||||
|
||||
if not is_main_process():
|
||||
dist.gather(value_tensor, dst=0)
|
||||
@@ -137,9 +139,13 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
|
||||
"""
|
||||
if is_main_process():
|
||||
value_scalar = fn()
|
||||
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
|
||||
value_tensor = torch.tensor(
|
||||
value_scalar, device=torch.cuda.current_device()
|
||||
).float()
|
||||
else:
|
||||
value_tensor = torch.tensor(0.0, device=dist.get_rank()) # Placeholder tensor
|
||||
value_tensor = torch.tensor(
|
||||
0.0, device=torch.cuda.current_device()
|
||||
) # Placeholder tensor
|
||||
|
||||
# Broadcast the tensor to all processes.
|
||||
barrier()
|
||||
@@ -164,7 +170,9 @@ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
|
||||
- A list of computed values from all ranks if on the gathering rank, otherwise None.
|
||||
"""
|
||||
value_scalar = fn()
|
||||
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
|
||||
value_tensor = torch.tensor(
|
||||
value_scalar, device=torch.cuda.current_device()
|
||||
).float()
|
||||
|
||||
# Placeholder tensor for gathering results
|
||||
if is_main_process():
|
||||
|
||||
Reference in New Issue
Block a user