[tests] reset known modules that are patched on each test function end (#2147)
* reset known modules that are patched on each test function end * fix the llama model module name * prevent unsloth patching multiple times * pop classes out of the globals after reset * fix tuple indexing * manually workaround for llama fa2
This commit is contained in:
@@ -3,14 +3,14 @@ fix for FSDP gradient accumulation
|
||||
see https://github.com/huggingface/transformers/pull/35128
|
||||
"""
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from accelerate.logging import get_logger
|
||||
from transformers import LlamaForCausalLM
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||
|
||||
LOG = get_logger("axolotl.monkeypatch.trainer_grad_accum")
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
|
||||
|
||||
ORIGINAL_CONTEXT_CODE = """
|
||||
with self.compute_loss_context_manager():
|
||||
@@ -145,7 +145,7 @@ def patch_training_step_for_ga():
|
||||
globals(),
|
||||
)
|
||||
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching training_step", main_process_only=True)
|
||||
LOG.info("patching training_step")
|
||||
Trainer.training_step = ( # pylint: disable=protected-access
|
||||
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
@@ -201,7 +201,7 @@ def patch_forward_for_ga():
|
||||
globals(),
|
||||
)
|
||||
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching forward", main_process_only=True)
|
||||
LOG.info("patching forward")
|
||||
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
|
||||
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
|
||||
@@ -102,7 +102,14 @@ def detab_code(code: str) -> Tuple[str, str]:
|
||||
return code, spaces
|
||||
|
||||
|
||||
self_attn_lora_patched = False # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def patch_self_attn_lora():
|
||||
global self_attn_lora_patched # pylint: disable=global-statement
|
||||
if self_attn_lora_patched:
|
||||
# prevent patching multiple times
|
||||
return
|
||||
self_attn_forward = get_self_attn_code()
|
||||
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
|
||||
self_attn_forward
|
||||
@@ -134,6 +141,7 @@ def patch_self_attn_lora():
|
||||
globals(),
|
||||
)
|
||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||
self_attn_lora_patched = True
|
||||
LOG.info("patching unsloth attn lora", main_process_only=True)
|
||||
LlamaFlashAttention2.forward = (
|
||||
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
shared pytest fixtures
|
||||
"""
|
||||
import functools
|
||||
import importlib
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
@@ -113,3 +115,30 @@ def temp_dir():
|
||||
yield _temp_dir
|
||||
# Clean up the directory after the test
|
||||
shutil.rmtree(_temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def cleanup_monkeypatches():
|
||||
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
||||
|
||||
original_fa2_forward = LlamaFlashAttention2.forward
|
||||
# monkey patches can happen inside the tests
|
||||
yield
|
||||
# Reset LlamaFlashAttention2 forward
|
||||
LlamaFlashAttention2.forward = original_fa2_forward
|
||||
|
||||
# Reset other known monkeypatches
|
||||
modules_to_reset: list[tuple[str, list[str]]] = [
|
||||
("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
|
||||
("transformers.trainer",),
|
||||
("transformers.loss.loss_utils",),
|
||||
]
|
||||
for module_name_tuple in modules_to_reset:
|
||||
module_name = module_name_tuple[0]
|
||||
module = importlib.import_module(module_name)
|
||||
sys.modules[module_name] = module
|
||||
importlib.reload(sys.modules[module_name])
|
||||
if len(module_name_tuple) > 1:
|
||||
module_globals = module_name_tuple[1]
|
||||
for module_global in module_globals:
|
||||
globals().pop(module_global, None)
|
||||
|
||||
Reference in New Issue
Block a user