diff --git a/src/axolotl/flash_attn.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
similarity index 100%
rename from src/axolotl/flash_attn.py
rename to src/axolotl/monkeypatch/llama_attn_hijack_flash.py
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index d4bda130c..23d7716a0 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -92,7 +92,9 @@ def load_model(
     if cfg.is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and not cfg.inference:
-            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                replace_llama_attn_with_flash_attn,
+            )
 
             LOG.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
 