handle return sig change for fa3

This commit is contained in:
Wing Lian
2025-05-18 08:28:52 -07:00
parent b22150751f
commit 323a9cb153

View File

@@ -644,11 +644,11 @@ class ModelLoader:
def flash_attn_func_v3_wrapper(*args, **kwargs):
kwargs.pop("dropout_p", None)
return flash_attn_func_v3(*args, **kwargs)
return flash_attn_func_v3(*args, **kwargs)[0]
def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
kwargs.pop("dropout_p", None)
return flash_attn_varlen_func_v3(*args, **kwargs)
return flash_attn_varlen_func_v3(*args, **kwargs)[0]
transformers.modeling_flash_attention_utils.flash_attn_func = (
flash_attn_func_v3_wrapper