pre-commit fix

This commit is contained in:
Dan Saunders
2025-04-06 00:36:27 +00:00
parent 741015b3cf
commit 2f3c52ea2f
2 changed files with 1 additions and 29 deletions

View File

@@ -1155,7 +1155,7 @@ class AxolotlInputConfig(
raise ValueError(
"flash_attention: true must be set with sequence_parallel_degree > 1"
)
if not info.data["micro_batch_size"] == 1:
raise ValueError(
"micro_batch_size must be set to 1 "

View File

@@ -12,7 +12,6 @@ from axolotl.monkeypatch.attention.ring_attn import (
get_ring_attn_group,
set_ring_attn_group,
)
from axolotl.utils.collators.batching import adjust_position_ids_for_slice
from axolotl.utils.dict import DictDefault
@@ -48,33 +47,6 @@ def fixture_cfg():
return cfg
class TestSequenceParallelHelpers:
"""Test helper functions used in sequence parallelism."""
def test_adjust_position_ids_for_slice(self, partial_state):
"""Test position_ids adjustment for sequence slices."""
# Create sample position_ids with multiple sequences
position_ids = torch.tensor(
[
# First sequence with 2 samples
[0, 1, 2, 3, 4, 0, 1, 2, 3],
# Second sequence with 3 samples
[0, 1, 2, 0, 1, 2, 3, 0, 1],
]
)
# Adjust as if this was the second slice (start_idx = 4)
adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)
# For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1]
# For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3]
expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4
expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4
assert torch.all(adjusted[0] == expected_first_seq)
assert torch.all(adjusted[1] == expected_second_seq)
class TestRingAttention:
"""Tests for the ring attention functionality."""