cleanup for failing test (#2436)

This commit is contained in:
Dan Saunders
2025-03-22 17:53:29 -04:00
committed by GitHub
parent e44953d50c
commit 86bac48d14
2 changed files with 25 additions and 23 deletions

View File

@@ -27,7 +27,7 @@ def get_ring_attn_group() -> dist.ProcessGroup:
return RING_ATTN_GROUP
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup):
def set_ring_attn_group(ring_attn_group: dist.ProcessGroup | None):
"""
Setter for ring attention group on this rank.

View File

@@ -8,14 +8,13 @@ import pytest
import torch
from accelerate.state import PartialState
from axolotl.monkeypatch.attention.ring_attn import (
get_ring_attn_group,
set_ring_attn_group,
)
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
from axolotl.utils.dict import DictDefault
# Use a single patch for ring_flash_attn if it's not available
ring_flash_attn_mock = MagicMock()
with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
@pytest.fixture
def partial_state():
@@ -79,6 +78,22 @@ class TestSequenceParallelHelpers:
class TestRingAttention:
"""Tests for the ring attention functionality."""
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_get_ring_attn_group_no_registration(
self, mock_world_size, mock_rank, partial_state
):
"""Test that get_ring_attn_group returns None when no group has been registered."""
# Setup mocks
mock_world_size.return_value = 4
mock_rank.return_value = 0
# Get the group without registration
group = get_ring_attn_group()
# Verify that None was returned
assert group is None
@patch("torch.distributed.new_group")
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
@@ -100,24 +115,11 @@ class TestRingAttention:
# Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2
# Just verify that new_group was called
# Verify that new_group was called
mock_new_group.assert_called()
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_get_ring_attn_group_no_registration(
self, mock_world_size, mock_rank, partial_state
):
"""Test that get_ring_attn_group returns None when no group has been registered."""
# Setup mocks
mock_world_size.return_value = 4
mock_rank.return_value = 0
# Get the group without registration
group = get_ring_attn_group()
# Verify that None was returned
assert group is None
# Clean up
set_ring_attn_group(None)
# Mock a simplified DataCollator test