Fix naming of rms_norm for Llama
@@ -75,7 +75,7 @@ def replace_llama_attn_with_flash_attn(
     if cross_entropy:
         replace_cross_entropy(transformers.models.llama.modeling_llama, "CrossEntropyLoss")
     if rms_norm:
-        replace_rms_norm(transformers.models.llama.modeling_llama, "RMSNorm")
+        replace_rms_norm(transformers.models.llama.modeling_llama, "LlamaRMSNorm")


     # Disable the transformation of the attention mask in LlamaModel as the flash attention
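For context, the patch resolves the RMSNorm class by attribute name on the transformers module, so the string must match the class name that the installed transformers version actually defines: LlamaRMSNorm, not RMSNorm. The sketch below is not the repository's helper; it is a minimal, assumed illustration of that name-based lookup, with a plain PyTorch stand-in instead of a fused kernel.

import torch
import torch.nn as nn
import transformers.models.llama.modeling_llama as llama_module


class SketchRMSNorm(nn.Module):
    # Stand-in replacement class; a real patch would swap in a fused RMSNorm kernel.
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Standard RMSNorm: scale by the reciprocal root-mean-square of the features.
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


def replace_rms_norm(module, class_name: str) -> None:
    # getattr fails immediately if the name is stale, e.g. "RMSNorm" on a
    # transformers version that only defines LlamaRMSNorm.
    getattr(module, class_name)
    setattr(module, class_name, SketchRMSNorm)


replace_rms_norm(llama_module, "LlamaRMSNorm")  # "RMSNorm" would raise AttributeError here

Because the replacement happens at the class attribute level, it only affects models instantiated after the patch is applied, which is why replace_llama_attn_with_flash_attn is meant to run before loading the model.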