Fix ShareGPT handling from Hugging Face; don't fail on loading Llama when using an earlier transformers release

This commit is contained in:
Wing Lian
2023-04-20 09:19:46 -04:00
parent 8e2a5609b3
commit 8d437853c8
4 changed files with 29 additions and 7 deletions

View File

@@ -7,11 +7,16 @@ import torch
import transformers
from transformers import (
AutoModelForCausalLM,
LlamaForCausalLM,
LlamaTokenizer,
AutoTokenizer,
PreTrainedModel,
)
# The Llama classes are only available in newer transformers releases, so
# import them opportunistically: on older releases the module still loads and
# a warning is logged instead. Callers in this file check ``globals()`` for
# these names before using them.
try:
    from transformers import (
        LlamaForCausalLM,
        LlamaTokenizer,
    )
except ImportError:
    # Only a missing/unsupported import is expected here; a bare ``except``
    # would also hide KeyboardInterrupt, SystemExit and unrelated bugs.
    logging.warning("This version of transformers does not support Llama. Consider upgrading.")
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
@@ -95,7 +100,7 @@ def load_model(
else True,
)
load_in_8bit = False
elif is_llama_derived_model:
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
model = LlamaForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit,
@@ -130,7 +135,7 @@ def load_model(
if not tokenizer:
try:
if is_llama_derived_model:
if is_llama_derived_model and "LlamaTokenizer" in globals():
tokenizer = LlamaTokenizer.from_pretrained(model)
else:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)