handle return sig change for fa3
This commit is contained in:
@@ -644,11 +644,11 @@ class ModelLoader:
|
|||||||
|
|
||||||
def flash_attn_func_v3_wrapper(*args, **kwargs):
|
def flash_attn_func_v3_wrapper(*args, **kwargs):
|
||||||
kwargs.pop("dropout_p", None)
|
kwargs.pop("dropout_p", None)
|
||||||
return flash_attn_func_v3(*args, **kwargs)
|
return flash_attn_func_v3(*args, **kwargs)[0]
|
||||||
|
|
||||||
def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
|
def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
|
||||||
kwargs.pop("dropout_p", None)
|
kwargs.pop("dropout_p", None)
|
||||||
return flash_attn_varlen_func_v3(*args, **kwargs)
|
return flash_attn_varlen_func_v3(*args, **kwargs)[0]
|
||||||
|
|
||||||
transformers.modeling_flash_attention_utils.flash_attn_func = (
|
transformers.modeling_flash_attention_utils.flash_attn_func = (
|
||||||
flash_attn_func_v3_wrapper
|
flash_attn_func_v3_wrapper
|
||||||
|
|||||||
Reference in New Issue
Block a user