fix device map
This commit is contained in:
@@ -47,10 +47,11 @@ def choose_device(cfg):
|
|||||||
return "cpu"
|
return "cpu"
|
||||||
|
|
||||||
cfg.device = get_device()
|
cfg.device = get_device()
|
||||||
if cfg.device == "cuda":
|
if cfg.device_map != "auto":
|
||||||
cfg.device_map = {"": cfg.local_rank}
|
if cfg.device.startswith("cuda"):
|
||||||
else:
|
cfg.device_map = {"": cfg.local_rank}
|
||||||
cfg.device_map = {"": cfg.device}
|
else:
|
||||||
|
cfg.device_map = {"": cfg.device}
|
||||||
|
|
||||||
|
|
||||||
def get_multi_line_input() -> Optional[str]:
|
def get_multi_line_input() -> Optional[str]:
|
||||||
|
|||||||
Reference in New Issue
Block a user