cleanup for failing test (#2436)
This commit is contained in:
@@ -27,7 +27,7 @@ def get_ring_attn_group() -> dist.ProcessGroup:
|
|||||||
return RING_ATTN_GROUP
|
return RING_ATTN_GROUP
|
||||||
|
|
||||||
|
|
||||||
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup):
|
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||||
"""
|
"""
|
||||||
Setter for ring attention group on this rank.
|
Setter for ring attention group on this rank.
|
||||||
|
|
||||||
|
|||||||
@@ -8,14 +8,13 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
from accelerate.state import PartialState
|
from accelerate.state import PartialState
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import (
|
||||||
|
get_ring_attn_group,
|
||||||
|
set_ring_attn_group,
|
||||||
|
)
|
||||||
|
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
# Use a single patch for ring_flash_attn if it's not available
|
|
||||||
ring_flash_attn_mock = MagicMock()
|
|
||||||
with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
|
|
||||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
|
||||||
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def partial_state():
|
def partial_state():
|
||||||
@@ -79,6 +78,22 @@ class TestSequenceParallelHelpers:
|
|||||||
class TestRingAttention:
|
class TestRingAttention:
|
||||||
"""Tests for the ring attention functionality."""
|
"""Tests for the ring attention functionality."""
|
||||||
|
|
||||||
|
@patch("torch.distributed.get_rank")
|
||||||
|
@patch("torch.distributed.get_world_size")
|
||||||
|
def test_get_ring_attn_group_no_registration(
|
||||||
|
self, mock_world_size, mock_rank, partial_state
|
||||||
|
):
|
||||||
|
"""Test that get_ring_attn_group returns None when no group has been registered."""
|
||||||
|
# Setup mocks
|
||||||
|
mock_world_size.return_value = 4
|
||||||
|
mock_rank.return_value = 0
|
||||||
|
|
||||||
|
# Get the group without registration
|
||||||
|
group = get_ring_attn_group()
|
||||||
|
|
||||||
|
# Verify that None was returned
|
||||||
|
assert group is None
|
||||||
|
|
||||||
@patch("torch.distributed.new_group")
|
@patch("torch.distributed.new_group")
|
||||||
@patch("torch.distributed.get_rank")
|
@patch("torch.distributed.get_rank")
|
||||||
@patch("torch.distributed.get_world_size")
|
@patch("torch.distributed.get_world_size")
|
||||||
@@ -100,24 +115,11 @@ class TestRingAttention:
|
|||||||
# Verify the number of calls without examining the arguments
|
# Verify the number of calls without examining the arguments
|
||||||
assert mock_new_group.call_count == 2
|
assert mock_new_group.call_count == 2
|
||||||
|
|
||||||
# Just verify that new_group was called
|
# Verify that new_group was called
|
||||||
mock_new_group.assert_called()
|
mock_new_group.assert_called()
|
||||||
|
|
||||||
@patch("torch.distributed.get_rank")
|
# Clean up
|
||||||
@patch("torch.distributed.get_world_size")
|
set_ring_attn_group(None)
|
||||||
def test_get_ring_attn_group_no_registration(
|
|
||||||
self, mock_world_size, mock_rank, partial_state
|
|
||||||
):
|
|
||||||
"""Test that get_ring_attn_group returns None when no group has been registered."""
|
|
||||||
# Setup mocks
|
|
||||||
mock_world_size.return_value = 4
|
|
||||||
mock_rank.return_value = 0
|
|
||||||
|
|
||||||
# Get the group without registration
|
|
||||||
group = get_ring_attn_group()
|
|
||||||
|
|
||||||
# Verify that None was returned
|
|
||||||
assert group is None
|
|
||||||
|
|
||||||
|
|
||||||
# Mock a simplified DataCollator test
|
# Mock a simplified DataCollator test
|
||||||
|
|||||||
Reference in New Issue
Block a user