getting better

This commit is contained in:
Wing Lian
2023-09-08 10:46:11 -04:00
parent 8a21e14a21
commit e13c2fd6b1

View File

@@ -1,16 +1,21 @@
"""Module for models and model loading""" """Module for models and model loading"""
import json
import logging import logging
import math import math
import os import os
from typing import Optional, Tuple # noqa: F401 from typing import Optional, Tuple # noqa: F401
import accelerate
import bitsandbytes as bnb import bitsandbytes as bnb
import tensor_parallel as tp import tensor_parallel as tp
import torch import torch
import transformers import transformers
import transformers.utils.bitsandbytes
from huggingface_hub import hf_hub_download
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig, prepare_model_for_kbit_training from peft import PeftConfig, prepare_model_for_kbit_training
from peft.tuners.lora import QuantLinear from peft.tuners.lora import QuantLinear
from safetensors.torch import load_file as load_safetensors_file
from transformers import ( # noqa: F401 from transformers import ( # noqa: F401
AddedToken, AddedToken,
AutoConfig, AutoConfig,
@@ -222,7 +227,7 @@ def load_model(
load_in_4bit=True, load_in_4bit=True,
llm_int8_threshold=6.0, llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False, llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=cfg.torch_dtype, bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True, bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4", bnb_4bit_quant_type="nf4",
) )
@@ -327,27 +332,15 @@ def load_model(
base_model, base_model,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
) )
# with accelerate.init_empty_weights(): with accelerate.init_empty_weights():
# model = AutoModelForCausalLM.from_config( model = AutoModelForCausalLM.from_config(
# config=config, config=config,
# trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
# ).half() ).half()
# model = tp.TensorParallelPreTrainedModel( model = tp.TensorParallelPreTrainedModel(
# model, model,
# sharded=False, sharded=False,
# )
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=config,
trust_remote_code=cfg.trust_remote_code or False,
low_cpu_mem_usage=True,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
device_map={"": "cpu"},
**model_kwargs,
) )
model = tp.tensor_parallel(model, sharded=False)
else: else:
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
base_model, base_model,
@@ -474,6 +467,52 @@ def load_model(
model, lora_config = load_adapter(model, cfg, cfg.adapter) model, lora_config = load_adapter(model, cfg, cfg.adapter)
if cfg.tensor_parallel and cfg.adapter == "qlora":
device_map = tp.infer_sharded_device_map(model)
load_file = torch.load
try:
with open(
hf_hub_download(base_model, "pytorch_model.bin.index.json"), "r"
) as index_file:
shard_filenames = set(json.load(index_file)["weight_map"].values())
except:
with open(
hf_hub_download(base_model, "model.safetensors.index.json"), "r"
) as index_file:
shard_filenames = set(json.load(index_file)["weight_map"].values())
load_file = load_safetensors_file
for shard_filename in sorted(shard_filenames):
# Download a shard
shard_path = hf_hub_download(base_model, shard_filename)
# Convert model shard
converted_state_dict = (
tp.convert_state_dict( # <- tensor_parallel helper function.
load_file(
shard_path
), # Creates a tensor_parallel checkpoint form a normal one
model.tensor_parallel_config,
world_size=torch.cuda.device_count(),
for_pretrained=True,
)
)
# Dispatch the shard
for param_name, param in converted_state_dict.items():
module_name = param_name
while len(module_name) > 0 and module_name not in device_map:
module_name = ".".join(module_name.split(".")[:-1])
param_device = device_map[module_name]
transformers.utils.bitsandbytes.set_module_quantized_tensor_to_device(
model, param_name, param_device, value=param.to(dtype=torch.float16)
)
converted_state_dict[param_name] = None
del converted_state_dict
if cfg.ddp and not load_in_8bit: if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}") model.to(f"cuda:{cfg.local_rank}")