Compare commits

...

5 Commits

Author          SHA1          Message                                   Date
Dan Saunders    09725be990    add support for CP + torch SDPA           2025-09-25 12:03:43 -04:00
Dan Saunders    f9bd6936c1    Merge branch 'main' into cp-fix           2025-09-24 14:01:23 -04:00
Dan Saunders    b9a3bfee5a    only patch in CP > 1 case                 2025-09-24 13:36:14 -04:00
Dan Saunders    08124a7c92    nits                                      2025-09-24 13:25:46 -04:00
Dan Saunders    56e0a77e0d    patch transformers to allow CP + FA2      2025-09-24 13:08:38 -04:00
10 changed files with 414 additions and 66 deletions

View File

@@ -84,6 +84,15 @@ class PatchManager:
         patch_evaluation_loop()
         patch_maybe_log_save_evaluate()
 
+        if self.cfg.context_parallel_size > 1 and getattr(
+            self.cfg, "flash_attention", False
+        ):
+            from axolotl.monkeypatch.transformers.trainer_context_parallel import (
+                patch_prepare_context_parallel_inputs,
+            )
+
+            patch_prepare_context_parallel_inputs()
+
     def apply_post_model_load_patches(self, model: PreTrainedModel):
         """Apply patches that require the model instance."""
         self._apply_llama_flash_attn_patches(model)

View File

@@ -13,21 +13,10 @@ from typing import Callable
 import torch
 import torch.distributed as dist
 import transformers
-import transformers.modeling_flash_attention_utils
+import transformers.modeling_flash_attention_utils as flash_utils
 from ring_flash_attn import ring_flash_attn_func
 from ring_flash_attn.adapters.hf_adapter import check_params
 from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
-
-try:
-    from transformers.modeling_flash_attention_utils import _flash_supports_window
-except ImportError:
-    try:
-        from transformers.modeling_flash_attention_utils import (
-            _flash_supports_window_size as _flash_supports_window,
-        )
-    except ImportError:
-        _flash_supports_window = True
-
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 
 from axolotl.utils.schemas.enums import RingAttnFunc
@@ -118,7 +107,7 @@ def create_flash_attn_forward_varlen_llama3(
 
         # Handle sliding window
         use_sliding_windows = (
-            _flash_supports_window
+            _flash_windows_supported()
             and sliding_window is not None
             and key_states.shape[1] > sliding_window
         )
@@ -194,3 +183,18 @@
     from ring_flash_attn.adapters.hf_adapter import flash_attention_forward
 
     ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
+
+
+def _flash_windows_supported() -> bool:
+    """Return whether current transformers build advertises sliding-window support."""
+    support = getattr(flash_utils, "_flash_supports_window", None)
+    if support is None:
+        support = getattr(flash_utils, "_flash_supports_window_size", None)
+    if support is None:
+        return True
+    if callable(support):
+        return True
+    return bool(support)
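This helper exists because the capability flag has moved around between transformers releases: newer builds expose `_flash_supports_window`, older ones `_flash_supports_window_size`, and in some versions it is a callable rather than a plain bool. A minimal standalone sketch of the same probe, using an illustrative stand-in namespace (`fake_utils` is not a real module):

from types import SimpleNamespace

# Stand-in for transformers.modeling_flash_attention_utils in an older build,
# where the flag was a module-level bool named _flash_supports_window_size.
fake_utils = SimpleNamespace(_flash_supports_window_size=False)

support = getattr(fake_utils, "_flash_supports_window", None)
if support is None:
    support = getattr(fake_utils, "_flash_supports_window_size", None)

# None -> assume supported; callable -> assume supported; bool -> use its value.
print(True if support is None or callable(support) else bool(support))  # -> False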

View File

@@ -13,18 +13,9 @@ from typing import Optional
 import torch
 import torch.distributed as dist
+import transformers.modeling_flash_attention_utils as flash_utils
 from torch.distributed import DeviceMesh
 
-try:
-    from transformers.modeling_flash_attention_utils import _flash_supports_window
-except ImportError:
-    try:
-        from transformers.modeling_flash_attention_utils import (
-            _flash_supports_window_size as _flash_supports_window,
-        )
-    except ImportError:
-        _flash_supports_window = True
-
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import RingAttnFunc
@@ -83,7 +74,7 @@ def create_ring_flash_attention_forward(
 
         # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
         use_sliding_windows = (
-            _flash_supports_window
+            _flash_windows_supported()
            and sliding_window is not None
            and key_states.shape[1] > sliding_window
        )
@@ -225,3 +216,19 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
     cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
     cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
     update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
+
+
+def _flash_windows_supported() -> bool:
+    """Best-effort check for FlashAttention sliding-window support."""
+    support = getattr(flash_utils, "_flash_supports_window", None)
+    if support is None:
+        support = getattr(flash_utils, "_flash_supports_window_size", None)
+    if support is None:
+        return True
+    if callable(support):
+        # Signature differs across versions; assume support when callable.
+        return True
+    return bool(support)

View File

@@ -0,0 +1,68 @@
"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer."""
from __future__ import annotations
import importlib
import inspect
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
PATCHED_GUARD = (
'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
)
def patch_prepare_context_parallel_inputs() -> None:
"""Relax the SDPA-only guard when running context parallelism with FlashAttention."""
if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False):
LOG.debug("Trainer._prepare_context_parallel_inputs already patched")
return
try:
original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)
except OSError as exc: # pragma: no cover - occurs when source is unavailable
LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc)
return
if GUARD_PATTERN not in original_source:
LOG.warning(
"Expected guard not found in Trainer._prepare_context_parallel_inputs; \n"
"skipping FlashAttention context parallelism patch"
)
return
patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)
patched_source, _ = detab_code(patched_source)
patched_source = patched_source.replace(
"def _prepare_context_parallel_inputs(",
"def axolotl_prepare_context_parallel_inputs(",
1,
)
module_name = Trainer.__module__
module = importlib.import_module(module_name)
# import symbols referenced in the method so exec can succeed
items_to_import = []
for item in dir(module):
if item in patched_source:
items_to_import.append(item)
exec(f"from {module_name} import ({', '.join(items_to_import)})", globals())
exec(patched_source, globals())
Trainer._original_prepare_context_parallel_inputs = (
Trainer._prepare_context_parallel_inputs
)
Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
Trainer._axolotl_prepare_context_parallel_inputs_patched = True
LOG.debug(
"Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
)
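A quick way to sanity-check this patch from a Python shell, as a sketch: it assumes a transformers version whose `Trainer._prepare_context_parallel_inputs` still contains the expected SDPA-only guard; otherwise the function logs a warning and leaves the Trainer untouched, and the assertions below would fail.

from transformers import Trainer

from axolotl.monkeypatch.transformers.trainer_context_parallel import (
    PATCHED_GUARD,
    patch_prepare_context_parallel_inputs,
)

patch_prepare_context_parallel_inputs()

# Both attributes are set by the patch itself (see above).
assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
assert PATCHED_GUARD in Trainer._axolotl_prepare_context_parallel_inputs_source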

View File

@@ -179,7 +179,11 @@ def execute_training(
         )
     )
 
-    if cfg.context_parallel_size > 1:
+    use_flash_cp = cfg.context_parallel_size > 1 and bool(
+        getattr(cfg, "flash_attention", False)
+    )
+    if use_flash_cp:
         models = [trainer.model]
         if hasattr(trainer, "ref_model") and trainer.ref_model:
             models.append(trainer.ref_model)
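Summarizing the selection logic as a small sketch (the `cp_backend` helper is hypothetical and only mirrors the branch above): with context parallelism enabled, FlashAttention configs keep the ring-attention wrapping, while SDPA configs skip it and rely on the Trainer's built-in context parallel handling.

def cp_backend(cfg) -> str:
    """Hypothetical helper mirroring the use_flash_cp branch above."""
    if cfg.context_parallel_size <= 1:
        return "no context parallelism"
    if bool(getattr(cfg, "flash_attention", False)):
        return "ring-flash-attn via SequenceParallelContextManager"
    return "transformers Trainer native CP (e.g. SDPA)"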

View File

@@ -1,7 +1,6 @@
 """Module with validation methods for config pydantic model."""
 
 import json
-import sys
 import tempfile
 from pathlib import Path
@@ -1314,50 +1313,40 @@ class ComplexValidationMixin:
         if not self.context_parallel_size:
             self.context_parallel_size = 1
         elif self.context_parallel_size > 1:
-            if not self.flash_attention:
-                raise ValueError(
-                    "flash_attention: true must be set with context_parallel_size > 1"
-                )
-
-            if self.sample_packing and self.micro_batch_size > 1:
-                raise ValueError(
-                    "micro_batch_size must be set to 1 when sample_packing is enabled "
-                    "due to a `ring-flash-attn` requirement"
-                )
-
-            try:
-                import transformers.modeling_flash_attention_utils
-                from transformers.utils import is_flash_attn_greater_or_equal
-
-                transformers.modeling_flash_attention_utils._flash_supports_window = (
-                    True
-                )
-                sys.modules[
-                    "transformers.modeling_flash_attention_utils"
-                ]._flash_supports_window = True
-                sys.modules[
-                    "transformers.modeling_flash_attention_utils"
-                ]._flash_supports_window_size = True
-                sys.modules[
-                    "transformers.modeling_flash_attention_utils"
-                ].is_flash_attn_greater_or_equal = is_flash_attn_greater_or_equal
-
-                import ring_flash_attn  # noqa: F401 # Required after monkey-patching
-            except ImportError as exception:
-                raise ImportError(
-                    "context_parallel_size > 1 but ring_flash_attn is not installed. "
-                    "Please install it with `pip install axolotl[ring-flash-attn] "
-                    "or `pip install ring-flash-attn>=0.1.4`."
-                ) from exception
-
-            LOG.warning(
-                "Sequence parallelism (SP) is enabled with "
-                f"context_parallel_size={self.context_parallel_size}. "
-                "Please note that logged losses may differ slightly to the non-SP "
-                "losses due to transformers Trainer implementation details. "
-                "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
-                "for more details."
-            )
+            use_flash_attention = getattr(self, "flash_attention", False)
+            use_sdp_attention = getattr(self, "sdp_attention", False)
+            if not (use_flash_attention or use_sdp_attention):
+                raise ValueError(
+                    "context_parallel_size > 1 requires either flash_attention: true "
+                    "or sdp_attention: true"
+                )
+
+            if use_flash_attention:
+                if self.sample_packing and self.micro_batch_size > 1:
+                    raise ValueError(
+                        "micro_batch_size must be set to 1 when sample_packing is enabled "
+                        "due to a `ring-flash-attn` requirement"
+                    )
+
+                try:
+                    import ring_flash_attn  # noqa: F401 # Required after monkey-patching
+                except ImportError as exception:
+                    raise ImportError(
+                        "context_parallel_size > 1 but ring_flash_attn is not installed. "
+                        "Please install it with `pip install axolotl[ring-flash-attn] "
+                        "or `pip install ring-flash-attn>=0.1.4`."
+                    ) from exception
+
+                LOG.warning(
+                    "Sequence parallelism (SP) is enabled with "
+                    f"context_parallel_size={self.context_parallel_size}. "
+                    "Please note that logged losses may differ slightly to the non-SP "
+                    "losses due to transformers Trainer implementation details. "
+                    "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
+                    "for more details."
+                )
 
         return self
 
     @model_validator(mode="after")
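For reference, the two configurations the relaxed validator now accepts can be sketched with the same `DictDefault` objects the new tests construct (field names are taken from this diff; other required training fields are omitted):

from axolotl.utils.dict import DictDefault

# FlashAttention-backed CP: still subject to the ring-flash-attn import check
# and the sample_packing / micro_batch_size restriction above.
flash_cp_cfg = DictDefault(
    {"context_parallel_size": 2, "flash_attention": True, "sdp_attention": False}
)

# SDPA-backed CP: newly allowed; skips the ring-flash-attn requirements entirely.
sdpa_cp_cfg = DictDefault(
    {"context_parallel_size": 2, "flash_attention": False, "sdp_attention": True}
)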

View File

@@ -23,6 +23,8 @@ class TestSequenceParallelism:
         pad_to_sequence_len=True,
         ring_attn_func=None,
         threshold=2.0,
+        flash_attention=True,
+        sdp_attention=False,
     ):
         """Helper method to run sequence parallel tests with different configurations"""
         cfg = DictDefault(
@@ -58,7 +60,8 @@
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_8bit",
                 "lr_scheduler": "cosine",
-                "flash_attention": True,
+                "flash_attention": flash_attention,
+                "sdp_attention": sdp_attention,
                 "loss_watchdog_threshold": 5.0,
                 "loss_watchdog_patience": 3,
                 "bf16": "auto",
@@ -132,3 +135,16 @@
             ring_attn_func=ring_attn_func,
             threshold=threshold,
         )
+
+    def test_sequence_parallel_training_sdpa(self, temp_dir):
+        """Smoke test for SDPA-based context parallelism."""
+        self._run_sequence_parallel_test(
+            temp_dir,
+            sample_packing=False,
+            micro_batch_size=1,
+            pad_to_sequence_len=True,
+            ring_attn_func=None,
+            threshold=3.0,
+            flash_attention=False,
+            sdp_attention=True,
+        )

View File

@@ -0,0 +1,74 @@
"""Tests for PatchManager context parallel patch selection."""
import addict
from axolotl.loaders.patch_manager import PatchManager
from axolotl.utils.dict import DictDefault
def _stub_transformers_patches(monkeypatch):
"""Replace trainer loss patchers with no-ops for isolation."""
monkeypatch.setattr(
"axolotl.monkeypatch.transformers.trainer_loss_calc.patch_evaluation_loop",
lambda: None,
)
monkeypatch.setattr(
"axolotl.monkeypatch.transformers.trainer_loss_calc.patch_maybe_log_save_evaluate",
lambda: None,
)
def test_patch_manager_applies_flash_cp_patch(monkeypatch):
"""When flash attention is enabled, we patch Trainer for CP."""
_stub_transformers_patches(monkeypatch)
patch_calls = {"count": 0}
def stub_patch():
patch_calls["count"] += 1
monkeypatch.setattr(
"axolotl.monkeypatch.transformers.trainer_context_parallel.patch_prepare_context_parallel_inputs",
stub_patch,
)
cfg = DictDefault(
{
"context_parallel_size": 2,
"flash_attention": True,
"sdp_attention": False,
}
)
manager = PatchManager(cfg, addict.Dict())
manager._apply_transformers_patches()
assert patch_calls["count"] == 1
def test_patch_manager_skips_flash_patch_for_sdpa(monkeypatch):
"""When only SDPA is requested, we should not patch Trainer."""
_stub_transformers_patches(monkeypatch)
patch_calls = {"count": 0}
def stub_patch():
patch_calls["count"] += 1
monkeypatch.setattr(
"axolotl.monkeypatch.transformers.trainer_context_parallel.patch_prepare_context_parallel_inputs",
stub_patch,
)
cfg = DictDefault(
{
"context_parallel_size": 2,
"flash_attention": False,
"sdp_attention": True,
}
)
manager = PatchManager(cfg, addict.Dict())
manager._apply_transformers_patches()
assert patch_calls["count"] == 0

View File

@@ -0,0 +1,66 @@
"""Tests for the HF Trainer context parallel patch."""
import pytest
from transformers import Trainer
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
GUARD_PATTERN,
PATCHED_GUARD,
patch_prepare_context_parallel_inputs,
)
@pytest.fixture
def restore_trainer_prepare_method():
"""Ensure Trainer._prepare_context_parallel_inputs is restored after a test."""
original_method = getattr(
Trainer,
"_original_prepare_context_parallel_inputs",
Trainer._prepare_context_parallel_inputs,
)
patched_attr_present = hasattr(
Trainer, "_axolotl_prepare_context_parallel_inputs_patched"
)
yield
Trainer._prepare_context_parallel_inputs = original_method
if patched_attr_present:
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
delattr(Trainer, "_original_prepare_context_parallel_inputs")
if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source"):
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
def test_patch_attention_guard(restore_trainer_prepare_method):
"""Patch should swap the guard to allow sdpa or flash attention."""
# Ensure we start from the unpatched method
if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
Trainer._prepare_context_parallel_inputs = (
Trainer._original_prepare_context_parallel_inputs
)
delattr(Trainer, "_original_prepare_context_parallel_inputs")
if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched"):
delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
patch_prepare_context_parallel_inputs()
patched_method = Trainer._prepare_context_parallel_inputs
assert patched_method is not None
assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
source = Trainer._axolotl_prepare_context_parallel_inputs_source
assert GUARD_PATTERN not in source
assert PATCHED_GUARD in source
def test_patch_is_idempotent(restore_trainer_prepare_method):
"""Calling the patch twice should leave the same patched function in place."""
patch_prepare_context_parallel_inputs()
first_patched = Trainer._prepare_context_parallel_inputs
patch_prepare_context_parallel_inputs()
second_patched = Trainer._prepare_context_parallel_inputs
assert first_patched is second_patched

View File

@@ -0,0 +1,111 @@
"""Unit tests for choosing the correct context parallel implementation."""
from types import SimpleNamespace
from axolotl.train import execute_training
from axolotl.utils.dict import DictDefault
class DummyTrainer:
"""Minimal trainer stub to exercise execute_training."""
def __init__(self):
self.model = object()
self.ref_model = None
self.accelerator = SimpleNamespace(torch_device_mesh=None)
self.train_called = False
def train(self, resume_from_checkpoint=None): # pylint: disable=unused-argument
self.train_called = True
class DummyPluginManager:
"""Minimal plugin manager stub."""
@staticmethod
def post_train(cfg, model): # pylint: disable=unused-argument
return None
class DummyContext:
"""Test context manager that records entries/exits."""
def __init__(self, recorder, **kwargs):
recorder.append({"kwargs": kwargs})
self.recorder = recorder
def __enter__(self):
self.recorder[-1]["entered"] = True
return self
def __exit__(self, exc_type, exc, tb): # pylint: disable=unused-argument
self.recorder[-1]["exited"] = True
return False
def _base_cfg(**overrides):
base = {
"context_parallel_size": 2,
"gradient_accumulation_steps": 1,
"ring_attn_func": None,
"heads_k_stride": None,
"rl": None,
"flash_optimum": False,
}
base.update(overrides)
return DictDefault(base)
def test_execute_training_uses_ring_when_flash(monkeypatch):
"""FlashAttention CP should engage the custom ring context manager."""
recorder: list[dict] = []
monkeypatch.setattr(
"axolotl.train.SequenceParallelContextManager",
lambda **kwargs: DummyContext(recorder, **kwargs),
)
monkeypatch.setattr(
"axolotl.train.PluginManager.get_instance",
lambda: DummyPluginManager(),
)
cfg = _base_cfg(flash_attention=True, sdp_attention=False)
trainer = DummyTrainer()
execute_training(cfg, trainer, resume_from_checkpoint=None)
assert trainer.train_called
assert len(recorder) == 1
assert recorder[0]["kwargs"]["context_parallel_size"] == 2
assert recorder[0].get("entered") is True
assert recorder[0].get("exited") is True
def test_execute_training_uses_transformers_cp_for_sdpa(monkeypatch):
"""SDPA CP should bypass the ring context manager."""
invoked = {"count": 0}
class NoOpContext:
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb): # pylint: disable=unused-argument
return False
monkeypatch.setattr(
"axolotl.train.SequenceParallelContextManager",
lambda **kwargs: invoked.__setitem__("count", invoked["count"] + 1)
or NoOpContext(),
)
monkeypatch.setattr(
"axolotl.train.PluginManager.get_instance",
lambda: DummyPluginManager(),
)
cfg = _base_cfg(flash_attention=False, sdp_attention=True)
trainer = DummyTrainer()
execute_training(cfg, trainer, resume_from_checkpoint=None)
assert trainer.train_called
assert invoked["count"] == 0