Mixtral: Replace FeedForward with SwiGLU

This commit is contained in:
Casper
2023-12-10 17:10:04 +01:00
parent 86487c2e96
commit 23103ac5ac
2 changed files with 66 additions and 2 deletions

View File

@@ -47,9 +47,11 @@ from transformers.utils import (
replace_return_docstrings,
)
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
from .configuration_moe_mistral import MixtralConfig
from xformers.ops import SwiGLU
if is_flash_attn_2_available():
from flash_attn import (
flash_attn_func,
@@ -68,6 +70,61 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MixtralConfig"
def replace_mixtral_mlp_with_swiglu(model):
for name, module in model.named_modules():
if isinstance(module, FeedForward):
mlp = FusedMLP(
module.config,
module.gate_proj,
module.up_proj,
module.down_proj,
)
set_module_name(model, name, mlp)
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = FeedForward(self.config)
new_mlp.w1.weight.data = w1
new_mlp.w2.weight.data = w2
new_mlp.w3.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)

View File

@@ -373,7 +373,10 @@ def load_model(
**model_kwargs,
)
elif model_type == "MixtralForCausalLM":
from axolotl.models.mixtral import MixtralForCausalLM
from axolotl.models.mixtral import (
MixtralForCausalLM,
replace_mixtral_mlp_with_swiglu
)
model = MixtralForCausalLM.from_pretrained(
base_model,
@@ -381,6 +384,10 @@ def load_model(
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
)
LOG.info("Mixtral MoE: Replacing experts with SwiGLU")
replace_mixtral_mlp_with_swiglu(model)
elif model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name