use new patch
This commit is contained in:
@@ -12,6 +12,7 @@ import transformers
|
|||||||
from transformers import PretrainedConfig, PreTrainedModel
|
from transformers import PretrainedConfig, PreTrainedModel
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
|
from axolotl.monkeypatch.llama_attn_hijack_flash import patch_mlp_with_swiglu
|
||||||
from axolotl.monkeypatch.multipack import (
|
from axolotl.monkeypatch.multipack import (
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||||
patch_for_multipack,
|
patch_for_multipack,
|
||||||
@@ -416,7 +417,8 @@ class PatchManager:
|
|||||||
|
|
||||||
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
||||||
LOG.info("Patching with SwiGLU...")
|
LOG.info("Patching with SwiGLU...")
|
||||||
replace_llama_mlp_with_swiglu(model)
|
# replace_llama_mlp_with_swiglu(model)
|
||||||
|
patch_mlp_with_swiglu(model)
|
||||||
|
|
||||||
if self.cfg.flash_attn_fuse_qkv:
|
if self.cfg.flash_attn_fuse_qkv:
|
||||||
LOG.info("Patching with fused QKV...")
|
LOG.info("Patching with fused QKV...")
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ def replace_llama_mlp_with_swiglu(model):
|
|||||||
|
|
||||||
def patch_mlp_with_swiglu(model_type):
|
def patch_mlp_with_swiglu(model_type):
|
||||||
if is_xformers_swiglu_available():
|
if is_xformers_swiglu_available():
|
||||||
from axolotl.monkeypatch.xformers_ import FusedMLP
|
from axolotl.monkeypatch.xformers_ import FusedMLPv2 as FusedMLP
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("xformers SwiGLU not available for this environment")
|
raise RuntimeError("xformers SwiGLU not available for this environment")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user