Compare commits
4 Commits
version-de
...
fused-mlp-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6978f09760 | ||
|
|
d41b3814d0 | ||
|
|
1649f91cd4 | ||
|
|
5a063f5c75 |
@@ -398,6 +398,18 @@ class PatchManager:
|
|||||||
"Shifted-sparse attention not currently implemented without flash attention."
|
"Shifted-sparse attention not currently implemented without flash attention."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
|
is_xformers_swiglu_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
||||||
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
|
patch_mlp_with_swiglu,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("Patching with SwiGLU...")
|
||||||
|
patch_mlp_with_swiglu(self.cfg.model_config_type)
|
||||||
|
|
||||||
def _apply_llama_flash_attn_patches(self, model):
|
def _apply_llama_flash_attn_patches(self, model):
|
||||||
"""Apply LLaMA-specific flash attention patches."""
|
"""Apply LLaMA-specific flash attention patches."""
|
||||||
if (
|
if (
|
||||||
@@ -408,15 +420,14 @@ class PatchManager:
|
|||||||
and not self.inference
|
and not self.inference
|
||||||
):
|
):
|
||||||
# TODO(MengqingCao): split these patches seperately
|
# TODO(MengqingCao): split these patches seperately
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
from axolotl.monkeypatch.llama_attn_hijack_flash import ( # is_xformers_swiglu_available,; replace_llama_mlp_with_swiglu,
|
||||||
is_xformers_swiglu_available,
|
|
||||||
replace_llama_mlp_with_swiglu,
|
|
||||||
replace_llama_qkv_with_fused,
|
replace_llama_qkv_with_fused,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
# if self.cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
|
||||||
LOG.info("Patching with SwiGLU...")
|
# LOG.info("Patching with SwiGLU...")
|
||||||
replace_llama_mlp_with_swiglu(model)
|
# # replace_llama_mlp_with_swiglu(model)
|
||||||
|
# patch_mlp_with_swiglu(model)
|
||||||
|
|
||||||
if self.cfg.flash_attn_fuse_qkv:
|
if self.cfg.flash_attn_fuse_qkv:
|
||||||
LOG.info("Patching with fused QKV...")
|
LOG.info("Patching with fused QKV...")
|
||||||
|
|||||||
@@ -82,6 +82,28 @@ def replace_llama_mlp_with_swiglu(model):
|
|||||||
set_module_name(model, name, mlp)
|
set_module_name(model, name, mlp)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_mlp_with_swiglu(model_type):
|
||||||
|
if is_xformers_swiglu_available():
|
||||||
|
from axolotl.monkeypatch.xformers_ import FusedMLPv2 as FusedMLP
|
||||||
|
else:
|
||||||
|
raise RuntimeError("xformers SwiGLU not available for this environment")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Dynamically import the module and MLP class
|
||||||
|
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||||
|
model_cls_prefix = "".join(
|
||||||
|
[part.capitalize() for part in model_type.split("_")]
|
||||||
|
)
|
||||||
|
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
|
||||||
|
_ = getattr(module, f"{model_cls_prefix}MLP")
|
||||||
|
setattr(module, f"{model_cls_prefix}MLP", FusedMLP)
|
||||||
|
except (ImportError, AttributeError) as e:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not import MLP class for model_type: {model_type}. "
|
||||||
|
f"Error: {str(e)}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
def replace_llama_qkv_with_fused(model):
|
def replace_llama_qkv_with_fused(model):
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if isinstance(module, LlamaAttention):
|
if isinstance(module, LlamaAttention):
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
"""
|
"""
|
||||||
Fused MLP layer for incrementally improved training efficiency
|
Fused MLP layer for incrementally improved training efficiency
|
||||||
"""
|
"""
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
from transformers.models.llama.modeling_llama import LlamaMLP
|
from transformers.models.llama.modeling_llama import LlamaMLP
|
||||||
from xformers.ops import SwiGLU
|
from xformers.ops import SwiGLU
|
||||||
|
|
||||||
@@ -50,3 +53,129 @@ class FusedMLP(torch.nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
||||||
return self.swiglu(x)
|
return self.swiglu(x)
|
||||||
|
|
||||||
|
class FusedMLPv2(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.intermediate_size = config.intermediate_size
|
||||||
|
self.swiglu = SwiGLU(
|
||||||
|
in_features=self.hidden_size,
|
||||||
|
hidden_features=self.intermediate_size,
|
||||||
|
bias=config.mlp_bias,
|
||||||
|
_pack_weights=True,
|
||||||
|
)
|
||||||
|
assert config.hidden_act == "silu"
|
||||||
|
|
||||||
|
def _convert_unpacked_to_packed_state_dict(self, unpacked_state_dict):
|
||||||
|
"""
|
||||||
|
Convert state dict from unpacked format (w1, w2, w3) to packed format (w13, w2).
|
||||||
|
"""
|
||||||
|
packed_state_dict = OrderedDict()
|
||||||
|
|
||||||
|
# Handle w1 and w3 -> w13 conversion for weights
|
||||||
|
if 'gate_proj.weight' in unpacked_state_dict and 'up_proj.weight' in unpacked_state_dict:
|
||||||
|
gate_proj_weight = unpacked_state_dict['gate_proj.weight']
|
||||||
|
up_proj_weight = unpacked_state_dict['up_proj.weight']
|
||||||
|
# Concatenate gate and up weights along output dimension (dim=0)
|
||||||
|
packed_state_dict['swiglu.w12.weight'] = torch.cat([gate_proj_weight, up_proj_weight], dim=0)
|
||||||
|
|
||||||
|
# Handle w1 and w3 -> w13 conversion for biases (if they exist)
|
||||||
|
if 'gate_proj.bias' in unpacked_state_dict and 'up_proj.bias' in unpacked_state_dict:
|
||||||
|
gate_proj_bias = unpacked_state_dict['gate_proj.bias']
|
||||||
|
up_proj_bias = unpacked_state_dict['up_proj.bias']
|
||||||
|
# Concatenate gate and up biases along dimension 0
|
||||||
|
packed_state_dict['swiglu.w12.bias'] = torch.cat([gate_proj_bias, up_proj_bias], dim=0)
|
||||||
|
|
||||||
|
# Copy down parameters as-is
|
||||||
|
if "down_proj.weight" in unpacked_state_dict:
|
||||||
|
packed_state_dict["swiglu.w3.weight"] = unpacked_state_dict['down_proj.weight']
|
||||||
|
if "down_proj.bias" in unpacked_state_dict:
|
||||||
|
packed_state_dict["swiglu.w3.bias"] = unpacked_state_dict['down_proj.bias']
|
||||||
|
|
||||||
|
for key in ['swiglu.w3.weight', 'swiglu.w3.bias']:
|
||||||
|
if key in unpacked_state_dict:
|
||||||
|
packed_state_dict[key] = unpacked_state_dict[key]
|
||||||
|
|
||||||
|
# Copy any other parameters that might exist
|
||||||
|
excluded_keys = [
|
||||||
|
'gate_proj.weight', 'gate_proj.bias',
|
||||||
|
'down_proj.weight', 'down_proj.bias',
|
||||||
|
'up_proj.weight', 'up_proj.bias',
|
||||||
|
'swiglu.w12.weight', 'swiglu.w12.bias',
|
||||||
|
'swiglu.w3.weight', 'swiglu.w3.bias',
|
||||||
|
]
|
||||||
|
for key, value in unpacked_state_dict.items():
|
||||||
|
if key not in excluded_keys:
|
||||||
|
packed_state_dict[key] = value
|
||||||
|
|
||||||
|
return packed_state_dict
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict, strict=True):
|
||||||
|
"""
|
||||||
|
Load state dict, handling both packed (w13) and unpacked (w1, w3) formats.
|
||||||
|
"""
|
||||||
|
# Check if this is an unpacked state dict (has w1 and w3 instead of w13)
|
||||||
|
has_unpacked_gate_up = 'gate_proj.weight' in state_dict and 'up_proj.weight' in state_dict
|
||||||
|
has_packed_swiglu = 'swiglu.w12.weight' in state_dict
|
||||||
|
|
||||||
|
if has_unpacked_gate_up and not has_packed_swiglu:
|
||||||
|
state_dict = self._convert_unpacked_to_packed_state_dict(state_dict)
|
||||||
|
|
||||||
|
return super().load_state_dict(state_dict, strict=strict)
|
||||||
|
|
||||||
|
def state_dict(self, destination=None, prefix='', keep_vars=False, packed=False):
|
||||||
|
"""
|
||||||
|
Return state dict in unpacked format by default for compatibility.
|
||||||
|
Set packed=True to get the internal packed format.
|
||||||
|
"""
|
||||||
|
if packed:
|
||||||
|
# Return the actual packed state dict
|
||||||
|
return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||||
|
else:
|
||||||
|
# Return unpacked format for compatibility
|
||||||
|
return self.get_unpacked_state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||||
|
|
||||||
|
def get_unpacked_state_dict(self, destination=None, prefix='', keep_vars=False):
|
||||||
|
"""
|
||||||
|
Convert current packed state dict to unpacked format for compatibility.
|
||||||
|
"""
|
||||||
|
# Get the actual packed state dict first
|
||||||
|
packed_state_dict = super().state_dict(destination=None, prefix='', keep_vars=keep_vars)
|
||||||
|
|
||||||
|
if destination is None:
|
||||||
|
destination = OrderedDict()
|
||||||
|
|
||||||
|
# Handle w13 -> w1 and w3 conversion for weights
|
||||||
|
if f'{prefix}swiglu.w12.weight' in packed_state_dict:
|
||||||
|
w13_weight = packed_state_dict[f'{prefix}swiglu.w12.weight']
|
||||||
|
hidden_dim = w13_weight.shape[0] // 2
|
||||||
|
w1_weight, w3_weight = torch.split(w13_weight, hidden_dim, dim=0)
|
||||||
|
destination[f'{prefix}gate_proj.weight'] = w1_weight if not keep_vars else w1_weight.detach().requires_grad_(w1_weight.requires_grad)
|
||||||
|
destination[f'{prefix}up_proj.weight'] = w3_weight if not keep_vars else w3_weight.detach().requires_grad_(w3_weight.requires_grad)
|
||||||
|
|
||||||
|
# Handle w13 -> w1 and w3 conversion for biases (if they exist)
|
||||||
|
if f'{prefix}swiglu.w12.bias' in packed_state_dict:
|
||||||
|
w13_bias = packed_state_dict[f'{prefix}swiglu.w12.bias']
|
||||||
|
hidden_dim = w13_bias.shape[0] // 2
|
||||||
|
w1_bias, w3_bias = torch.split(w13_bias, hidden_dim, dim=0)
|
||||||
|
destination[f'{prefix}gate_proj.bias'] = w1_bias if not keep_vars else w1_bias.detach().requires_grad_(w1_bias.requires_grad)
|
||||||
|
destination[f'{prefix}up_proj.bias'] = w3_bias if not keep_vars else w3_bias.detach().requires_grad_(w3_bias.requires_grad)
|
||||||
|
|
||||||
|
# Copy w2 parameters as-is
|
||||||
|
for param_name in ['weight', 'bias']:
|
||||||
|
key = f'{prefix}swiglu.w3.{param_name}'
|
||||||
|
if key in packed_state_dict:
|
||||||
|
destination[f'{prefix}down_proj.{param_name}'] = packed_state_dict[key]
|
||||||
|
|
||||||
|
# Copy any other parameters
|
||||||
|
excluded_prefixes = [f'{prefix}swiglu.w12.', f'{prefix}swiglu.w3.']
|
||||||
|
for key, value in packed_state_dict.items():
|
||||||
|
if not any(key.startswith(excluded_prefix) for excluded_prefix in excluded_prefixes) and key not in destination:
|
||||||
|
destination[key] = value
|
||||||
|
|
||||||
|
return destination
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.swiglu(x)
|
||||||
|
|||||||
Reference in New Issue
Block a user