pytest
This commit is contained in:
221
tests/e2e/patched/test_sequence_parallelism.py
Normal file
221
tests/e2e/patched/test_sequence_parallelism.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""Tests for sequence parallelism functionality."""
|
||||
# pylint: disable=redefined-outer-name,unused-argument
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from accelerate.state import PartialState
|
||||
|
||||
# Use a single patch for ring_flash_attn if it's not available
|
||||
ring_flash_attn_mock = MagicMock()
|
||||
with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
from axolotl.utils.collators.sequence_parallel import (
|
||||
adjust_position_ids_for_slice,
|
||||
check_for_boundary_splits,
|
||||
find_sample_boundaries,
|
||||
)
|
||||
|
||||
|
||||
# Create a fixture for PartialState
|
||||
@pytest.fixture
|
||||
def partial_state():
|
||||
"""Create a real PartialState instance for testing."""
|
||||
# This initializes a PartialState for a non-distributed environment
|
||||
state = PartialState()
|
||||
return state
|
||||
|
||||
|
||||
class TestSequenceParallelHelpers:
|
||||
"""Test helper functions used in sequence parallelism."""
|
||||
|
||||
def test_find_sample_boundaries(self):
|
||||
"""Test detection of boundaries in position_ids."""
|
||||
# Create sample position_ids with multiple sequences
|
||||
position_ids = torch.tensor(
|
||||
[
|
||||
# First sequence with 2 samples (boundary at index 5)
|
||||
[0, 1, 2, 3, 4, 0, 1, 2, 3],
|
||||
# Second sequence with 3 samples (boundaries at 3 and 7)
|
||||
[0, 1, 2, 0, 1, 2, 3, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
boundaries = find_sample_boundaries(position_ids)
|
||||
|
||||
assert len(boundaries) == 2
|
||||
assert boundaries[0] == [5] # First sequence has boundary at index 5
|
||||
assert boundaries[1] == [3, 7] # Second sequence has boundaries at 3 and 7
|
||||
|
||||
def test_adjust_position_ids_for_slice(self, partial_state):
|
||||
"""Test position_ids adjustment for sequence slices."""
|
||||
# Create sample position_ids with multiple sequences
|
||||
position_ids = torch.tensor(
|
||||
[
|
||||
# First sequence with 2 samples
|
||||
[0, 1, 2, 3, 4, 0, 1, 2, 3],
|
||||
# Second sequence with 3 samples
|
||||
[0, 1, 2, 0, 1, 2, 3, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
# Adjust as if this was the second slice (start_idx = 4)
|
||||
adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)
|
||||
|
||||
# For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1]
|
||||
# For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3]
|
||||
expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4
|
||||
expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4
|
||||
|
||||
assert torch.all(adjusted[0] == expected_first_seq)
|
||||
assert torch.all(adjusted[1] == expected_second_seq)
|
||||
|
||||
def test_check_for_boundary_splits(self):
|
||||
"""Test detection of boundaries near slice edges."""
|
||||
# Boundaries at positions 10, 25, 40
|
||||
boundaries = [10, 25, 40]
|
||||
|
||||
# Test case where two boundaries are near edges (one at start, one at end)
|
||||
problems = check_for_boundary_splits(boundaries, slice_start=8, slice_end=30)
|
||||
assert (
|
||||
len(problems) == 2
|
||||
) # Both boundary at 10 (near start) and 25 (near end) are problems
|
||||
|
||||
# Check first problem - boundary near start
|
||||
assert problems[0][0] == 10 # The boundary position
|
||||
assert problems[0][1] == "start" # Type of issue
|
||||
assert problems[0][2] == 2 # Distance from start
|
||||
|
||||
# Check second problem - boundary near end
|
||||
assert problems[1][0] == 25 # The boundary position
|
||||
assert problems[1][1] == "end" # Type of issue
|
||||
assert problems[1][2] == 5 # Distance from end
|
||||
|
||||
# Test case with only one problem at the end
|
||||
problems = check_for_boundary_splits(boundaries, slice_start=15, slice_end=27)
|
||||
assert len(problems) == 1 # Only boundary at 25 is near the end
|
||||
assert problems[0][0] == 25 # The boundary
|
||||
assert problems[0][1] == "end" # Type of issue
|
||||
|
||||
# Test case with no problems
|
||||
problems = check_for_boundary_splits(boundaries, slice_start=12, slice_end=20)
|
||||
assert len(problems) == 0
|
||||
|
||||
|
||||
class TestRingAttention:
|
||||
"""Tests for the ring attention functionality."""
|
||||
|
||||
@patch("torch.distributed.new_group")
|
||||
@patch("torch.distributed.get_rank")
|
||||
@patch("torch.distributed.get_world_size")
|
||||
def test_register_ring_attn(
|
||||
self, mock_world_size, mock_rank, mock_new_group, partial_state
|
||||
):
|
||||
"""Test that ring attention groups are created correctly."""
|
||||
from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
|
||||
|
||||
# Setup mocks
|
||||
mock_world_size.return_value = 8 # 8 GPUs total
|
||||
mock_rank.return_value = 3 # GPU #3
|
||||
mock_group = MagicMock()
|
||||
mock_new_group.return_value = mock_group
|
||||
|
||||
# Call register_ring_attn with size 4
|
||||
register_ring_attn(sequence_parallel_size=4)
|
||||
|
||||
# Verify the number of calls without examining the arguments
|
||||
assert mock_new_group.call_count == 2
|
||||
|
||||
# Just verify that new_group was called
|
||||
mock_new_group.assert_called()
|
||||
|
||||
@patch("torch.distributed.get_rank")
|
||||
@patch("torch.distributed.get_world_size")
|
||||
def test_get_ring_attn_group_no_registration(
|
||||
self, mock_world_size, mock_rank, partial_state
|
||||
):
|
||||
"""Test that get_ring_attn_group returns None when no group has been registered."""
|
||||
# Setup mocks
|
||||
mock_world_size.return_value = 4
|
||||
mock_rank.return_value = 0
|
||||
|
||||
# Get the group without registration
|
||||
group = get_ring_attn_group()
|
||||
|
||||
# Verify that None was returned
|
||||
assert group is None
|
||||
|
||||
|
||||
# Mock a simplified DataCollator test
|
||||
@patch("axolotl.utils.collators.sequence_parallel.get_ring_attn_group")
|
||||
@patch("torch.distributed.get_rank")
|
||||
@patch("torch.distributed.get_world_size")
|
||||
def test_sequence_parallel_slicing(
|
||||
mock_world_size, mock_rank, mock_get_group, partial_state
|
||||
):
|
||||
"""Test the basic sequence slicing logic without full collator instantiation."""
|
||||
# Setup mocks
|
||||
mock_get_group.return_value = MagicMock()
|
||||
mock_rank.return_value = 1 # Second GPU
|
||||
mock_world_size.return_value = 4 # 4 GPUs total
|
||||
|
||||
# Create a sample batch
|
||||
batch = {
|
||||
"input_ids": torch.tensor(
|
||||
[
|
||||
[101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112],
|
||||
[201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
|
||||
]
|
||||
),
|
||||
"attention_mask": torch.ones(2, 12),
|
||||
}
|
||||
|
||||
# Simplified slicing logic from SequenceParallelDataCollator
|
||||
def slice_batch(batch, rank, world_size):
|
||||
result = {}
|
||||
for key in batch:
|
||||
seq_len = batch[key].shape[1]
|
||||
slice_size = seq_len // world_size
|
||||
start_idx = rank * slice_size
|
||||
end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
|
||||
result[key] = batch[key][:, start_idx:end_idx]
|
||||
return result
|
||||
|
||||
# Slice the batch
|
||||
result = slice_batch(
|
||||
batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value
|
||||
)
|
||||
|
||||
# Check slicing
|
||||
assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU
|
||||
expected_input_ids = torch.tensor(
|
||||
[
|
||||
[104, 105, 106], # Second slice of first sequence
|
||||
[204, 205, 206], # Second slice of second sequence
|
||||
]
|
||||
)
|
||||
assert torch.all(result["input_ids"] == expected_input_ids)
|
||||
|
||||
|
||||
# Simple test for configuration validation
|
||||
@pytest.mark.parametrize(
|
||||
"config,should_validate",
|
||||
[
|
||||
({"sequence_parallel_size": 2, "flash_attention": True}, True),
|
||||
({"sequence_parallel_size": 2, "flash_attention": False}, False),
|
||||
({"sequence_parallel_size": 1, "flash_attention": False}, True),
|
||||
],
|
||||
)
|
||||
def test_sequence_parallel_config_requirements(config, should_validate):
|
||||
"""Test basic sequence parallelism configuration requirements."""
|
||||
|
||||
# Simple validation function that mimics the actual validator
|
||||
def validate_sp_config(config):
|
||||
if config.get("sequence_parallel_size", 1) > 1 and not config.get(
|
||||
"flash_attention", False
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
assert validate_sp_config(config) == should_validate
|
||||
Reference in New Issue
Block a user