Model parallel (#538)
* model-parallel for single process * fix device/device_map * fix handling for device
This commit is contained in:
@@ -28,7 +28,7 @@ def gpu_memory_usage_smi(device=0):
|
||||
|
||||
|
||||
def log_gpu_memory_usage(log, msg, device):
|
||||
if not torch.cuda.is_available():
|
||||
if not torch.cuda.is_available() or device == "auto":
|
||||
return (0, 0, 0)
|
||||
|
||||
usage, cache, misc = gpu_memory_usage_all(device)
|
||||
|
||||
@@ -25,7 +25,9 @@ def choose_device(cfg):
|
||||
return "cpu"
|
||||
|
||||
cfg.device = get_device()
|
||||
if cfg.device_map != "auto":
|
||||
if cfg.world_size == 1:
|
||||
cfg.device_map = "auto"
|
||||
else:
|
||||
if cfg.device.startswith("cuda"):
|
||||
cfg.device_map = {"": cfg.local_rank}
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user