Transformers version flexibility and FSDP optimizer patch (#2155)
* allow flexibility in transformers version for FSDP * more flexibility with dev versions of 4.47.0.dev0 * add patch for fsdp * fix typo * correct fn name * stray character * fix patch * reset Trainer too * also reset Trainer.training_step * allow tests/patched to run more than one process on e2e runner * skip tests/patched in e2e for now since it's run in regular pytest
This commit is contained in:
@@ -119,18 +119,27 @@ def temp_dir():
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def cleanup_monkeypatches():
|
||||
from transformers import Trainer
|
||||
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
||||
|
||||
original_fa2_forward = LlamaFlashAttention2.forward
|
||||
original_trainer_inner_training_loop = (
|
||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||
)
|
||||
original_trainer_training_step = Trainer.training_step
|
||||
# monkey patches can happen inside the tests
|
||||
yield
|
||||
# Reset LlamaFlashAttention2 forward
|
||||
LlamaFlashAttention2.forward = original_fa2_forward
|
||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||
original_trainer_inner_training_loop
|
||||
)
|
||||
Trainer.training_step = original_trainer_training_step
|
||||
|
||||
# Reset other known monkeypatches
|
||||
modules_to_reset: list[tuple[str, list[str]]] = [
|
||||
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
|
||||
("transformers.trainer",),
|
||||
("transformers.trainer", ["Trainer"]),
|
||||
("transformers.loss.loss_utils",),
|
||||
]
|
||||
for module_name_tuple in modules_to_reset:
|
||||
|
||||
Reference in New Issue
Block a user