Implement Mistral SwiGLU

This commit is contained in:
Casper
2023-12-07 19:59:29 +01:00
parent dfd06a0f88
commit 1cb7977026
2 changed files with 20 additions and 1 deletions

View File

@@ -19,19 +19,29 @@ from transformers.models.mistral.modeling_mistral import (
)
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer,
MistralMLP
)
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
from axolotl.monkeypatch.flash_modules import (
flashattn_forward,
replace_cross_entropy,
replace_rms_norm
)
from axolotl.monkeypatch.fused_modules import FusedMLP
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
def replace_mistral_mlp_with_swiglu(model):
    """Swap every ``MistralMLP`` submodule of *model* for a fused SwiGLU MLP.

    Walks the full module tree and, for each MLP found, builds a ``FusedMLP``
    reusing the original gate/up/down projection layers (weights are shared,
    not copied), then splices it back in at the same dotted module path.
    """
    for module_path, submodule in model.named_modules():
        if not isinstance(submodule, MistralMLP):
            continue
        fused = FusedMLP(
            submodule.config,
            submodule.gate_proj,
            submodule.up_proj,
            submodule.down_proj,
        )
        set_module_name(model, module_path, fused)
def replace_mistral_attn_with_flash_attn(
packed: Optional[bool] = False,

View File

@@ -278,6 +278,15 @@ def load_model(
if cfg.flash_attn_fuse_qkv:
LOG.info("patching with fused QKV")
replace_llama_qkv_with_fused(model)
elif cfg.is_mistral_derived_model and not cfg.trust_remote_code and not cfg.gptq:
if cfg.flash_attention and not inference:
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
replace_mistral_mlp_with_swiglu,
)
if cfg.flash_attn_fuse_mlp:
LOG.info("patching with SwiGLU")
replace_mistral_mlp_with_swiglu(model)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
# This is a WIP, still an issue with the backward pass
# RuntimeError: grad can be implicitly created only for scalar outputs