From dfd06a0f8845a6314041bbf7b3ffe9ede6ee9294 Mon Sep 17 00:00:00 2001 From: Casper Date: Thu, 7 Dec 2023 19:53:20 +0100 Subject: [PATCH] Fix naming of rms_norm for Llama --- src/axolotl/monkeypatch/llama_attn_hijack_flash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index b0e053242..b546793f3 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -75,7 +75,7 @@ def replace_llama_attn_with_flash_attn( if cross_entropy: replace_cross_entropy(transformers.models.llama.modeling_llama, "CrossEntropyLoss") if rms_norm: - replace_rms_norm(transformers.models.llama.modeling_llama, "RMSNorm") + replace_rms_norm(transformers.models.llama.modeling_llama, "LlamaRMSNorm") # Disable the transformation of the attention mask in LlamaModel as the flash attention