Add support for context parallelism (CP) + torch SDPA

This commit is contained in:
Dan Saunders
2025-09-25 12:03:43 -04:00
parent f9bd6936c1
commit 09725be990
8 changed files with 274 additions and 67 deletions

View File

@@ -23,6 +23,8 @@ class TestSequenceParallelism:
pad_to_sequence_len=True,
ring_attn_func=None,
threshold=2.0,
flash_attention=True,
sdp_attention=False,
):
"""Helper method to run sequence parallel tests with different configurations"""
cfg = DictDefault(
@@ -58,7 +60,8 @@ class TestSequenceParallelism:
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"flash_attention": flash_attention,
"sdp_attention": sdp_attention,
"loss_watchdog_threshold": 5.0,
"loss_watchdog_patience": 3,
"bf16": "auto",
@@ -132,3 +135,16 @@ class TestSequenceParallelism:
ring_attn_func=ring_attn_func,
threshold=threshold,
)
def test_sequence_parallel_training_sdpa(self, temp_dir):
    """Smoke-test context parallelism using the SDPA attention path.

    Delegates to ``_run_sequence_parallel_test`` with flash attention
    switched off and ``sdp_attention`` switched on. Sample packing is
    disabled, the micro batch size is 1, padding to sequence length is
    enabled, no ring attention function is supplied, and the threshold
    passed to the helper is 3.0.
    """
    # Keyword overrides forwarded verbatim to the shared helper.
    sdpa_overrides = {
        "sample_packing": False,
        "micro_batch_size": 1,
        "pad_to_sequence_len": True,
        "ring_attn_func": None,
        "threshold": 3.0,
        "flash_attention": False,
        "sdp_attention": True,
    }
    self._run_sequence_parallel_test(temp_dir, **sdpa_overrides)