From 65f3a4f7032f110c50cab103cdfbe4f5916f2074 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 8 Sep 2023 00:47:01 -0400
Subject: [PATCH] tensor-parallel support

---
 requirements.txt            |  1 +
 src/axolotl/utils/models.py | 45 +++++++++++++++++++++++++++++++++++--
 2 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 2b60ef14d..fcb6e69d1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -31,3 +31,4 @@ scikit-learn==1.2.2
 pynvml
 art
 fschat==0.2.29
+tensor_parallel
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index cc83840ba..f61b27a90 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -4,7 +4,9 @@
 import math
 import os
 from typing import Optional, Tuple  # noqa: F401
+import accelerate
 import bitsandbytes as bnb
+import tensor_parallel as tp
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
@@ -235,7 +237,12 @@ def load_model(
             model_kwargs["use_flash_attention_2"] = True

     try:
-        if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
+        if (
+            cfg.is_llama_derived_model
+            and not cfg.trust_remote_code
+            and not cfg.gptq
+            and not cfg.tensor_parallel
+        ):
             from transformers import LlamaForCausalLM

             config_kwargs = {}
@@ -301,7 +308,7 @@ def load_model(
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 **model_kwargs,
             )
-        elif model_type and not cfg.trust_remote_code:
+        elif model_type and not cfg.trust_remote_code and not cfg.tensor_parallel:
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
@@ -316,6 +323,19 @@ def load_model(
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
+        elif cfg.tensor_parallel:
+            config = AutoConfig.from_pretrained(
+                base_model,
+                trust_remote_code=cfg.trust_remote_code or False,
+            )
+            with accelerate.init_empty_weights():
+                model = AutoModelForCausalLM.from_config(
+                    config=config,
+                    trust_remote_code=cfg.trust_remote_code or False,
+                ).half()
+            model = tp.TensorParallelPreTrainedModel(
+                model,
+            )
         else:
             config = AutoConfig.from_pretrained(
                 base_model,
@@ -478,6 +498,8 @@ def load_adapter(model, cfg, adapter, inference=False):
         return model, None
     if hasattr(model, "enable_input_require_grads"):
         model.enable_input_require_grads()
+    if adapter == "qlora" and cfg.tensor_parallel:
+        return load_tp_qlora(model)
     if adapter in ["lora", "qlora"]:
         return load_lora(model, cfg, inference=inference)
     if adapter == "llama-adapter":
@@ -529,6 +551,25 @@ def find_all_linear_names(model):
     return list(lora_module_names)


+def load_tp_qlora(model):
+    from transformers.utils.bitsandbytes import replace_with_bnb_linear
+
+    model = replace_with_bnb_linear(
+        model,
+        quantization_config=BitsAndBytesConfig(
+            load_in_4bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+        ),
+    )
+    model.is_loaded_in_4bit = True
+
+    return model, None
+
+
 def load_lora(model, cfg, inference=False):
     # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
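
Reviewer note, not part of the patch: for anyone trying the new code path outside axolotl, here is a minimal standalone sketch of what `cfg.tensor_parallel` now wires into load_model() and load_adapter(). The checkpoint name is a placeholder, and the weight-dispatch step in comment 3 follows the upstream tensor_parallel README rather than this diff.

import accelerate
import tensor_parallel as tp
import torch
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
from transformers.utils.bitsandbytes import replace_with_bnb_linear

base_model = "huggyllama/llama-7b"  # placeholder checkpoint

# 1. Build the model skeleton on the meta device so a full-precision copy
#    is never materialized in host RAM, then cast to fp16 as the patch does.
config = AutoConfig.from_pretrained(base_model)
with accelerate.init_empty_weights():
    model = AutoModelForCausalLM.from_config(config).half()

# 2. Wrap the (still empty) model; tensor_parallel shards the linear layers
#    column-/row-wise across the visible CUDA devices.
model = tp.TensorParallelPreTrainedModel(model)

# 3. The diff stops here: real weights still have to be dispatched into the
#    shards. The tensor_parallel README suggests building a device map with
#    tp.infer_sharded_device_map(model) and then using accelerate's
#    checkpoint-loading utilities; that step is an assumption, not this patch.

# 4. Optional QLoRA path, mirroring the new load_tp_qlora(): swap nn.Linear
#    for 4-bit NF4 bitsandbytes layers before attaching LoRA adapters.
model = replace_with_bnb_linear(
    model,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    ),
)
model.is_loaded_in_4bit = True  # flag downstream k-bit training prep checks

The meta-device init is what makes the feature practical: sharding an already-loaded model would briefly need the full fp16 weights on one device, while the empty-weights wrap keeps peak memory at roughly one shard per GPU.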