diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py index 44751b465..6e4b3e4d0 100644 --- a/src/axolotl/core/trainers/mixins/__init__.py +++ b/src/axolotl/core/trainers/mixins/__init__.py @@ -6,4 +6,4 @@ from .optimizer import OptimizerMixin from .rng_state_loader import RngLoaderMixin from .scheduler import SchedulerMixin -from .sequence_parallel import SequenceParallelMixin +from .sequence_parallel import SequenceParallelContextManager, SequenceParallelMixin diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py index b5101e035..362acb88e 100644 --- a/src/axolotl/core/trainers/mixins/sequence_parallel.py +++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py @@ -2,6 +2,7 @@ Module for Axolotl trainer sequence parallelism mixin and training context manager """ +import functools import logging import torch @@ -20,6 +21,66 @@ from axolotl.monkeypatch.attention.ring_attn import ( LOG = logging.getLogger(__name__) +def apply_sequence_parallelism( + batch: dict[str, torch.Tensor], + local_rank: int, + local_world_size: int, + ring_attn_func: RingAttnFunc, +) -> dict[str, torch.Tensor]: + """ + Apply sequence parallelism slicing to a batch. + + Args: + batch: Batch dictionary (e.g., input_ids, attention_mask, etc.) + local_rank: Local rank in the sequence parallel group + local_world_size: World size of the sequence parallel group + ring_attn_func: The ring attention function to use + + Returns: + Sliced batch dictionary. + """ + # Update ring attention params if needed + if batch.get("position_ids") is not None: + update_ring_attn_params(position_ids=batch["position_ids"]) + + # Slice batch for sequence parallel processing + total_seq_len = batch["input_ids"].size(1) + for key in batch: + if ( + key in batch + and isinstance(batch[key], torch.Tensor) + and batch[key].dim() > 1 + and batch[key].size(1) == total_seq_len + ): + + if ring_attn_func in [ + RingAttnFunc.VARLEN_LLAMA3, + RingAttnFunc.BATCH_RING, + ]: + # Split in sequential fashion and grab this rank's chunk + batch[key] = ( + batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous() + ) + elif ring_attn_func is RingAttnFunc.BATCH_ZIGZAG: + chunks = batch[key].chunk(2 * local_world_size, dim=1) + + # Take rank's chunk and opposing chunk for zigzag pattern + selected_chunks = [ + chunks[local_rank], + chunks[2 * local_world_size - local_rank - 1], + ] + batch[key] = torch.cat(selected_chunks, dim=1).contiguous() + elif ring_attn_func is RingAttnFunc.BATCH_STRIPE: + # Split into striped data and stack + tensor = torch.stack( + batch[key].split(local_world_size, dim=1), + dim=1, + ).transpose(1, 2) + batch[key] = tensor[:, local_rank].contiguous() + + return batch + + class SequenceParallelMixin: """ Mixin class for sequence parallelism support in trainers. 
@@ -125,11 +186,20 @@ class SequenceParallelContextManager: # Will store hook handles for removal self.hook_handles: list[RemovableHandle] = [] + # Create a partially applied version of the apply_sequence_parallelism function + # with pre-configured params + self.apply_sequence_parallelism = functools.partial( + apply_sequence_parallelism, + local_rank=self.local_rank, + local_world_size=self.local_world_size, + ring_attn_func=self.ring_attn_func, + ) + def __enter__(self): # Forward pre-hook to apply sequence parallelism def sequence_parallel_pre_hook(_, args, kwargs): # Apply sequence parallelism to kwargs - kwargs = self.apply_sequence_parallelism(kwargs) + kwargs = self.apply_sequence_parallelism(batch=kwargs) return args, kwargs # Forward post-hook to gather outputs @@ -155,61 +225,6 @@ class SequenceParallelContextManager: handle.remove() self.hook_handles = [] - def apply_sequence_parallelism( - self, batch: dict[str, torch.Tensor] - ) -> dict[str, torch.Tensor]: - """ - Apply sequence parallelism slicing to a batch. - - Args: - batch: Batch dictionary (e.g., input_ids, attention_mask, etc.) - - Returns: - Sliced batch dictionary. - """ - # Update ring attention params if needed - if batch.get("position_ids") is not None: - update_ring_attn_params(position_ids=batch["position_ids"]) - - # Slice batch for sequence parallel processing - total_seq_len = batch["input_ids"].size(1) - for key in batch: - if ( - key in batch - and isinstance(batch[key], torch.Tensor) - and batch[key].dim() > 1 - and batch[key].size(1) == total_seq_len - ): - - if self.ring_attn_func in [ - RingAttnFunc.VARLEN_LLAMA3, - RingAttnFunc.BATCH_RING, - ]: - # Split in sequential fashion and grab this rank's chunk - batch[key] = ( - batch[key] - .chunk(self.local_world_size, dim=1)[self.local_rank] - .contiguous() - ) - elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG: - chunks = batch[key].chunk(2 * self.local_world_size, dim=1) - - # Take rank's chunk and opposing chunk for zigzag pattern - selected_chunks = [ - chunks[self.local_rank], - chunks[2 * self.local_world_size - self.local_rank - 1], - ] - batch[key] = torch.cat(selected_chunks, dim=1).contiguous() - elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE: - # Split into striped data and stack - tensor = torch.stack( - batch[key].split(self.local_world_size, dim=1), - dim=1, - ).transpose(1, 2) - batch[key] = tensor[:, self.local_rank].contiguous() - - return batch - def gather_outputs(self, output): """Gather sharded outputs from all ranks and reconstruct the full tensor.""" # Handle different output formats (dict, tensor, etc.) 
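Note (not part of the patch): the three slicing strategies hoisted into the module-level apply_sequence_parallelism above differ only in which tokens each rank receives. A minimal sketch of the chunk arithmetic for a single 8-token sequence split across 2 ranks, matching the logic added in this hunk:

import torch

seq = torch.arange(8).unsqueeze(0)  # [[0, 1, 2, 3, 4, 5, 6, 7]]
world_size, rank = 2, 0

# BATCH_RING / VARLEN_LLAMA3: each rank takes one contiguous chunk
ring = seq.chunk(world_size, dim=1)[rank]  # rank 0 -> [[0, 1, 2, 3]]

# BATCH_ZIGZAG: of 2 * world_size chunks, rank i takes chunks i and 2W - 1 - i
chunks = seq.chunk(2 * world_size, dim=1)
zigzag = torch.cat([chunks[rank], chunks[2 * world_size - rank - 1]], dim=1)
# rank 0 -> [[0, 1, 6, 7]], rank 1 -> [[2, 3, 4, 5]]

# BATCH_STRIPE: tokens are interleaved, one per rank per stripe
stripe = torch.stack(seq.split(world_size, dim=1), dim=1).transpose(1, 2)[:, rank]
# rank 0 -> [[0, 2, 4, 6]], rank 1 -> [[1, 3, 5, 7]]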
diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 70a601f63..6e1e2f2cb 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -2,14 +2,19 @@ # pylint: disable=redefined-outer-name,unused-argument +import functools +import sys from unittest.mock import MagicMock, patch import pytest import torch from accelerate.state import PartialState +from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism from axolotl.monkeypatch.attention.ring_attn import ( + RingAttnFunc, get_ring_attn_group, + register_ring_attn, set_ring_attn_group, ) from axolotl.utils.dict import DictDefault @@ -47,6 +52,27 @@ def fixture_cfg(): return cfg +@pytest.fixture +def sequence_parallel_batch(): + """Create a test batch for sequence parallelism tests.""" + batch_size = 1 + seq_len = 8 + + # Create test tensors + input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len) + attention_mask = torch.ones(batch_size, seq_len) + position_ids = torch.arange(seq_len).expand(batch_size, seq_len) + + # Create test batch + batch = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + } + + return batch + + class TestRingAttention: """Tests for the ring attention functionality.""" @@ -73,11 +99,6 @@ class TestRingAttention: self, mock_world_size, mock_rank, mock_new_group, partial_state ): """Test that ring attention groups are created correctly.""" - from axolotl.monkeypatch.attention.ring_attn import ( - RingAttnFunc, - register_ring_attn, - ) - # Setup mocks mock_world_size.return_value = 8 # 8 GPUs total mock_rank.return_value = 3 # GPU #3 @@ -101,88 +122,308 @@ class TestRingAttention: set_ring_attn_group(None) -# Mock a simplified DataCollator test -@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group") -@patch("torch.distributed.get_rank") -@patch("torch.distributed.get_world_size") -def test_sequence_parallel_slicing( - mock_world_size, mock_rank, mock_get_group, partial_state -): - """Test the basic sequence slicing logic without full collator instantiation.""" - # Setup mocks - mock_get_group.return_value = MagicMock() - mock_rank.return_value = 1 # Second GPU - mock_world_size.return_value = 4 # 4 GPUs total +class TestConfigValidation: + """Tests for validating sequence parallelism configurations.""" - # Create a sample batch - batch = { - "input_ids": torch.tensor( - [ - [101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112], - [201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212], - ] - ), - "attention_mask": torch.ones(2, 12), - } + @pytest.fixture(autouse=True) + def setup_mocks(self, monkeypatch): + """Set up mocks for all tests in this class.""" + # Mock the ring_flash_attn module + monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock()) - # Simplified slicing logic from SequenceParallelDataCollator - def slice_batch(batch, rank, world_size): - result = {} - for key in batch: - seq_len = batch[key].shape[1] - slice_size = seq_len // world_size - start_idx = rank * slice_size - end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len - result[key] = batch[key][:, start_idx:end_idx] - return result + # Mock the is_main_process function to return True + monkeypatch.setattr( + "axolotl.utils.schemas.config.is_main_process", lambda: True + ) - # Slice the batch - result = slice_batch( - batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value - ) + @pytest.fixture + def base_cfg(self): + """Create a base 
configuration for testing.""" + return DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}], + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 1e-3, + "output_dir": "./model-out", + "sequence_len": 512, + "special_tokens": {"pad_token": "<|endoftext|>"}, + } + ) - # Check slicing - assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU - expected_input_ids = torch.tensor( + @pytest.mark.parametrize( + "config_updates, expected_values, should_pass, error_msg", [ - [104, 105, 106], # Second slice of first sequence - [204, 205, 206], # Second slice of second sequence - ] + # Valid configuration + ( + {"sequence_parallel_degree": 2, "flash_attention": True}, + {"sequence_parallel_degree": 2, "flash_attention": True}, + True, + None, + ), + # Default sequence_parallel_degree + ({}, {"sequence_parallel_degree": 1}, True, None), + # Invalid: sequence_parallel_degree > 1 without flash_attention + ( + {"sequence_parallel_degree": 2, "flash_attention": False}, + None, + False, + "flash_attention: true must be set", + ), + # Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1 + ( + { + "sequence_parallel_degree": 2, + "flash_attention": True, + "sample_packing": True, + "micro_batch_size": 2, + "pad_to_sequence_len": True, + }, + None, + False, + "micro_batch_size must be set to 1", + ), + ], + ids=[ + "valid_config", + "default_sp_degree", + "without_flash_attention", + "sample_packing_with_large_batch", + ], ) - assert torch.all(result["input_ids"] == expected_input_ids) + def test_sequence_parallel_config_validation( + self, base_cfg, config_updates, expected_values, should_pass, error_msg + ): + """Test various sequence parallelism configuration scenarios.""" + from axolotl.utils.schemas.config import AxolotlInputConfig + + # Apply updates to base config + cfg = base_cfg + cfg.update(config_updates) + + if should_pass: + # Should validate without errors + config = AxolotlInputConfig(**cfg) + + # Check expected values + for key, value in expected_values.items(): + assert getattr(config, key) == value + else: + # Should raise exception + with pytest.raises(ValueError) as excinfo: + AxolotlInputConfig(**cfg) + assert error_msg in str(excinfo.value) + + @pytest.mark.parametrize( + "ring_attn_func, sample_packing, expected_func", + [ + (None, True, RingAttnFunc.VARLEN_LLAMA3), + (None, False, RingAttnFunc.BATCH_RING), + ], + ids=["default_with_sample_packing", "default_without_sample_packing"], + ) + def test_ring_attn_func_validation( + self, base_cfg, ring_attn_func, sample_packing, expected_func + ): + """Test ring_attn_func validation and defaults.""" + from axolotl.utils.schemas.config import AxolotlInputConfig + + # Apply updates to base config + cfg = base_cfg | { + "sequence_parallel_degree": 2, + "flash_attention": True, + "sample_packing": sample_packing, + } + + if ring_attn_func is not None: + cfg["ring_attn_func"] = ring_attn_func + + # Should validate without errors + config = AxolotlInputConfig(**cfg) + + # Check ring_attn_func value + assert config.ring_attn_func.value == expected_func + + def test_invalid_ring_attn_func(self, base_cfg): + """Test that an invalid ring_attn_func is rejected.""" + from axolotl.utils.schemas.config import AxolotlInputConfig + + # Invalid configuration with invalid ring_attn_func + cfg = base_cfg | { + "sequence_parallel_degree": 2, + "flash_attention": True, + "ring_attn_func": 
"INVALID_FUNC", + } + + # Should raise ValidationError + with pytest.raises(ValueError) as excinfo: + AxolotlInputConfig(**cfg) + + # Verify error message + assert "ring_attn_func: INVALID_FUNC must be in" in str(excinfo.value) -@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()}) -def test_config_validation_with_valid_inputs(cfg): - """Test that valid sequence parallelism configurations pass validation.""" - # Import the actual model class with appropriate mocks - from axolotl.utils.schemas.config import AxolotlInputConfig +class TestApplySequenceParallelism: + """Tests for the apply_sequence_parallelism function.""" - # Valid configuration: sequence_parallel_degree > 1 and flash_attention is True - cfg = cfg | { - "sequence_parallel_degree": 2, - "flash_attention": True, - } + @pytest.fixture(autouse=True) + def mock_distributed(self, monkeypatch): + """Mock torch.distributed functions for testing.""" + # Mock is_initialized to return True + monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True) - # Should validate without errors - config = AxolotlInputConfig(**cfg) - assert config.sequence_parallel_degree == 2 - assert config.flash_attention is True + # Mock get_rank to return 0 by default + monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0) + # Mock get_world_size to return 2 by default + monkeypatch.setattr( + torch.distributed, "get_world_size", lambda *args, **kwargs: 2 + ) -def test_config_validation_with_invalid_inputs(cfg): - """Test that invalid sequence parallelism configurations fail validation.""" - from axolotl.utils.schemas.config import AxolotlInputConfig + # Mock the process group + monkeypatch.setattr( + "axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group", + MagicMock, + ) - # Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False - cfg = cfg | { - "sequence_parallel_degree": 2, - "flash_attention": False, - } + # Mock update_ring_attn_params + monkeypatch.setattr( + "axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params", + lambda **kwargs: None, + ) - # Should raise ValidationError - with pytest.raises(ValueError) as excinfo: - AxolotlInputConfig(**cfg) + def test_world_size_one(self, sequence_parallel_batch): + """Test that function returns original batch when world size is 1.""" + result = apply_sequence_parallelism( + batch=sequence_parallel_batch, + local_rank=0, + local_world_size=1, + ring_attn_func=RingAttnFunc.BATCH_RING, + ) - # Verify error message - assert "flash_attention: true must be set" in str(excinfo.value) + # Should return the original batch unchanged + assert result == sequence_parallel_batch + + def test_batch_ring_rank0(self, sequence_parallel_batch): + """Test BATCH_RING sharding for rank 0 in a 2-process group.""" + batch = sequence_parallel_batch + seq_len = batch["input_ids"].size(1) + + result = apply_sequence_parallelism( + batch=batch, + local_rank=0, + local_world_size=2, + ring_attn_func=RingAttnFunc.BATCH_RING, + ) + + # Check that sequence dimension was sharded correctly + assert result["input_ids"].shape[1] == seq_len // 2 + assert result["attention_mask"].shape[1] == seq_len // 2 + + # Verify content: rank 0 should get the first half of the sequence + assert torch.equal(result["input_ids"], batch["input_ids"][:, : seq_len // 2]) + assert torch.equal( + result["position_ids"], batch["position_ids"][:, : seq_len // 2] + ) + + def test_batch_ring_rank1(self, sequence_parallel_batch): + """Test BATCH_RING sharding for rank 1 in a 2-process 
group.""" + batch = sequence_parallel_batch + seq_len = batch["input_ids"].size(1) + original_input_ids = batch["input_ids"].clone() + + result = apply_sequence_parallelism( + batch=batch, + local_rank=1, + local_world_size=2, + ring_attn_func=RingAttnFunc.BATCH_RING, + ) + + # Verify content: rank 1 should get the second half of the sequence + assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :]) + + def test_batch_zigzag(self, sequence_parallel_batch): + """Test BATCH_ZIGZAG sharding pattern.""" + batch = sequence_parallel_batch + original_input_ids = batch["input_ids"].clone() + seq_len = batch["input_ids"].size(1) + + # Test rank 0 + result_rank0 = apply_sequence_parallelism( + batch={k: v.clone() for k, v in batch.items()}, + local_rank=0, + local_world_size=2, + ring_attn_func=RingAttnFunc.BATCH_ZIGZAG, + ) + + # Test rank 1 + result_rank1 = apply_sequence_parallelism( + batch={k: v.clone() for k, v in batch.items()}, + local_rank=1, + local_world_size=2, + ring_attn_func=RingAttnFunc.BATCH_ZIGZAG, + ) + + # Checks for both ranks + assert result_rank0["input_ids"].shape[1] == seq_len // 2 + assert result_rank1["input_ids"].shape[1] == seq_len // 2 + + # For a 2-rank system with 8 tokens, check specific zigzag pattern + # Rank 0 should get chunks [0, 1] and [6, 7] + # Rank 1 should get chunks [2, 3] and [4, 5] + if seq_len == 8: + # Create expected tensors for comparison + rank0_expected = torch.cat( + [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1 + ) + + rank1_expected = torch.cat( + [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1 + ) + + assert torch.equal(result_rank0["input_ids"], rank0_expected) + assert torch.equal(result_rank1["input_ids"], rank1_expected) + + def test_partial_application(self, sequence_parallel_batch): + """Test that we can create a partially applied version of the function.""" + batch = sequence_parallel_batch + original_input_ids = batch["input_ids"].clone() + + # Create a partially applied function + rank0_ring_parallel = functools.partial( + apply_sequence_parallelism, + local_rank=0, + local_world_size=2, + ring_attn_func=RingAttnFunc.BATCH_RING, + ) + + # Use the partially applied function + result = rank0_ring_parallel(batch=batch) + + # Verify it works as expected + assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2 + assert torch.equal( + result["input_ids"], + original_input_ids[:, : original_input_ids.shape[1] // 2], + ) + + def test_missing_position_ids(self, sequence_parallel_batch): + """Test handling of batch without position_ids.""" + # Create a batch without position_ids + batch = { + k: v for k, v in sequence_parallel_batch.items() if k != "position_ids" + } + original_input_ids = batch["input_ids"].clone() + + # This should run without error even though position_ids is missing + result = apply_sequence_parallelism( + batch=batch, + local_rank=0, + local_world_size=2, + ring_attn_func=RingAttnFunc.BATCH_RING, + ) + + # Verification should pass + assert "position_ids" not in result + assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2