fix device map

This commit is contained in:
Wing Lian
2023-06-02 14:29:08 -04:00
parent 76a70fd739
commit 74ebbf4371

View File

@@ -47,10 +47,11 @@ def choose_device(cfg):
return "cpu"
cfg.device = get_device()
if cfg.device == "cuda":
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
if cfg.device_map != "auto":
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
def get_multi_line_input() -> Optional[str]: