diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 2544429e6..623943a65 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -12,6 +12,7 @@ import transformers
 from transformers import PretrainedConfig, PreTrainedModel
 
 from axolotl.integrations.base import PluginManager
+from axolotl.monkeypatch.llama_attn_hijack_flash import patch_mlp_with_swiglu
 from axolotl.monkeypatch.multipack import (
     SUPPORTED_MULTIPACK_MODEL_TYPES,
     patch_for_multipack,
@@ -416,7 +417,8 @@ class PatchManager:
 
         if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
             LOG.info("Patching with SwiGLU...")
-            replace_llama_mlp_with_swiglu(model)
+            # replace_llama_mlp_with_swiglu(model)
+            patch_mlp_with_swiglu(model)
 
         if self.cfg.flash_attn_fuse_qkv:
             LOG.info("Patching with fused QKV...")
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 209933482..d77b2b4b6 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -84,7 +84,7 @@ def replace_llama_mlp_with_swiglu(model):
 
 def patch_mlp_with_swiglu(model_type):
     if is_xformers_swiglu_available():
-        from axolotl.monkeypatch.xformers_ import FusedMLP
+        from axolotl.monkeypatch.xformers_ import FusedMLPv2 as FusedMLP
     else:
         raise RuntimeError("xformers SwiGLU not available for this environment")
 
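
Review note: the last hunk shows only the import swap inside `patch_mlp_with_swiglu`; the rest of the function body falls outside the diff context, and its parameter is named `model_type` even though the call site in `PatchManager` passes the model instance. For reviewers, here is a minimal sketch of what the full replacement pass plausibly looks like, modeled on the existing `replace_llama_mlp_with_swiglu` pattern. This is an assumption, not the merged code: the `FusedMLPv2` constructor arguments, the `LlamaMLP` type check, the `_swap_submodule` helper, and the import paths are all illustrative.

```python
# Sketch only -- NOT the merged implementation. Assumes FusedMLPv2 keeps the
# constructor shape (config, gate_proj, down_proj, up_proj); verify against
# src/axolotl/monkeypatch/xformers_.py before relying on this.
from transformers.models.llama.modeling_llama import LlamaMLP

# Assumed import path; the diff shows this helper used in both touched files.
from axolotl.monkeypatch.llama_attn_hijack_flash import is_xformers_swiglu_available


def _swap_submodule(model, name, new_module):
    """Replace the dotted-path submodule `name` on `model` with `new_module`."""
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, new_module)


def patch_mlp_with_swiglu(model):
    """Swap every LlamaMLP for the fused xformers SwiGLU implementation."""
    if not is_xformers_swiglu_available():
        raise RuntimeError("xformers SwiGLU not available for this environment")
    from axolotl.monkeypatch.xformers_ import FusedMLPv2 as FusedMLP

    # Collect names first so the module tree isn't mutated while iterating.
    targets = [
        name for name, mod in model.named_modules() if isinstance(mod, LlamaMLP)
    ]
    for name in targets:
        mlp = model.get_submodule(name)
        fused = FusedMLP(mlp.config, mlp.gate_proj, mlp.down_proj, mlp.up_proj)
        _swap_submodule(model, name, fused)
```

Collecting the target names before swapping avoids invalidating the `named_modules()` iterator mid-walk; the real helper in `llama_attn_hijack_flash.py` may differ, so treat this as a reading aid for the diff rather than a drop-in.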