Fix naming of rms_norm for Llama
This commit is contained in:
@@ -75,7 +75,7 @@ def replace_llama_attn_with_flash_attn(
     if cross_entropy:
         replace_cross_entropy(transformers.models.llama.modeling_llama, "CrossEntropyLoss")
     if rms_norm:
-        replace_rms_norm(transformers.models.llama.modeling_llama, "RMSNorm")
+        replace_rms_norm(transformers.models.llama.modeling_llama, "LlamaRMSNorm")
 
 
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
Reference in New Issue
Block a user