From 026172eaa88cef17edd9fac1324c0cfed2417e04 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 1 Nov 2023 20:31:51 -0400
Subject: [PATCH] remove unused code, support adapter for tensor parallel

---
 src/axolotl/utils/models.py | 61 +++----------------------------------
 1 file changed, 5 insertions(+), 56 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index b7ecdf08a..2b28903fd 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -1,20 +1,16 @@
 """Module for models and model loading"""
-import json
 import logging
 import math
 import os
 from typing import Optional, Tuple  # noqa: F401
 
 import bitsandbytes as bnb
-import tensor_parallel as tp
 import torch
 import transformers
 import transformers.utils.bitsandbytes
-from huggingface_hub import hf_hub_download
 from optimum.bettertransformer import BetterTransformer
 from peft import PeftConfig, prepare_model_for_kbit_training
 from peft.tuners.lora import QuantLinear
-from safetensors.torch import load_file as load_safetensors_file
 from transformers import (  # noqa: F401
     AddedToken,
     AutoConfig,
@@ -327,11 +323,15 @@ def load_model(
             **model_kwargs,
         )
     elif cfg.tensor_parallel:
+        model_kwargs.pop("device_map")
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
-            torch_dtype=cfg.torch_dtype,
+            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             low_cpu_mem_usage=True,
             offload_state_dict=True,
+            trust_remote_code=cfg.trust_remote_code or False,
+            **model_kwargs,
         )
     else:
         config = AutoConfig.from_pretrained(
@@ -459,57 +459,6 @@ def load_model(
 
     model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
-    if cfg.tensor_parallel and cfg.adapter == "qlora":
-        device_map = tp.infer_sharded_device_map(model)
-
-        load_file = torch.load
-        try:
-            with open(
-                hf_hub_download(base_model, "pytorch_model.bin.index.json"),
-                "r",
-                encoding="utf=8",
-            ) as index_file:
-                shard_filenames = set(json.load(index_file)["weight_map"].values())
-        except Exception as err:  # pylint: disable=broad-exception-caught
-            LOG.warning(err)
-            with open(
-                hf_hub_download(base_model, "model.safetensors.index.json"),
-                "r",
-                encoding="utf=8",
-            ) as index_file:
-                shard_filenames = set(json.load(index_file)["weight_map"].values())
-            load_file = load_safetensors_file
-
-        for shard_filename in sorted(shard_filenames):
-            # Download a shard
-            shard_path = hf_hub_download(base_model, shard_filename)
-
-            # Convert model shard
-            converted_state_dict = (
-                tp.convert_state_dict(  # <- tensor_parallel helper function.
-                    load_file(
-                        shard_path
-                    ),  # Creates a tensor_parallel checkpoint form a normal one
-                    model.tensor_parallel_config,
-                    world_size=torch.cuda.device_count(),
-                    for_pretrained=True,
-                )
-            )
-
-            # Dispatch the shard
-            for param_name, param in converted_state_dict.items():
-                module_name = param_name
-
-                while len(module_name) > 0 and module_name not in device_map:
-                    module_name = ".".join(module_name.split(".")[:-1])
-                param_device = device_map[module_name]
-
-                transformers.utils.bitsandbytes.set_module_quantized_tensor_to_device(
-                    model, param_name, param_device, value=param.to(dtype=torch.float16)
-                )
-                converted_state_dict[param_name] = None
-        del converted_state_dict
-
     if cfg.ddp and not load_in_8bit:
         model.to(f"cuda:{cfg.local_rank}")
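
Reviewer note, not part of the patch: the new `cfg.tensor_parallel` branch is easiest to read in isolation. The sketch below mirrors the patched call, but the standalone harness is assumed: `SimpleNamespace` stands in for axolotl's real config object, the model id is a placeholder, and `model_kwargs` is pre-populated the way `load_model` does earlier in the function.

```python
# Illustrative sketch of the patched tensor-parallel load path.
from types import SimpleNamespace

from transformers import AutoModelForCausalLM

# Hypothetical config mirroring the cfg fields the branch reads.
cfg = SimpleNamespace(
    tensor_parallel=True,
    adapter="qlora",        # any non-None adapter enables quantized loading
    load_in_8bit=False,
    load_in_4bit=True,
    trust_remote_code=False,
)
base_model = "openlm-research/open_llama_3b"  # placeholder model id
model_kwargs = {"device_map": "auto"}  # set earlier in load_model

if cfg.tensor_parallel:
    # tensor_parallel handles its own sharding, so the accelerate-style
    # device_map is dropped before calling from_pretrained
    model_kwargs.pop("device_map")
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        # 8-bit/4-bit loading is gated on an adapter being configured
        load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
        load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
        low_cpu_mem_usage=True,
        offload_state_dict=True,
        trust_remote_code=cfg.trust_remote_code or False,
        **model_kwargs,
    )
```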
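
The quantization gating itself can be checked without loading any weights; the helper below is an assumed test harness, not code from the repository.

```python
# Assumed sanity check for the quantization gating added by this patch.
from types import SimpleNamespace


def quant_flags(cfg):
    """Reproduce the patch's gating: quantize only when an adapter is set."""
    return (
        cfg.load_in_8bit and cfg.adapter is not None,
        cfg.load_in_4bit and cfg.adapter is not None,
    )


# With no adapter, 8-bit loading stays off even if requested in the config.
assert quant_flags(
    SimpleNamespace(load_in_8bit=True, load_in_4bit=False, adapter=None)
) == (False, False)
# With a qlora adapter, the requested 4-bit flag passes through.
assert quant_flags(
    SimpleNamespace(load_in_8bit=False, load_in_4bit=True, adapter="qlora")
) == (False, True)
```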