diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 3dd5fc21b..a28234135 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -644,10 +644,16 @@ class ModelLoader:
 
         def flash_attn_func_v3_wrapper(*args, **kwargs):
             kwargs.pop("dropout_p", None)
+            if "softmax_scale" in kwargs and len(args) >= 4:
+                # if softmax_scale is passed as a kwarg, the positional arg at index 3 is dropout_p, which we need to drop
+                args = (*args[:3],) + args[4:]
             return flash_attn_func_v3(*args, **kwargs)[0]
 
         def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
             kwargs.pop("dropout_p", None)
+            if "softmax_scale" in kwargs and len(args) >= 4:
+                # if softmax_scale is passed as a kwarg, the positional arg at index 3 is dropout_p, which we need to drop
+                args = (*args[:3],) + args[4:]
             return flash_attn_varlen_func_v3(*args, **kwargs)[0]
 
         transformers.modeling_flash_attention_utils.flash_attn_func = (
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 26a4fe043..29b7f6f72 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -233,7 +233,7 @@ class AxolotlInputConfig(
 
     flash_attn_fuse_qkv: bool | None = None
     flash_attn_fuse_mlp: bool | None = None
     flash_optimum: bool | None = None
-    use_flash_attention_3: Literal["auto"] | bool | None = "auto"
+    use_flash_attention_3: Literal["auto"] | bool | None = None
     eager_attention: bool | None = None
 