Add Ascend NPU support (#1758)
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.integrations.config import merge_input_args
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
@@ -29,7 +30,10 @@ def choose_device(cfg):
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
|
||||
raise SystemError("No CUDA/mps device found")
|
||||
if is_torch_npu_available():
|
||||
return f"npu:{cfg.local_rank}"
|
||||
|
||||
raise SystemError("No CUDA/mps/npu device found")
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return "cpu"
|
||||
|
||||
@@ -39,6 +43,8 @@ def choose_device(cfg):
|
||||
else:
|
||||
if cfg.device.startswith("cuda"):
|
||||
cfg.device_map = {"": torch.cuda.current_device()}
|
||||
elif cfg.device.startswith("npu"):
|
||||
cfg.device_map = {"npu": torch.npu.current_device()}
|
||||
else:
|
||||
cfg.device_map = {"": cfg.device}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user