use new patch
This commit is contained in:
@@ -12,6 +12,7 @@ import transformers
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import patch_mlp_with_swiglu
|
||||
from axolotl.monkeypatch.multipack import (
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||
patch_for_multipack,
|
||||
@@ -416,7 +417,8 @@ class PatchManager:
|
||||
|
||||
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
||||
LOG.info("Patching with SwiGLU...")
|
||||
replace_llama_mlp_with_swiglu(model)
|
||||
# replace_llama_mlp_with_swiglu(model)
|
||||
patch_mlp_with_swiglu(model)
|
||||
|
||||
if self.cfg.flash_attn_fuse_qkv:
|
||||
LOG.info("Patching with fused QKV...")
|
||||
|
||||
@@ -84,7 +84,7 @@ def replace_llama_mlp_with_swiglu(model):
|
||||
|
||||
def patch_mlp_with_swiglu(model_type):
|
||||
if is_xformers_swiglu_available():
|
||||
from axolotl.monkeypatch.xformers_ import FusedMLP
|
||||
from axolotl.monkeypatch.xformers_ import FusedMLPv2 as FusedMLP
|
||||
else:
|
||||
raise RuntimeError("xformers SwiGLU not available for this environment")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user