move flash-attn monkey patch alongside the others
This commit is contained in:
@@ -92,7 +92,9 @@ def load_model(
|
|||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||||
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
|
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
|
||||||
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
|
replace_llama_attn_with_flash_attn,
|
||||||
|
)
|
||||||
|
|
||||||
LOG.info("patching with flash attention")
|
LOG.info("patching with flash attention")
|
||||||
replace_llama_attn_with_flash_attn()
|
replace_llama_attn_with_flash_attn()
|
||||||
|
|||||||
Reference in New Issue
Block a user