diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index d66831861..4f9bdfc0b 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -35,7 +35,7 @@ def load_model( # TODO refactor as a kwarg load_in_8bit = cfg.load_in_8bit tokenizer = None - is_llama_derived_model = "llama" in base_model or "llama" in cfg.model_type.lower() + is_llama_derived_model = "llama" in base_model or (cfg.model_type and "llama" in cfg.model_type.lower()) if is_llama_derived_model and cfg.flash_attention: if cfg.device not in ["mps", "cpu"] and inference is False: