From bb991fd870287df0602660878497d3dac49f20a2 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 19 Apr 2023 13:15:33 -0400
Subject: [PATCH] fix bug when model_type not explicitly passed

---
 src/axolotl/utils/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index d66831861..4f9bdfc0b 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -35,7 +35,7 @@ def load_model(
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
     tokenizer = None
-    is_llama_derived_model = "llama" in base_model or "llama" in cfg.model_type.lower()
+    is_llama_derived_model = "llama" in base_model or (cfg.model_type and "llama" in cfg.model_type.lower())
 
     if is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and inference is False: