From 09725be9907960721b8bd3b3f8f3b48a33fcc2fa Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 25 Sep 2025 12:03:43 -0400 Subject: [PATCH] add support for CP + torch SDPA --- src/axolotl/loaders/patch_manager.py | 4 +- .../monkeypatch/ring_attn/adapters/batch.py | 30 +++-- src/axolotl/monkeypatch/ring_attn/patch.py | 29 +++-- src/axolotl/train.py | 6 +- src/axolotl/utils/schemas/validation.py | 69 +++++------ tests/e2e/multigpu/patched/test_sp.py | 18 ++- tests/loaders/test_patch_manager_cp.py | 74 ++++++++++++ tests/test_train_context_parallel.py | 111 ++++++++++++++++++ 8 files changed, 274 insertions(+), 67 deletions(-) create mode 100644 tests/loaders/test_patch_manager_cp.py create mode 100644 tests/test_train_context_parallel.py diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 1e46f5c34..c6aa187b1 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -84,7 +84,9 @@ class PatchManager: patch_evaluation_loop() patch_maybe_log_save_evaluate() - if self.cfg.context_parallel_size > 1: + if self.cfg.context_parallel_size > 1 and getattr( + self.cfg, "flash_attention", False + ): from axolotl.monkeypatch.transformers.trainer_context_parallel import ( patch_prepare_context_parallel_inputs, ) diff --git a/src/axolotl/monkeypatch/ring_attn/adapters/batch.py b/src/axolotl/monkeypatch/ring_attn/adapters/batch.py index 74d33ed4a..c39df3ac3 100644 --- a/src/axolotl/monkeypatch/ring_attn/adapters/batch.py +++ b/src/axolotl/monkeypatch/ring_attn/adapters/batch.py @@ -13,21 +13,10 @@ from typing import Callable import torch import torch.distributed as dist import transformers -import transformers.modeling_flash_attention_utils +import transformers.modeling_flash_attention_utils as flash_utils from ring_flash_attn import ring_flash_attn_func from ring_flash_attn.adapters.hf_adapter import check_params from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal - -try: - from transformers.modeling_flash_attention_utils import _flash_supports_window -except ImportError: - try: - from transformers.modeling_flash_attention_utils import ( - _flash_supports_window_size as _flash_supports_window, - ) - except ImportError: - _flash_supports_window = True - from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from axolotl.utils.schemas.enums import RingAttnFunc @@ -118,7 +107,7 @@ def create_flash_attn_forward_varlen_llama3( # Handle sliding window use_sliding_windows = ( - _flash_supports_window + _flash_windows_supported() and sliding_window is not None and key_states.shape[1] > sliding_window ) @@ -194,3 +183,18 @@ def substitute_hf_flash_attn( from ring_flash_attn.adapters.hf_adapter import flash_attention_forward ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward + + +def _flash_windows_supported() -> bool: + """Return whether current transformers build advertises sliding-window support.""" + support = getattr(flash_utils, "_flash_supports_window", None) + if support is None: + support = getattr(flash_utils, "_flash_supports_window_size", None) + + if support is None: + return True + + if callable(support): + return True + + return bool(support) diff --git a/src/axolotl/monkeypatch/ring_attn/patch.py b/src/axolotl/monkeypatch/ring_attn/patch.py index e1fd10b3a..17424292e 100644 --- a/src/axolotl/monkeypatch/ring_attn/patch.py +++ b/src/axolotl/monkeypatch/ring_attn/patch.py @@ -13,18 +13,9 @@ from typing import Optional import torch import 
torch.distributed as dist +import transformers.modeling_flash_attention_utils as flash_utils from torch.distributed import DeviceMesh -try: - from transformers.modeling_flash_attention_utils import _flash_supports_window -except ImportError: - try: - from transformers.modeling_flash_attention_utils import ( - _flash_supports_window_size as _flash_supports_window, - ) - except ImportError: - _flash_supports_window = True - from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import RingAttnFunc @@ -83,7 +74,7 @@ def create_ring_flash_attention_forward( # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). use_sliding_windows = ( - _flash_supports_window + _flash_windows_supported() and sliding_window is not None and key_states.shape[1] > sliding_window ) @@ -225,3 +216,19 @@ def update_ring_attn_params(position_ids: torch.Tensor | None): cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids) cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device()) update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group()) + + +def _flash_windows_supported() -> bool: + """Best-effort check for FlashAttention sliding-window support.""" + support = getattr(flash_utils, "_flash_supports_window", None) + if support is None: + support = getattr(flash_utils, "_flash_supports_window_size", None) + + if support is None: + return True + + if callable(support): + # Signature differs across versions; assume support when callable. + return True + + return bool(support) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 2a70d9712..04ec8dde9 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -179,7 +179,11 @@ def execute_training( ) ) - if cfg.context_parallel_size > 1: + use_flash_cp = cfg.context_parallel_size > 1 and bool( + getattr(cfg, "flash_attention", False) + ) + + if use_flash_cp: models = [trainer.model] if hasattr(trainer, "ref_model") and trainer.ref_model: models.append(trainer.ref_model) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 9671b10ae..a5f4a25dd 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1,7 +1,6 @@ """Module with validation methods for config pydantic model.""" import json -import sys import tempfile from pathlib import Path @@ -1314,50 +1313,40 @@ class ComplexValidationMixin: if not self.context_parallel_size: self.context_parallel_size = 1 elif self.context_parallel_size > 1: - if not self.flash_attention: + use_flash_attention = getattr(self, "flash_attention", False) + use_sdp_attention = getattr(self, "sdp_attention", False) + + if not (use_flash_attention or use_sdp_attention): raise ValueError( - "flash_attention: true must be set with context_parallel_size > 1" + "context_parallel_size > 1 requires either flash_attention: true " + "or sdp_attention: true" ) - if self.sample_packing and self.micro_batch_size > 1: - raise ValueError( - "micro_batch_size must be set to 1 when sample_packing is enabled " - "due to a `ring-flash-attn` requirement" + if use_flash_attention: + if self.sample_packing and self.micro_batch_size > 1: + raise ValueError( + "micro_batch_size must be set to 1 when sample_packing is enabled " + "due to a `ring-flash-attn` requirement" + ) + + try: + import ring_flash_attn # noqa: F401 # Required after monkey-patching + except ImportError as exception: + raise ImportError( + 
"context_parallel_size > 1 but ring_flash_attn is not installed. " + "Please install it with `pip install axolotl[ring-flash-attn] " + "or `pip install ring-flash-attn>=0.1.4`." + ) from exception + + LOG.warning( + "Sequence parallelism (SP) is enabled with " + f"context_parallel_size={self.context_parallel_size}. " + "Please note that logged losses may differ slightly to the non-SP " + "losses due to transformers Trainer implementation details. " + "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " + "for more details." ) - try: - import transformers.modeling_flash_attention_utils - from transformers.utils import is_flash_attn_greater_or_equal - - transformers.modeling_flash_attention_utils._flash_supports_window = ( - True - ) - sys.modules[ - "transformers.modeling_flash_attention_utils" - ]._flash_supports_window = True - sys.modules[ - "transformers.modeling_flash_attention_utils" - ]._flash_supports_window_size = True - sys.modules[ - "transformers.modeling_flash_attention_utils" - ].is_flash_attn_greater_or_equal = is_flash_attn_greater_or_equal - import ring_flash_attn # noqa: F401 # Required after monkey-patching - except ImportError as exception: - raise ImportError( - "context_parallel_size > 1 but ring_flash_attn is not installed. " - "Please install it with `pip install axolotl[ring-flash-attn] " - "or `pip install ring-flash-attn>=0.1.4`." - ) from exception - - LOG.warning( - "Sequence parallelism (SP) is enabled with " - f"context_parallel_size={self.context_parallel_size}. " - "Please note that logged losses may differ slightly to the non-SP " - "losses due to transformers Trainer implementation details. " - "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " - "for more details." 
- ) - return self @model_validator(mode="after") diff --git a/tests/e2e/multigpu/patched/test_sp.py b/tests/e2e/multigpu/patched/test_sp.py index a005e6742..14f9a77ea 100644 --- a/tests/e2e/multigpu/patched/test_sp.py +++ b/tests/e2e/multigpu/patched/test_sp.py @@ -23,6 +23,8 @@ class TestSequenceParallelism: pad_to_sequence_len=True, ring_attn_func=None, threshold=2.0, + flash_attention=True, + sdp_attention=False, ): """Helper method to run sequence parallel tests with different configurations""" cfg = DictDefault( @@ -58,7 +60,8 @@ class TestSequenceParallelism: "learning_rate": 0.00001, "optimizer": "adamw_8bit", "lr_scheduler": "cosine", - "flash_attention": True, + "flash_attention": flash_attention, + "sdp_attention": sdp_attention, "loss_watchdog_threshold": 5.0, "loss_watchdog_patience": 3, "bf16": "auto", @@ -132,3 +135,16 @@ class TestSequenceParallelism: ring_attn_func=ring_attn_func, threshold=threshold, ) + + def test_sequence_parallel_training_sdpa(self, temp_dir): + """Smoke test for SDPA-based context parallelism.""" + self._run_sequence_parallel_test( + temp_dir, + sample_packing=False, + micro_batch_size=1, + pad_to_sequence_len=True, + ring_attn_func=None, + threshold=3.0, + flash_attention=False, + sdp_attention=True, + ) diff --git a/tests/loaders/test_patch_manager_cp.py b/tests/loaders/test_patch_manager_cp.py new file mode 100644 index 000000000..0ff824f50 --- /dev/null +++ b/tests/loaders/test_patch_manager_cp.py @@ -0,0 +1,74 @@ +"""Tests for PatchManager context parallel patch selection.""" + +import addict + +from axolotl.loaders.patch_manager import PatchManager +from axolotl.utils.dict import DictDefault + + +def _stub_transformers_patches(monkeypatch): + """Replace trainer loss patchers with no-ops for isolation.""" + monkeypatch.setattr( + "axolotl.monkeypatch.transformers.trainer_loss_calc.patch_evaluation_loop", + lambda: None, + ) + monkeypatch.setattr( + "axolotl.monkeypatch.transformers.trainer_loss_calc.patch_maybe_log_save_evaluate", + lambda: None, + ) + + +def test_patch_manager_applies_flash_cp_patch(monkeypatch): + """When flash attention is enabled, we patch Trainer for CP.""" + _stub_transformers_patches(monkeypatch) + + patch_calls = {"count": 0} + + def stub_patch(): + patch_calls["count"] += 1 + + monkeypatch.setattr( + "axolotl.monkeypatch.transformers.trainer_context_parallel.patch_prepare_context_parallel_inputs", + stub_patch, + ) + + cfg = DictDefault( + { + "context_parallel_size": 2, + "flash_attention": True, + "sdp_attention": False, + } + ) + + manager = PatchManager(cfg, addict.Dict()) + manager._apply_transformers_patches() + + assert patch_calls["count"] == 1 + + +def test_patch_manager_skips_flash_patch_for_sdpa(monkeypatch): + """When only SDPA is requested, we should not patch Trainer.""" + _stub_transformers_patches(monkeypatch) + + patch_calls = {"count": 0} + + def stub_patch(): + patch_calls["count"] += 1 + + monkeypatch.setattr( + "axolotl.monkeypatch.transformers.trainer_context_parallel.patch_prepare_context_parallel_inputs", + stub_patch, + ) + + cfg = DictDefault( + { + "context_parallel_size": 2, + "flash_attention": False, + "sdp_attention": True, + } + ) + + manager = PatchManager(cfg, addict.Dict()) + manager._apply_transformers_patches() + + assert patch_calls["count"] == 0 diff --git a/tests/test_train_context_parallel.py b/tests/test_train_context_parallel.py new file mode 100644 index 000000000..4772487f6 --- /dev/null +++ b/tests/test_train_context_parallel.py @@ -0,0 +1,111 @@ +"""Unit tests for choosing 
the correct context parallel implementation.""" + +from types import SimpleNamespace + +from axolotl.train import execute_training +from axolotl.utils.dict import DictDefault + + +class DummyTrainer: + """Minimal trainer stub to exercise execute_training.""" + + def __init__(self): + self.model = object() + self.ref_model = None + self.accelerator = SimpleNamespace(torch_device_mesh=None) + self.train_called = False + + def train(self, resume_from_checkpoint=None): # pylint: disable=unused-argument + self.train_called = True + + +class DummyPluginManager: + """Minimal plugin manager stub.""" + + @staticmethod + def post_train(cfg, model): # pylint: disable=unused-argument + return None + + +class DummyContext: + """Test context manager that records entries/exits.""" + + def __init__(self, recorder, **kwargs): + recorder.append({"kwargs": kwargs}) + self.recorder = recorder + + def __enter__(self): + self.recorder[-1]["entered"] = True + return self + + def __exit__(self, exc_type, exc, tb): # pylint: disable=unused-argument + self.recorder[-1]["exited"] = True + return False + + +def _base_cfg(**overrides): + base = { + "context_parallel_size": 2, + "gradient_accumulation_steps": 1, + "ring_attn_func": None, + "heads_k_stride": None, + "rl": None, + "flash_optimum": False, + } + base.update(overrides) + return DictDefault(base) + + +def test_execute_training_uses_ring_when_flash(monkeypatch): + """FlashAttention CP should engage the custom ring context manager.""" + recorder: list[dict] = [] + + monkeypatch.setattr( + "axolotl.train.SequenceParallelContextManager", + lambda **kwargs: DummyContext(recorder, **kwargs), + ) + monkeypatch.setattr( + "axolotl.train.PluginManager.get_instance", + lambda: DummyPluginManager(), + ) + + cfg = _base_cfg(flash_attention=True, sdp_attention=False) + trainer = DummyTrainer() + + execute_training(cfg, trainer, resume_from_checkpoint=None) + + assert trainer.train_called + assert len(recorder) == 1 + assert recorder[0]["kwargs"]["context_parallel_size"] == 2 + assert recorder[0].get("entered") is True + assert recorder[0].get("exited") is True + + +def test_execute_training_uses_transformers_cp_for_sdpa(monkeypatch): + """SDPA CP should bypass the ring context manager.""" + invoked = {"count": 0} + + class NoOpContext: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): # pylint: disable=unused-argument + return False + + monkeypatch.setattr( + "axolotl.train.SequenceParallelContextManager", + lambda **kwargs: invoked.__setitem__("count", invoked["count"] + 1) + or NoOpContext(), + ) + monkeypatch.setattr( + "axolotl.train.PluginManager.get_instance", + lambda: DummyPluginManager(), + ) + + cfg = _base_cfg(flash_attention=False, sdp_attention=True) + trainer = DummyTrainer() + + execute_training(cfg, trainer, resume_from_checkpoint=None) + + assert trainer.train_called + assert invoked["count"] == 0
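Usage sketch (not part of this patch): a minimal, hypothetical config in the same DictDefault style the new tests use, showing how the torch SDPA context-parallel path added here would be selected instead of the ring-flash-attn path. The base_model value is only a placeholder.

from axolotl.utils.dict import DictDefault

# Hypothetical example config: shard sequences across 2 ranks with context
# parallelism, using torch scaled_dot_product_attention instead of FlashAttention.
cfg = DictDefault(
    {
        "base_model": "HuggingFaceTB/SmolLM2-135M",  # placeholder model for illustration
        "sequence_len": 2048,
        "micro_batch_size": 1,
        "sample_packing": False,
        "context_parallel_size": 2,
        "flash_attention": False,  # skip the ring-flash-attn path
        "sdp_attention": True,  # use the torch SDPA context-parallel path instead
    }
)

# With flash_attention disabled, PatchManager skips
# patch_prepare_context_parallel_inputs and execute_training() does not wrap
# training in SequenceParallelContextManager, matching the behavior asserted by
# tests/loaders/test_patch_manager_cp.py and tests/test_train_context_parallel.py.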