pre-patch the mlp
This commit is contained in:
@@ -12,7 +12,6 @@ import transformers
|
|||||||
from transformers import PretrainedConfig, PreTrainedModel
|
from transformers import PretrainedConfig, PreTrainedModel
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import patch_mlp_with_swiglu
|
|
||||||
from axolotl.monkeypatch.multipack import (
|
from axolotl.monkeypatch.multipack import (
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||||
patch_for_multipack,
|
patch_for_multipack,
|
||||||
@@ -399,6 +398,18 @@ class PatchManager:
|
|||||||
"Shifted-sparse attention not currently implemented without flash attention."
|
"Shifted-sparse attention not currently implemented without flash attention."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
|
is_xformers_swiglu_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
||||||
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
|
patch_mlp_with_swiglu,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("Patching with SwiGLU...")
|
||||||
|
patch_mlp_with_swiglu(self.cfg.model_config_type)
|
||||||
|
|
||||||
def _apply_llama_flash_attn_patches(self, model):
|
def _apply_llama_flash_attn_patches(self, model):
|
||||||
"""Apply LLaMA-specific flash attention patches."""
|
"""Apply LLaMA-specific flash attention patches."""
|
||||||
if (
|
if (
|
||||||
@@ -409,16 +420,14 @@ class PatchManager:
|
|||||||
and not self.inference
|
and not self.inference
|
||||||
):
|
):
|
||||||
# TODO(MengqingCao): split these patches separately
|
# TODO(MengqingCao): split these patches separately
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
from axolotl.monkeypatch.llama_attn_hijack_flash import ( # is_xformers_swiglu_available,; replace_llama_mlp_with_swiglu,
|
||||||
is_xformers_swiglu_available,
|
|
||||||
replace_llama_mlp_with_swiglu,
|
|
||||||
replace_llama_qkv_with_fused,
|
replace_llama_qkv_with_fused,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
# if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
||||||
LOG.info("Patching with SwiGLU...")
|
# LOG.info("Patching with SwiGLU...")
|
||||||
# replace_llama_mlp_with_swiglu(model)
|
# # replace_llama_mlp_with_swiglu(model)
|
||||||
patch_mlp_with_swiglu(model)
|
# patch_mlp_with_swiglu(model)
|
||||||
|
|
||||||
if self.cfg.flash_attn_fuse_qkv:
|
if self.cfg.flash_attn_fuse_qkv:
|
||||||
LOG.info("Patching with fused QKV...")
|
LOG.info("Patching with fused QKV...")
|
||||||
|
|||||||
Reference in New Issue
Block a user