Merge pull request #148 from OpenAccess-AI-Collective/fix-device-load

This commit is contained in:
Wing Lian
2023-06-02 14:37:17 -04:00
committed by GitHub

View File

@@ -47,7 +47,8 @@ def choose_device(cfg):
return "cpu" return "cpu"
cfg.device = get_device() cfg.device = get_device()
if cfg.device == "cuda": if cfg.device_map != "auto":
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank} cfg.device_map = {"": cfg.local_rank}
else: else:
cfg.device_map = {"": cfg.device} cfg.device_map = {"": cfg.device}