wip patch

This commit is contained in:
Wing Lian
2025-07-07 09:56:22 -04:00
parent 5a063f5c75
commit 1649f91cd4

View File

@@ -82,6 +82,28 @@ def replace_llama_mlp_with_swiglu(model):
set_module_name(model, name, mlp)
def patch_mlp_with_swiglu(model_type):
if is_xformers_swiglu_available():
from axolotl.monkeypatch.xformers_ import FusedMLP
else:
raise RuntimeError("xformers SwiGLU not available for this environment")
try:
# Dynamically import the module and MLP class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
_ = getattr(module, f"{model_cls_prefix}MLP")
setattr(module, f"{model_cls_prefix}MLP", FusedMLP)
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import MLP class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
def replace_llama_qkv_with_fused(model):
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):