fa3 doesn't support dropout_p, fix unpatching

Wing Lian
2025-05-18 06:26:08 -07:00
parent a064f1c9b4
commit 8c4bc59bfc
2 changed files with 19 additions and 2 deletions


@@ -642,11 +642,19 @@ class ModelLoader:
                 flash_attn_varlen_func as flash_attn_varlen_func_v3,
             )
 
+            def flash_attn_func_v3_wrapper(*args, **kwargs):
+                kwargs.pop("dropout_p", None)
+                return flash_attn_func_v3(*args, **kwargs)
+
+            def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
+                kwargs.pop("dropout_p", None)
+                return flash_attn_varlen_func_v3(*args, **kwargs)
+
             transformers.modeling_flash_attention_utils.flash_attn_func = (
-                flash_attn_func_v3
+                flash_attn_func_v3_wrapper
             )
             transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
-                flash_attn_varlen_func_v3
+                flash_attn_varlen_func_v3_wrapper
             )
             LOG.info("Switched to Flash Attention v3")