From 6978f09760db94c75e80d3c11484592624245bdd Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 13 Jul 2025 23:01:49 -0400 Subject: [PATCH] pre-patch the mlp --- src/axolotl/loaders/patch_manager.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 623943a65..99d2c5ec9 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -12,7 +12,6 @@ import transformers from transformers import PretrainedConfig, PreTrainedModel from axolotl.integrations.base import PluginManager -from axolotl.monkeypatch.llama_attn_hijack_flash import patch_mlp_with_swiglu from axolotl.monkeypatch.multipack import ( SUPPORTED_MULTIPACK_MODEL_TYPES, patch_for_multipack, @@ -399,6 +398,18 @@ class PatchManager: "Shifted-sparse attention not currently implemented without flash attention." ) + from axolotl.monkeypatch.llama_attn_hijack_flash import ( + is_xformers_swiglu_available, + ) + + if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): + from axolotl.monkeypatch.llama_attn_hijack_flash import ( + patch_mlp_with_swiglu, + ) + + LOG.info("Patching with SwiGLU...") + patch_mlp_with_swiglu(self.cfg.model_config_type) + def _apply_llama_flash_attn_patches(self, model): """Apply LLaMA-specific flash attention patches.""" if ( @@ -409,16 +420,14 @@ class PatchManager: and not self.inference ): # TODO(MengqingCao): split these patches seperately - from axolotl.monkeypatch.llama_attn_hijack_flash import ( - is_xformers_swiglu_available, - replace_llama_mlp_with_swiglu, + from axolotl.monkeypatch.llama_attn_hijack_flash import ( # is_xformers_swiglu_available,; replace_llama_mlp_with_swiglu, replace_llama_qkv_with_fused, ) - if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): - LOG.info("Patching with SwiGLU...") - # replace_llama_mlp_with_swiglu(model) - patch_mlp_with_swiglu(model) + # if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): + # LOG.info("Patching with SwiGLU...") + # # replace_llama_mlp_with_swiglu(model) + # patch_mlp_with_swiglu(model) if self.cfg.flash_attn_fuse_qkv: LOG.info("Patching with fused QKV...")