pre-patch the mlp
This commit is contained in:
@@ -12,7 +12,6 @@ import transformers
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import patch_mlp_with_swiglu
|
||||
from axolotl.monkeypatch.multipack import (
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||
patch_for_multipack,
|
||||
@@ -399,6 +398,18 @@ class PatchManager:
|
||||
"Shifted-sparse attention not currently implemented without flash attention."
|
||||
)
|
||||
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
is_xformers_swiglu_available,
|
||||
)
|
||||
|
||||
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
patch_mlp_with_swiglu,
|
||||
)
|
||||
|
||||
LOG.info("Patching with SwiGLU...")
|
||||
patch_mlp_with_swiglu(self.cfg.model_config_type)
|
||||
|
||||
def _apply_llama_flash_attn_patches(self, model):
|
||||
"""Apply LLaMA-specific flash attention patches."""
|
||||
if (
|
||||
@@ -409,16 +420,14 @@ class PatchManager:
|
||||
and not self.inference
|
||||
):
|
||||
# TODO(MengqingCao): split these patches separately
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
is_xformers_swiglu_available,
|
||||
replace_llama_mlp_with_swiglu,
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import ( # is_xformers_swiglu_available,; replace_llama_mlp_with_swiglu,
|
||||
replace_llama_qkv_with_fused,
|
||||
)
|
||||
|
||||
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
||||
LOG.info("Patching with SwiGLU...")
|
||||
# replace_llama_mlp_with_swiglu(model)
|
||||
patch_mlp_with_swiglu(model)
|
||||
# if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
||||
# LOG.info("Patching with SwiGLU...")
|
||||
# # replace_llama_mlp_with_swiglu(model)
|
||||
# patch_mlp_with_swiglu(model)
|
||||
|
||||
if self.cfg.flash_attn_fuse_qkv:
|
||||
LOG.info("Patching with fused QKV...")
|
||||
|
||||
Reference in New Issue
Block a user