This commit is contained in:
Dan Saunders
2025-03-10 21:18:04 +00:00
parent b44a207248
commit 4190ad0647
7 changed files with 187 additions and 432 deletions

View File

@@ -14,7 +14,6 @@ with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
from axolotl.utils.collators.sequence_parallel import (
adjust_position_ids_for_slice,
check_for_boundary_splits,
find_sample_boundaries,
)
@@ -30,24 +29,6 @@ def partial_state():
class TestSequenceParallelHelpers:
"""Test helper functions used in sequence parallelism."""
def test_find_sample_boundaries(self):
"""Test detection of boundaries in position_ids."""
# Create sample position_ids with multiple sequences
position_ids = torch.tensor(
[
# First sequence with 2 samples (boundary at index 5)
[0, 1, 2, 3, 4, 0, 1, 2, 3],
# Second sequence with 3 samples (boundaries at 3 and 7)
[0, 1, 2, 0, 1, 2, 3, 0, 1],
]
)
boundaries = find_sample_boundaries(position_ids)
assert len(boundaries) == 2
assert boundaries[0] == [5] # First sequence has boundary at index 5
assert boundaries[1] == [3, 7] # Second sequence has boundaries at 3 and 7
def test_adjust_position_ids_for_slice(self, partial_state):
"""Test position_ids adjustment for sequence slices."""
# Create sample position_ids with multiple sequences