precommit fixes
This commit is contained in:
@@ -11,10 +11,7 @@ from accelerate.state import PartialState
|
||||
ring_flash_attn_mock = MagicMock()
|
||||
with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
|
||||
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||
from axolotl.utils.collators.sequence_parallel import (
|
||||
adjust_position_ids_for_slice,
|
||||
check_for_boundary_splits,
|
||||
)
|
||||
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
|
||||
|
||||
|
||||
# Create a fixture for PartialState
|
||||
@@ -52,37 +49,6 @@ class TestSequenceParallelHelpers:
|
||||
assert torch.all(adjusted[0] == expected_first_seq)
|
||||
assert torch.all(adjusted[1] == expected_second_seq)
|
||||
|
||||
def test_check_for_boundary_splits(self):
|
||||
"""Test detection of boundaries near slice edges."""
|
||||
# Boundaries at positions 10, 25, 40
|
||||
boundaries = [10, 25, 40]
|
||||
|
||||
# Test case where two boundaries are near edges (one at start, one at end)
|
||||
problems = check_for_boundary_splits(boundaries, slice_start=8, slice_end=30)
|
||||
assert (
|
||||
len(problems) == 2
|
||||
) # Both boundary at 10 (near start) and 25 (near end) are problems
|
||||
|
||||
# Check first problem - boundary near start
|
||||
assert problems[0][0] == 10 # The boundary position
|
||||
assert problems[0][1] == "start" # Type of issue
|
||||
assert problems[0][2] == 2 # Distance from start
|
||||
|
||||
# Check second problem - boundary near end
|
||||
assert problems[1][0] == 25 # The boundary position
|
||||
assert problems[1][1] == "end" # Type of issue
|
||||
assert problems[1][2] == 5 # Distance from end
|
||||
|
||||
# Test case with only one problem at the end
|
||||
problems = check_for_boundary_splits(boundaries, slice_start=15, slice_end=27)
|
||||
assert len(problems) == 1 # Only boundary at 25 is near the end
|
||||
assert problems[0][0] == 25 # The boundary
|
||||
assert problems[0][1] == "end" # Type of issue
|
||||
|
||||
# Test case with no problems
|
||||
problems = check_for_boundary_splits(boundaries, slice_start=12, slice_end=20)
|
||||
assert len(problems) == 0
|
||||
|
||||
|
||||
class TestRingAttention:
|
||||
"""Tests for the ring attention functionality."""
|
||||
|
||||
Reference in New Issue
Block a user