From b22150751f5f07586770f2222c471c9f609a36a6 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 18 May 2025 07:04:48 -0700
Subject: [PATCH] check for fa first

---
 tests/conftest.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 2bd6d331d..8c25a3606 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -435,10 +435,19 @@ def cleanup_monkeypatches():
         Trainer._inner_training_loop  # pylint: disable=protected-access
     )
     original_trainer_training_step = Trainer.training_step
-    original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
-    original_fa_varlen_func = (
-        transformers.modeling_flash_attention_utils.flash_attn_varlen_func
-    )
+    original_fa_func = None
+    original_fa_varlen_func = None
+    if (
+        importlib.util.find_spec("flash_attn")
+        and hasattr(transformers.modeling_flash_attention_utils, "flash_attn_func")
+        and hasattr(
+            transformers.modeling_flash_attention_utils, "flash_attn_varlen_func"
+        )
+    ):
+        original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
+        original_fa_varlen_func = (
+            transformers.modeling_flash_attention_utils.flash_attn_varlen_func
+        )
     # monkey patches can happen inside the tests
     yield
     # Reset LlamaFlashAttention2 forward
@@ -449,10 +458,11 @@ def cleanup_monkeypatches():
         original_trainer_inner_training_loop
     )
     Trainer.training_step = original_trainer_training_step
-    transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
-    transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
-        original_fa_varlen_func
-    )
+    if original_fa_func:
+        transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
+        transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
+            original_fa_varlen_func
+        )
     # Reset other known monkeypatches
     modules_to_reset: list[tuple[str, list[str]]] = [
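
Note for reviewers: in isolation, the guarded capture-and-restore this patch
applies looks like the sketch below. It is a minimal reconstruction, not the
full fixture from tests/conftest.py: the autouse decorator and the fa_utils
alias are assumptions for brevity, and it presumes importlib.util is imported
and that transformers exposes modeling_flash_attention_utils, as the diff
itself does.

    import importlib.util

    import pytest
    import transformers.modeling_flash_attention_utils as fa_utils  # assumed importable, as in the diff

    @pytest.fixture(autouse=True)  # autouse is an assumption; the diff does not show the decorator
    def cleanup_monkeypatches():
        original_fa_func = None
        original_fa_varlen_func = None
        # Capture the originals only when flash_attn is installed AND the
        # transformers shim actually exposes both functions.
        if (
            importlib.util.find_spec("flash_attn")
            and hasattr(fa_utils, "flash_attn_func")
            and hasattr(fa_utils, "flash_attn_varlen_func")
        ):
            original_fa_func = fa_utils.flash_attn_func
            original_fa_varlen_func = fa_utils.flash_attn_varlen_func
        yield  # tests may monkeypatch the two functions here
        # Restore only what was captured; on machines without flash-attn the
        # attributes were never saved, so there is nothing to put back.
        if original_fa_func:
            fa_utils.flash_attn_func = original_fa_func
            fa_utils.flash_attn_varlen_func = original_fa_varlen_func

Using None as the sentinel keeps teardown to a single truthiness check, which
is why the restore branch tests original_fa_func rather than probing
find_spec("flash_attn") a second time.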