From 312a9fad079f33ed4d371863b8c36b2b40c7a214 Mon Sep 17 00:00:00 2001
From: Aman Karmani
Date: Thu, 3 Aug 2023 17:20:49 +0000
Subject: [PATCH] move flash-attn monkey patch alongside the others

---
 .../{flash_attn.py => monkeypatch/llama_attn_hijack_flash.py} | 0
 src/axolotl/utils/models.py                                   | 4 +++-
 2 files changed, 3 insertions(+), 1 deletion(-)
 rename src/axolotl/{flash_attn.py => monkeypatch/llama_attn_hijack_flash.py} (100%)

diff --git a/src/axolotl/flash_attn.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
similarity index 100%
rename from src/axolotl/flash_attn.py
rename to src/axolotl/monkeypatch/llama_attn_hijack_flash.py
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index d4bda130c..23d7716a0 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -92,7 +92,9 @@ def load_model(
     if cfg.is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and not cfg.inference:
-            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                replace_llama_attn_with_flash_attn,
+            )
 
             LOG.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()