Compare commits

5 Commits

dump-confi ... mixtral_sw

| Author | SHA1 | Date |
|---|---|---|
| | a58a9e5f6c | |
| | 279a1401b5 | |
| | 083beb6425 | |
| | 2ac1a72e4b | |
| | 23103ac5ac | |
@@ -3,4 +3,7 @@ Custom modeling code for mixtral
 """
 from .configuration_moe_mistral import MixtralConfig  # noqa
-from .modeling_moe_mistral import MixtralForCausalLM  # noqa
+from .modeling_moe_mistral import (  # noqa
+    MixtralForCausalLM,
+    replace_mixtral_mlp_with_swiglu,
+)
@@ -46,8 +46,9 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
+from xformers.ops import SwiGLU

-from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
+from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
 from .configuration_moe_mistral import MixtralConfig

 if is_flash_attn_2_available():
@@ -68,6 +69,61 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "MixtralConfig"


+def replace_mixtral_mlp_with_swiglu(model):
+    for name, module in model.named_modules():
+        if isinstance(module, FeedForward):
+            mlp = FusedMLP(
+                module.config,
+                module.gate_proj,
+                module.up_proj,
+                module.down_proj,
+            )
+            set_module_name(model, name, mlp)
+
+
+class FusedMLP(torch.nn.Module):
+    """
+    Fused MLP layer for incrementally improved training efficiency
+    """
+
+    def __init__(
+        self,
+        config,
+        gate_proj: torch.nn.Linear,
+        up_proj: torch.nn.Linear,
+        down_proj: torch.nn.Linear,
+    ):
+        super().__init__()
+        self.config = config
+        self.swiglu = SwiGLU(
+            in_features=config.hidden_size,
+            hidden_features=config.intermediate_size,
+            bias=False,
+            _pack_weights=True,
+        )
+        # overwrite initialized weights with pretrained weights
+        self.swiglu.w12.weight.data = torch.cat(
+            (gate_proj.weight.data, up_proj.weight.data), dim=0
+        )
+        self.swiglu.w3.weight.data = down_proj.weight.data
+
+    def _post_training(self, model, name):
+        w1, w2 = torch.split(  # pylint: disable=invalid-name
+            self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
+        )
+
+        # Assign the split weights back to the original layers
+        new_mlp = FeedForward(self.config)
+        new_mlp.w1.weight.data = w1
+        new_mlp.w2.weight.data = w2
+        new_mlp.w3.weight.data = self.swiglu.w3.weight.data
+
+        set_module_name(model, name, new_mlp)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
+        return self.swiglu(x)
+
+
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
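For orientation (not part of the diff): the packing above only reproduces the stock gate/up/down feed-forward if the first `intermediate_size` rows of `w12` hold the gate weights and the remaining rows hold the up weights, which is exactly the split that `_post_training` undoes. Below is a minimal pure-PyTorch sketch of that equivalence, assuming the SiLU activation used by the Mixtral feed-forward; the module names here are illustrative and not taken from the PR.

```python
import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 32
gate = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
up = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
down = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

# Pack gate and up the same way FusedMLP packs them into swiglu.w12:
# stacked along dim=0 so a single matmul produces both projections.
w12 = torch.cat((gate.weight.data, up.weight.data), dim=0)

x = torch.randn(4, hidden_size)

# Reference LLaMA/Mixtral-style feed-forward: down(silu(gate(x)) * up(x))
ref = down(F.silu(gate(x)) * up(x))

# Packed path: one matmul, then split the result back into its two halves.
g, u = (x @ w12.t()).split(intermediate_size, dim=-1)
fused = (F.silu(g) * u) @ down.weight.data.t()

assert torch.allclose(ref, fused, atol=1e-5)
```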
@@ -373,7 +373,10 @@ def load_model(
             **model_kwargs,
         )
     elif model_type == "MixtralForCausalLM":
-        from axolotl.models.mixtral import MixtralForCausalLM
+        from axolotl.models.mixtral import (
+            MixtralForCausalLM,
+            replace_mixtral_mlp_with_swiglu,
+        )

         model = MixtralForCausalLM.from_pretrained(
             base_model,
@@ -381,6 +384,11 @@ def load_model(
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             **model_kwargs,
         )
+
+        if cfg.flash_attn_fuse_mlp:
+            LOG.info("Mixtral MoE: Replacing experts with SwiGLU")
+            replace_mixtral_mlp_with_swiglu(model)
+
     elif model_type == "MambaLMHeadModel":
         # FIXME this is janky at best and hacked together to make it work
         MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
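The `_post_training` hook defined earlier in this compare splits the packed `w12` back into separate weights so that saved checkpoints keep the original, unfused layer layout. A rough sketch of how a training loop could invoke such hooks before saving; the `revert_fused_modules` helper below is hypothetical and not part of this change.

```python
def revert_fused_modules(model):
    """Swap fused modules back to their original form before saving.

    Walks the model and calls each module's `_post_training` hook,
    which replaces the fused layer with its unfused equivalent.
    """
    # Snapshot the module list first: the hook mutates the model in place.
    for name, module in list(model.named_modules()):
        if hasattr(module, "_post_training"):
            module._post_training(model, name)
```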