[tests] reset known modules that are patched on each test function end (#2147)

* reset known modules that are patched on each test function end

* fix the llama model module name

* prevent unsloth patching multiple times

* pop classes out of the globals after reset

* fix tuple indexing

* manually workaround for llama fa2
This commit is contained in:
Wing Lian
2024-12-07 17:24:46 -05:00
committed by GitHub
parent 743ba62bd5
commit 5bef19064b
3 changed files with 41 additions and 4 deletions

View File

@@ -3,14 +3,14 @@ fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/35128
"""
import inspect
import logging
from accelerate.logging import get_logger
from transformers import LlamaForCausalLM
from transformers.trainer import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = get_logger("axolotl.monkeypatch.trainer_grad_accum")
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
@@ -145,7 +145,7 @@ def patch_training_step_for_ga():
globals(),
)
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step", main_process_only=True)
LOG.info("patching training_step")
Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
)
@@ -201,7 +201,7 @@ def patch_forward_for_ga():
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward", main_process_only=True)
LOG.info("patching forward")
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)

View File

@@ -102,7 +102,14 @@ def detab_code(code: str) -> Tuple[str, str]:
return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name
def patch_self_attn_lora():
global self_attn_lora_patched # pylint: disable=global-statement
if self_attn_lora_patched:
# prevent patching multiple times
return
self_attn_forward = get_self_attn_code()
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
self_attn_forward
@@ -134,6 +141,7 @@ def patch_self_attn_lora():
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
self_attn_lora_patched = True
LOG.info("patching unsloth attn lora", main_process_only=True)
LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821