Add Ascend NPU support (#1758)

This commit is contained in:
Mengqing Cao
2024-11-21 10:28:41 +08:00
committed by GitHub
parent 2e99bb303e
commit 838b74d05b
5 changed files with 114 additions and 16 deletions

View File

@@ -55,7 +55,7 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import zero_only
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -570,7 +570,8 @@ class ModelLoader:
)
max_memory = {}
for i in range(torch.cuda.device_count()):
num_device = get_device_count()
for i in range(num_device):
max_memory[i] = gpu_memory_limit
max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything
@@ -595,8 +596,11 @@ class ModelLoader:
self.model_kwargs["device_map"] = device_map
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
if torch.backends.mps.is_available():
cur_device = get_device_type()
if "mps" in str(cur_device):
self.model_kwargs["device_map"] = "mps:0"
elif "npu" in str(cur_device):
self.model_kwargs["device_map"] = "npu:0"
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
# if cfg.rl:
@@ -1050,7 +1054,11 @@ class ModelLoader:
self.ajust_model_config()
# log device memory usage
if hasattr(self.model, "device") and self.model.device.type in ("cuda", "mps"):
if hasattr(self.model, "device") and self.model.device.type in (
"cuda",
"mps",
"npu",
):
log_gpu_memory_usage(LOG, "after model load", self.model.device)
# make sure these are fp32 per Ramesh et al. (2021)
@@ -1118,9 +1126,9 @@ class ModelLoader:
and not skip_move_to_device
):
# TODO revaldate this conditional
self.model.to(f"cuda:{self.cfg.local_rank}")
self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
setattr(self.model, "is_parallelizable", True)
setattr(self.model, "model_parallel", True)