From 563b6d89e6ee7fc104204ce7cd178b56c35b633d Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 11 Jun 2023 11:58:31 +0900 Subject: [PATCH] Fix undefined LlamaForCausalLM and del try except --- src/axolotl/utils/models.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index b3a5eeb60..43d4a6a9c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -81,7 +81,6 @@ def load_model( Load a model from a base model and a model type. """ - global LlamaForCausalLM # pylint: disable=global-statement # TODO refactor as a kwarg load_in_8bit = cfg.load_in_8bit cfg.is_llama_derived_model = "llama" in base_model or ( @@ -203,12 +202,7 @@ def load_model( ) load_in_8bit = False elif cfg.is_llama_derived_model: - try: - from transformers import LlamaForCausalLM - except ImportError: - logging.warning( - "This version of transformers does not support Llama. Consider upgrading." - ) + from transformers import LlamaForCausalLM config = LlamaConfig.from_pretrained(base_model_config) model = LlamaForCausalLM.from_pretrained(