tensor-parallel support
@@ -31,3 +31,4 @@ scikit-learn==1.2.2
 pynvml
 art
 fschat==0.2.29
+tensor_parallel
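For context, this first hunk adds the tensor_parallel PyPI package to the pinned dependencies (presumably requirements.txt). Below is a minimal sketch of how that package is typically used via its high-level tp.tensor_parallel entry point; the model name and device list are illustrative, not part of this commit:

# Minimal sketch of the tensor_parallel API this commit builds on.
# The model name and device list are illustrative only.
import tensor_parallel as tp
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b", torch_dtype=torch.float16
)
# Split each linear/embedding layer across the listed GPUs and insert the
# communication ops needed to combine per-shard results on forward().
model = tp.tensor_parallel(model, ["cuda:0", "cuda:1"])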
@@ -4,7 +4,9 @@ import math
 import os
 from typing import Optional, Tuple  # noqa: F401

+import accelerate
 import bitsandbytes as bnb
+import tensor_parallel as tp
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
@@ -235,7 +237,12 @@ def load_model(
         model_kwargs["use_flash_attention_2"] = True

     try:
-        if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
+        if (
+            cfg.is_llama_derived_model
+            and not cfg.trust_remote_code
+            and not cfg.gptq
+            and not cfg.tensor_parallel
+        ):
             from transformers import LlamaForCausalLM

             config_kwargs = {}
@@ -301,7 +308,7 @@ def load_model(
                 load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                 **model_kwargs,
             )
-        elif model_type and not cfg.trust_remote_code:
+        elif model_type and not cfg.trust_remote_code and not cfg.tensor_parallel:
             if cfg.gptq:
                 model = AutoModelForCausalLM.from_pretrained(
                     base_model,
@@ -316,6 +323,19 @@ def load_model(
                 trust_remote_code=cfg.trust_remote_code or False,
                 **model_kwargs,
             )
+        elif cfg.tensor_parallel:
+            config = AutoConfig.from_pretrained(
+                base_model,
+                trust_remote_code=cfg.trust_remote_code or False,
+            )
+            with accelerate.init_empty_weights():
+                model = AutoModelForCausalLM.from_config(
+                    config=config,
+                    trust_remote_code=cfg.trust_remote_code or False,
+                ).half()
+            model = tp.TensorParallelPreTrainedModel(
+                model,
+            )
         else:
             config = AutoConfig.from_pretrained(
                 base_model,
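The two "and not cfg.tensor_parallel" guards in the preceding hunks steer tensor-parallel runs away from the branches that load pretrained weights directly onto a single device and into this new branch instead. A standalone sketch of the pattern, assuming the tensor_parallel package's TensorParallelPreTrainedModel wrapper as imported above; the model name is illustrative:

# Standalone sketch of the meta-device loading pattern in the new branch.
import accelerate
import tensor_parallel as tp
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("huggyllama/llama-7b")
with accelerate.init_empty_weights():
    # Parameters are created on the "meta" device: no GPU/CPU memory is
    # allocated for the full fp16 model before it gets sharded.
    model = AutoModelForCausalLM.from_config(config).half()
# Wrap the skeleton so the model is held as per-GPU tensor-parallel shards.
model = tp.TensorParallelPreTrainedModel(model)

Under accelerate.init_empty_weights(), from_config yields meta tensors only, so the actual checkpoint weights still have to be loaded into the per-GPU shards before the model can run.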
@@ -478,6 +498,8 @@ def load_adapter(model, cfg, adapter, inference=False):
         return model, None
     if hasattr(model, "enable_input_require_grads"):
         model.enable_input_require_grads()
+    if adapter == "qlora" and cfg.tensor_parallel:
+        return load_tp_qlora(model)
     if adapter in ["lora", "qlora"]:
         return load_lora(model, cfg, inference=inference)
     if adapter == "llama-adapter":
@@ -529,6 +551,25 @@ def find_all_linear_names(model):
     return list(lora_module_names)


+def load_tp_qlora(model):
+    from transformers.utils.bitsandbytes import replace_with_bnb_linear
+
+    model = replace_with_bnb_linear(
+        model,
+        quantization_config=BitsAndBytesConfig(
+            load_in_4bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+        ),
+    )
+    model.is_loaded_in_4bit = True
+
+    return model, None
+
+
 def load_lora(model, cfg, inference=False):
     # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
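The new load_tp_qlora helper, dispatched from load_adapter before the generic lora/qlora branch, quantizes the tensor-parallel model in place: replace_with_bnb_linear swaps eligible nn.Linear modules for bitsandbytes 4-bit NF4 layers, and is_loaded_in_4bit is set manually because the weights did not go through from_pretrained's quantization path. A hypothetical sanity check for that swap (not part of the commit):

# Hypothetical helper to verify the in-place conversion done by
# replace_with_bnb_linear in load_tp_qlora above.
import bitsandbytes as bnb
import torch.nn as nn

def count_4bit_layers(model: nn.Module) -> int:
    # replace_with_bnb_linear converts eligible nn.Linear modules to
    # bnb.nn.Linear4bit; count how many were actually converted.
    return sum(isinstance(m, bnb.nn.Linear4bit) for m in model.modules())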