better handling of llama model import

Author: Wing Lian
Date:   2023-04-14 19:30:41 -04:00
parent 949a27be21
commit 45f77dd51e


@@ -19,7 +19,7 @@ from peft import (
     get_peft_model_state_dict, PeftModel,
 )
 from torch import nn
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer
 # add src to the pythonpath so we don't need to pip install this
 from transformers.trainer_pt_utils import get_parameter_names
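
Note that the new imports assume a transformers release that ships the dedicated LLaMA classes (they landed around v4.28). A guarded variant, sketched here as an assumption rather than something in this commit, would degrade gracefully on older installs:

# Hypothetical guard, not part of the commit: tolerate a transformers
# release that predates LlamaForCausalLM/LlamaTokenizer (~v4.28).
try:
    from transformers import LlamaForCausalLM, LlamaTokenizer
except ImportError:
    LlamaForCausalLM = None  # callers must check for None before use
    LlamaTokenizer = None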
@@ -53,16 +53,23 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
         raise NotImplementedError(f"{adapter} peft adapter not available")
     if "llama" in base_model:
         from axolotl.flash_attn import replace_llama_attn_with_flash_attn
         replace_llama_attn_with_flash_attn()
     try:
-        model = getattr(transformers, model_type).from_pretrained(
-            base_model,
-            load_in_8bit=cfg.load_in_8bit,
-            torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
-            device_map=cfg.device_map,
-        )
+        if "llama" in base_model:
+            model = LlamaForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit,
+                torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
+                device_map=cfg.device_map,
+            )
+        else:
+            model = getattr(transformers, model_type).from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit,
+                torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
+                device_map=cfg.device_map,
+            )
     except:
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
@@ -72,7 +79,10 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
         )
     try:
-        tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
+        if "llama" in base_model:
+            tokenizer = LlamaTokenizer.from_pretrained(model)
+        else:
+            tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
     except:
         tokenizer = AutoTokenizer.from_pretrained(base_model)
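
Taken together, the hunks produce the fallback chain below: the dedicated LLaMA classes first, then getattr-based class resolution, then the Auto classes. This is a minimal standalone sketch, not axolotl's actual module: the Cfg container is an assumed stand-in for the real config, and the tokenizer is loaded from base_model rather than from the model object the diff passes in, since from_pretrained expects a model name or path.

from dataclasses import dataclass

import torch
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaForCausalLM,
    LlamaTokenizer,
)


@dataclass
class Cfg:
    # Assumed stand-in for axolotl's config object; only the fields
    # the diff touches are modeled here.
    load_in_8bit: bool = False
    device_map: str = "auto"


def load_model_sketch(base_model: str, model_type: str, tokenizer_type: str, cfg: Cfg):
    kwargs = dict(
        load_in_8bit=cfg.load_in_8bit,
        torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
        device_map=cfg.device_map,
    )
    try:
        if "llama" in base_model:
            # Use the dedicated class rather than getattr resolution,
            # which was unreliable for early LLaMA support in transformers.
            model = LlamaForCausalLM.from_pretrained(base_model, **kwargs)
        else:
            model = getattr(transformers, model_type).from_pretrained(base_model, **kwargs)
    except Exception:
        # Last resort: let the Auto class infer the architecture.
        model = AutoModelForCausalLM.from_pretrained(base_model, **kwargs)

    try:
        if "llama" in base_model:
            tokenizer = LlamaTokenizer.from_pretrained(base_model)
        else:
            tokenizer = getattr(transformers, tokenizer_type).from_pretrained(base_model)
    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(base_model)
    return model, tokenizer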