diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 623943a65..99d2c5ec9 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -12,7 +12,6 @@ import transformers
 from transformers import PretrainedConfig, PreTrainedModel
 
 from axolotl.integrations.base import PluginManager
-from axolotl.monkeypatch.llama_attn_hijack_flash import patch_mlp_with_swiglu
 from axolotl.monkeypatch.multipack import (
     SUPPORTED_MULTIPACK_MODEL_TYPES,
     patch_for_multipack,
@@ -399,6 +398,18 @@ class PatchManager:
                 "Shifted-sparse attention not currently implemented without flash attention."
             )
 
+        from axolotl.monkeypatch.llama_attn_hijack_flash import (
+            is_xformers_swiglu_available,
+        )
+
+        if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                patch_mlp_with_swiglu,
+            )
+
+            LOG.info("Patching with SwiGLU...")
+            patch_mlp_with_swiglu(self.cfg.model_config_type)
+
     def _apply_llama_flash_attn_patches(self, model):
         """Apply LLaMA-specific flash attention patches."""
         if (
@@ -409,16 +420,14 @@ class PatchManager:
             and not self.inference
         ):
             # TODO(MengqingCao): split these patches seperately
-            from axolotl.monkeypatch.llama_attn_hijack_flash import (
-                is_xformers_swiglu_available,
-                replace_llama_mlp_with_swiglu,
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (  # is_xformers_swiglu_available,; replace_llama_mlp_with_swiglu,
                 replace_llama_qkv_with_fused,
             )
 
-            if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
-                LOG.info("Patching with SwiGLU...")
-                # replace_llama_mlp_with_swiglu(model)
-                patch_mlp_with_swiglu(model)
+            # if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
+            #     LOG.info("Patching with SwiGLU...")
+            #     # replace_llama_mlp_with_swiglu(model)
+            #     patch_mlp_with_swiglu(model)
 
             if self.cfg.flash_attn_fuse_qkv:
                 LOG.info("Patching with fused QKV...")
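
A minimal sketch, not part of the diff, of how the relocated SwiGLU patch path is exercised. The guard on flash_attn_fuse_mlp plus is_xformers_swiglu_available() and the patch_mlp_with_swiglu(model_config_type) call mirror the added hunk above; the standalone cfg object and the "llama" value are illustrative assumptions only.

    from types import SimpleNamespace

    from axolotl.monkeypatch.llama_attn_hijack_flash import (
        is_xformers_swiglu_available,
        patch_mlp_with_swiglu,
    )

    # Illustrative stand-in for PatchManager's self.cfg (hypothetical values).
    cfg = SimpleNamespace(flash_attn_fuse_mlp=True, model_config_type="llama")

    # The fused-MLP patch is now keyed on the model config type rather than an
    # instantiated model, so it can be applied before the model is loaded.
    if cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
        patch_mlp_with_swiglu(cfg.model_config_type)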