wip state dict compatible fused mlp

This commit is contained in:
Wing Lian
2025-07-06 14:32:29 -04:00
parent 9a8073e73d
commit 5a063f5c75

View File

@@ -1,8 +1,11 @@
""" """
Fused MLP layer for incrementally improved training efficiency Fused MLP layer for incrementally improved training efficiency
""" """
from collections import OrderedDict
import torch import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.models.llama.modeling_llama import LlamaMLP from transformers.models.llama.modeling_llama import LlamaMLP
from xformers.ops import SwiGLU from xformers.ops import SwiGLU
@@ -50,3 +53,129 @@ class FusedMLP(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x) return self.swiglu(x)
class FusedMLPv2(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.swiglu = SwiGLU(
in_features=self.hidden_size,
hidden_features=self.intermediate_size,
bias=config.mlp_bias,
_pack_weights=True,
)
assert config.hidden_act == "silu"
def _convert_unpacked_to_packed_state_dict(self, unpacked_state_dict):
"""
Convert state dict from unpacked format (w1, w2, w3) to packed format (w13, w2).
"""
packed_state_dict = OrderedDict()
# Handle w1 and w3 -> w13 conversion for weights
if 'gate_proj.weight' in unpacked_state_dict and 'up_proj.weight' in unpacked_state_dict:
gate_proj_weight = unpacked_state_dict['gate_proj.weight']
up_proj_weight = unpacked_state_dict['up_proj.weight']
# Concatenate gate and up weights along output dimension (dim=0)
packed_state_dict['swiglu.w12.weight'] = torch.cat([gate_proj_weight, up_proj_weight], dim=0)
# Handle w1 and w3 -> w13 conversion for biases (if they exist)
if 'gate_proj.bias' in unpacked_state_dict and 'up_proj.bias' in unpacked_state_dict:
gate_proj_bias = unpacked_state_dict['gate_proj.bias']
up_proj_bias = unpacked_state_dict['up_proj.bias']
# Concatenate gate and up biases along dimension 0
packed_state_dict['swiglu.w12.bias'] = torch.cat([gate_proj_bias, up_proj_bias], dim=0)
# Copy down parameters as-is
if "down_proj.weight" in unpacked_state_dict:
packed_state_dict["swiglu.w3.weight"] = unpacked_state_dict['down_proj.weight']
if "down_proj.bias" in unpacked_state_dict:
packed_state_dict["swiglu.w3.bias"] = unpacked_state_dict['down_proj.bias']
for key in ['swiglu.w3.weight', 'swiglu.w3.bias']:
if key in unpacked_state_dict:
packed_state_dict[key] = unpacked_state_dict[key]
# Copy any other parameters that might exist
excluded_keys = [
'gate_proj.weight', 'gate_proj.bias',
'down_proj.weight', 'down_proj.bias',
'up_proj.weight', 'up_proj.bias',
'swiglu.w12.weight', 'swiglu.w12.bias',
'swiglu.w3.weight', 'swiglu.w3.bias',
]
for key, value in unpacked_state_dict.items():
if key not in excluded_keys:
packed_state_dict[key] = value
return packed_state_dict
def load_state_dict(self, state_dict, strict=True):
"""
Load state dict, handling both packed (w13) and unpacked (w1, w3) formats.
"""
# Check if this is an unpacked state dict (has w1 and w3 instead of w13)
has_unpacked_gate_up = 'gate_proj.weight' in state_dict and 'up_proj.weight' in state_dict
has_packed_swiglu = 'swiglu.w12.weight' in state_dict
if has_unpacked_gate_up and not has_packed_swiglu:
state_dict = self._convert_unpacked_to_packed_state_dict(state_dict)
return super().load_state_dict(state_dict, strict=strict)
def state_dict(self, destination=None, prefix='', keep_vars=False, packed=False):
"""
Return state dict in unpacked format by default for compatibility.
Set packed=True to get the internal packed format.
"""
if packed:
# Return the actual packed state dict
return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
else:
# Return unpacked format for compatibility
return self.get_unpacked_state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
def get_unpacked_state_dict(self, destination=None, prefix='', keep_vars=False):
"""
Convert current packed state dict to unpacked format for compatibility.
"""
# Get the actual packed state dict first
packed_state_dict = super().state_dict(destination=None, prefix='', keep_vars=keep_vars)
if destination is None:
destination = OrderedDict()
# Handle w13 -> w1 and w3 conversion for weights
if f'{prefix}swiglu.w12.weight' in packed_state_dict:
w13_weight = packed_state_dict[f'{prefix}swiglu.w12.weight']
hidden_dim = w13_weight.shape[0] // 2
w1_weight, w3_weight = torch.split(w13_weight, hidden_dim, dim=0)
destination[f'{prefix}gate_proj.weight'] = w1_weight if not keep_vars else w1_weight.detach().requires_grad_(w1_weight.requires_grad)
destination[f'{prefix}up_proj.weight'] = w3_weight if not keep_vars else w3_weight.detach().requires_grad_(w3_weight.requires_grad)
# Handle w13 -> w1 and w3 conversion for biases (if they exist)
if f'{prefix}swiglu.w12.bias' in packed_state_dict:
w13_bias = packed_state_dict[f'{prefix}swiglu.w12.bias']
hidden_dim = w13_bias.shape[0] // 2
w1_bias, w3_bias = torch.split(w13_bias, hidden_dim, dim=0)
destination[f'{prefix}gate_proj.bias'] = w1_bias if not keep_vars else w1_bias.detach().requires_grad_(w1_bias.requires_grad)
destination[f'{prefix}up_proj.bias'] = w3_bias if not keep_vars else w3_bias.detach().requires_grad_(w3_bias.requires_grad)
# Copy w2 parameters as-is
for param_name in ['weight', 'bias']:
key = f'{prefix}swiglu.w3.{param_name}'
if key in packed_state_dict:
destination[f'{prefix}down_proj.{param_name}'] = packed_state_dict[key]
# Copy any other parameters
excluded_prefixes = [f'{prefix}swiglu.w12.', f'{prefix}swiglu.w3.']
for key, value in packed_state_dict.items():
if not any(key.startswith(excluded_prefix) for excluded_prefix in excluded_prefixes) and key not in destination:
destination[key] = value
return destination
def forward(self, x):
return self.swiglu(x)