diff --git a/src/axolotl/models/mixtral/modeling_moe_mistral.py b/src/axolotl/models/mixtral/modeling_moe_mistral.py
index 6f1fb7a4a..6e1dc5b24 100644
--- a/src/axolotl/models/mixtral/modeling_moe_mistral.py
+++ b/src/axolotl/models/mixtral/modeling_moe_mistral.py
@@ -47,9 +47,11 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 
-from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
+from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
 from .configuration_moe_mistral import MixtralConfig
 
+from xformers.ops import SwiGLU
+
 if is_flash_attn_2_available():
     from flash_attn import (
         flash_attn_func,
@@ -68,6 +70,61 @@ logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "MixtralConfig"
 
+def replace_mixtral_mlp_with_swiglu(model):
+    for name, module in model.named_modules():
+        if isinstance(module, FeedForward):
+            mlp = FusedMLP(
+                module.config,
+                module.gate_proj,
+                module.up_proj,
+                module.down_proj,
+            )
+            set_module_name(model, name, mlp)
+
+
+class FusedMLP(torch.nn.Module):
+    """
+    Fused MLP layer for incrementally improved training efficiency
+    """
+
+    def __init__(
+        self,
+        config,
+        gate_proj: torch.nn.Linear,
+        up_proj: torch.nn.Linear,
+        down_proj: torch.nn.Linear,
+    ):
+        super().__init__()
+        self.config = config
+        self.swiglu = SwiGLU(
+            in_features=config.hidden_size,
+            hidden_features=config.intermediate_size,
+            bias=False,
+            _pack_weights=True,
+        )
+        # overwrite initialized weights with pretrained weights
+        self.swiglu.w12.weight.data = torch.cat(
+            (gate_proj.weight.data, up_proj.weight.data), dim=0
+        )
+        self.swiglu.w3.weight.data = down_proj.weight.data
+
+    def _post_training(self, model, name):
+        w1, w2 = torch.split(  # pylint: disable=invalid-name
+            self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
+        )
+
+        # Assign the split weights back to the original layers
+        new_mlp = FeedForward(self.config)
+        new_mlp.w1.weight.data = w1
+        new_mlp.w2.weight.data = w2
+        new_mlp.w3.weight.data = self.swiglu.w3.weight.data
+
+        set_module_name(model, name, new_mlp)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
+        return self.swiglu(x)
+
+
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 8531cb251..84a4f6cf1 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -373,7 +373,10 @@ def load_model(
             **model_kwargs,
         )
     elif model_type == "MixtralForCausalLM":
-        from axolotl.models.mixtral import MixtralForCausalLM
+        from axolotl.models.mixtral import (
+            MixtralForCausalLM,
+            replace_mixtral_mlp_with_swiglu,
+        )
 
         model = MixtralForCausalLM.from_pretrained(
             base_model,
@@ -381,6 +384,10 @@
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             **model_kwargs,
         )
+
+        LOG.info("Mixtral MoE: Replacing experts with SwiGLU")
+        replace_mixtral_mlp_with_swiglu(model)
+
     elif model_type == "MambaLMHeadModel":
         # FIXME this is janky at best and hacked together to make it work
         MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
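
For context, the packing this patch writes into `SwiGLU.w12` relies on a simple identity: concatenating the gate and up projection weights along the output dimension lets a single matmul produce both activations, which can then be split back apart. Below is a minimal plain-PyTorch sketch of that equivalence; the sizes and tensor names are illustrative only, not taken from the Mixtral config, and the `silu` gating follows the usual SwiGLU formulation assumed by the patch.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only -- not the real Mixtral dimensions.
hidden_size, intermediate_size, batch = 16, 32, 4

gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(batch, hidden_size)

# Unfused reference: down(silu(gate(x)) * up(x))
reference = down_proj(F.silu(gate_proj(x)) * up_proj(x))

# Fused path: one matmul against the concatenated [gate; up] weight,
# then split the result -- the same packing the patch applies to
# SwiGLU.w12 (gate and up weights concatenated along dim 0).
w12 = torch.cat((gate_proj.weight, up_proj.weight), dim=0)  # (2 * intermediate, hidden)
gate_out, up_out = F.linear(x, w12).split(intermediate_size, dim=-1)
fused = F.linear(F.silu(gate_out) * up_out, down_proj.weight)

assert torch.allclose(reference, fused, atol=1e-6)
```

The `_post_training` hook in the patch performs the inverse of this packing: it splits `w12` back into two per-projection weights and copies them into a fresh `FeedForward`, so the saved checkpoint keeps the original unfused layout.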