updates
This commit is contained in:
@@ -14,7 +14,6 @@ with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
|
||||
from axolotl.utils.collators.sequence_parallel import (
|
||||
adjust_position_ids_for_slice,
|
||||
check_for_boundary_splits,
|
||||
find_sample_boundaries,
|
||||
)
|
||||
|
||||
|
||||
@@ -30,24 +29,6 @@ def partial_state():
|
||||
class TestSequenceParallelHelpers:
|
||||
"""Test helper functions used in sequence parallelism."""
|
||||
|
||||
def test_find_sample_boundaries(self):
|
||||
"""Test detection of boundaries in position_ids."""
|
||||
# Create sample position_ids with multiple sequences
|
||||
position_ids = torch.tensor(
|
||||
[
|
||||
# First sequence with 2 samples (boundary at index 5)
|
||||
[0, 1, 2, 3, 4, 0, 1, 2, 3],
|
||||
# Second sequence with 3 samples (boundaries at 3 and 7)
|
||||
[0, 1, 2, 0, 1, 2, 3, 0, 1],
|
||||
]
|
||||
)
|
||||
|
||||
boundaries = find_sample_boundaries(position_ids)
|
||||
|
||||
assert len(boundaries) == 2
|
||||
assert boundaries[0] == [5] # First sequence has boundary at index 5
|
||||
assert boundaries[1] == [3, 7] # Second sequence has boundaries at 3 and 7
|
||||
|
||||
def test_adjust_position_ids_for_slice(self, partial_state):
|
||||
"""Test position_ids adjustment for sequence slices."""
|
||||
# Create sample position_ids with multiple sequences
|
||||
|
||||
Reference in New Issue
Block a user