fix distributed devices (#612)
* fix distributed devices * Update distributed.py * Update distributed.py
This commit is contained in:
@@ -77,7 +77,9 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
|
|||||||
value_scalar = fn()
|
value_scalar = fn()
|
||||||
if not is_distributed():
|
if not is_distributed():
|
||||||
return [value_scalar]
|
return [value_scalar]
|
||||||
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
|
value_tensor = torch.tensor(
|
||||||
|
value_scalar, device=torch.cuda.current_device()
|
||||||
|
).float()
|
||||||
|
|
||||||
if not is_main_process():
|
if not is_main_process():
|
||||||
dist.gather(value_tensor, dst=0)
|
dist.gather(value_tensor, dst=0)
|
||||||
@@ -137,9 +139,13 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
|
|||||||
"""
|
"""
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
value_scalar = fn()
|
value_scalar = fn()
|
||||||
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
|
value_tensor = torch.tensor(
|
||||||
|
value_scalar, device=torch.cuda.current_device()
|
||||||
|
).float()
|
||||||
else:
|
else:
|
||||||
value_tensor = torch.tensor(0.0, device=dist.get_rank()) # Placeholder tensor
|
value_tensor = torch.tensor(
|
||||||
|
0.0, device=torch.cuda.current_device()
|
||||||
|
) # Placeholder tensor
|
||||||
|
|
||||||
# Broadcast the tensor to all processes.
|
# Broadcast the tensor to all processes.
|
||||||
barrier()
|
barrier()
|
||||||
@@ -164,7 +170,9 @@ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
|
|||||||
- A list of computed values from all ranks if on the gathering rank, otherwise None.
|
- A list of computed values from all ranks if on the gathering rank, otherwise None.
|
||||||
"""
|
"""
|
||||||
value_scalar = fn()
|
value_scalar = fn()
|
||||||
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
|
value_tensor = torch.tensor(
|
||||||
|
value_scalar, device=torch.cuda.current_device()
|
||||||
|
).float()
|
||||||
|
|
||||||
# Placeholder tensor for gathering results
|
# Placeholder tensor for gathering results
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
|
|||||||
Reference in New Issue
Block a user