diff --git a/tests/conftest.py b/tests/conftest.py index a9dde9dd8..030ddd68d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -138,6 +138,7 @@ def cleanup_monkeypatches(): # Reset other known monkeypatches modules_to_reset: list[tuple[str, list[str]]] = [ + ("transformers",), ("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]), ("transformers.trainer", ["Trainer"]), ("transformers.loss.loss_utils",),