Use new SwiGLU MLP patch (patch_mlp_with_swiglu backed by FusedMLPv2)

This commit is contained in:
Wing Lian
2025-07-13 22:40:37 -04:00
parent 1649f91cd4
commit d41b3814d0
2 changed files with 4 additions and 2 deletions

View File

@@ -12,6 +12,7 @@ import transformers
from transformers import PretrainedConfig, PreTrainedModel
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.llama_attn_hijack_flash import patch_mlp_with_swiglu
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
@@ -416,7 +417,8 @@ class PatchManager:
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
LOG.info("Patching with SwiGLU...")
replace_llama_mlp_with_swiglu(model)
# replace_llama_mlp_with_swiglu(model)
patch_mlp_with_swiglu(model)
if self.cfg.flash_attn_fuse_qkv:
LOG.info("Patching with fused QKV...")

View File

@@ -84,7 +84,7 @@ def replace_llama_mlp_with_swiglu(model):
def patch_mlp_with_swiglu(model_type):
if is_xformers_swiglu_available():
from axolotl.monkeypatch.xformers_ import FusedMLP
from axolotl.monkeypatch.xformers_ import FusedMLPv2 as FusedMLP
else:
raise RuntimeError("xformers SwiGLU not available for this environment")