Fix undefined LlamaForCausalLM and remove the try/except around its import

This commit is contained in:
NanoCode012
2023-06-11 11:58:31 +09:00
parent e285e24f7f
commit 563b6d89e6

View File

@@ -81,7 +81,6 @@ def load_model(
Load a model from a base model and a model type.
"""
global LlamaForCausalLM # pylint: disable=global-statement
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
cfg.is_llama_derived_model = "llama" in base_model or (
@@ -203,12 +202,7 @@ def load_model(
)
load_in_8bit = False
elif cfg.is_llama_derived_model:
try:
from transformers import LlamaForCausalLM
except ImportError:
logging.warning(
"This version of transformers does not support Llama. Consider upgrading."
)
from transformers import LlamaForCausalLM
config = LlamaConfig.from_pretrained(base_model_config)
model = LlamaForCausalLM.from_pretrained(