Add Ascend NPU support (#1758)

This commit is contained in:
Mengqing Cao
2024-11-21 10:28:41 +08:00
committed by GitHub
parent 2e99bb303e
commit 838b74d05b
5 changed files with 114 additions and 16 deletions

View File

@@ -4,6 +4,9 @@ import functools
import pynvml
import torch
from pynvml.nvml import NVMLError
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.distributed import get_device_type
def check_cuda_device(default_value):
@@ -53,6 +56,12 @@ def mps_memory_usage_all():
return usage, reserved - usage, 0
def npu_memory_usage_all(device=0):
usage = torch.npu.memory_allocated(device) / 1024.0**3
reserved = torch.npu.memory_reserved(device) / 1024.0**3
return usage, reserved - usage, 0
@check_cuda_device(0.0)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
@@ -69,8 +78,11 @@ def gpu_memory_usage_smi(device=0):
def log_gpu_memory_usage(log, msg, device):
cur_device = get_device_type()
if torch.backends.mps.is_available():
usage, cache, misc = mps_memory_usage_all()
elif "npu" in str(cur_device) and is_torch_npu_available():
usage, cache, misc = npu_memory_usage_all(device)
else:
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
@@ -79,6 +91,7 @@ def log_gpu_memory_usage(log, msg, device):
if misc > 0:
extras.append(f"+{misc:.03f}GB misc")
log.info(
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
f"{str(cur_device)} memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})",
stacklevel=2,
)
return usage, cache, misc

View File

@@ -5,6 +5,7 @@ from typing import Optional
import torch
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.integrations.config import merge_input_args
from axolotl.utils.bench import log_gpu_memory_usage
@@ -29,7 +30,10 @@ def choose_device(cfg):
if torch.backends.mps.is_available():
return "mps"
raise SystemError("No CUDA/mps device found")
if is_torch_npu_available():
return f"npu:{cfg.local_rank}"
raise SystemError("No CUDA/mps/npu device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
@@ -39,6 +43,8 @@ def choose_device(cfg):
else:
if cfg.device.startswith("cuda"):
cfg.device_map = {"": torch.cuda.current_device()}
elif cfg.device.startswith("npu"):
cfg.device_map = {"npu": torch.npu.current_device()}
else:
cfg.device_map = {"": cfg.device}

View File

@@ -19,6 +19,7 @@ from pydantic import (
)
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.config.models.internals import GPUCapabilities
@@ -1433,6 +1434,40 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_npu_config(cls, data):
if is_torch_npu_available():
# check attention config
attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
for attn in attn_list:
if data.get(attn):
raise NotImplementedError(
f"{attn} is currently not supported in Ascend npu, please disable this configuration."
)
# check quant config
if data.get("optimizer") is not None and "bit" in data.get("optimizer"):
optimizer = data.get("optimizer")
raise NotImplementedError(
f"{optimizer} is currently not supported in Ascend npu, choose another one please."
)
quant_list = ["load_in_8bit", "load_in_4bit"]
for quant in quant_list:
if data.get(quant):
raise NotImplementedError(
f"Quantification is currently not supported in Ascend npu, please disable {quant}."
)
# check dtype config
if data.get("tf32"):
raise NotImplementedError(
"tf32 dtype is currently not supported in Ascend npu, please disable this configuration"
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""

View File

@@ -9,10 +9,44 @@ from datetime import timedelta
import torch
import torch.distributed as dist
from accelerate import PartialState
from transformers.utils.import_utils import (
is_torch_cuda_available,
is_torch_mps_available,
is_torch_npu_available,
)
distributed_state = None # pylint: disable=invalid-name
def get_device_type():
device = torch.device("cpu")
if is_torch_cuda_available():
device = torch.device("cuda")
elif is_torch_mps_available():
device = torch.device("mps")
elif is_torch_npu_available():
device = torch.device("npu")
return device
def get_device_count():
cur_device = get_device_type()
if "cuda" in str(cur_device):
return torch.cuda.device_count()
if "npu" in str(cur_device):
return torch.npu.device_count()
return 1
def get_current_device():
cur_device = get_device_type()
if "cuda" in str(cur_device):
return torch.cuda.current_device()
if "npu" in str(cur_device):
return torch.npu.current_device()
return 0
def is_distributed():
"""
Check if distributed training is initialized.
@@ -91,7 +125,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
if not is_distributed():
return [value_scalar]
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device()
value_scalar, device=f"{get_device_type()}:{get_current_device()}"
).float()
if not is_main_process():
@@ -115,13 +149,14 @@ def broadcast_dict(vals: dict):
if not is_distributed():
return vals
cur_device = get_device_type()
if is_main_process():
data_byte = pickle.dumps(vals)
data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
data_size = torch.IntTensor([len(data_byte)]).to("cuda")
data_tensor = torch.ByteTensor(list(data_byte)).to(cur_device)
data_size = torch.IntTensor([len(data_byte)]).to(cur_device)
else:
data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
data_size = torch.IntTensor([0]).to("cuda")
data_tensor = torch.empty([1024], dtype=torch.uint8, device=cur_device)
data_size = torch.IntTensor([0]).to(cur_device)
dist.broadcast(data_size, 0)
if not is_main_process():
@@ -150,14 +185,15 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
Returns:
- The computed value (int or float).
"""
cur_device = f"{get_device_type()}:{get_current_device()}"
if is_main_process():
value_scalar = fn()
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
value_scalar, device=cur_device, dtype=torch.float32
)
else:
value_tensor = torch.tensor(
0.0, device=torch.cuda.current_device(), dtype=torch.float32
0.0, device=cur_device, dtype=torch.float32
) # Placeholder tensor
# Broadcast the tensor to all processes.
@@ -184,7 +220,7 @@ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
"""
value_scalar = fn()
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device()
value_scalar, device=f"{get_device_type()}:{get_current_device()}"
).float()
# Placeholder tensor for gathering results

View File

@@ -55,7 +55,7 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import zero_only
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -570,7 +570,8 @@ class ModelLoader:
)
max_memory = {}
for i in range(torch.cuda.device_count()):
num_device = get_device_count()
for i in range(num_device):
max_memory[i] = gpu_memory_limit
max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything
@@ -595,8 +596,11 @@ class ModelLoader:
self.model_kwargs["device_map"] = device_map
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
if torch.backends.mps.is_available():
cur_device = get_device_type()
if "mps" in str(cur_device):
self.model_kwargs["device_map"] = "mps:0"
elif "npu" in str(cur_device):
self.model_kwargs["device_map"] = "npu:0"
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
# if cfg.rl:
@@ -1050,7 +1054,11 @@ class ModelLoader:
self.ajust_model_config()
# log device memory usage
if hasattr(self.model, "device") and self.model.device.type in ("cuda", "mps"):
if hasattr(self.model, "device") and self.model.device.type in (
"cuda",
"mps",
"npu",
):
log_gpu_memory_usage(LOG, "after model load", self.model.device)
# make sure these are fp32 per Ramesh et al. (2021)
@@ -1118,9 +1126,9 @@ class ModelLoader:
and not skip_move_to_device
):
# TODO revaldate this conditional
self.model.to(f"cuda:{self.cfg.local_rank}")
self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
setattr(self.model, "is_parallelizable", True)
setattr(self.model, "model_parallel", True)