From 86bac48d149ffd92ec854ddeb0e6547676cd2da1 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Sat, 22 Mar 2025 17:53:29 -0400 Subject: [PATCH] cleanup for failing test (#2436) --- .../monkeypatch/attention/ring_attn.py | 2 +- tests/e2e/patched/test_sp.py | 46 ++++++++++--------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/axolotl/monkeypatch/attention/ring_attn.py b/src/axolotl/monkeypatch/attention/ring_attn.py index 95c44a820..a81b87909 100644 --- a/src/axolotl/monkeypatch/attention/ring_attn.py +++ b/src/axolotl/monkeypatch/attention/ring_attn.py @@ -27,7 +27,7 @@ def get_ring_attn_group() -> dist.ProcessGroup: return RING_ATTN_GROUP -def set_ring_attn_group(ring_attn_group: dist.ProcessGroup): +def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None): """ Setter for ring attention group on this rank. diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index a20ad9ff2..903cb1450 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -8,14 +8,13 @@ import pytest import torch from accelerate.state import PartialState +from axolotl.monkeypatch.attention.ring_attn import ( + get_ring_attn_group, + set_ring_attn_group, +) +from axolotl.utils.collators.batching import adjust_position_ids_for_slice from axolotl.utils.dict import DictDefault -# Use a single patch for ring_flash_attn if it's not available -ring_flash_attn_mock = MagicMock() -with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}): - from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group - from axolotl.utils.collators.batching import adjust_position_ids_for_slice - @pytest.fixture def partial_state(): @@ -79,6 +78,22 @@ class TestSequenceParallelHelpers: class TestRingAttention: """Tests for the ring attention functionality.""" + @patch("torch.distributed.get_rank") + @patch("torch.distributed.get_world_size") + def test_get_ring_attn_group_no_registration( + self, mock_world_size, mock_rank, partial_state + ): + """Test that get_ring_attn_group returns None when no group has been registered.""" + # Setup mocks + mock_world_size.return_value = 4 + mock_rank.return_value = 0 + + # Get the group without registration + group = get_ring_attn_group() + + # Verify that None was returned + assert group is None + @patch("torch.distributed.new_group") @patch("torch.distributed.get_rank") @patch("torch.distributed.get_world_size") @@ -100,24 +115,11 @@ class TestRingAttention: # Verify the number of calls without examining the arguments assert mock_new_group.call_count == 2 - # Just verify that new_group was called + # Verify that new_group was called mock_new_group.assert_called() - @patch("torch.distributed.get_rank") - @patch("torch.distributed.get_world_size") - def test_get_ring_attn_group_no_registration( - self, mock_world_size, mock_rank, partial_state - ): - """Test that get_ring_attn_group returns None when no group has been registered.""" - # Setup mocks - mock_world_size.return_value = 4 - mock_rank.return_value = 0 - - # Get the group without registration - group = get_ring_attn_group() - - # Verify that None was returned - assert group is None + # Clean up + set_ring_attn_group(None) # Mock a simplified DataCollator test