Fix missing cfg.
Commit a808bf913f (parent 79e2a6f140), committed via GitHub.
@@ -96,7 +96,7 @@ def load_model(
     )
 
     if cfg.is_llama_derived_model and cfg.flash_attention:
-        if cfg.device not in ["mps", "cpu"] and inference is False:
+        if cfg.device not in ["mps", "cpu"] and not cfg.inference:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn
 
             logging.info("patching with flash attention")
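For context: the old condition read a bare `inference` name, which is not defined in this scope, so the flash-attention guard checked the wrong value; the fix reads the flag from the config object as `cfg.inference`. Below is a minimal sketch of what the corrected guard does, assuming a plain attribute-style cfg. The helper name `maybe_patch_flash_attention` and the `SimpleNamespace` config are illustrative only; in axolotl the check lives inline inside `load_model`.

from types import SimpleNamespace
import logging

def maybe_patch_flash_attention(cfg):
    # Hypothetical wrapper around the guard shown in the diff above.
    if cfg.is_llama_derived_model and cfg.flash_attention:
        # The fix: read `inference` from cfg rather than an undefined local.
        if cfg.device not in ["mps", "cpu"] and not cfg.inference:
            logging.info("patching with flash attention")
            # In axolotl this is where the monkey-patch is applied:
            # from axolotl.flash_attn import replace_llama_attn_with_flash_attn
            # replace_llama_attn_with_flash_attn()

# Example: a training run on CUDA triggers the patch; an inference run
# (inference=True) or a CPU/MPS device skips it.
cfg = SimpleNamespace(
    is_llama_derived_model=True,
    flash_attention=True,
    device="cuda",
    inference=False,
)
maybe_patch_flash_attention(cfg)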