[tests] reset known modules that are patched on each test function end (#2147)

* reset known modules that are patched on each test function end

* fix the llama model module name

* prevent unsloth patching multiple times

* pop classes out of the globals after reset

* fix tuple indexing

* manually workaround for llama fa2
This commit is contained in:
Wing Lian
2024-12-07 17:24:46 -05:00
committed by GitHub
parent 743ba62bd5
commit 5bef19064b
3 changed files with 41 additions and 4 deletions

View File

@@ -3,14 +3,14 @@ fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/35128
"""
import inspect
import logging
from accelerate.logging import get_logger
from transformers import LlamaForCausalLM
from transformers.trainer import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = get_logger("axolotl.monkeypatch.trainer_grad_accum")
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
@@ -145,7 +145,7 @@ def patch_training_step_for_ga():
globals(),
)
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step", main_process_only=True)
LOG.info("patching training_step")
Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
)
@@ -201,7 +201,7 @@ def patch_forward_for_ga():
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward", main_process_only=True)
LOG.info("patching forward")
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -102,7 +102,14 @@ def detab_code(code: str) -> Tuple[str, str]:
return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name
def patch_self_attn_lora():
global self_attn_lora_patched # pylint: disable=global-statement
if self_attn_lora_patched:
# prevent patching multiple times
return
self_attn_forward = get_self_attn_code()
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
self_attn_forward
@@ -134,6 +141,7 @@ def patch_self_attn_lora():
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
self_attn_lora_patched = True
LOG.info("patching unsloth attn lora", main_process_only=True)
LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821

View File

@@ -2,7 +2,9 @@
shared pytest fixtures
"""
import functools
import importlib
import shutil
import sys
import tempfile
import time
@@ -113,3 +115,30 @@ def temp_dir():
yield _temp_dir
# Clean up the directory after the test
shutil.rmtree(_temp_dir)
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
original_fa2_forward = LlamaFlashAttention2.forward
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward
# Reset other known monkeypatches
modules_to_reset: list[tuple[str, list[str]]] = [
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
("transformers.trainer",),
("transformers.loss.loss_utils",),
]
for module_name_tuple in modules_to_reset:
module_name = module_name_tuple[0]
module = importlib.import_module(module_name)
sys.modules[module_name] = module
importlib.reload(sys.modules[module_name])
if len(module_name_tuple) > 1:
module_globals = module_name_tuple[1]
for module_global in module_globals:
globals().pop(module_global, None)