tensor-parallel support

Wing Lian
2023-09-08 00:47:01 -04:00
parent 10388a8daf
commit 65f3a4f703
2 changed files with 44 additions and 2 deletions

View File

@@ -31,3 +31,4 @@ scikit-learn==1.2.2
 pynvml
 art
 fschat==0.2.29
+tensor_parallel
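
The new tensor_parallel requirement is the PyPI package of that name (the BlackSamorez tensor_parallel library), which shards a PyTorch module's layers across local GPUs. A minimal usage sketch, independent of this commit; the model name and device list below are illustrative, not taken from the diff:

    import tensor_parallel as tp
    import transformers

    # Load any causal LM, then wrap it so its weight matrices are
    # split across the listed devices.
    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
    model = tp.tensor_parallel(model, ["cuda:0", "cuda:1"])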

View File

@@ -4,7 +4,9 @@ import math
 import os
 from typing import Optional, Tuple  # noqa: F401

+import accelerate
 import bitsandbytes as bnb
+import tensor_parallel as tp
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
@@ -235,7 +237,12 @@ def load_model(
             model_kwargs["use_flash_attention_2"] = True

     try:
-        if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
+        if (
+            cfg.is_llama_derived_model
+            and not cfg.trust_remote_code
+            and not cfg.gptq
+            and not cfg.tensor_parallel
+        ):
             from transformers import LlamaForCausalLM

             config_kwargs = {}
@@ -301,7 +308,7 @@ def load_model(
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 **model_kwargs,
             )
-        elif model_type and not cfg.trust_remote_code:
+        elif model_type and not cfg.trust_remote_code and not cfg.tensor_parallel:
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
@@ -316,6 +323,19 @@ def load_model(
                     trust_remote_code=cfg.trust_remote_code or False,
                     **model_kwargs,
                 )
+        elif cfg.tensor_parallel:
+            config = AutoConfig.from_pretrained(
+                base_model,
+                trust_remote_code=cfg.trust_remote_code or False,
+            )
+            with accelerate.init_empty_weights():
+                model = AutoModelForCausalLM.from_config(
+                    config=config,
+                    trust_remote_code=cfg.trust_remote_code or False,
+                ).half()
+            model = tp.TensorParallelPreTrainedModel(
+                model,
+            )
         else:
             config = AutoConfig.from_pretrained(
                 base_model,
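
The pattern in this new cfg.tensor_parallel branch is: build the model skeleton with empty (meta-device) weights via accelerate, then hand it to tensor_parallel for sharding. A standalone sketch of the same idea; the checkpoint name is illustrative, and note that under init_empty_weights the parameters stay on the meta device, so real weights still have to be materialized into the shards afterwards:

    import accelerate
    import tensor_parallel as tp
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("gpt2")  # illustrative checkpoint
    with accelerate.init_empty_weights():
        # Builds the module tree without allocating real weight storage.
        model = AutoModelForCausalLM.from_config(config).half()
    # Wraps the model and splits each layer across the visible GPUs;
    # an explicit device list can also be passed.
    model = tp.TensorParallelPreTrainedModel(model)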
@@ -478,6 +498,8 @@ def load_adapter(model, cfg, adapter, inference=False):
         return model, None
     if hasattr(model, "enable_input_require_grads"):
         model.enable_input_require_grads()
+    if adapter == "qlora" and cfg.tensor_parallel:
+        return load_tp_qlora(model)
     if adapter in ["lora", "qlora"]:
         return load_lora(model, cfg, inference=inference)
     if adapter == "llama-adapter":
@@ -529,6 +551,25 @@ def find_all_linear_names(model):
     return list(lora_module_names)


+def load_tp_qlora(model):
+    from transformers.utils.bitsandbytes import replace_with_bnb_linear
+
+    model = replace_with_bnb_linear(
+        model,
+        quantization_config=BitsAndBytesConfig(
+            load_in_4bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+        ),
+    )
+    model.is_loaded_in_4bit = True
+    return model, None
+
+
 def load_lora(model, cfg, inference=False):
     # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
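
A note on load_tp_qlora: replace_with_bnb_linear rewrites the module tree, swapping nn.Linear layers for bnb.nn.Linear4bit (modules such as lm_head are excluded by default), and the manual is_loaded_in_4bit = True mimics the flag that from_pretrained(load_in_4bit=True) would normally set. A quick sanity check one could run after the swap; a sketch, assuming a bitsandbytes version that provides bnb.nn.Linear4bit:

    import bitsandbytes as bnb

    def count_4bit_linears(model):
        # Count how many layers were replaced by 4-bit linear modules.
        return sum(isinstance(m, bnb.nn.Linear4bit) for m in model.modules())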