From 323a9cb153dacfca4e92f27b5407d32ed0d35f3a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 18 May 2025 08:28:52 -0700 Subject: [PATCH] handle return sig change for fa3 --- src/axolotl/utils/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 73dc17ab2..5f84079da 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -644,11 +644,11 @@ class ModelLoader: def flash_attn_func_v3_wrapper(*args, **kwargs): kwargs.pop("dropout_p", None) - return flash_attn_func_v3(*args, **kwargs) + return flash_attn_func_v3(*args, **kwargs)[0] def flash_attn_varlen_func_v3_wrapper(*args, **kwargs): kwargs.pop("dropout_p", None) - return flash_attn_varlen_func_v3(*args, **kwargs) + return flash_attn_varlen_func_v3(*args, **kwargs)[0] transformers.modeling_flash_attention_utils.flash_attn_func = ( flash_attn_func_v3_wrapper