Model parallel (#538)
* model-parallel for single process
* fix device/device_map
* fix handling for device
This commit is contained in:
@@ -28,7 +28,7 @@ def gpu_memory_usage_smi(device=0):
|
|||||||
|
|
||||||
|
|
||||||
def log_gpu_memory_usage(log, msg, device):
|
def log_gpu_memory_usage(log, msg, device):
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available() or device == "auto":
|
||||||
return (0, 0, 0)
|
return (0, 0, 0)
|
||||||
|
|
||||||
usage, cache, misc = gpu_memory_usage_all(device)
|
usage, cache, misc = gpu_memory_usage_all(device)
|
||||||
|
|||||||
@@ -25,7 +25,9 @@ def choose_device(cfg):
|
|||||||
return "cpu"
|
return "cpu"
|
||||||
|
|
||||||
cfg.device = get_device()
|
cfg.device = get_device()
|
||||||
if cfg.device_map != "auto":
|
if cfg.world_size == 1:
|
||||||
|
cfg.device_map = "auto"
|
||||||
|
else:
|
||||||
if cfg.device.startswith("cuda"):
|
if cfg.device.startswith("cuda"):
|
||||||
cfg.device_map = {"": cfg.local_rank}
|
cfg.device_map = {"": cfg.local_rank}
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user