* patch transformers to allow CP + FA2

* nits

* only patch in CP > 1 case
This commit is contained in:
Dan Saunders
2025-09-25 12:03:50 -04:00
committed by GitHub
parent 33975ce4bc
commit f9748c4dc5
3 changed files with 141 additions and 0 deletions

View File

@@ -84,6 +84,13 @@ class PatchManager:
patch_evaluation_loop()
patch_maybe_log_save_evaluate()
if self.cfg.context_parallel_size > 1:
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
patch_prepare_context_parallel_inputs,
)
patch_prepare_context_parallel_inputs()
def apply_post_model_load_patches(self, model: PreTrainedModel):
"""Apply patches that require the model instance."""
self._apply_llama_flash_attn_patches(model)

View File

@@ -0,0 +1,68 @@
"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer."""
from __future__ import annotations
import importlib
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
PATCHED_GUARD = (
'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
)
def patch_prepare_context_parallel_inputs() -> None:
"""Relax the SDPA-only guard when running context parallelism with FlashAttention."""
if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False):
LOG.debug("Trainer._prepare_context_parallel_inputs already patched")
return
try:
original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)
except OSError as exc: # pragma: no cover - occurs when source is unavailable
LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc)
return
if GUARD_PATTERN not in original_source:
LOG.warning(
"Expected guard not found in Trainer._prepare_context_parallel_inputs; \n"
"skipping FlashAttention context parallelism patch"
)
return
patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)
patched_source, _ = detab_code(patched_source)
patched_source = patched_source.replace(
"def _prepare_context_parallel_inputs(",
"def axolotl_prepare_context_parallel_inputs(",
1,
)
module_name = Trainer.__module__
module = importlib.import_module(module_name)
# import symbols referenced in the method so exec can succeed
items_to_import = []
for item in dir(module):
if item in patched_source:
items_to_import.append(item)
exec(f"from {module_name} import ({', '.join(items_to_import)})", globals())
exec(patched_source, globals())
Trainer._original_prepare_context_parallel_inputs = (
Trainer._prepare_context_parallel_inputs
)
Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
Trainer._axolotl_prepare_context_parallel_inputs_patched = True
LOG.debug(
"Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
)

View File

@@ -0,0 +1,66 @@
"""Tests for the HF Trainer context parallel patch."""
import pytest
from transformers import Trainer
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
GUARD_PATTERN,
PATCHED_GUARD,
patch_prepare_context_parallel_inputs,
)
@pytest.fixture
def restore_trainer_prepare_method():
"""Ensure Trainer._prepare_context_parallel_inputs is restored after a test."""
original_method = getattr(
Trainer,
"_original_prepare_context_parallel_inputs",
Trainer._prepare_context_parallel_inputs,
)
patched_attr_present = hasattr(
Trainer, "_axolotl_prepare_context_parallel_inputs_patched"
)
yield
Trainer._prepare_context_parallel_inputs = original_method
if patched_attr_present:
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
delattr(Trainer, "_original_prepare_context_parallel_inputs")
if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source"):
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
def test_patch_attention_guard(restore_trainer_prepare_method):
"""Patch should swap the guard to allow sdpa or flash attention."""
# Ensure we start from the unpatched method
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
Trainer._prepare_context_parallel_inputs = (
Trainer._original_prepare_context_parallel_inputs
)
delattr(Trainer, "_original_prepare_context_parallel_inputs")
if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched"):
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
patch_prepare_context_parallel_inputs()
patched_method = Trainer._prepare_context_parallel_inputs
assert patched_method is not None
assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
source = Trainer._axolotl_prepare_context_parallel_inputs_source
assert GUARD_PATTERN not in source
assert PATCHED_GUARD in source
def test_patch_is_idempotent(restore_trainer_prepare_method):
"""Calling the patch twice should leave the same patched function in place."""
patch_prepare_context_parallel_inputs()
first_patched = Trainer._prepare_context_parallel_inputs
patch_prepare_context_parallel_inputs()
second_patched = Trainer._prepare_context_parallel_inputs
assert first_patched is second_patched