diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 70e36714c..209933482 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -82,6 +82,28 @@ def replace_llama_mlp_with_swiglu(model):
             set_module_name(model, name, mlp)
 
 
+def patch_mlp_with_swiglu(model_type):
+    if is_xformers_swiglu_available():
+        from axolotl.monkeypatch.xformers_ import FusedMLP
+    else:
+        raise RuntimeError("xformers SwiGLU not available for this environment")
+
+    try:
+        # Dynamically import the module and MLP class
+        module_path = f"transformers.models.{model_type}.modeling_{model_type}"
+        model_cls_prefix = "".join(
+            [part.capitalize() for part in model_type.split("_")]
+        )
+        module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
+        _ = getattr(module, f"{model_cls_prefix}MLP")
+        setattr(module, f"{model_cls_prefix}MLP", FusedMLP)
+    except (ImportError, AttributeError) as e:
+        raise RuntimeError(
+            f"Could not import MLP class for model_type: {model_type}. "
+            f"Error: {str(e)}"
+        ) from e
+
+
 def replace_llama_qkv_with_fused(model):
     for name, module in model.named_modules():
         if isinstance(module, LlamaAttention):
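
Reviewer note: a minimal usage sketch of the new helper, not part of this diff. It assumes the patch must run before model construction, since `transformers` resolves the MLP class from the modeling module when layers are built; already-instantiated models are unaffected. The checkpoint id is illustrative.

    # Hypothetical usage sketch: rebind the module-level MLP class, then
    # construct the model so the patched class is picked up.
    from transformers import AutoModelForCausalLM
    from axolotl.monkeypatch.llama_attn_hijack_flash import patch_mlp_with_swiglu

    patch_mlp_with_swiglu("llama")  # rebinds LlamaMLP -> FusedMLP in modeling_llama
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")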