diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 3d4b7b96b..1e46f5c34 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -84,6 +84,13 @@ class PatchManager:
         patch_evaluation_loop()
         patch_maybe_log_save_evaluate()
 
+        if self.cfg.context_parallel_size > 1:
+            from axolotl.monkeypatch.transformers.trainer_context_parallel import (
+                patch_prepare_context_parallel_inputs,
+            )
+
+            patch_prepare_context_parallel_inputs()
+
     def apply_post_model_load_patches(self, model: PreTrainedModel):
         """Apply patches that require the model instance."""
         self._apply_llama_flash_attn_patches(model)
diff --git a/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py b/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py
new file mode 100644
index 000000000..74a35e83f
--- /dev/null
+++ b/src/axolotl/monkeypatch/transformers/trainer_context_parallel.py
@@ -0,0 +1,68 @@
+"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer."""
+
+from __future__ import annotations
+
+import importlib
+import inspect
+
+from transformers import Trainer
+
+from axolotl.monkeypatch.utils import detab_code
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
+PATCHED_GUARD = (
+    'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
+)
+
+
+def patch_prepare_context_parallel_inputs() -> None:
+    """Relax the SDPA-only guard when running context parallelism with FlashAttention."""
+    if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False):
+        LOG.debug("Trainer._prepare_context_parallel_inputs already patched")
+        return
+
+    try:
+        original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)
+    except OSError as exc:  # pragma: no cover - occurs when source is unavailable
+        LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc)
+        return
+
+    if GUARD_PATTERN not in original_source:
+        LOG.warning(
+            "Expected guard not found in Trainer._prepare_context_parallel_inputs; \n"
+            "skipping FlashAttention context parallelism patch"
+        )
+        return
+
+    patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)
+    patched_source, _ = detab_code(patched_source)
+    patched_source = patched_source.replace(
+        "def _prepare_context_parallel_inputs(",
+        "def axolotl_prepare_context_parallel_inputs(",
+        1,
+    )
+
+    module_name = Trainer.__module__
+    module = importlib.import_module(module_name)
+
+    # import symbols referenced in the method so exec can succeed
+    items_to_import = []
+    for item in dir(module):
+        if item in patched_source:
+            items_to_import.append(item)
+
+    exec(f"from {module_name} import ({', '.join(items_to_import)})", globals())
+    exec(patched_source, globals())
+
+    Trainer._original_prepare_context_parallel_inputs = (
+        Trainer._prepare_context_parallel_inputs
+    )
+    Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
+    Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
+    Trainer._axolotl_prepare_context_parallel_inputs_patched = True
+    LOG.debug(
+        "Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
+    )
diff --git a/tests/monkeypatch/test_trainer_context_parallel_patch.py b/tests/monkeypatch/test_trainer_context_parallel_patch.py
new file mode 100644
index 000000000..84c883e91
--- /dev/null
+++ b/tests/monkeypatch/test_trainer_context_parallel_patch.py
@@ -0,0 +1,66 @@
+"""Tests for the HF Trainer context parallel patch."""
+
+import pytest
+from transformers import Trainer
+
+from axolotl.monkeypatch.transformers.trainer_context_parallel import (
+    GUARD_PATTERN,
+    PATCHED_GUARD,
+    patch_prepare_context_parallel_inputs,
+)
+
+
+@pytest.fixture
+def restore_trainer_prepare_method():
+    """Ensure Trainer._prepare_context_parallel_inputs is restored after a test."""
+    original_method = getattr(
+        Trainer,
+        "_original_prepare_context_parallel_inputs",
+        Trainer._prepare_context_parallel_inputs,
+    )
+    patched_attr_present = hasattr(
+        Trainer, "_axolotl_prepare_context_parallel_inputs_patched"
+    )
+
+    yield
+
+    Trainer._prepare_context_parallel_inputs = original_method
+    if patched_attr_present:
+        delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
+    if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
+        delattr(Trainer, "_original_prepare_context_parallel_inputs")
+    if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source"):
+        delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
+
+
+def test_patch_attention_guard(restore_trainer_prepare_method):
+    """Patch should swap the guard to allow sdpa or flash attention."""
+    # Ensure we start from the unpatched method
+    if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
+        Trainer._prepare_context_parallel_inputs = (
+            Trainer._original_prepare_context_parallel_inputs
+        )
+        delattr(Trainer, "_original_prepare_context_parallel_inputs")
+    if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched"):
+        delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
+
+    patch_prepare_context_parallel_inputs()
+
+    patched_method = Trainer._prepare_context_parallel_inputs
+    assert patched_method is not None
+    assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
+
+    source = Trainer._axolotl_prepare_context_parallel_inputs_source
+    assert GUARD_PATTERN not in source
+    assert PATCHED_GUARD in source
+
+
+def test_patch_is_idempotent(restore_trainer_prepare_method):
+    """Calling the patch twice should leave the same patched function in place."""
+    patch_prepare_context_parallel_inputs()
+    first_patched = Trainer._prepare_context_parallel_inputs
+
+    patch_prepare_context_parallel_inputs()
+    second_patched = Trainer._prepare_context_parallel_inputs
+
+    assert first_patched is second_patched
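For reference, a minimal usage sketch (not part of the diff): it shows how the new patch can be exercised directly, assuming the installed `transformers` release defines `Trainer._prepare_context_parallel_inputs` (the same assumption the patch itself makes). During normal training, `PatchManager.apply_post_trainer_creation_patches` is not needed; the patch is applied automatically whenever `context_parallel_size > 1`.

```python
# Illustrative sketch only; not part of the diff above.
from transformers import Trainer

from axolotl.monkeypatch.transformers.trainer_context_parallel import (
    GUARD_PATTERN,
    PATCHED_GUARD,
    patch_prepare_context_parallel_inputs,
)

# Swap the SDPA-only guard for one that also accepts flash_attention_2;
# repeated calls are no-ops thanks to the _patched sentinel attribute.
patch_prepare_context_parallel_inputs()

source = Trainer._axolotl_prepare_context_parallel_inputs_source
assert GUARD_PATTERN not in source
assert PATCHED_GUARD in source
```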