pre-patch the mlp

This commit is contained in:
Wing Lian
2025-07-13 23:01:49 -04:00
parent d41b3814d0
commit 6978f09760

View File

@@ -12,7 +12,6 @@ import transformers
from transformers import PretrainedConfig, PreTrainedModel from transformers import PretrainedConfig, PreTrainedModel
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.llama_attn_hijack_flash import patch_mlp_with_swiglu
from axolotl.monkeypatch.multipack import ( from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES, SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack, patch_for_multipack,
@@ -399,6 +398,18 @@ class PatchManager:
"Shifted-sparse attention not currently implemented without flash attention." "Shifted-sparse attention not currently implemented without flash attention."
) )
from axolotl.monkeypatch.llama_attn_hijack_flash import (
is_xformers_swiglu_available,
)
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_mlp_with_swiglu,
)
LOG.info("Patching with SwiGLU...")
patch_mlp_with_swiglu(self.cfg.model_config_type)
def _apply_llama_flash_attn_patches(self, model): def _apply_llama_flash_attn_patches(self, model):
"""Apply LLaMA-specific flash attention patches.""" """Apply LLaMA-specific flash attention patches."""
if ( if (
@@ -409,16 +420,14 @@ class PatchManager:
and not self.inference and not self.inference
): ):
# TODO(MengqingCao): split these patches seperately # TODO(MengqingCao): split these patches seperately
from axolotl.monkeypatch.llama_attn_hijack_flash import ( from axolotl.monkeypatch.llama_attn_hijack_flash import ( # is_xformers_swiglu_available,; replace_llama_mlp_with_swiglu,
is_xformers_swiglu_available,
replace_llama_mlp_with_swiglu,
replace_llama_qkv_with_fused, replace_llama_qkv_with_fused,
) )
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available(): # if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
LOG.info("Patching with SwiGLU...") # LOG.info("Patching with SwiGLU...")
# replace_llama_mlp_with_swiglu(model) # # replace_llama_mlp_with_swiglu(model)
patch_mlp_with_swiglu(model) # patch_mlp_with_swiglu(model)
if self.cfg.flash_attn_fuse_qkv: if self.cfg.flash_attn_fuse_qkv:
LOG.info("Patching with fused QKV...") LOG.info("Patching with fused QKV...")