Compare commits
5 commits: version-de...cp-sdpa

Commits (SHA1):
- 09725be990
- f9bd6936c1
- b9a3bfee5a
- 08124a7c92
- 56e0a77e0d
@@ -84,6 +84,15 @@ class PatchManager:
         patch_evaluation_loop()
         patch_maybe_log_save_evaluate()
 
+        if self.cfg.context_parallel_size > 1 and getattr(
+            self.cfg, "flash_attention", False
+        ):
+            from axolotl.monkeypatch.transformers.trainer_context_parallel import (
+                patch_prepare_context_parallel_inputs,
+            )
+
+            patch_prepare_context_parallel_inputs()
+
     def apply_post_model_load_patches(self, model: PreTrainedModel):
         """Apply patches that require the model instance."""
         self._apply_llama_flash_attn_patches(model)
@@ -13,21 +13,10 @@ from typing import Callable
 import torch
 import torch.distributed as dist
 import transformers
-import transformers.modeling_flash_attention_utils
+import transformers.modeling_flash_attention_utils as flash_utils
 from ring_flash_attn import ring_flash_attn_func
 from ring_flash_attn.adapters.hf_adapter import check_params
 from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal
-
-try:
-    from transformers.modeling_flash_attention_utils import _flash_supports_window
-except ImportError:
-    try:
-        from transformers.modeling_flash_attention_utils import (
-            _flash_supports_window_size as _flash_supports_window,
-        )
-    except ImportError:
-        _flash_supports_window = True
-
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 
 from axolotl.utils.schemas.enums import RingAttnFunc
@@ -118,7 +107,7 @@ def create_flash_attn_forward_varlen_llama3(
 
     # Handle sliding window
     use_sliding_windows = (
-        _flash_supports_window
+        _flash_windows_supported()
         and sliding_window is not None
         and key_states.shape[1] > sliding_window
     )
@@ -194,3 +183,18 @@ def substitute_hf_flash_attn(
     from ring_flash_attn.adapters.hf_adapter import flash_attention_forward
 
     ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
+
+
+def _flash_windows_supported() -> bool:
+    """Return whether current transformers build advertises sliding-window support."""
+    support = getattr(flash_utils, "_flash_supports_window", None)
+    if support is None:
+        support = getattr(flash_utils, "_flash_supports_window_size", None)
+
+    if support is None:
+        return True
+
+    if callable(support):
+        return True
+
+    return bool(support)
@@ -13,18 +13,9 @@ from typing import Optional
 
 import torch
 import torch.distributed as dist
+import transformers.modeling_flash_attention_utils as flash_utils
 from torch.distributed import DeviceMesh
 
-try:
-    from transformers.modeling_flash_attention_utils import _flash_supports_window
-except ImportError:
-    try:
-        from transformers.modeling_flash_attention_utils import (
-            _flash_supports_window_size as _flash_supports_window,
-        )
-    except ImportError:
-        _flash_supports_window = True
-
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.enums import RingAttnFunc
@@ -83,7 +74,7 @@ def create_ring_flash_attention_forward(
 
     # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
     use_sliding_windows = (
-        _flash_supports_window
+        _flash_windows_supported()
        and sliding_window is not None
        and key_states.shape[1] > sliding_window
     )
@@ -225,3 +216,19 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
     cu_seqlens, _ = get_cu_seqlens_from_pos_ids(position_ids)
     cu_seqlens = cu_seqlens.squeeze().to(device=torch.cuda.current_device())
     update_ring_flash_attn_params(cu_seqlens, get_ring_attn_group())
+
+
+def _flash_windows_supported() -> bool:
+    """Best-effort check for FlashAttention sliding-window support."""
+    support = getattr(flash_utils, "_flash_supports_window", None)
+    if support is None:
+        support = getattr(flash_utils, "_flash_supports_window_size", None)
+
+    if support is None:
+        return True
+
+    if callable(support):
+        # Signature differs across versions; assume support when callable.
+        return True
+
+    return bool(support)
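Side note: this helper has to tolerate the different shapes the flag has taken across transformers releases, where it may be a module-level bool, an older callable probe, or missing entirely. A small illustrative sketch of the same lookup logic against stand-in modules (the SimpleNamespace objects and the `probe` name are ours, not transformers internals):

from types import SimpleNamespace


def probe(mod) -> bool:
    # mirrors _flash_windows_supported(): prefer the bool flag, fall back to the
    # older callable probe, and assume support when neither attribute exists
    support = getattr(mod, "_flash_supports_window", None)
    if support is None:
        support = getattr(mod, "_flash_supports_window_size", None)
    if support is None or callable(support):
        return True
    return bool(support)


assert probe(SimpleNamespace(_flash_supports_window=True)) is True
assert probe(SimpleNamespace(_flash_supports_window_size=lambda *a, **k: True)) is True
assert probe(SimpleNamespace()) is True
assert probe(SimpleNamespace(_flash_supports_window=False)) is False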
@@ -0,0 +1,68 @@
+"""Monkey patch to allow context parallelism with FlashAttention in HF Trainer."""
+
+from __future__ import annotations
+
+import importlib
+import inspect
+
+from transformers import Trainer
+
+from axolotl.monkeypatch.utils import detab_code
+from axolotl.utils.logging import get_logger
+
+LOG = get_logger(__name__)
+
+GUARD_PATTERN = 'if model.config._attn_implementation != "sdpa":'
+PATCHED_GUARD = (
+    'if model.config._attn_implementation not in ("sdpa", "flash_attention_2"):'
+)
+
+
+def patch_prepare_context_parallel_inputs() -> None:
+    """Relax the SDPA-only guard when running context parallelism with FlashAttention."""
+    if getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False):
+        LOG.debug("Trainer._prepare_context_parallel_inputs already patched")
+        return
+
+    try:
+        original_source = inspect.getsource(Trainer._prepare_context_parallel_inputs)
+    except OSError as exc:  # pragma: no cover - occurs when source is unavailable
+        LOG.warning("Unable to patch Trainer._prepare_context_parallel_inputs: %s", exc)
+        return
+
+    if GUARD_PATTERN not in original_source:
+        LOG.warning(
+            "Expected guard not found in Trainer._prepare_context_parallel_inputs; \n"
+            "skipping FlashAttention context parallelism patch"
+        )
+        return
+
+    patched_source = original_source.replace(GUARD_PATTERN, PATCHED_GUARD)
+    patched_source, _ = detab_code(patched_source)
+    patched_source = patched_source.replace(
+        "def _prepare_context_parallel_inputs(",
+        "def axolotl_prepare_context_parallel_inputs(",
+        1,
+    )
+
+    module_name = Trainer.__module__
+    module = importlib.import_module(module_name)
+
+    # import symbols referenced in the method so exec can succeed
+    items_to_import = []
+    for item in dir(module):
+        if item in patched_source:
+            items_to_import.append(item)
+
+    exec(f"from {module_name} import ({', '.join(items_to_import)})", globals())
+    exec(patched_source, globals())
+
+    Trainer._original_prepare_context_parallel_inputs = (
+        Trainer._prepare_context_parallel_inputs
+    )
+    Trainer._prepare_context_parallel_inputs = axolotl_prepare_context_parallel_inputs
+    Trainer._axolotl_prepare_context_parallel_inputs_source = patched_source
+    Trainer._axolotl_prepare_context_parallel_inputs_patched = True
+    LOG.debug(
+        "Patched Trainer._prepare_context_parallel_inputs for FlashAttention + CP"
+    )
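For context, the patch above works by fetching the source of `Trainer._prepare_context_parallel_inputs`, textually swapping the attention-implementation guard, and exec-ing the rewritten function back onto the class. A self-contained sketch of that string-replacement idea on a toy function (the `_toy_guard` and `_patched_guard` names are illustrative, not the real Trainer method):

import inspect


def _toy_guard(impl: str) -> None:
    # stands in for the guarded HF Trainer method
    if impl != "sdpa":
        raise ValueError("context parallelism requires SDPA")


# rewrite the guard the same way the patch swaps GUARD_PATTERN for PATCHED_GUARD
source = inspect.getsource(_toy_guard)
source = source.replace(
    'if impl != "sdpa":',
    'if impl not in ("sdpa", "flash_attention_2"):',
).replace("def _toy_guard(", "def _patched_guard(", 1)
exec(source, globals())

_patched_guard("flash_attention_2")  # no longer raises after the rewrite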
@@ -179,7 +179,11 @@ def execute_training(
             )
         )
 
-    if cfg.context_parallel_size > 1:
+    use_flash_cp = cfg.context_parallel_size > 1 and bool(
+        getattr(cfg, "flash_attention", False)
+    )
+
+    if use_flash_cp:
         models = [trainer.model]
         if hasattr(trainer, "ref_model") and trainer.ref_model:
             models.append(trainer.ref_model)
@@ -1,7 +1,6 @@
 """Module with validation methods for config pydantic model."""
 
 import json
-import sys
 import tempfile
 from pathlib import Path
 
@@ -1314,50 +1313,40 @@ class ComplexValidationMixin:
         if not self.context_parallel_size:
             self.context_parallel_size = 1
         elif self.context_parallel_size > 1:
-            if not self.flash_attention:
-                raise ValueError(
-                    "flash_attention: true must be set with context_parallel_size > 1"
-                )
-
-            if self.sample_packing and self.micro_batch_size > 1:
-                raise ValueError(
-                    "micro_batch_size must be set to 1 when sample_packing is enabled "
-                    "due to a `ring-flash-attn` requirement"
-                )
-
-            try:
-                import transformers.modeling_flash_attention_utils
-                from transformers.utils import is_flash_attn_greater_or_equal
-
-                transformers.modeling_flash_attention_utils._flash_supports_window = (
-                    True
-                )
-                sys.modules[
-                    "transformers.modeling_flash_attention_utils"
-                ]._flash_supports_window = True
-                sys.modules[
-                    "transformers.modeling_flash_attention_utils"
-                ]._flash_supports_window_size = True
-                sys.modules[
-                    "transformers.modeling_flash_attention_utils"
-                ].is_flash_attn_greater_or_equal = is_flash_attn_greater_or_equal
-                import ring_flash_attn  # noqa: F401 # Required after monkey-patching
-            except ImportError as exception:
-                raise ImportError(
-                    "context_parallel_size > 1 but ring_flash_attn is not installed. "
-                    "Please install it with `pip install axolotl[ring-flash-attn] "
-                    "or `pip install ring-flash-attn>=0.1.4`."
-                ) from exception
-
-            LOG.warning(
-                "Sequence parallelism (SP) is enabled with "
-                f"context_parallel_size={self.context_parallel_size}. "
-                "Please note that logged losses may differ slightly to the non-SP "
-                "losses due to transformers Trainer implementation details. "
-                "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
-                "for more details."
-            )
+            use_flash_attention = getattr(self, "flash_attention", False)
+            use_sdp_attention = getattr(self, "sdp_attention", False)
+
+            if not (use_flash_attention or use_sdp_attention):
+                raise ValueError(
+                    "context_parallel_size > 1 requires either flash_attention: true "
+                    "or sdp_attention: true"
+                )
+
+            if use_flash_attention:
+                if self.sample_packing and self.micro_batch_size > 1:
+                    raise ValueError(
+                        "micro_batch_size must be set to 1 when sample_packing is enabled "
+                        "due to a `ring-flash-attn` requirement"
+                    )
+
+                try:
+                    import ring_flash_attn  # noqa: F401 # Required after monkey-patching
+                except ImportError as exception:
+                    raise ImportError(
+                        "context_parallel_size > 1 but ring_flash_attn is not installed. "
+                        "Please install it with `pip install axolotl[ring-flash-attn] "
+                        "or `pip install ring-flash-attn>=0.1.4`."
+                    ) from exception
+
+                LOG.warning(
+                    "Sequence parallelism (SP) is enabled with "
+                    f"context_parallel_size={self.context_parallel_size}. "
+                    "Please note that logged losses may differ slightly to the non-SP "
+                    "losses due to transformers Trainer implementation details. "
+                    "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 "
+                    "for more details."
+                )
 
         return self
 
     @model_validator(mode="after")
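For reference, a minimal sketch of the two attention settings the relaxed validator now accepts alongside context parallelism (field names are taken from this diff; the values are illustrative):

from axolotl.utils.dict import DictDefault

# FlashAttention path: still subject to the ring-flash-attn import check and the
# micro_batch_size == 1 requirement when sample packing is enabled.
flash_cp_cfg = DictDefault({"context_parallel_size": 2, "flash_attention": True})

# SDPA path: newly accepted; skips the ring-flash-attn and sample-packing checks.
sdpa_cp_cfg = DictDefault({"context_parallel_size": 2, "sdp_attention": True})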
@@ -23,6 +23,8 @@ class TestSequenceParallelism:
         pad_to_sequence_len=True,
         ring_attn_func=None,
         threshold=2.0,
+        flash_attention=True,
+        sdp_attention=False,
     ):
         """Helper method to run sequence parallel tests with different configurations"""
         cfg = DictDefault(
@@ -58,7 +60,8 @@ class TestSequenceParallelism:
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_8bit",
                 "lr_scheduler": "cosine",
-                "flash_attention": True,
+                "flash_attention": flash_attention,
+                "sdp_attention": sdp_attention,
                 "loss_watchdog_threshold": 5.0,
                 "loss_watchdog_patience": 3,
                 "bf16": "auto",
@@ -132,3 +135,16 @@ class TestSequenceParallelism:
             ring_attn_func=ring_attn_func,
             threshold=threshold,
         )
+
+    def test_sequence_parallel_training_sdpa(self, temp_dir):
+        """Smoke test for SDPA-based context parallelism."""
+        self._run_sequence_parallel_test(
+            temp_dir,
+            sample_packing=False,
+            micro_batch_size=1,
+            pad_to_sequence_len=True,
+            ring_attn_func=None,
+            threshold=3.0,
+            flash_attention=False,
+            sdp_attention=True,
+        )
tests/loaders/test_patch_manager_cp.py (new file, 74 lines)
@@ -0,0 +1,74 @@
+"""Tests for PatchManager context parallel patch selection."""
+
+import addict
+
+from axolotl.loaders.patch_manager import PatchManager
+from axolotl.utils.dict import DictDefault
+
+
+def _stub_transformers_patches(monkeypatch):
+    """Replace trainer loss patchers with no-ops for isolation."""
+    monkeypatch.setattr(
+        "axolotl.monkeypatch.transformers.trainer_loss_calc.patch_evaluation_loop",
+        lambda: None,
+    )
+    monkeypatch.setattr(
+        "axolotl.monkeypatch.transformers.trainer_loss_calc.patch_maybe_log_save_evaluate",
+        lambda: None,
+    )
+
+
+def test_patch_manager_applies_flash_cp_patch(monkeypatch):
+    """When flash attention is enabled, we patch Trainer for CP."""
+    _stub_transformers_patches(monkeypatch)
+
+    patch_calls = {"count": 0}
+
+    def stub_patch():
+        patch_calls["count"] += 1
+
+    monkeypatch.setattr(
+        "axolotl.monkeypatch.transformers.trainer_context_parallel.patch_prepare_context_parallel_inputs",
+        stub_patch,
+    )
+
+    cfg = DictDefault(
+        {
+            "context_parallel_size": 2,
+            "flash_attention": True,
+            "sdp_attention": False,
+        }
+    )
+
+    manager = PatchManager(cfg, addict.Dict())
+    manager._apply_transformers_patches()
+
+    assert patch_calls["count"] == 1
+
+
+def test_patch_manager_skips_flash_patch_for_sdpa(monkeypatch):
+    """When only SDPA is requested, we should not patch Trainer."""
+    _stub_transformers_patches(monkeypatch)
+
+    patch_calls = {"count": 0}
+
+    def stub_patch():
+        patch_calls["count"] += 1
+
+    monkeypatch.setattr(
+        "axolotl.monkeypatch.transformers.trainer_context_parallel.patch_prepare_context_parallel_inputs",
+        stub_patch,
+    )
+
+    cfg = DictDefault(
+        {
+            "context_parallel_size": 2,
+            "flash_attention": False,
+            "sdp_attention": True,
+        }
+    )
+
+    manager = PatchManager(cfg, addict.Dict())
+    manager._apply_transformers_patches()
+
+    assert patch_calls["count"] == 0
tests/monkeypatch/test_trainer_context_parallel_patch.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+"""Tests for the HF Trainer context parallel patch."""
+
+import pytest
+from transformers import Trainer
+
+from axolotl.monkeypatch.transformers.trainer_context_parallel import (
+    GUARD_PATTERN,
+    PATCHED_GUARD,
+    patch_prepare_context_parallel_inputs,
+)
+
+
+@pytest.fixture
+def restore_trainer_prepare_method():
+    """Ensure Trainer._prepare_context_parallel_inputs is restored after a test."""
+    original_method = getattr(
+        Trainer,
+        "_original_prepare_context_parallel_inputs",
+        Trainer._prepare_context_parallel_inputs,
+    )
+    patched_attr_present = hasattr(
+        Trainer, "_axolotl_prepare_context_parallel_inputs_patched"
+    )
+
+    yield
+
+    Trainer._prepare_context_parallel_inputs = original_method
+    if patched_attr_present:
+        delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
+    if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
+        delattr(Trainer, "_original_prepare_context_parallel_inputs")
+    if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source"):
+        delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_source")
+
+
+def test_patch_attention_guard(restore_trainer_prepare_method):
+    """Patch should swap the guard to allow sdpa or flash attention."""
+    # Ensure we start from the unpatched method
+    if hasattr(Trainer, "_original_prepare_context_parallel_inputs"):
+        Trainer._prepare_context_parallel_inputs = (
+            Trainer._original_prepare_context_parallel_inputs
+        )
+        delattr(Trainer, "_original_prepare_context_parallel_inputs")
+    if hasattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched"):
+        delattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched")
+
+    patch_prepare_context_parallel_inputs()
+
+    patched_method = Trainer._prepare_context_parallel_inputs
+    assert patched_method is not None
+    assert getattr(Trainer, "_axolotl_prepare_context_parallel_inputs_patched", False)
+
+    source = Trainer._axolotl_prepare_context_parallel_inputs_source
+    assert GUARD_PATTERN not in source
+    assert PATCHED_GUARD in source
+
+
+def test_patch_is_idempotent(restore_trainer_prepare_method):
+    """Calling the patch twice should leave the same patched function in place."""
+    patch_prepare_context_parallel_inputs()
+    first_patched = Trainer._prepare_context_parallel_inputs
+
+    patch_prepare_context_parallel_inputs()
+    second_patched = Trainer._prepare_context_parallel_inputs
+
+    assert first_patched is second_patched
tests/test_train_context_parallel.py (new file, 111 lines)
@@ -0,0 +1,111 @@
+"""Unit tests for choosing the correct context parallel implementation."""
+
+from types import SimpleNamespace
+
+from axolotl.train import execute_training
+from axolotl.utils.dict import DictDefault
+
+
+class DummyTrainer:
+    """Minimal trainer stub to exercise execute_training."""
+
+    def __init__(self):
+        self.model = object()
+        self.ref_model = None
+        self.accelerator = SimpleNamespace(torch_device_mesh=None)
+        self.train_called = False
+
+    def train(self, resume_from_checkpoint=None):  # pylint: disable=unused-argument
+        self.train_called = True
+
+
+class DummyPluginManager:
+    """Minimal plugin manager stub."""
+
+    @staticmethod
+    def post_train(cfg, model):  # pylint: disable=unused-argument
+        return None
+
+
+class DummyContext:
+    """Test context manager that records entries/exits."""
+
+    def __init__(self, recorder, **kwargs):
+        recorder.append({"kwargs": kwargs})
+        self.recorder = recorder
+
+    def __enter__(self):
+        self.recorder[-1]["entered"] = True
+        return self
+
+    def __exit__(self, exc_type, exc, tb):  # pylint: disable=unused-argument
+        self.recorder[-1]["exited"] = True
+        return False
+
+
+def _base_cfg(**overrides):
+    base = {
+        "context_parallel_size": 2,
+        "gradient_accumulation_steps": 1,
+        "ring_attn_func": None,
+        "heads_k_stride": None,
+        "rl": None,
+        "flash_optimum": False,
+    }
+    base.update(overrides)
+    return DictDefault(base)
+
+
+def test_execute_training_uses_ring_when_flash(monkeypatch):
+    """FlashAttention CP should engage the custom ring context manager."""
+    recorder: list[dict] = []
+
+    monkeypatch.setattr(
+        "axolotl.train.SequenceParallelContextManager",
+        lambda **kwargs: DummyContext(recorder, **kwargs),
+    )
+    monkeypatch.setattr(
+        "axolotl.train.PluginManager.get_instance",
+        lambda: DummyPluginManager(),
+    )
+
+    cfg = _base_cfg(flash_attention=True, sdp_attention=False)
+    trainer = DummyTrainer()
+
+    execute_training(cfg, trainer, resume_from_checkpoint=None)
+
+    assert trainer.train_called
+    assert len(recorder) == 1
+    assert recorder[0]["kwargs"]["context_parallel_size"] == 2
+    assert recorder[0].get("entered") is True
+    assert recorder[0].get("exited") is True
+
+
+def test_execute_training_uses_transformers_cp_for_sdpa(monkeypatch):
+    """SDPA CP should bypass the ring context manager."""
+    invoked = {"count": 0}
+
+    class NoOpContext:
+        def __enter__(self):
+            return self
+
+        def __exit__(self, exc_type, exc, tb):  # pylint: disable=unused-argument
+            return False
+
+    monkeypatch.setattr(
+        "axolotl.train.SequenceParallelContextManager",
+        lambda **kwargs: invoked.__setitem__("count", invoked["count"] + 1)
+        or NoOpContext(),
+    )
+    monkeypatch.setattr(
+        "axolotl.train.PluginManager.get_instance",
+        lambda: DummyPluginManager(),
+    )
+
+    cfg = _base_cfg(flash_attention=False, sdp_attention=True)
+    trainer = DummyTrainer()
+
+    execute_training(cfg, trainer, resume_from_checkpoint=None)
+
+    assert trainer.train_called
+    assert invoked["count"] == 0