FA3 doesn't support dropout_p; fix unpatching of the flash-attn functions in test cleanup

This commit is contained in:
Wing Lian
2025-05-18 06:26:08 -07:00
parent a064f1c9b4
commit 8c4bc59bfc
2 changed files with 19 additions and 2 deletions

View File

@@ -642,11 +642,19 @@ class ModelLoader:
flash_attn_varlen_func as flash_attn_varlen_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3,
) )
def flash_attn_func_v3_wrapper(*args, **kwargs):
    """Call the Flash Attention v3 ``flash_attn_func`` without ``dropout_p``.

    Per the commit description, FA3 does not support ``dropout_p``, so that
    keyword is discarded (if present) before delegating. All positional
    arguments and every other keyword argument are forwarded unchanged.

    NOTE(review): a non-zero ``dropout_p`` is dropped silently, so no
    attention dropout is requested from the v3 function -- confirm this is
    acceptable for callers that configure dropout.
    """
    # Forward everything except the unsupported keyword.
    forwarded = {key: value for key, value in kwargs.items() if key != "dropout_p"}
    return flash_attn_func_v3(*args, **forwarded)
def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
    """Call the Flash Attention v3 ``flash_attn_varlen_func`` without ``dropout_p``.

    Per the commit description, FA3 does not support ``dropout_p``, so that
    keyword is discarded (if present) before delegating. All positional
    arguments and every other keyword argument are forwarded unchanged.

    NOTE(review): a non-zero ``dropout_p`` is dropped silently, so no
    attention dropout is requested from the v3 function -- confirm this is
    acceptable for callers that configure dropout.
    """
    # Forward everything except the unsupported keyword.
    forwarded = {key: value for key, value in kwargs.items() if key != "dropout_p"}
    return flash_attn_varlen_func_v3(*args, **forwarded)
transformers.modeling_flash_attention_utils.flash_attn_func = ( transformers.modeling_flash_attention_utils.flash_attn_func = (
flash_attn_func_v3 flash_attn_func_v3_wrapper
) )
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = ( transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
flash_attn_varlen_func_v3 flash_attn_varlen_func_v3_wrapper
) )
LOG.info("Switched to Flash Attention v3") LOG.info("Switched to Flash Attention v3")

View File

@@ -421,6 +421,7 @@ def temp_dir():
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches(): def cleanup_monkeypatches():
import transformers.modeling_flash_attention_utils
from transformers import Trainer from transformers import Trainer
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2, from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
LlamaAttention, LlamaAttention,
@@ -434,6 +435,10 @@ def cleanup_monkeypatches():
Trainer._inner_training_loop # pylint: disable=protected-access Trainer._inner_training_loop # pylint: disable=protected-access
) )
original_trainer_training_step = Trainer.training_step original_trainer_training_step = Trainer.training_step
original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
original_fa_varlen_func = (
transformers.modeling_flash_attention_utils.flash_attn_varlen_func
)
# monkey patches can happen inside the tests # monkey patches can happen inside the tests
yield yield
# Reset LlamaFlashAttention2 forward # Reset LlamaFlashAttention2 forward
@@ -444,6 +449,10 @@ def cleanup_monkeypatches():
original_trainer_inner_training_loop original_trainer_inner_training_loop
) )
Trainer.training_step = original_trainer_training_step Trainer.training_step = original_trainer_training_step
transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
original_fa_varlen_func
)
# Reset other known monkeypatches # Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [ modules_to_reset: list[tuple[str, list[str]]] = [