diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 73dc17ab2..5f84079da 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -644,11 +644,11 @@ class ModelLoader: def flash_attn_func_v3_wrapper(*args, **kwargs): kwargs.pop("dropout_p", None) - return flash_attn_func_v3(*args, **kwargs) + return flash_attn_func_v3(*args, **kwargs)[0] def flash_attn_varlen_func_v3_wrapper(*args, **kwargs): kwargs.pop("dropout_p", None) - return flash_attn_varlen_func_v3(*args, **kwargs) + return flash_attn_varlen_func_v3(*args, **kwargs)[0] transformers.modeling_flash_attention_utils.flash_attn_func = ( flash_attn_func_v3_wrapper