check for fa first
This commit is contained in:
@@ -435,10 +435,19 @@ def cleanup_monkeypatches():
|
||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||
)
|
||||
original_trainer_training_step = Trainer.training_step
|
||||
original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
|
||||
original_fa_varlen_func = (
|
||||
transformers.modeling_flash_attention_utils.flash_attn_varlen_func
|
||||
)
|
||||
original_fa_func = None
|
||||
original_fa_varlen_func = None
|
||||
if (
|
||||
importlib.util.find_spec("flash_attn")
|
||||
and hasattr(transformers.modeling_flash_attention_utils, "flash_attn_func")
|
||||
and hasattr(
|
||||
transformers.modeling_flash_attention_utils, "flash_attn_varlen_func"
|
||||
)
|
||||
):
|
||||
original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
|
||||
original_fa_varlen_func = (
|
||||
transformers.modeling_flash_attention_utils.flash_attn_varlen_func
|
||||
)
|
||||
# monkey patches can happen inside the tests
|
||||
yield
|
||||
# Reset LlamaFlashAttention2 forward
|
||||
@@ -449,10 +458,11 @@ def cleanup_monkeypatches():
|
||||
original_trainer_inner_training_loop
|
||||
)
|
||||
Trainer.training_step = original_trainer_training_step
|
||||
transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
|
||||
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
|
||||
original_fa_varlen_func
|
||||
)
|
||||
if original_fa_func:
|
||||
transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
|
||||
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
|
||||
original_fa_varlen_func
|
||||
)
|
||||
|
||||
# Reset other known monkeypatches
|
||||
modules_to_reset: list[tuple[str, list[str]]] = [
|
||||
|
||||
Reference in New Issue
Block a user