Enable replacing xentropy, rmsnorm for Mistral

This commit is contained in:
Casper
2023-12-07 19:52:40 +01:00
parent bf289123e9
commit 40d231a91b
2 changed files with 17 additions and 1 deletion
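Adds cross_entropy and rms_norm flags to replace_mistral_attn_with_flash_attn so the Mistral patch can also swap in flash-attn's fused cross-entropy (xentropy) and RMSNorm kernels, and wires the flags up to cfg.flash_attn_cross_entropy and cfg.flash_attn_rms_norm in load_model. Both flags default to False, so behavior is unchanged unless opted in.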

@@ -24,11 +24,19 @@ from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, r
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 from axolotl.monkeypatch.flash_modules import (
     flashattn_forward,
+    replace_cross_entropy,
+    replace_rms_norm
 )
 LOG = logging.getLogger("axolotl.monkeypatch.mistral")
 def replace_mistral_attn_with_flash_attn(
     packed: Optional[bool] = False,
+    cross_entropy: Optional[bool] = False,
+    rms_norm: Optional[bool] = False,
 ):
     transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
         _prepare_decoder_attention_mask
@@ -45,6 +53,10 @@ def replace_mistral_attn_with_flash_attn(
         transformers.models.mistral.modeling_mistral.MistralModel.forward = (
             mistral_model_forward
         )
+    if cross_entropy:
+        replace_cross_entropy(transformers.models.mistral.modeling_mistral, "CrossEntropyLoss")
+    if rms_norm:
+        replace_rms_norm(transformers.models.mistral.modeling_mistral, "MistralRMSNorm")
 @torch.jit.script
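replace_cross_entropy and replace_rms_norm come from axolotl.monkeypatch.flash_modules, which this diff does not show. A minimal sketch of what such helpers can look like, assuming flash-attn's fused CrossEntropyLoss (the xentropy kernel) and RMSNorm ops are installed; the helper bodies and the Flash* aliases are illustrative, not the module's actual code:

# Illustrative sketch only; the real helpers live in axolotl.monkeypatch.flash_modules,
# which is not part of this diff.
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
from flash_attn.ops.rms_norm import RMSNorm as FlashRMSNorm


def replace_cross_entropy(module, name):
    # modeling_mistral resolves CrossEntropyLoss as a module-level global at
    # call time, so rebinding the symbol is enough to swap in the fused loss.
    setattr(module, name, FlashCrossEntropyLoss)


def replace_rms_norm(module, name):
    # MistralRMSNorm(hidden_size, eps=...) and flash-attn's RMSNorm(hidden_size, eps=...)
    # take compatible constructor arguments, so the class can be substituted directly.
    setattr(module, name, FlashRMSNorm)

Because the swap rebinds symbols on the module rather than on an instantiated model, it has to run before the model is built, which is presumably why load_model applies the patch before loading weights.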

@@ -193,7 +193,11 @@ def load_model(
         )
         LOG.info("patching with flash attention")
-        replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
+        replace_mistral_attn_with_flash_attn(
+            packed=cfg.sample_packing,
+            cross_entropy=cfg.flash_attn_cross_entropy,
+            rms_norm=cfg.flash_attn_rms_norm,
+        )
     if cfg.is_llama_derived_model and cfg.xpos_rope:
         from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
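With both new parameters defaulting to False, existing configs are unaffected; assuming axolotl's usual one-to-one mapping from YAML keys to cfg attributes, setting flash_attn_cross_entropy: true and flash_attn_rms_norm: true in the training config opts Mistral models into the fused kernels.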