From 2f3c52ea2f321695831331e62e81c36836cd281d Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Sun, 6 Apr 2025 00:36:27 +0000 Subject: [PATCH] pre-commit fix --- src/axolotl/utils/schemas/config.py | 2 +- tests/e2e/patched/test_sp.py | 28 ---------------------------- 2 files changed, 1 insertion(+), 29 deletions(-) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index b044b280e..42ece6023 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -1155,7 +1155,7 @@ class AxolotlInputConfig( raise ValueError( "flash_attention: true must be set with sequence_parallel_degree > 1" ) - + if not info.data["micro_batch_size"] == 1: raise ValueError( "micro_batch_size must be set to 1 " diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 70beb8a54..1361a8522 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -12,7 +12,6 @@ from axolotl.monkeypatch.attention.ring_attn import ( get_ring_attn_group, set_ring_attn_group, ) -from axolotl.utils.collators.batching import adjust_position_ids_for_slice from axolotl.utils.dict import DictDefault @@ -48,33 +47,6 @@ def fixture_cfg(): return cfg -class TestSequenceParallelHelpers: - """Test helper functions used in sequence parallelism.""" - - def test_adjust_position_ids_for_slice(self, partial_state): - """Test position_ids adjustment for sequence slices.""" - # Create sample position_ids with multiple sequences - position_ids = torch.tensor( - [ - # First sequence with 2 samples - [0, 1, 2, 3, 4, 0, 1, 2, 3], - # Second sequence with 3 samples - [0, 1, 2, 0, 1, 2, 3, 0, 1], - ] - ) - - # Adjust as if this was the second slice (start_idx = 4) - adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4) - - # For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1] - # For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3] - expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4 - expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4 - - assert torch.all(adjusted[0] == expected_first_seq) - assert torch.all(adjusted[1] == expected_second_seq) - - class TestRingAttention: """Tests for the ring attention functionality."""