pre-patch the mlp

Author: Wing Lian
Date: 2025-07-13 23:01:49 -04:00
Parent: d41b3814d0
Commit: 6978f09760


@@ -12,7 +12,6 @@ import transformers
 from transformers import PretrainedConfig, PreTrainedModel
 
 from axolotl.integrations.base import PluginManager
-from axolotl.monkeypatch.llama_attn_hijack_flash import patch_mlp_with_swiglu
 from axolotl.monkeypatch.multipack import (
     SUPPORTED_MULTIPACK_MODEL_TYPES,
     patch_for_multipack,
@@ -399,6 +398,18 @@ class PatchManager:
                 "Shifted-sparse attention not currently implemented without flash attention."
             )
 
+        from axolotl.monkeypatch.llama_attn_hijack_flash import (
+            is_xformers_swiglu_available,
+        )
+
+        if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                patch_mlp_with_swiglu,
+            )
+
+            LOG.info("Patching with SwiGLU...")
+            patch_mlp_with_swiglu(self.cfg.model_config_type)
+
     def _apply_llama_flash_attn_patches(self, model):
         """Apply LLaMA-specific flash attention patches."""
         if (
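The hunk above moves the SwiGLU fusion into the pre-model-load phase: instead of rewriting each layer of an already-instantiated model, the MLP class is patched up front, keyed only on cfg.model_config_type. A minimal sketch of that class-level pre-patch pattern, using hypothetical names (FusedSwiGLUMLP, pre_patch_mlp) rather than axolotl's actual patch_mlp_with_swiglu, which swaps in an xformers-backed fused implementation:

# Hypothetical sketch, not axolotl's implementation: the general shape of
# "pre-patching" the MLP class before the model is built, keyed on the
# config's model type, versus replacing module instances after loading.
import torch
import torch.nn as nn
import transformers.models.llama.modeling_llama as llama_modeling


class FusedSwiGLUMLP(nn.Module):
    # Drop-in stand-in for LlamaMLP; a real patch would back this with a
    # fused SwiGLU kernel (e.g. from xformers) instead of plain torch ops.
    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        # SwiGLU: down(silu(gate(x)) * up(x)); a fused kernel does this in one pass.
        return self.down_proj(torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))


def pre_patch_mlp(model_config_type: str) -> None:
    # Swap the class *before* from_pretrained so every decoder layer is
    # constructed with the fused module; no post-load surgery on the model.
    if model_config_type == "llama":
        llama_modeling.LlamaMLP = FusedSwiGLUMLP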
@@ -409,16 +420,14 @@ class PatchManager:
             and not self.inference
         ):
             # TODO(MengqingCao): split these patches seperately
-            from axolotl.monkeypatch.llama_attn_hijack_flash import (
-                is_xformers_swiglu_available,
-                replace_llama_mlp_with_swiglu,
+            from axolotl.monkeypatch.llama_attn_hijack_flash import ( # is_xformers_swiglu_available,; replace_llama_mlp_with_swiglu,
                 replace_llama_qkv_with_fused,
             )
 
-            if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
-                LOG.info("Patching with SwiGLU...")
-                # replace_llama_mlp_with_swiglu(model)
-                patch_mlp_with_swiglu(model)
+            # if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
+            #     LOG.info("Patching with SwiGLU...")
+            #     # replace_llama_mlp_with_swiglu(model)
+            #     patch_mlp_with_swiglu(model)
 
             if self.cfg.flash_attn_fuse_qkv:
                 LOG.info("Patching with fused QKV...")