chore: Clean up repetitive model kwargs (#670)

This commit is contained in:
NanoCode012
2023-10-04 20:41:26 +09:00
committed by GitHub
parent 697c50d408
commit e62d5901b5

View File

@@ -176,6 +176,10 @@ def load_model(
hijack_expand_mask() hijack_expand_mask()
model_kwargs = {} model_kwargs = {}
model_kwargs["device_map"] = cfg.device_map
model_kwargs["torch_dtype"] = cfg.torch_dtype
if cfg.model_revision: if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision model_kwargs["revision"] = cfg.model_revision
if cfg.gptq: if cfg.gptq:
@@ -206,6 +210,7 @@ def load_model(
or cfg.is_mistral_derived_model or cfg.is_mistral_derived_model
): ):
model_kwargs["use_flash_attention_2"] = True model_kwargs["use_flash_attention_2"] = True
try: try:
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq: if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
from transformers import LlamaForCausalLM from transformers import LlamaForCausalLM
@@ -220,10 +225,8 @@ def load_model(
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
**model_kwargs, **model_kwargs,
) )
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention: # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -257,28 +260,22 @@ def load_model(
model = MixFormerSequentialForCausalLM.from_pretrained( model = MixFormerSequentialForCausalLM.from_pretrained(
base_model, base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
**model_kwargs, **model_kwargs,
) )
elif model_type and not cfg.trust_remote_code: elif model_type and not cfg.trust_remote_code:
if cfg.gptq: if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
device_map=cfg.device_map,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
else: else:
model = getattr(transformers, model_type).from_pretrained( model = getattr(transformers, model_type).from_pretrained(
base_model, base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
@@ -307,8 +304,6 @@ def load_model(
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=config,
device_map=cfg.device_map,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
@@ -316,10 +311,8 @@ def load_model(
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=config,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
@@ -330,10 +323,8 @@ def load_model(
LOG.exception(err) LOG.exception(err)
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )