diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index b0e053242..b546793f3 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -75,7 +75,7 @@ def replace_llama_attn_with_flash_attn( if cross_entropy: replace_cross_entropy(transformers.models.llama.modeling_llama, "CrossEntropyLoss") if rms_norm: - replace_rms_norm(transformers.models.llama.modeling_llama, "RMSNorm") + replace_rms_norm(transformers.models.llama.modeling_llama, "LlamaRMSNorm") # Disable the transformation of the attention mask in LlamaModel as the flash attention