diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 7156adec0..67facd607 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -96,7 +96,7 @@ def load_model(
         )

     if cfg.is_llama_derived_model and cfg.flash_attention:
-        if cfg.device not in ["mps", "cpu"] and inference is False:
+        if cfg.device not in ["mps", "cpu"] and not cfg.inference:
             from axolotl.flash_attn import replace_llama_attn_with_flash_attn

             logging.info("patching with flash attention")
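
A minimal sketch of the guard after the change, for illustration only: the Cfg class below is a hypothetical stand-in for axolotl's config object, not the library's actual type. It shows the effect of reading the flag from cfg.inference instead of a bare inference name, so an inference run skips the flash-attention monkey patch.

import logging

class Cfg:
    # Hypothetical minimal config; real axolotl configs carry many more fields.
    is_llama_derived_model = True
    flash_attention = True
    device = "cuda"
    inference = False  # previously checked as a bare `inference` local

cfg = Cfg()

if cfg.is_llama_derived_model and cfg.flash_attention:
    # Skip the patch on mps/cpu devices and whenever inference mode is set.
    if cfg.device not in ["mps", "cpu"] and not cfg.inference:
        logging.info("patching with flash attention")

With cfg.inference = True, the inner branch is not taken and no patching occurs; with the old bare-name check, the config's inference setting had no effect on this path.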