Fix ShareGPT handling from HF; don't require the Llama classes to load when using an earlier transformers release
This commit is contained in:
@@ -7,11 +7,16 @@ import torch
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
LlamaForCausalLM,
|
||||
LlamaTokenizer,
|
||||
AutoTokenizer,
|
||||
PreTrainedModel,
|
||||
)
|
||||
try:
|
||||
from transformers import (
|
||||
LlamaForCausalLM,
|
||||
LlamaTokenizer,
|
||||
)
|
||||
except:
|
||||
logging.warning("This version of transformers does not support Llama. Consider upgrading.")
|
||||
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
||||
|
||||
@@ -95,7 +100,7 @@ def load_model(
|
||||
else True,
|
||||
)
|
||||
load_in_8bit = False
|
||||
elif is_llama_derived_model:
|
||||
elif is_llama_derived_model and "LlamaForCausalLM" in globals():
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=cfg.load_in_8bit,
|
||||
@@ -130,7 +135,7 @@ def load_model(
|
||||
|
||||
if not tokenizer:
|
||||
try:
|
||||
if is_llama_derived_model:
|
||||
if is_llama_derived_model and "LlamaTokenizer" in globals():
|
||||
tokenizer = LlamaTokenizer.from_pretrained(model)
|
||||
else:
|
||||
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
|
||||
|
||||
Reference in New Issue
Block a user