move flash-attn monkey patch alongside the others

Aman Karmani
2023-08-03 17:20:49 +00:00
parent 248bf90f89
commit 312a9fad07
2 changed files with 3 additions and 1 deletion


@@ -92,7 +92,9 @@ def load_model(
     if cfg.is_llama_derived_model and cfg.flash_attention:
         if cfg.device not in ["mps", "cpu"] and not cfg.inference:
-            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                replace_llama_attn_with_flash_attn,
+            )
             LOG.info("patching with flash attention")
             replace_llama_attn_with_flash_attn()
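
For context, the relocated helper presumably follows the standard monkey-patch pattern: it reassigns the attention forward on the transformers Llama classes so flash-attn kernels run instead of the stock implementation. A minimal sketch of that pattern, where _flash_attn_forward is a hypothetical placeholder and not axolotl's real code (the actual implementation lives in the axolotl.monkeypatch.llama_attn_hijack_flash module):

    # Sketch only: shows the class-level patch a replace_llama_attn_with_flash_attn()
    # helper performs; the placeholder forward below is hypothetical.
    import transformers.models.llama.modeling_llama as llama_modeling


    def _flash_attn_forward(self, hidden_states, *args, **kwargs):
        # A real patch would route the attention math through flash-attn
        # kernels here; this stub only marks where that logic would go.
        raise NotImplementedError("illustrative placeholder")


    def replace_llama_attn_with_flash_attn():
        # Reassigning forward on the class patches every LlamaAttention
        # instance, including ones the model has already constructed.
        llama_modeling.LlamaAttention.forward = _flash_attn_forward

Keeping this helper under axolotl.monkeypatch alongside the other hijack modules (rather than at axolotl.flash_attn) groups all such patches in one place, which is what the import change above reflects.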