getting better
This commit is contained in:
@@ -1,16 +1,21 @@
|
|||||||
"""Module for models and model loading"""
|
"""Module for models and model loading"""
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from typing import Optional, Tuple # noqa: F401
|
from typing import Optional, Tuple # noqa: F401
|
||||||
|
|
||||||
|
import accelerate
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import tensor_parallel as tp
|
import tensor_parallel as tp
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
import transformers.utils.bitsandbytes
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
from peft import PeftConfig, prepare_model_for_kbit_training
|
from peft import PeftConfig, prepare_model_for_kbit_training
|
||||||
from peft.tuners.lora import QuantLinear
|
from peft.tuners.lora import QuantLinear
|
||||||
|
from safetensors.torch import load_file as load_safetensors_file
|
||||||
from transformers import ( # noqa: F401
|
from transformers import ( # noqa: F401
|
||||||
AddedToken,
|
AddedToken,
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
@@ -222,7 +227,7 @@ def load_model(
|
|||||||
load_in_4bit=True,
|
load_in_4bit=True,
|
||||||
llm_int8_threshold=6.0,
|
llm_int8_threshold=6.0,
|
||||||
llm_int8_has_fp16_weight=False,
|
llm_int8_has_fp16_weight=False,
|
||||||
bnb_4bit_compute_dtype=cfg.torch_dtype,
|
bnb_4bit_compute_dtype=torch.float16,
|
||||||
bnb_4bit_use_double_quant=True,
|
bnb_4bit_use_double_quant=True,
|
||||||
bnb_4bit_quant_type="nf4",
|
bnb_4bit_quant_type="nf4",
|
||||||
)
|
)
|
||||||
@@ -327,27 +332,15 @@ def load_model(
|
|||||||
base_model,
|
base_model,
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
)
|
)
|
||||||
# with accelerate.init_empty_weights():
|
with accelerate.init_empty_weights():
|
||||||
# model = AutoModelForCausalLM.from_config(
|
model = AutoModelForCausalLM.from_config(
|
||||||
# config=config,
|
config=config,
|
||||||
# trust_remote_code=cfg.trust_remote_code or False,
|
trust_remote_code=cfg.trust_remote_code or False,
|
||||||
# ).half()
|
).half()
|
||||||
# model = tp.TensorParallelPreTrainedModel(
|
model = tp.TensorParallelPreTrainedModel(
|
||||||
# model,
|
model,
|
||||||
# sharded=False,
|
sharded=False,
|
||||||
# )
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
base_model,
|
|
||||||
config=config,
|
|
||||||
trust_remote_code=cfg.trust_remote_code or False,
|
|
||||||
low_cpu_mem_usage=True,
|
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
|
||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
|
||||||
torch_dtype=cfg.torch_dtype,
|
|
||||||
device_map={"": "cpu"},
|
|
||||||
**model_kwargs,
|
|
||||||
)
|
)
|
||||||
model = tp.tensor_parallel(model, sharded=False)
|
|
||||||
else:
|
else:
|
||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
@@ -474,6 +467,52 @@ def load_model(
|
|||||||
|
|
||||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
|
|
||||||
|
if cfg.tensor_parallel and cfg.adapter == "qlora":
|
||||||
|
device_map = tp.infer_sharded_device_map(model)
|
||||||
|
|
||||||
|
load_file = torch.load
|
||||||
|
try:
|
||||||
|
with open(
|
||||||
|
hf_hub_download(base_model, "pytorch_model.bin.index.json"), "r"
|
||||||
|
) as index_file:
|
||||||
|
shard_filenames = set(json.load(index_file)["weight_map"].values())
|
||||||
|
except:
|
||||||
|
with open(
|
||||||
|
hf_hub_download(base_model, "model.safetensors.index.json"), "r"
|
||||||
|
) as index_file:
|
||||||
|
shard_filenames = set(json.load(index_file)["weight_map"].values())
|
||||||
|
load_file = load_safetensors_file
|
||||||
|
|
||||||
|
for shard_filename in sorted(shard_filenames):
|
||||||
|
# Download a shard
|
||||||
|
shard_path = hf_hub_download(base_model, shard_filename)
|
||||||
|
|
||||||
|
# Convert model shard
|
||||||
|
converted_state_dict = (
|
||||||
|
tp.convert_state_dict( # <- tensor_parallel helper function.
|
||||||
|
load_file(
|
||||||
|
shard_path
|
||||||
|
), # Creates a tensor_parallel checkpoint form a normal one
|
||||||
|
model.tensor_parallel_config,
|
||||||
|
world_size=torch.cuda.device_count(),
|
||||||
|
for_pretrained=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dispatch the shard
|
||||||
|
for param_name, param in converted_state_dict.items():
|
||||||
|
module_name = param_name
|
||||||
|
|
||||||
|
while len(module_name) > 0 and module_name not in device_map:
|
||||||
|
module_name = ".".join(module_name.split(".")[:-1])
|
||||||
|
param_device = device_map[module_name]
|
||||||
|
|
||||||
|
transformers.utils.bitsandbytes.set_module_quantized_tensor_to_device(
|
||||||
|
model, param_name, param_device, value=param.to(dtype=torch.float16)
|
||||||
|
)
|
||||||
|
converted_state_dict[param_name] = None
|
||||||
|
del converted_state_dict
|
||||||
|
|
||||||
if cfg.ddp and not load_in_8bit:
|
if cfg.ddp and not load_in_8bit:
|
||||||
model.to(f"cuda:{cfg.local_rank}")
|
model.to(f"cuda:{cfg.local_rank}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user