Model parallel (#538)

* model-parallel for single process

* fix device/device_map

* fix handling for device
This commit is contained in:
Wing Lian
2023-09-13 11:45:30 -04:00
committed by GitHub
parent a4e1bb6606
commit f6060a664e
2 changed files with 4 additions and 2 deletions

View File

@@ -28,7 +28,7 @@ def gpu_memory_usage_smi(device=0):
 def log_gpu_memory_usage(log, msg, device):
-    if not torch.cuda.is_available():
+    if not torch.cuda.is_available() or device == "auto":
         return (0, 0, 0)
     usage, cache, misc = gpu_memory_usage_all(device)

View File

@@ -25,7 +25,9 @@ def choose_device(cfg):
         return "cpu"
     cfg.device = get_device()
-    if cfg.device_map != "auto":
+    if cfg.world_size == 1:
+        cfg.device_map = "auto"
+    else:
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": cfg.local_rank}
         else: