SP context manager update (#2699)

* utilize accelerate prepare_data_loader with patching

* lint

* cleanup, fix

* update to support DPO quirk

* coderabbit commits, cleanup, remove dead code

* fix

* move ring attn patching to sp ctx manager

* lint

* lint

* test fix

* test fix
This commit is contained in:
Dan Saunders
2025-05-22 11:18:32 -04:00
committed by GitHub
parent aa0492c366
commit 5f8f817200
5 changed files with 68 additions and 46 deletions

View File

@@ -84,16 +84,16 @@ class TestRingAttention:
def test_get_ring_attn_group_no_registration(
self, mock_world_size, mock_rank, partial_state
):
"""Test that get_ring_attn_group returns None when no group has been registered."""
"""Test that get_ring_attn_group raises RuntimeError when no group has been registered."""
# Setup mocks
mock_world_size.return_value = 4
mock_rank.return_value = 0
# Get the group without registration
group = get_ring_attn_group()
# Verify that None was returned
assert group is None
# Verify that RuntimeError is raised when no group is registered
with pytest.raises(
RuntimeError, match="register_ring_attn\\(\\) not yet called"
):
get_ring_attn_group()
@patch("torch.distributed.new_group")
@patch("torch.distributed.get_rank")
@@ -323,8 +323,11 @@ class TestApplySequenceParallelism:
lambda **kwargs: None,
)
def test_world_size_one(self, sequence_parallel_batch):
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
"""Test that function returns original batch when world size is 1."""
mock_get_ring_attn_group.return_value = 0
result, _, _ = apply_sequence_parallelism(
batch=sequence_parallel_batch,
local_rank=0,
@@ -336,8 +339,11 @@ class TestApplySequenceParallelism:
# Should return the original batch unchanged
assert result == sequence_parallel_batch
def test_batch_ring_rank0(self, sequence_parallel_batch):
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
@@ -359,8 +365,11 @@ class TestApplySequenceParallelism:
result["position_ids"], batch["position_ids"][:, : seq_len // 2]
)
def test_batch_ring_rank1(self, sequence_parallel_batch):
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
@@ -419,8 +428,13 @@ class TestApplySequenceParallelism:
# assert torch.equal(result_rank0["input_ids"], rank0_expected)
# assert torch.equal(result_rank1["input_ids"], rank1_expected)
def test_partial_application(self, sequence_parallel_batch):
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_partial_application(
self, mock_get_ring_attn_group, sequence_parallel_batch
):
"""Test that we can create a partially applied version of the function."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
original_input_ids = batch["input_ids"].clone()