diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index b84597076..b3a5eeb60 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -202,7 +202,7 @@ def load_model( else True, ) load_in_8bit = False - elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals(): + elif cfg.is_llama_derived_model: try: from transformers import LlamaForCausalLM except ImportError: