Add Ascend NPU support (#1758)

This commit is contained in:
Mengqing Cao
2024-11-21 10:28:41 +08:00
committed by GitHub
parent 2e99bb303e
commit 838b74d05b
5 changed files with 114 additions and 16 deletions

View File

@@ -5,6 +5,7 @@ from typing import Optional
import torch
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.integrations.config import merge_input_args
from axolotl.utils.bench import log_gpu_memory_usage
@@ -29,7 +30,10 @@ def choose_device(cfg):
if torch.backends.mps.is_available():
return "mps"
raise SystemError("No CUDA/mps device found")
if is_torch_npu_available():
return f"npu:{cfg.local_rank}"
raise SystemError("No CUDA/mps/npu device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
@@ -39,6 +43,8 @@ def choose_device(cfg):
else:
if cfg.device.startswith("cuda"):
cfg.device_map = {"": torch.cuda.current_device()}
elif cfg.device.startswith("npu"):
cfg.device_map = {"npu": torch.npu.current_device()}
else:
cfg.device_map = {"": cfg.device}