wip patch
This commit is contained in:
@@ -82,6 +82,28 @@ def replace_llama_mlp_with_swiglu(model):
|
|||||||
set_module_name(model, name, mlp)
|
set_module_name(model, name, mlp)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_mlp_with_swiglu(model_type):
|
||||||
|
if is_xformers_swiglu_available():
|
||||||
|
from axolotl.monkeypatch.xformers_ import FusedMLP
|
||||||
|
else:
|
||||||
|
raise RuntimeError("xformers SwiGLU not available for this environment")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Dynamically import the module and MLP class
|
||||||
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
|
model_cls_prefix = "".join(
|
||||||
|
[part.capitalize() for part in model_type.split("_")]
|
||||||
|
)
|
||||||
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
|
||||||
|
_ = getattr(module, f"{model_cls_prefix}MLP")
|
||||||
|
setattr(module, f"{model_cls_prefix}MLP", FusedMLP)
|
||||||
|
except (ImportError, AttributeError) as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not import MLP class for model_type: {model_type}. "
|
||||||
|
f"Error: {str(e)}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
def replace_llama_qkv_with_fused(model):
|
def replace_llama_qkv_with_fused(model):
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if isinstance(module, LlamaAttention):
|
if isinstance(module, LlamaAttention):
|
||||||
|
|||||||
Reference in New Issue
Block a user