check for fa first

Wing Lian
2025-05-18 07:04:48 -07:00
parent 8c4bc59bfc
commit b22150751f


@@ -435,10 +435,19 @@ def cleanup_monkeypatches():
         Trainer._inner_training_loop  # pylint: disable=protected-access
     )
     original_trainer_training_step = Trainer.training_step
-    original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
-    original_fa_varlen_func = (
-        transformers.modeling_flash_attention_utils.flash_attn_varlen_func
-    )
+    original_fa_func = None
+    original_fa_varlen_func = None
+    if (
+        importlib.util.find_spec("flash_attn")
+        and hasattr(transformers.modeling_flash_attention_utils, "flash_attn_func")
+        and hasattr(
+            transformers.modeling_flash_attention_utils, "flash_attn_varlen_func"
+        )
+    ):
+        original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
+        original_fa_varlen_func = (
+            transformers.modeling_flash_attention_utils.flash_attn_varlen_func
+        )
     # monkey patches can happen inside the tests
     yield
     # Reset LlamaFlashAttention2 forward
@@ -449,10 +458,11 @@ def cleanup_monkeypatches():
         original_trainer_inner_training_loop
     )
     Trainer.training_step = original_trainer_training_step
-    transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
-    transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
-        original_fa_varlen_func
-    )
+    if original_fa_func:
+        transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
+        transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
+            original_fa_varlen_func
+        )
     # Reset other known monkeypatches
     modules_to_reset: list[tuple[str, list[str]]] = [
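
For context, a minimal standalone sketch of the guard pattern this commit applies: capture the flash-attn entry points only when the optional flash_attn package is importable and transformers actually exposes them, and restore them only if they were captured. The fixture name and autouse flag below are illustrative, not copied from the repo, and it assumes a transformers version that ships modeling_flash_attention_utils.

    # sketch only -- mirrors the guarded capture/restore from the diff, not the full fixture
    import importlib.util

    import pytest
    from transformers import modeling_flash_attention_utils as fa_utils


    @pytest.fixture(autouse=True)
    def cleanup_monkeypatches_sketch():
        original_fa_func = None
        original_fa_varlen_func = None
        if (
            importlib.util.find_spec("flash_attn")
            and hasattr(fa_utils, "flash_attn_func")
            and hasattr(fa_utils, "flash_attn_varlen_func")
        ):
            # flash_attn is installed, so the module-level symbols exist and can be saved
            original_fa_func = fa_utils.flash_attn_func
            original_fa_varlen_func = fa_utils.flash_attn_varlen_func

        yield  # tests may monkeypatch the flash-attn entry points here

        if original_fa_func:
            # restore only what was captured; otherwise leave the module untouched
            fa_utils.flash_attn_func = original_fa_func
            fa_utils.flash_attn_varlen_func = original_fa_varlen_func

Without the guard, importing or restoring these attributes on a machine without flash_attn raises an AttributeError, which is what the conditional capture and the `if original_fa_func:` restore avoid.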