add support for CP + torch SDPA
This commit is contained in:
@@ -23,6 +23,8 @@ class TestSequenceParallelism:
|
||||
pad_to_sequence_len=True,
|
||||
ring_attn_func=None,
|
||||
threshold=2.0,
|
||||
flash_attention=True,
|
||||
sdp_attention=False,
|
||||
):
|
||||
"""Helper method to run sequence parallel tests with different configurations"""
|
||||
cfg = DictDefault(
|
||||
@@ -58,7 +60,8 @@ class TestSequenceParallelism:
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"flash_attention": flash_attention,
|
||||
"sdp_attention": sdp_attention,
|
||||
"loss_watchdog_threshold": 5.0,
|
||||
"loss_watchdog_patience": 3,
|
||||
"bf16": "auto",
|
||||
@@ -132,3 +135,16 @@ class TestSequenceParallelism:
|
||||
ring_attn_func=ring_attn_func,
|
||||
threshold=threshold,
|
||||
)
|
||||
|
||||
def test_sequence_parallel_training_sdpa(self, temp_dir):
|
||||
"""Smoke test for SDPA-based context parallelism."""
|
||||
self._run_sequence_parallel_test(
|
||||
temp_dir,
|
||||
sample_packing=False,
|
||||
micro_batch_size=1,
|
||||
pad_to_sequence_len=True,
|
||||
ring_attn_func=None,
|
||||
threshold=3.0,
|
||||
flash_attention=False,
|
||||
sdp_attention=True,
|
||||
)
|
||||
|
||||
74
tests/loaders/test_patch_manager_cp.py
Normal file
74
tests/loaders/test_patch_manager_cp.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Tests for PatchManager context parallel patch selection."""
|
||||
|
||||
import addict
|
||||
|
||||
from axolotl.loaders.patch_manager import PatchManager
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def _stub_transformers_patches(monkeypatch):
|
||||
"""Replace trainer loss patchers with no-ops for isolation."""
|
||||
monkeypatch.setattr(
|
||||
"axolotl.monkeypatch.transformers.trainer_loss_calc.patch_evaluation_loop",
|
||||
lambda: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"axolotl.monkeypatch.transformers.trainer_loss_calc.patch_maybe_log_save_evaluate",
|
||||
lambda: None,
|
||||
)
|
||||
|
||||
|
||||
def test_patch_manager_applies_flash_cp_patch(monkeypatch):
|
||||
"""When flash attention is enabled, we patch Trainer for CP."""
|
||||
_stub_transformers_patches(monkeypatch)
|
||||
|
||||
patch_calls = {"count": 0}
|
||||
|
||||
def stub_patch():
|
||||
patch_calls["count"] += 1
|
||||
|
||||
monkeypatch.setattr(
|
||||
"axolotl.monkeypatch.transformers.trainer_context_parallel.patch_prepare_context_parallel_inputs",
|
||||
stub_patch,
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"context_parallel_size": 2,
|
||||
"flash_attention": True,
|
||||
"sdp_attention": False,
|
||||
}
|
||||
)
|
||||
|
||||
manager = PatchManager(cfg, addict.Dict())
|
||||
manager._apply_transformers_patches()
|
||||
|
||||
assert patch_calls["count"] == 1
|
||||
|
||||
|
||||
def test_patch_manager_skips_flash_patch_for_sdpa(monkeypatch):
|
||||
"""When only SDPA is requested, we should not patch Trainer."""
|
||||
_stub_transformers_patches(monkeypatch)
|
||||
|
||||
patch_calls = {"count": 0}
|
||||
|
||||
def stub_patch():
|
||||
patch_calls["count"] += 1
|
||||
|
||||
monkeypatch.setattr(
|
||||
"axolotl.monkeypatch.transformers.trainer_context_parallel.patch_prepare_context_parallel_inputs",
|
||||
stub_patch,
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"context_parallel_size": 2,
|
||||
"flash_attention": False,
|
||||
"sdp_attention": True,
|
||||
}
|
||||
)
|
||||
|
||||
manager = PatchManager(cfg, addict.Dict())
|
||||
manager._apply_transformers_patches()
|
||||
|
||||
assert patch_calls["count"] == 0
|
||||
111
tests/test_train_context_parallel.py
Normal file
111
tests/test_train_context_parallel.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Unit tests for choosing the correct context parallel implementation."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from axolotl.train import execute_training
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class DummyTrainer:
|
||||
"""Minimal trainer stub to exercise execute_training."""
|
||||
|
||||
def __init__(self):
|
||||
self.model = object()
|
||||
self.ref_model = None
|
||||
self.accelerator = SimpleNamespace(torch_device_mesh=None)
|
||||
self.train_called = False
|
||||
|
||||
def train(self, resume_from_checkpoint=None): # pylint: disable=unused-argument
|
||||
self.train_called = True
|
||||
|
||||
|
||||
class DummyPluginManager:
|
||||
"""Minimal plugin manager stub."""
|
||||
|
||||
@staticmethod
|
||||
def post_train(cfg, model): # pylint: disable=unused-argument
|
||||
return None
|
||||
|
||||
|
||||
class DummyContext:
|
||||
"""Test context manager that records entries/exits."""
|
||||
|
||||
def __init__(self, recorder, **kwargs):
|
||||
recorder.append({"kwargs": kwargs})
|
||||
self.recorder = recorder
|
||||
|
||||
def __enter__(self):
|
||||
self.recorder[-1]["entered"] = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb): # pylint: disable=unused-argument
|
||||
self.recorder[-1]["exited"] = True
|
||||
return False
|
||||
|
||||
|
||||
def _base_cfg(**overrides):
|
||||
base = {
|
||||
"context_parallel_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"ring_attn_func": None,
|
||||
"heads_k_stride": None,
|
||||
"rl": None,
|
||||
"flash_optimum": False,
|
||||
}
|
||||
base.update(overrides)
|
||||
return DictDefault(base)
|
||||
|
||||
|
||||
def test_execute_training_uses_ring_when_flash(monkeypatch):
|
||||
"""FlashAttention CP should engage the custom ring context manager."""
|
||||
recorder: list[dict] = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"axolotl.train.SequenceParallelContextManager",
|
||||
lambda **kwargs: DummyContext(recorder, **kwargs),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"axolotl.train.PluginManager.get_instance",
|
||||
lambda: DummyPluginManager(),
|
||||
)
|
||||
|
||||
cfg = _base_cfg(flash_attention=True, sdp_attention=False)
|
||||
trainer = DummyTrainer()
|
||||
|
||||
execute_training(cfg, trainer, resume_from_checkpoint=None)
|
||||
|
||||
assert trainer.train_called
|
||||
assert len(recorder) == 1
|
||||
assert recorder[0]["kwargs"]["context_parallel_size"] == 2
|
||||
assert recorder[0].get("entered") is True
|
||||
assert recorder[0].get("exited") is True
|
||||
|
||||
|
||||
def test_execute_training_uses_transformers_cp_for_sdpa(monkeypatch):
|
||||
"""SDPA CP should bypass the ring context manager."""
|
||||
invoked = {"count": 0}
|
||||
|
||||
class NoOpContext:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb): # pylint: disable=unused-argument
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"axolotl.train.SequenceParallelContextManager",
|
||||
lambda **kwargs: invoked.__setitem__("count", invoked["count"] + 1)
|
||||
or NoOpContext(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"axolotl.train.PluginManager.get_instance",
|
||||
lambda: DummyPluginManager(),
|
||||
)
|
||||
|
||||
cfg = _base_cfg(flash_attention=False, sdp_attention=True)
|
||||
trainer = DummyTrainer()
|
||||
|
||||
execute_training(cfg, trainer, resume_from_checkpoint=None)
|
||||
|
||||
assert trainer.train_called
|
||||
assert invoked["count"] == 0
|
||||
Reference in New Issue
Block a user