remove unused code, support adapter for tensor parallel

This commit is contained in:
Wing Lian
2023-11-01 20:31:51 -04:00
parent b3689f73e3
commit 026172eaa8

View File

@@ -1,20 +1,16 @@
"""Module for models and model loading""" """Module for models and model loading"""
import json
import logging import logging
import math import math
import os import os
from typing import Optional, Tuple # noqa: F401 from typing import Optional, Tuple # noqa: F401
import bitsandbytes as bnb import bitsandbytes as bnb
import tensor_parallel as tp
import torch import torch
import transformers import transformers
import transformers.utils.bitsandbytes import transformers.utils.bitsandbytes
from huggingface_hub import hf_hub_download
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig, prepare_model_for_kbit_training from peft import PeftConfig, prepare_model_for_kbit_training
from peft.tuners.lora import QuantLinear from peft.tuners.lora import QuantLinear
from safetensors.torch import load_file as load_safetensors_file
from transformers import ( # noqa: F401 from transformers import ( # noqa: F401
AddedToken, AddedToken,
AutoConfig, AutoConfig,
@@ -327,11 +323,15 @@ def load_model(
**model_kwargs, **model_kwargs,
) )
elif cfg.tensor_parallel: elif cfg.tensor_parallel:
model_kwargs.pop("device_map")
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
torch_dtype=cfg.torch_dtype, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
low_cpu_mem_usage=True, low_cpu_mem_usage=True,
offload_state_dict=True, offload_state_dict=True,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
) )
else: else:
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
@@ -459,57 +459,6 @@ def load_model(
model, lora_config = load_adapter(model, cfg, cfg.adapter) model, lora_config = load_adapter(model, cfg, cfg.adapter)
if cfg.tensor_parallel and cfg.adapter == "qlora":
device_map = tp.infer_sharded_device_map(model)
load_file = torch.load
try:
with open(
hf_hub_download(base_model, "pytorch_model.bin.index.json"),
"r",
encoding="utf=8",
) as index_file:
shard_filenames = set(json.load(index_file)["weight_map"].values())
except Exception as err: # pylint: disable=broad-exception-caught
LOG.warning(err)
with open(
hf_hub_download(base_model, "model.safetensors.index.json"),
"r",
encoding="utf=8",
) as index_file:
shard_filenames = set(json.load(index_file)["weight_map"].values())
load_file = load_safetensors_file
for shard_filename in sorted(shard_filenames):
# Download a shard
shard_path = hf_hub_download(base_model, shard_filename)
# Convert model shard
converted_state_dict = (
tp.convert_state_dict( # <- tensor_parallel helper function.
load_file(
shard_path
), # Creates a tensor_parallel checkpoint from a normal one
model.tensor_parallel_config,
world_size=torch.cuda.device_count(),
for_pretrained=True,
)
)
# Dispatch the shard
for param_name, param in converted_state_dict.items():
module_name = param_name
while len(module_name) > 0 and module_name not in device_map:
module_name = ".".join(module_name.split(".")[:-1])
param_device = device_map[module_name]
transformers.utils.bitsandbytes.set_module_quantized_tensor_to_device(
model, param_name, param_device, value=param.to(dtype=torch.float16)
)
converted_state_dict[param_name] = None
del converted_state_dict
if cfg.ddp and not load_in_8bit: if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}") model.to(f"cuda:{cfg.local_rank}")