pre-commit fix
@@ -1155,7 +1155,7 @@ class AxolotlInputConfig(
            raise ValueError(
                "flash_attention: true must be set with sequence_parallel_degree > 1"
            )

        if not info.data["micro_batch_size"] == 1:
            raise ValueError(
                "micro_batch_size must be set to 1 "
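
The hunk above enforces Axolotl's sequence-parallelism constraints: when sequence_parallel_degree > 1, flash_attention: true and micro_batch_size: 1 are required. A self-contained sketch of that validation pattern in Pydantic v2 follows; the model name, field defaults, and the tail of the second error message are assumptions, not Axolotl's actual definitions:

# Hedged sketch of the validation pattern shown above (Pydantic v2).
# Names, defaults, and the second message's continuation are assumptions.
from pydantic import BaseModel, ValidationInfo, field_validator


class SequenceParallelSketch(BaseModel):
    # Field order matters: info.data only holds fields validated earlier.
    flash_attention: bool = False
    micro_batch_size: int = 1
    sequence_parallel_degree: int = 1

    @field_validator("sequence_parallel_degree")
    @classmethod
    def check_sequence_parallel_degree(cls, value: int, info: ValidationInfo) -> int:
        if value > 1:
            if not info.data["flash_attention"]:
                raise ValueError(
                    "flash_attention: true must be set with sequence_parallel_degree > 1"
                )
            if not info.data["micro_batch_size"] == 1:
                raise ValueError(
                    "micro_batch_size must be set to 1 "
                    "when sequence_parallel_degree > 1"  # assumed continuation
                )
        return value

With this sketch, SequenceParallelSketch(sequence_parallel_degree=2, flash_attention=True) validates, while leaving flash_attention at its default raises the first error as a pydantic ValidationError.
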
@@ -12,7 +12,6 @@ from axolotl.monkeypatch.attention.ring_attn import (
    get_ring_attn_group,
    set_ring_attn_group,
)
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
from axolotl.utils.dict import DictDefault

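The get_ring_attn_group / set_ring_attn_group pair imported above suggests a simple accessor pattern around the process group used for ring attention. A rough sketch of that pattern, assuming module-global storage (an assumption about the implementation, not shown in this diff):

# Assumed sketch of the get/set accessor pattern the import names imply;
# the real module in axolotl.monkeypatch.attention.ring_attn may differ.
from typing import Optional

import torch.distributed as dist

# Assumed storage: a module-level handle for the ring-attention process group.
_RING_ATTN_GROUP: Optional[dist.ProcessGroup] = None


def set_ring_attn_group(group: Optional[dist.ProcessGroup]) -> None:
    """Record the process group that ring attention should communicate over."""
    global _RING_ATTN_GROUP
    _RING_ATTN_GROUP = group


def get_ring_attn_group() -> Optional[dist.ProcessGroup]:
    """Return the process group registered via set_ring_attn_group, if any."""
    return _RING_ATTN_GROUP
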
@@ -48,33 +47,6 @@ def fixture_cfg():
    return cfg


class TestSequenceParallelHelpers:
    """Test helper functions used in sequence parallelism."""

    def test_adjust_position_ids_for_slice(self, partial_state):
        """Test position_ids adjustment for sequence slices."""
        # Create sample position_ids with multiple sequences
        position_ids = torch.tensor(
            [
                # First sequence with 2 samples
                [0, 1, 2, 3, 4, 0, 1, 2, 3],
                # Second sequence with 3 samples
                [0, 1, 2, 0, 1, 2, 3, 0, 1],
            ]
        )

        # Adjust as if this was the second slice (start_idx = 4)
        adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)

        # For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1]
        # For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3]
        expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4
        expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4

        assert torch.all(adjusted[0] == expected_first_seq)
        assert torch.all(adjusted[1] == expected_second_seq)


class TestRingAttention:
    """Tests for the ring attention functionality."""
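
The expected tensors in the test above pin down adjust_position_ids_for_slice for this case: every position id is shifted down by start_idx, across all samples in the packed row. A minimal sketch consistent with those expectations (the real helper in axolotl.utils.collators.batching may do more):

# Minimal sketch matching the test's expected_* tensors; not the actual
# implementation from axolotl.utils.collators.batching.
import torch


def adjust_position_ids_for_slice(
    position_ids: torch.Tensor, start_idx: int
) -> torch.Tensor:
    # Make position ids relative to the slice's starting offset, e.g.
    # [0, 1, 2, 3, 4, ...] with start_idx=4 becomes [-4, -3, -2, -1, 0, ...].
    return position_ids - start_idx

Applied to the test's two-row input with start_idx=4, this sketch reproduces both expected rows exactly.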