remove unused code, support adapter for tensor parallel

Wing Lian
2023-11-01 20:31:51 -04:00
parent b3689f73e3
commit 026172eaa8

@@ -1,20 +1,16 @@
"""Module for models and model loading"""
import json
import logging
import math
import os
from typing import Optional, Tuple # noqa: F401
import bitsandbytes as bnb
import tensor_parallel as tp
import torch
import transformers
import transformers.utils.bitsandbytes
from huggingface_hub import hf_hub_download
from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig, prepare_model_for_kbit_training
from peft.tuners.lora import QuantLinear
from safetensors.torch import load_file as load_safetensors_file
from transformers import ( # noqa: F401
    AddedToken,
    AutoConfig,
@@ -327,11 +323,15 @@ def load_model(
            **model_kwargs,
        )
    elif cfg.tensor_parallel:
        model_kwargs.pop("device_map")
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            torch_dtype=cfg.torch_dtype,
            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
            low_cpu_mem_usage=True,
            offload_state_dict=True,
            trust_remote_code=cfg.trust_remote_code or False,
            **model_kwargs,
        )
    else:
        config = AutoConfig.from_pretrained(
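
In this branch, device placement is left to the tensor_parallel library rather than to a transformers device_map, which is presumably why device_map is popped from model_kwargs. A minimal, self-contained sketch of the sharding step that follows such a load, assuming the library's tp.tensor_parallel entry point and a hypothetical model id:

    import tensor_parallel as tp
    import torch
    from transformers import AutoModelForCausalLM

    base_model = "huggyllama/llama-7b"  # hypothetical model id
    model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.float16)
    # Shard the model across every visible GPU; tensor_parallel handles
    # placement itself, so no device_map is passed to from_pretrained.
    model = tp.tensor_parallel(
        model, [f"cuda:{i}" for i in range(torch.cuda.device_count())]
    )
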
@@ -459,57 +459,6 @@ def load_model(
    model, lora_config = load_adapter(model, cfg, cfg.adapter)
    if cfg.tensor_parallel and cfg.adapter == "qlora":
        device_map = tp.infer_sharded_device_map(model)
        load_file = torch.load
        try:
            with open(
                hf_hub_download(base_model, "pytorch_model.bin.index.json"),
                "r",
                encoding="utf-8",
            ) as index_file:
                shard_filenames = set(json.load(index_file)["weight_map"].values())
        except Exception as err:  # pylint: disable=broad-exception-caught
            LOG.warning(err)
            with open(
                hf_hub_download(base_model, "model.safetensors.index.json"),
                "r",
                encoding="utf-8",
            ) as index_file:
                shard_filenames = set(json.load(index_file)["weight_map"].values())
            load_file = load_safetensors_file
        for shard_filename in sorted(shard_filenames):
            # Download a shard
            shard_path = hf_hub_download(base_model, shard_filename)
            # tensor_parallel helper: creates a tensor_parallel checkpoint
            # from a normal one
            converted_state_dict = tp.convert_state_dict(
                load_file(shard_path),
                model.tensor_parallel_config,
                world_size=torch.cuda.device_count(),
                for_pretrained=True,
            )
            # Dispatch the shard
            for param_name, param in converted_state_dict.items():
                module_name = param_name
                while len(module_name) > 0 and module_name not in device_map:
                    module_name = ".".join(module_name.split(".")[:-1])
                param_device = device_map[module_name]
                transformers.utils.bitsandbytes.set_module_quantized_tensor_to_device(
                    model, param_name, param_device, value=param.to(dtype=torch.float16)
                )
                converted_state_dict[param_name] = None
        del converted_state_dict
    if cfg.ddp and not load_in_8bit:
        model.to(f"cuda:{cfg.local_rank}")
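
The subtlest part of the removed block is the walk up the dotted parameter name to find the owning module's entry in the sharded device map. A self-contained sketch of that lookup, with a hypothetical device map and parameter name:

    from typing import Dict

    def resolve_param_device(param_name: str, device_map: Dict[str, str]) -> str:
        # Strip trailing components from the dotted name until an entry in
        # device_map matches, exactly as the removed while-loop does.
        module_name = param_name
        while len(module_name) > 0 and module_name not in device_map:
            module_name = ".".join(module_name.split(".")[:-1])
        return device_map[module_name]

    # Hypothetical map of the kind tp.infer_sharded_device_map returns:
    device_map = {"model.layers.0": "cuda:0", "model.layers.1": "cuda:1"}
    print(resolve_param_device("model.layers.1.self_attn.q_proj.weight", device_map))
    # -> cuda:1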