fa3 doesn't support dropout_p, fix unpatching
This commit is contained in:
@@ -642,11 +642,19 @@ class ModelLoader:
|
|||||||
flash_attn_varlen_func as flash_attn_varlen_func_v3,
|
flash_attn_varlen_func as flash_attn_varlen_func_v3,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def flash_attn_func_v3_wrapper(*args, **kwargs):
|
||||||
|
kwargs.pop("dropout_p", None)
|
||||||
|
return flash_attn_func_v3(*args, **kwargs)
|
||||||
|
|
||||||
|
def flash_attn_varlen_func_v3_wrapper(*args, **kwargs):
|
||||||
|
kwargs.pop("dropout_p", None)
|
||||||
|
return flash_attn_varlen_func_v3(*args, **kwargs)
|
||||||
|
|
||||||
transformers.modeling_flash_attention_utils.flash_attn_func = (
|
transformers.modeling_flash_attention_utils.flash_attn_func = (
|
||||||
flash_attn_func_v3
|
flash_attn_func_v3_wrapper
|
||||||
)
|
)
|
||||||
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
|
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
|
||||||
flash_attn_varlen_func_v3
|
flash_attn_varlen_func_v3_wrapper
|
||||||
)
|
)
|
||||||
LOG.info("Switched to Flash Attention v3")
|
LOG.info("Switched to Flash Attention v3")
|
||||||
|
|
||||||
|
|||||||
@@ -421,6 +421,7 @@ def temp_dir():
|
|||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=True)
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
def cleanup_monkeypatches():
|
def cleanup_monkeypatches():
|
||||||
|
import transformers.modeling_flash_attention_utils
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
|
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
|
||||||
LlamaAttention,
|
LlamaAttention,
|
||||||
@@ -434,6 +435,10 @@ def cleanup_monkeypatches():
|
|||||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||||
)
|
)
|
||||||
original_trainer_training_step = Trainer.training_step
|
original_trainer_training_step = Trainer.training_step
|
||||||
|
original_fa_func = transformers.modeling_flash_attention_utils.flash_attn_func
|
||||||
|
original_fa_varlen_func = (
|
||||||
|
transformers.modeling_flash_attention_utils.flash_attn_varlen_func
|
||||||
|
)
|
||||||
# monkey patches can happen inside the tests
|
# monkey patches can happen inside the tests
|
||||||
yield
|
yield
|
||||||
# Reset LlamaFlashAttention2 forward
|
# Reset LlamaFlashAttention2 forward
|
||||||
@@ -444,6 +449,10 @@ def cleanup_monkeypatches():
|
|||||||
original_trainer_inner_training_loop
|
original_trainer_inner_training_loop
|
||||||
)
|
)
|
||||||
Trainer.training_step = original_trainer_training_step
|
Trainer.training_step = original_trainer_training_step
|
||||||
|
transformers.modeling_flash_attention_utils.flash_attn_func = original_fa_func
|
||||||
|
transformers.modeling_flash_attention_utils.flash_attn_varlen_func = (
|
||||||
|
original_fa_varlen_func
|
||||||
|
)
|
||||||
|
|
||||||
# Reset other known monkeypatches
|
# Reset other known monkeypatches
|
||||||
modules_to_reset: list[tuple[str, list[str]]] = [
|
modules_to_reset: list[tuple[str, list[str]]] = [
|
||||||
|
|||||||
Reference in New Issue
Block a user