From 1cb79770260f0ec70c97c5e891a347d55ebc74b0 Mon Sep 17 00:00:00 2001
From: Casper
Date: Thu, 7 Dec 2023 19:59:29 +0100
Subject: [PATCH] Implement Mistral SwiGLU

---
 src/axolotl/monkeypatch/mistral_attn_hijack_flash.py | 12 +++++++++++-
 src/axolotl/utils/models.py                          |  9 +++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
index d79a4d451..16a4e557d 100644
--- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
@@ -19,19 +19,29 @@ from transformers.models.mistral.modeling_mistral import (
 )
 from transformers.models.mistral.modeling_mistral import (
     MistralDecoderLayer as OriginalMistralDecoderLayer,
+    MistralMLP,
 )
 from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
 
-from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
 from axolotl.monkeypatch.flash_modules import (
     flashattn_forward,
     replace_cross_entropy,
     replace_rms_norm,
 )
+from axolotl.monkeypatch.fused_modules import FusedMLP
 
 LOG = logging.getLogger("axolotl.monkeypatch.mistral")
 
 
+def replace_mistral_mlp_with_swiglu(model):
+    for name, module in model.named_modules():
+        if isinstance(module, MistralMLP):
+            mlp = FusedMLP(
+                module.config, module.gate_proj, module.up_proj, module.down_proj
+            )
+            set_module_name(model, name, mlp)
+
 def replace_mistral_attn_with_flash_attn(
     packed: Optional[bool] = False,
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 6acc33efc..2b547d9de 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -278,6 +278,15 @@ def load_model(
             if cfg.flash_attn_fuse_qkv:
                 LOG.info("patching with fused QKV")
                 replace_llama_qkv_with_fused(model)
+    elif cfg.is_mistral_derived_model and not cfg.trust_remote_code and not cfg.gptq:
+        if cfg.flash_attention and not inference:
+            from axolotl.monkeypatch.mistral_attn_hijack_flash import (
+                replace_mistral_mlp_with_swiglu,
+            )
+
+            if cfg.flash_attn_fuse_mlp:
+                LOG.info("patching with SwiGLU")
+                replace_mistral_mlp_with_swiglu(model)
     # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
     # This is a WIP, still an issue with the backward pass
     # RuntimeError: grad can be implicitly created only for scalar outputs
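
Note: the FusedMLP class imported above (axolotl.monkeypatch.fused_modules) is not shown in this diff. As a minimal sketch, assuming FusedMLP simply reuses the existing gate_proj/up_proj/down_proj weights and computes the standard SwiGLU down_proj(silu(gate_proj(x)) * up_proj(x)), the module it replaces MistralMLP with would behave like the hypothetical NaiveSwiGLUMLP below; the real FusedMLP presumably routes the same computation through a fused kernel rather than three separate matmuls plus elementwise ops.

    # Minimal sketch, not part of the patch: illustrates the SwiGLU computation
    # that the replaced MistralMLP modules are expected to keep performing.
    # NaiveSwiGLUMLP is a hypothetical name; the patch actually installs FusedMLP.
    import torch
    from torch import nn


    class NaiveSwiGLUMLP(nn.Module):
        """SwiGLU MLP reusing the original Mistral projection weights:
        down_proj(silu(gate_proj(x)) * up_proj(x))."""

        def __init__(self, config, gate_proj: nn.Linear, up_proj: nn.Linear, down_proj: nn.Linear):
            super().__init__()
            self.config = config
            self.gate_proj = gate_proj   # hidden_size -> intermediate_size
            self.up_proj = up_proj       # hidden_size -> intermediate_size
            self.down_proj = down_proj   # intermediate_size -> hidden_size
            self.act_fn = nn.SiLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

With the models.py hunk above, the replacement only runs for Mistral-derived models when flash_attention is enabled, inference is off, and flash_attn_fuse_mlp is set in the axolotl config.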