cleanup for failing test (#2436)
This commit is contained in:
@@ -27,7 +27,7 @@ def get_ring_attn_group() -> dist.ProcessGroup:
|
||||
return RING_ATTN_GROUP
|
||||
|
||||
|
||||
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup):
|
||||
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
|
||||
"""
|
||||
Setter for ring attention group on this rank.
|
||||
|
||||
|
||||
@@ -8,14 +8,13 @@ import pytest
|
||||
import torch
|
||||
from accelerate.state import PartialState
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn import (
|
||||
get_ring_attn_group,
|
||||
set_ring_attn_group,
|
||||
)
|
||||
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
# Use a single patch for ring_flash_attn if it's not available
|
||||
ring_flash_attn_mock = MagicMock()
|
||||
with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def partial_state():
|
||||
@@ -79,6 +78,22 @@ class TestSequenceParallelHelpers:
|
||||
class TestRingAttention:
|
||||
"""Tests for the ring attention functionality."""
|
||||
|
||||
@patch("torch.distributed.get_rank")
|
||||
@patch("torch.distributed.get_world_size")
|
||||
def test_get_ring_attn_group_no_registration(
|
||||
self, mock_world_size, mock_rank, partial_state
|
||||
):
|
||||
"""Test that get_ring_attn_group returns None when no group has been registered."""
|
||||
# Setup mocks
|
||||
mock_world_size.return_value = 4
|
||||
mock_rank.return_value = 0
|
||||
|
||||
# Get the group without registration
|
||||
group = get_ring_attn_group()
|
||||
|
||||
# Verify that None was returned
|
||||
assert group is None
|
||||
|
||||
@patch("torch.distributed.new_group")
|
||||
@patch("torch.distributed.get_rank")
|
||||
@patch("torch.distributed.get_world_size")
|
||||
@@ -100,24 +115,11 @@ class TestRingAttention:
|
||||
# Verify the number of calls without examining the arguments
|
||||
assert mock_new_group.call_count == 2
|
||||
|
||||
# Just verify that new_group was called
|
||||
# Verify that new_group was called
|
||||
mock_new_group.assert_called()
|
||||
|
||||
@patch("torch.distributed.get_rank")
|
||||
@patch("torch.distributed.get_world_size")
|
||||
def test_get_ring_attn_group_no_registration(
|
||||
self, mock_world_size, mock_rank, partial_state
|
||||
):
|
||||
"""Test that get_ring_attn_group returns None when no group has been registered."""
|
||||
# Setup mocks
|
||||
mock_world_size.return_value = 4
|
||||
mock_rank.return_value = 0
|
||||
|
||||
# Get the group without registration
|
||||
group = get_ring_attn_group()
|
||||
|
||||
# Verify that None was returned
|
||||
assert group is None
|
||||
# Clean up
|
||||
set_ring_attn_group(None)
|
||||
|
||||
|
||||
# Mock a simplified DataCollator test
|
||||
|
||||
Reference in New Issue
Block a user