diff --git a/scripts/finetune.py b/scripts/finetune.py index 1b1e994dd..9a2d62904 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -47,10 +47,11 @@ def choose_device(cfg): return "cpu" cfg.device = get_device() - if cfg.device == "cuda": - cfg.device_map = {"": cfg.local_rank} - else: - cfg.device_map = {"": cfg.device} + if cfg.device_map != "auto": + if cfg.device.startswith("cuda"): + cfg.device_map = {"": cfg.local_rank} + else: + cfg.device_map = {"": cfg.device} def get_multi_line_input() -> Optional[str]: