FA3 (Flash Attention v3) doesn't support `dropout_p`; fix patching by stripping it in wrapper functions.
This commit is contained in:
@@ -642,11 +642,19 @@ class ModelLoader:
|
||||
flash_attn_varlen_func as flash_attn_varlen_func_v3,
|
||||
)
|
||||
|
||||
def flash_attn_func_v3_wrapper(*args, **kwargs):
    """Delegate to flash_attn_func_v3, silently dropping ``dropout_p``.

    Flash Attention v3 does not accept a ``dropout_p`` keyword, so any
    such argument passed by callers (e.g. transformers) is discarded
    before forwarding everything else unchanged.
    """
    forwarded = {key: val for key, val in kwargs.items() if key != "dropout_p"}
    return flash_attn_func_v3(*args, **forwarded)
|
||||
|
||||
def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
    """Delegate to flash_attn_varlen_func_v3, silently dropping ``dropout_p``.

    Flash Attention v3 does not accept a ``dropout_p`` keyword, so any
    such argument passed by callers (e.g. transformers) is discarded
    before forwarding everything else unchanged.
    """
    forwarded = {key: val for key, val in kwargs.items() if key != "dropout_p"}
    return flash_attn_varlen_func_v3(*args, **forwarded)
|
||||
|
||||
transformers.modeling_flash_attention_utils.flash_attn_func = (
|
||||
flash_attn_func_v3
|
||||
flash_attn_func_v3_wrapper
|
||||
)
|
||||
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
|
||||
flash_attn_varlen_func_v3
|
||||
flash_attn_varlen_func_v3_wrapper
|
||||
)
|
||||
LOG.info("Switched to Flash Attention v3")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user