pre-commit fix
This commit is contained in:
@@ -12,7 +12,6 @@ from axolotl.monkeypatch.attention.ring_attn import (
|
|||||||
get_ring_attn_group,
|
get_ring_attn_group,
|
||||||
set_ring_attn_group,
|
set_ring_attn_group,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
@@ -48,33 +47,6 @@ def fixture_cfg():
|
|||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
class TestSequenceParallelHelpers:
|
|
||||||
"""Test helper functions used in sequence parallelism."""
|
|
||||||
|
|
||||||
def test_adjust_position_ids_for_slice(self, partial_state):
|
|
||||||
"""Test position_ids adjustment for sequence slices."""
|
|
||||||
# Create sample position_ids with multiple sequences
|
|
||||||
position_ids = torch.tensor(
|
|
||||||
[
|
|
||||||
# First sequence with 2 samples
|
|
||||||
[0, 1, 2, 3, 4, 0, 1, 2, 3],
|
|
||||||
# Second sequence with 3 samples
|
|
||||||
[0, 1, 2, 0, 1, 2, 3, 0, 1],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Adjust as if this was the second slice (start_idx = 4)
|
|
||||||
adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)
|
|
||||||
|
|
||||||
# For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1]
|
|
||||||
# For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3]
|
|
||||||
expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4
|
|
||||||
expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4
|
|
||||||
|
|
||||||
assert torch.all(adjusted[0] == expected_first_seq)
|
|
||||||
assert torch.all(adjusted[1] == expected_second_seq)
|
|
||||||
|
|
||||||
|
|
||||||
class TestRingAttention:
|
class TestRingAttention:
|
||||||
"""Tests for the ring attention functionality."""
|
"""Tests for the ring attention functionality."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user