handle args to drop dropout
This commit is contained in:
@@ -644,10 +644,16 @@ class ModelLoader:
|
|||||||
|
|
||||||
def flash_attn_func_v3_wrapper(*args, **kwargs):
|
def flash_attn_func_v3_wrapper(*args, **kwargs):
|
||||||
kwargs.pop("dropout_p", None)
|
kwargs.pop("dropout_p", None)
|
||||||
|
if "softmax_scale" in kwargs and len(args) >= 4:
|
||||||
|
# if softmax_scale is provided, then the 3rd position is dropout_p that we need to drop
|
||||||
|
args = (*args[:3],) + args[4:]
|
||||||
return flash_attn_func_v3(*args, **kwargs)[0]
|
return flash_attn_func_v3(*args, **kwargs)[0]
|
||||||
|
|
||||||
def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
|
def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
|
||||||
kwargs.pop("dropout_p", None)
|
kwargs.pop("dropout_p", None)
|
||||||
|
if "softmax_scale" in kwargs and len(args) >= 4:
|
||||||
|
# if softmax_scale is provided, then the 3rd position is dropout_p that we need to drop
|
||||||
|
args = (*args[:3],) + args[4:]
|
||||||
return flash_attn_varlen_func_v3(*args, **kwargs)[0]
|
return flash_attn_varlen_func_v3(*args, **kwargs)[0]
|
||||||
|
|
||||||
transformers.modeling_flash_attention_utils.flash_attn_func = (
|
transformers.modeling_flash_attention_utils.flash_attn_func = (
|
||||||
|
|||||||
@@ -233,7 +233,7 @@ class AxolotlInputConfig(
|
|||||||
flash_attn_fuse_qkv: bool | None = None
|
flash_attn_fuse_qkv: bool | None = None
|
||||||
flash_attn_fuse_mlp: bool | None = None
|
flash_attn_fuse_mlp: bool | None = None
|
||||||
flash_optimum: bool | None = None
|
flash_optimum: bool | None = None
|
||||||
use_flash_attention_3: Literal["auto"] | bool | None = "auto"
|
use_flash_attention_3: Literal["auto"] | bool | None = None
|
||||||
|
|
||||||
eager_attention: bool | None = None
|
eager_attention: bool | None = None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user