From d6f64a36847319832e29a936b64e2e16ade4485c Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 18 May 2025 13:11:56 -0700
Subject: [PATCH] handle args to drop dropout

---
 src/axolotl/utils/models.py         | 6 ++++++
 src/axolotl/utils/schemas/config.py | 2 +-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 3dd5fc21b..a28234135 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -644,10 +644,16 @@ class ModelLoader:
 
         def flash_attn_func_v3_wrapper(*args, **kwargs):
             kwargs.pop("dropout_p", None)
+            if "softmax_scale" in kwargs and len(args) >= 4:
+                # softmax_scale was passed as a kwarg, so the positional arg at index 3 is dropout_p; drop it
+                args = (*args[:3],) + args[4:]
             return flash_attn_func_v3(*args, **kwargs)[0]
 
         def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
             kwargs.pop("dropout_p", None)
+            if "softmax_scale" in kwargs and len(args) >= 4:
+                # softmax_scale was passed as a kwarg, so the positional arg at index 3 is dropout_p; drop it
+                args = (*args[:3],) + args[4:]
             return flash_attn_varlen_func_v3(*args, **kwargs)[0]
 
         transformers.modeling_flash_attention_utils.flash_attn_func = (
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 26a4fe043..29b7f6f72 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -233,7 +233,7 @@ class AxolotlInputConfig(
     flash_attn_fuse_qkv: bool | None = None
     flash_attn_fuse_mlp: bool | None = None
     flash_optimum: bool | None = None
-    use_flash_attention_3: Literal["auto"] | bool | None = "auto"
+    use_flash_attention_3: Literal["auto"] | bool | None = None
     eager_attention: bool | None = None
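For illustration only (not part of the commit): a minimal, self-contained sketch of the argument-stripping logic the two wrappers add. The attention call is replaced by returning the cleaned arguments, so it runs without flash-attn installed, and the helper name _strip_dropout is hypothetical. FA3 kernels take no dropout, so dropout_p must be removed whether the caller passes it as a keyword or positionally as the 4th argument of the usual (q, k, v, dropout_p, softmax_scale, ...) calling convention.

    # Sketch: mirrors the wrapper bodies above, but returns the cleaned
    # arguments instead of forwarding to flash_attn_func_v3.
    def _strip_dropout(*args, **kwargs):
        kwargs.pop("dropout_p", None)  # keyword form
        if "softmax_scale" in kwargs and len(args) >= 4:
            # softmax_scale arrived as a kwarg, so positional index 3 is dropout_p
            args = (*args[:3],) + args[4:]
        return args, kwargs

    q, k, v = "q", "k", "v"  # placeholder tensors
    print(_strip_dropout(q, k, v, 0.1, softmax_scale=0.125))
    # (('q', 'k', 'v'), {'softmax_scale': 0.125})
    print(_strip_dropout(q, k, v, dropout_p=0.1, softmax_scale=0.125))
    # (('q', 'k', 'v'), {'softmax_scale': 0.125})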