Enable replacing xentropy, rmsnorm for Mistral
This commit is contained in:
@@ -24,11 +24,19 @@ from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, r
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
from axolotl.monkeypatch.flash_modules import (
|
||||
flashattn_forward,
|
||||
replace_cross_entropy,
|
||||
replace_rms_norm
|
||||
)
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
||||
|
||||
|
||||
def replace_mistral_attn_with_flash_attn(
|
||||
packed: Optional[bool] = False,
|
||||
cross_entropy: Optional[bool] = False,
|
||||
rms_norm: Optional[bool] = False,
|
||||
):
|
||||
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||
_prepare_decoder_attention_mask
|
||||
@@ -45,6 +53,10 @@ def replace_mistral_attn_with_flash_attn(
|
||||
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
|
||||
mistral_model_forward
|
||||
)
|
||||
if cross_entropy:
|
||||
replace_cross_entropy(transformers.mistral.llama.modeling_mistral, "CrossEntropyLoss")
|
||||
if rms_norm:
|
||||
replace_rms_norm(transformers.mistral.llama.modeling_mistral, "MistralRMSNorm")
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
|
||||
@@ -193,7 +193,11 @@ def load_model(
|
||||
)
|
||||
|
||||
LOG.info("patching with flash attention")
|
||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||
replace_mistral_attn_with_flash_attn(
|
||||
packed=cfg.sample_packing,
|
||||
cross_entropy=cfg.flash_attn_cross_entropy,
|
||||
rms_norm=cfg.flash_attn_rms_norm,
|
||||
)
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||
|
||||
Reference in New Issue
Block a user