From 8c4bc59bfc687ccb0b4fbeba3166b506e6245687 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 18 May 2025 06:26:08 -0700
Subject: [PATCH] fa3 doesn't support dropout_p, fix unpatching

---
 src/axolotl/utils/models.py | 12 ++++++++++--
 tests/conftest.py           |  9 +++++++++
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index f14818e39..73dc17ab2 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -642,11 +642,19 @@ class ModelLoader:
                 flash_attn_varlen_func as flash_attn_varlen_func_v3,
             )
 
+            def flash_attn_func_v3_wrapper(*args, **kwargs):
+                kwargs.pop("dropout_p", None)
+                return flash_attn_func_v3(*args, **kwargs)
+
+            def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
+                kwargs.pop("dropout_p", None)
+                return flash_attn_varlen_func_v3(*args, **kwargs)
+
             transformers.modeling_flash_attention_utils.flash_attn_func = (
-                flash_attn_func_v3
+                flash_attn_func_v3_wrapper
             )
             transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
-                flash_attn_varlen_func_v3
+                flash_attn_varlen_func_v3_wrapper
             )
 
             LOG.info("Switched to Flash Attention v3")
diff --git a/tests/conftest.py b/tests/conftest.py
index f1df17e2b..2bd6d331d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -421,6 +421,7 @@ def temp_dir():
 
 @pytest.fixture(scope="function", autouse=True)
 def cleanup_monkeypatches():
+    import transformers.modeling_flash_attention_utils
     from transformers import Trainer
     from transformers.models.llama.modeling_llama import (  # LlamaFlashAttention2,
         LlamaAttention,
@@ -434,6 +435,10 @@ def cleanup_monkeypatches():
         Trainer._inner_training_loop  # pylint: disable=protected-access
     )
     original_trainer_training_step = Trainer.training_step
+    original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
+    original_fa_varlen_func = (
+        transformers.modeling_flash_attention_utils.flash_attn_varlen_func
+    )
     # monkey patches can happen inside the tests
     yield
     # Reset LlamaFlashAttention2 forward
@@ -444,6 +449,10 @@ def cleanup_monkeypatches():
         original_trainer_inner_training_loop
     )
     Trainer.training_step = original_trainer_training_step
+    transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
+    transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
+        original_fa_varlen_func
+    )
 
     # Reset other known monkeypatches
     modules_to_reset: list[tuple[str, list[str]]] = [
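
Standalone sketch (not part of the patch above) of the kwarg-stripping wrapper pattern the first hunk applies. It uses a stand-in function so it runs without flash-attn installed; fa3_like_func and drop_unsupported_kwargs are hypothetical names introduced only for illustration.

def fa3_like_func(q, k, v, causal=False):
    # Stand-in for flash_attn_interface.flash_attn_func, which (per the
    # commit subject) does not accept a dropout_p keyword argument.
    return (q, k, v, causal)


def drop_unsupported_kwargs(func, unsupported=("dropout_p",)):
    # Wrap func so call sites may keep passing kwargs the target rejects.
    def wrapper(*args, **kwargs):
        for name in unsupported:
            kwargs.pop(name, None)
        return func(*args, **kwargs)

    return wrapper


wrapped = drop_unsupported_kwargs(fa3_like_func)
# A caller that still passes dropout_p no longer raises TypeError;
# the wrapper silently discards the unsupported keyword.
print(wrapped("q", "k", "v", causal=True, dropout_p=0.1))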