diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 19b13a342..ad1432583 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -136,7 +136,11 @@ def load_model( replace_stablelm_attn_with_flash_attn(cfg.base_model) - if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing: + if ( + cfg.is_llama_derived_model + and cfg.flash_attention + and (cfg.noisy_embeddings_alpha or cfg.sample_packing) + ): if cfg.device not in ["mps", "cpu"] and not inference: from axolotl.monkeypatch.llama_attn_hijack_flash import ( replace_llama_attn_with_flash_attn,