diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index d347387cf..b0e053242 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -25,8 +25,11 @@ from transformers.models.llama.modeling_llama import (
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
 from axolotl.monkeypatch.fused_modules import FusedAttention, FusedMLP
-from axolotl.monkeypatch.flash_modules import flashattn_forward
-
+from axolotl.monkeypatch.flash_modules import (
+    flashattn_forward,
+    replace_cross_entropy,
+    replace_rms_norm
+)
 
 LOG = logging.getLogger("axolotl")
 
 
@@ -69,6 +72,10 @@ def replace_llama_attn_with_flash_attn(
         transformers.models.llama.modeling_llama.LlamaModel.forward = (
             llama_model_forward
         )
+    if cross_entropy:
+        replace_cross_entropy(transformers.models.llama.modeling_llama, "CrossEntropyLoss")
+    if rms_norm:
+        replace_rms_norm(transformers.models.llama.modeling_llama, "RMSNorm")
 
     # Disable the transformation of the attention mask in LlamaModel as the flash attention
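
For context, here is a minimal sketch of what the two new helpers in axolotl/monkeypatch/flash_modules.py might look like. The (module, attribute name) signatures are taken from the call sites in the diff above; the flash-attn imports, the inplace_backward flag, and the PatchedRMSNorm wrapper are assumptions, not the actual implementation in this PR.

    # Sketch only: signatures inferred from the call sites above; the
    # flash-attn imports and constructor details are assumptions.
    from functools import partial
    from types import ModuleType

    from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
    from flash_attn.ops.rms_norm import RMSNorm as FlashRMSNorm


    def replace_cross_entropy(module: ModuleType, attr_name: str) -> None:
        # Rebind the module-level loss class (looked up inside
        # LlamaForCausalLM.forward at call time) to the fused flash-attn kernel.
        setattr(module, attr_name, partial(FlashCrossEntropyLoss, inplace_backward=True))


    def replace_rms_norm(module: ModuleType, attr_name: str) -> None:
        # Swap the named norm class for a thin wrapper around the fused
        # flash-attn RMSNorm, keeping the (hidden_size, eps) constructor
        # that transformers' LlamaRMSNorm exposes.
        class PatchedRMSNorm(FlashRMSNorm):
            def __init__(self, hidden_size: int, eps: float = 1e-6):
                super().__init__(hidden_size, eps=eps)

        setattr(module, attr_name, PatchedRMSNorm)

One design note, under the same assumptions: rebinding a class on the module only affects objects created afterwards, so the norm patch has to run before the model is instantiated, whereas the loss patch also works for an already-built model because the loss class is resolved from the module globals each forward pass.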