Add Ascend NPU support (#1758)
This commit is contained in:
@@ -4,6 +4,9 @@ import functools
|
||||
import pynvml
|
||||
import torch
|
||||
from pynvml.nvml import NVMLError
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.utils.distributed import get_device_type
|
||||
|
||||
|
||||
def check_cuda_device(default_value):
|
||||
@@ -53,6 +56,12 @@ def mps_memory_usage_all():
|
||||
return usage, reserved - usage, 0
|
||||
|
||||
|
||||
def npu_memory_usage_all(device=0):
|
||||
usage = torch.npu.memory_allocated(device) / 1024.0**3
|
||||
reserved = torch.npu.memory_reserved(device) / 1024.0**3
|
||||
return usage, reserved - usage, 0
|
||||
|
||||
|
||||
@check_cuda_device(0.0)
|
||||
def gpu_memory_usage_smi(device=0):
|
||||
if isinstance(device, torch.device):
|
||||
@@ -69,8 +78,11 @@ def gpu_memory_usage_smi(device=0):
|
||||
|
||||
|
||||
def log_gpu_memory_usage(log, msg, device):
|
||||
cur_device = get_device_type()
|
||||
if torch.backends.mps.is_available():
|
||||
usage, cache, misc = mps_memory_usage_all()
|
||||
elif "npu" in str(cur_device) and is_torch_npu_available():
|
||||
usage, cache, misc = npu_memory_usage_all(device)
|
||||
else:
|
||||
usage, cache, misc = gpu_memory_usage_all(device)
|
||||
extras = []
|
||||
@@ -79,6 +91,7 @@ def log_gpu_memory_usage(log, msg, device):
|
||||
if misc > 0:
|
||||
extras.append(f"+{misc:.03f}GB misc")
|
||||
log.info(
|
||||
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
|
||||
f"{str(cur_device)} memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})",
|
||||
stacklevel=2,
|
||||
)
|
||||
return usage, cache, misc
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.integrations.config import merge_input_args
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
@@ -29,7 +30,10 @@ def choose_device(cfg):
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
|
||||
raise SystemError("No CUDA/mps device found")
|
||||
if is_torch_npu_available():
|
||||
return f"npu:{cfg.local_rank}"
|
||||
|
||||
raise SystemError("No CUDA/mps/npu device found")
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return "cpu"
|
||||
|
||||
@@ -39,6 +43,8 @@ def choose_device(cfg):
|
||||
else:
|
||||
if cfg.device.startswith("cuda"):
|
||||
cfg.device_map = {"": torch.cuda.current_device()}
|
||||
elif cfg.device.startswith("npu"):
|
||||
cfg.device_map = {"npu": torch.npu.current_device()}
|
||||
else:
|
||||
cfg.device_map = {"": cfg.device}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from pydantic import (
|
||||
)
|
||||
from transformers import SchedulerType
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.utils.config.models.internals import GPUCapabilities
|
||||
|
||||
@@ -1433,6 +1434,40 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_npu_config(cls, data):
|
||||
if is_torch_npu_available():
|
||||
# check attention config
|
||||
attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
|
||||
for attn in attn_list:
|
||||
if data.get(attn):
|
||||
raise NotImplementedError(
|
||||
f"{attn} is currently not supported in Ascend npu, please disable this configuration."
|
||||
)
|
||||
|
||||
# check quant config
|
||||
if data.get("optimizer") is not None and "bit" in data.get("optimizer"):
|
||||
optimizer = data.get("optimizer")
|
||||
raise NotImplementedError(
|
||||
f"{optimizer} is currently not supported in Ascend npu, choose another one please."
|
||||
)
|
||||
|
||||
quant_list = ["load_in_8bit", "load_in_4bit"]
|
||||
for quant in quant_list:
|
||||
if data.get(quant):
|
||||
raise NotImplementedError(
|
||||
f"Quantification is currently not supported in Ascend npu, please disable {quant}."
|
||||
)
|
||||
|
||||
# check dtype config
|
||||
if data.get("tf32"):
|
||||
raise NotImplementedError(
|
||||
"tf32 dtype is currently not supported in Ascend npu, please disable this configuration"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||
|
||||
@@ -9,10 +9,44 @@ from datetime import timedelta
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate import PartialState
|
||||
from transformers.utils.import_utils import (
|
||||
is_torch_cuda_available,
|
||||
is_torch_mps_available,
|
||||
is_torch_npu_available,
|
||||
)
|
||||
|
||||
distributed_state = None # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def get_device_type():
|
||||
device = torch.device("cpu")
|
||||
if is_torch_cuda_available():
|
||||
device = torch.device("cuda")
|
||||
elif is_torch_mps_available():
|
||||
device = torch.device("mps")
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device("npu")
|
||||
return device
|
||||
|
||||
|
||||
def get_device_count():
|
||||
cur_device = get_device_type()
|
||||
if "cuda" in str(cur_device):
|
||||
return torch.cuda.device_count()
|
||||
if "npu" in str(cur_device):
|
||||
return torch.npu.device_count()
|
||||
return 1
|
||||
|
||||
|
||||
def get_current_device():
|
||||
cur_device = get_device_type()
|
||||
if "cuda" in str(cur_device):
|
||||
return torch.cuda.current_device()
|
||||
if "npu" in str(cur_device):
|
||||
return torch.npu.current_device()
|
||||
return 0
|
||||
|
||||
|
||||
def is_distributed():
|
||||
"""
|
||||
Check if distributed training is initialized.
|
||||
@@ -91,7 +125,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
|
||||
if not is_distributed():
|
||||
return [value_scalar]
|
||||
value_tensor = torch.tensor(
|
||||
value_scalar, device=torch.cuda.current_device()
|
||||
value_scalar, device=f"{get_device_type()}:{get_current_device()}"
|
||||
).float()
|
||||
|
||||
if not is_main_process():
|
||||
@@ -115,13 +149,14 @@ def broadcast_dict(vals: dict):
|
||||
if not is_distributed():
|
||||
return vals
|
||||
|
||||
cur_device = get_device_type()
|
||||
if is_main_process():
|
||||
data_byte = pickle.dumps(vals)
|
||||
data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
|
||||
data_size = torch.IntTensor([len(data_byte)]).to("cuda")
|
||||
data_tensor = torch.ByteTensor(list(data_byte)).to(cur_device)
|
||||
data_size = torch.IntTensor([len(data_byte)]).to(cur_device)
|
||||
else:
|
||||
data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
|
||||
data_size = torch.IntTensor([0]).to("cuda")
|
||||
data_tensor = torch.empty([1024], dtype=torch.uint8, device=cur_device)
|
||||
data_size = torch.IntTensor([0]).to(cur_device)
|
||||
|
||||
dist.broadcast(data_size, 0)
|
||||
if not is_main_process():
|
||||
@@ -150,14 +185,15 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
|
||||
Returns:
|
||||
- The computed value (int or float).
|
||||
"""
|
||||
cur_device = f"{get_device_type()}:{get_current_device()}"
|
||||
if is_main_process():
|
||||
value_scalar = fn()
|
||||
value_tensor = torch.tensor(
|
||||
value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
|
||||
value_scalar, device=cur_device, dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
value_tensor = torch.tensor(
|
||||
0.0, device=torch.cuda.current_device(), dtype=torch.float32
|
||||
0.0, device=cur_device, dtype=torch.float32
|
||||
) # Placeholder tensor
|
||||
|
||||
# Broadcast the tensor to all processes.
|
||||
@@ -184,7 +220,7 @@ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
|
||||
"""
|
||||
value_scalar = fn()
|
||||
value_tensor = torch.tensor(
|
||||
value_scalar, device=torch.cuda.current_device()
|
||||
value_scalar, device=f"{get_device_type()}:{get_current_device()}"
|
||||
).float()
|
||||
|
||||
# Placeholder tensor for gathering results
|
||||
|
||||
@@ -55,7 +55,7 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import zero_only
|
||||
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
|
||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
|
||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||
@@ -570,7 +570,8 @@ class ModelLoader:
|
||||
)
|
||||
|
||||
max_memory = {}
|
||||
for i in range(torch.cuda.device_count()):
|
||||
num_device = get_device_count()
|
||||
for i in range(num_device):
|
||||
max_memory[i] = gpu_memory_limit
|
||||
max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything
|
||||
|
||||
@@ -595,8 +596,11 @@ class ModelLoader:
|
||||
self.model_kwargs["device_map"] = device_map
|
||||
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
cur_device = get_device_type()
|
||||
if "mps" in str(cur_device):
|
||||
self.model_kwargs["device_map"] = "mps:0"
|
||||
elif "npu" in str(cur_device):
|
||||
self.model_kwargs["device_map"] = "npu:0"
|
||||
|
||||
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
|
||||
# if cfg.rl:
|
||||
@@ -1050,7 +1054,11 @@ class ModelLoader:
|
||||
self.ajust_model_config()
|
||||
|
||||
# log device memory usage
|
||||
if hasattr(self.model, "device") and self.model.device.type in ("cuda", "mps"):
|
||||
if hasattr(self.model, "device") and self.model.device.type in (
|
||||
"cuda",
|
||||
"mps",
|
||||
"npu",
|
||||
):
|
||||
log_gpu_memory_usage(LOG, "after model load", self.model.device)
|
||||
|
||||
# make sure these are fp32 per Ramesh et al. (2021)
|
||||
@@ -1118,9 +1126,9 @@ class ModelLoader:
|
||||
and not skip_move_to_device
|
||||
):
|
||||
# TODO revaldate this conditional
|
||||
self.model.to(f"cuda:{self.cfg.local_rank}")
|
||||
self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
|
||||
|
||||
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
||||
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
||||
setattr(self.model, "is_parallelizable", True)
|
||||
setattr(self.model, "model_parallel", True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user