fix bug when model_type not explicitly passed

This commit is contained in:
Wing Lian
2023-04-19 13:15:33 -04:00
parent d65385912e
commit bb991fd870

View File

@@ -35,7 +35,7 @@ def load_model(
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
tokenizer = None
-    is_llama_derived_model = "llama" in base_model or "llama" in cfg.model_type.lower()
+    is_llama_derived_model = "llama" in base_model or (cfg.model_type and "llama" in cfg.model_type.lower())
if is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and inference is False: