Add support for context parallelism (CP) with torch SDPA
This commit is contained in:
@@ -23,6 +23,8 @@ class TestSequenceParallelism:
|
||||
pad_to_sequence_len=True,
|
||||
ring_attn_func=None,
|
||||
threshold=2.0,
|
||||
flash_attention=True,
|
||||
sdp_attention=False,
|
||||
):
|
||||
"""Helper method to run sequence parallel tests with different configurations"""
|
||||
cfg = DictDefault(
|
||||
@@ -58,7 +60,8 @@ class TestSequenceParallelism:
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"flash_attention": flash_attention,
|
||||
"sdp_attention": sdp_attention,
|
||||
"loss_watchdog_threshold": 5.0,
|
||||
"loss_watchdog_patience": 3,
|
||||
"bf16": "auto",
|
||||
@@ -132,3 +135,16 @@ class TestSequenceParallelism:
|
||||
ring_attn_func=ring_attn_func,
|
||||
threshold=threshold,
|
||||
)
|
||||
|
||||
def test_sequence_parallel_training_sdpa(self, temp_dir):
    """Exercise context parallelism with SDPA attention as a smoke test.

    Delegates to ``_run_sequence_parallel_test`` with flash attention
    switched off and ``sdp_attention`` switched on, without sample packing.
    """
    # Collect the SDPA-specific configuration in one place so the single
    # helper call below stays short.
    sdpa_options = {
        "sample_packing": False,
        "micro_batch_size": 1,
        "pad_to_sequence_len": True,
        "ring_attn_func": None,
        "threshold": 3.0,
        "flash_attention": False,
        "sdp_attention": True,
    }
    self._run_sequence_parallel_test(temp_dir, **sdpa_options)
|
||||
|
||||
Reference in New Issue
Block a user