diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index bddd388e4..06f349f83 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -133,13 +133,6 @@ class PatchManager:
             patch_evaluation_loop()
             patch_maybe_log_save_evaluate()
 
-        if self.cfg.context_parallel_size > 1:
-            from axolotl.monkeypatch.transformers.trainer_context_parallel import (
-                patch_prepare_context_parallel_inputs,
-            )
-
-            patch_prepare_context_parallel_inputs()
-
     def apply_post_model_build_patches(self, model: PreTrainedModel):
         """Apply patches right after model build, before post-load setup."""
         self._finalize_moe_expert_quantization(model)
diff --git a/src/axolotl/monkeypatch/accelerate/parallelism_config.py b/src/axolotl/monkeypatch/accelerate/parallelism_config.py
index ebd9a6f0d..56636d697 100644
--- a/src/axolotl/monkeypatch/accelerate/parallelism_config.py
+++ b/src/axolotl/monkeypatch/accelerate/parallelism_config.py
@@ -81,6 +81,7 @@ def patch_prepare_cp():
     import contextlib
 
     from accelerate import Accelerator
+    from transformers import Trainer
 
     def patched_prepare_cp(self, *args):
         if self.parallelism_config.cp_backend == "deepspeed":
@@ -95,4 +96,11 @@ def patch_prepare_cp():
             self._cp_context = _noop_cp_context
             return args
 
+    def _noop_prepare_context_parallel_inputs(self, model, inputs):
+        return contextlib.nullcontext, inputs
+
+    # prevent double CP partition
     Accelerator._prepare_cp = patched_prepare_cp
+
+    # remove unneeded calculation upstream
+    Trainer._prepare_context_parallel_inputs = _noop_prepare_context_parallel_inputs
diff --git a/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py b/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py
deleted file mode 100644
index 15f90423e..000000000
--- a/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py
+++ /dev/null
@@ -1,72 +0,0 @@
-"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer."""
-
-from __future__ import annotations
-
-import importlib
-import inspect
-
-from transformers import Trainer
-
-from axolotl.monkeypatch.utils import detab_code
-from axolotl.utils.logging import get_logger
-
-LOG = get_logger(__name__)
-
-GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
-PATCHED_GUARD = 'if (attn_impl := (getattr(model.config, "_attn_implementation", None) or getattr(model.model.config, "_attn_implementation", None))) and attn_impl not in ("sdpa", "flash_attention_2"):'
-
-
-def patch_prepare_context_parallel_inputs() -> None:
-    """Relax the SDPA-only guard when running context parallelism with FlashAttention."""
-    if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False):
-        LOG.debug("Trainer._prepare_context_parallel_inputs already patched")
-        return
-
-    try:
-        original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)
-    except OSError as exc:  # pragma: no cover - occurs when source is unavailable
-        LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc)
-        return
-
-    if GUARD_PATTERN not in original_source:
-        LOG.warning(
-            "Expected guard not found in Trainer._prepare_context_parallel_inputs; \n"
-            "skipping FlashAttention context parallelism patch"
-        )
-        return
-
-    patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)
-    patched_source, _ = detab_code(patched_source)
-    patched_source = patched_source.replace(
-        "def _prepare_context_parallel_inputs(",
-        "def axolotl_prepare_context_parallel_inputs(",
-        1,
-    )
-
-    module_name = Trainer.__module__
-    module = importlib.import_module(module_name)
-
-    # import symbols referenced in the method so exec can succeed
-    items_to_import = []
-    for item in dir(module):
-        if item in patched_source:
-            items_to_import.append(item)
-
-    # Use a separate namespace to capture the exec'd function
-    namespace = {}
-    exec(f"from {module_name} import ({', '.join(items_to_import)})", namespace)
-    exec(patched_source, namespace)
-
-    # Explicitly get the function from the namespace
-    axolotl_prepare_context_parallel_inputs = namespace[
-        "axolotl_prepare_context_parallel_inputs"
-    ]
-    Trainer._original_prepare_context_parallel_inputs = (
-        Trainer._prepare_context_parallel_inputs
-    )
-    Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
-    Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
-    Trainer._axolotl_prepare_context_parallel_inputs_patched = True
-    LOG.debug(
-        "Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
-    )
diff --git a/tests/monkeypatch/test_trainer_context_parallel_patch.py b/tests/monkeypatch/test_trainer_context_parallel_patch.py
deleted file mode 100644
index 84c883e91..000000000
--- a/tests/monkeypatch/test_trainer_context_parallel_patch.py
+++ /dev/null
@@ -1,66 +0,0 @@
-"""Tests for the HF Trainer context parallel patch."""
-
-import pytest
-from transformers import Trainer
-
-from axolotl.monkeypatch.transformers.trainer_context_parallel import (
-    GUARD_PATTERN,
-    PATCHED_GUARD,
-    patch_prepare_context_parallel_inputs,
-)
-
-
-@pytest.fixture
-def restore_trainer_prepare_method():
-    """Ensure Trainer._prepare_context_parallel_inputs is restored after a test."""
-    original_method = getattr(
-        Trainer,
-        "_original_prepare_context_parallel_inputs",
-        Trainer._prepare_context_parallel_inputs,
-    )
-    patched_attr_present = hasattr(
-        Trainer, "_axolotl_prepare_context_parallel_inputs_patched"
-    )
-
-    yield
-
-    Trainer._prepare_context_parallel_inputs = original_method
-    if patched_attr_present:
-        delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
-    if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
-        delattr(Trainer, "_original_prepare_context_parallel_inputs")
-    if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source"):
-        delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
-
-
-def test_patch_attention_guard(restore_trainer_prepare_method):
-    """Patch should swap the guard to allow sdpa or flash attention."""
-    # Ensure we start from the unpatched method
-    if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
-        Trainer._prepare_context_parallel_inputs = (
-            Trainer._original_prepare_context_parallel_inputs
-        )
-        delattr(Trainer, "_original_prepare_context_parallel_inputs")
-    if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched"):
-        delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
-
-    patch_prepare_context_parallel_inputs()
-
-    patched_method = Trainer._prepare_context_parallel_inputs
-    assert patched_method is not None
-    assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
-
-    source = Trainer._axolotl_prepare_context_parallel_inputs_source
-    assert GUARD_PATTERN not in source
-    assert PATCHED_GUARD in source
-
-
-def test_patch_is_idempotent(restore_trainer_prepare_method):
-    """Calling the patch twice should leave the same patched function in place."""
-    patch_prepare_context_parallel_inputs()
-    first_patched = Trainer._prepare_context_parallel_inputs
-
-    patch_prepare_context_parallel_inputs()
-    second_patched = Trainer._prepare_context_parallel_inputs
-
-    assert first_patched is second_patched
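Reviewer note: a minimal sketch of why the no-op override is safe, assuming the HF Trainer unpacks the return value as `cp_context, inputs = self._prepare_context_parallel_inputs(model, inputs)` and later enters `with cp_context():` around the forward pass. That call-site shape is an assumption about transformers internals, and the names below are illustrative only; with Accelerate's patched `_prepare_cp` owning the context-parallel sharding, the Trainer-side hook simply hands the batch back untouched.

    import contextlib

    def _noop_prepare_context_parallel_inputs(self, model, inputs):
        # Mirrors the override added in parallelism_config.py: return the batch
        # unchanged plus a do-nothing context-manager factory (uncalled, so the
        # caller can still invoke it).
        return contextlib.nullcontext, inputs

    # Illustrative call pattern (assumed, not verbatim HF Trainer code):
    cp_context, inputs = _noop_prepare_context_parallel_inputs(
        None, model=None, inputs={"input_ids": [0, 1, 2]}
    )
    with cp_context():   # contextlib.nullcontext() enters/exits with no side effects
        print(inputs)    # batch passes through untouched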