From 40d231a91b9712c7289fa89daa3a31d66b4c7db3 Mon Sep 17 00:00:00 2001 From: Casper Date: Thu, 7 Dec 2023 19:52:40 +0100 Subject: [PATCH] Enable replacing xentropy, rmsnorm for Mistral --- src/axolotl/monkeypatch/mistral_attn_hijack_flash.py | 12 ++++++++++++ src/axolotl/utils/models.py | 6 +++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py index 1bf9851d7..d79a4d451 100644 --- a/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py @@ -24,11 +24,19 @@ from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, r from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids +from axolotl.monkeypatch.flash_modules import ( + flashattn_forward, + replace_cross_entropy, + replace_rms_norm +) + LOG = logging.getLogger("axolotl.monkeypatch.mistral") def replace_mistral_attn_with_flash_attn( packed: Optional[bool] = False, + cross_entropy: Optional[bool] = False, + rms_norm: Optional[bool] = False, ): transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access _prepare_decoder_attention_mask @@ -45,6 +53,10 @@ def replace_mistral_attn_with_flash_attn( transformers.models.mistral.modeling_mistral.MistralModel.forward = ( mistral_model_forward ) + if cross_entropy: + replace_cross_entropy(transformers.mistral.llama.modeling_mistral, "CrossEntropyLoss") + if rms_norm: + replace_rms_norm(transformers.mistral.llama.modeling_mistral, "MistralRMSNorm") @torch.jit.script diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index acc6f41fa..6acc33efc 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -193,7 +193,11 @@ def load_model( ) LOG.info("patching with flash attention") - replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing) + replace_mistral_attn_with_flash_attn( + packed=cfg.sample_packing, + cross_entropy=cfg.flash_attn_cross_entropy, + rms_norm=cfg.flash_attn_rms_norm, + ) if cfg.is_llama_derived_model and cfg.xpos_rope: from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (