From e13c2fd6b1a4beb030257cecefe4eaa21288298a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 8 Sep 2023 10:46:11 -0400 Subject: [PATCH] getting better --- src/axolotl/utils/models.py | 81 +++++++++++++++++++++++++++---------- 1 file changed, 60 insertions(+), 21 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index ee91d81ae..ccaf41f35 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1,16 +1,21 @@ """Module for models and model loading""" +import json import logging import math import os from typing import Optional, Tuple # noqa: F401 +import accelerate import bitsandbytes as bnb import tensor_parallel as tp import torch import transformers +import transformers.utils.bitsandbytes +from huggingface_hub import hf_hub_download from optimum.bettertransformer import BetterTransformer from peft import PeftConfig, prepare_model_for_kbit_training from peft.tuners.lora import QuantLinear +from safetensors.torch import load_file as load_safetensors_file from transformers import ( # noqa: F401 AddedToken, AutoConfig, @@ -222,7 +227,7 @@ def load_model( load_in_4bit=True, llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, - bnb_4bit_compute_dtype=cfg.torch_dtype, + bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) @@ -327,27 +332,15 @@ def load_model( base_model, trust_remote_code=cfg.trust_remote_code or False, ) - # with accelerate.init_empty_weights(): - # model = AutoModelForCausalLM.from_config( - # config=config, - # trust_remote_code=cfg.trust_remote_code or False, - # ).half() - # model = tp.TensorParallelPreTrainedModel( - # model, - # sharded=False, - # ) - model = AutoModelForCausalLM.from_pretrained( - base_model, - config=config, - trust_remote_code=cfg.trust_remote_code or False, - low_cpu_mem_usage=True, - load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, - load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, - torch_dtype=cfg.torch_dtype, - device_map={"": "cpu"}, - **model_kwargs, + with accelerate.init_empty_weights(): + model = AutoModelForCausalLM.from_config( + config=config, + trust_remote_code=cfg.trust_remote_code or False, + ).half() + model = tp.TensorParallelPreTrainedModel( + model, + sharded=False, ) - model = tp.tensor_parallel(model, sharded=False) else: config = AutoConfig.from_pretrained( base_model, @@ -474,6 +467,52 @@ def load_model( model, lora_config = load_adapter(model, cfg, cfg.adapter) + if cfg.tensor_parallel and cfg.adapter == "qlora": + device_map = tp.infer_sharded_device_map(model) + + load_file = torch.load + try: + with open( + hf_hub_download(base_model, "pytorch_model.bin.index.json"), "r" + ) as index_file: + shard_filenames = set(json.load(index_file)["weight_map"].values()) + except: + with open( + hf_hub_download(base_model, "model.safetensors.index.json"), "r" + ) as index_file: + shard_filenames = set(json.load(index_file)["weight_map"].values()) + load_file = load_safetensors_file + + for shard_filename in sorted(shard_filenames): + # Download a shard + shard_path = hf_hub_download(base_model, shard_filename) + + # Convert model shard + converted_state_dict = ( + tp.convert_state_dict( # <- tensor_parallel helper function. + load_file( + shard_path + ), # Creates a tensor_parallel checkpoint form a normal one + model.tensor_parallel_config, + world_size=torch.cuda.device_count(), + for_pretrained=True, + ) + ) + + # Dispatch the shard + for param_name, param in converted_state_dict.items(): + module_name = param_name + + while len(module_name) > 0 and module_name not in device_map: + module_name = ".".join(module_name.split(".")[:-1]) + param_device = device_map[module_name] + + transformers.utils.bitsandbytes.set_module_quantized_tensor_to_device( + model, param_name, param_device, value=param.to(dtype=torch.float16) + ) + converted_state_dict[param_name] = None + del converted_state_dict + if cfg.ddp and not load_in_8bit: model.to(f"cuda:{cfg.local_rank}")