Fix naming of rms_norm for Llama

Casper
2023-12-07 19:53:20 +01:00
parent 40d231a91b
commit dfd06a0f88


@@ -75,7 +75,7 @@ def replace_llama_attn_with_flash_attn(
     if cross_entropy:
         replace_cross_entropy(transformers.models.llama.modeling_llama, "CrossEntropyLoss")
     if rms_norm:
-        replace_rms_norm(transformers.models.llama.modeling_llama, "RMSNorm")
+        replace_rms_norm(transformers.models.llama.modeling_llama, "LlamaRMSNorm")
     # Disable the transformation of the attention mask in LlamaModel as the flash attention
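For context: the RMS norm class defined in transformers' modeling_llama is named LlamaRMSNorm, not RMSNorm. A setattr-based patcher given the wrong name would quietly create a new, unused module attribute while the model keeps using the original class, so the patch would silently do nothing. Below is a minimal sketch, not the repo's actual helper (its body isn't shown in this diff), of how such a monkey-patch typically works; FusedRMSNorm is a hypothetical stand-in for a faster kernel:

import torch
import torch.nn as nn

class FusedRMSNorm(nn.Module):
    # Hypothetical stand-in for a fused RMSNorm kernel; it keeps
    # LlamaRMSNorm's constructor signature so it is a drop-in replacement.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states

def replace_rms_norm(module, class_name):
    # Guard against the bug fixed above: setattr with a wrong name would
    # silently add a new attribute instead of replacing the real class.
    if not hasattr(module, class_name):
        raise AttributeError(f"{module.__name__} has no class named {class_name!r}")
    setattr(module, class_name, FusedRMSNorm)

Note that a patch like this must run before the model is instantiated: modules constructed earlier already hold references to the original class and are unaffected by reassigning the module-level name.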