diff --git a/codecov.yml b/codecov.yml index b4810bfa4..c85268b4c 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,5 +1,7 @@ codecov: require_ci_to_pass: yes + notify: + wait_for_ci: true coverage: precision: 2 diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index a51211263..44f8c5d2b 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -932,9 +932,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): collator = DataCollatorForSeq2Seq kwargs["return_tensors"] = "pt" - if issubclass(collator, DataCollatorForSeq2Seq): - kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree - kwargs["ring_attn_func"] = training_args.ring_attn_func return collator( *collator_args, diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index fd72cd6db..3864903a5 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -371,13 +371,15 @@ class AxolotlTrainer( num_items_in_batch=num_items_in_batch, ) - return super().compute_loss( + loss = super().compute_loss( model, inputs, return_outputs=return_outputs, num_items_in_batch=num_items_in_batch, ) + return loss + @staticmethod def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): concatenated_batch = {} diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py index 44751b465..6e4b3e4d0 100644 --- a/src/axolotl/core/trainers/mixins/__init__.py +++ b/src/axolotl/core/trainers/mixins/__init__.py @@ -6,4 +6,4 @@ from .optimizer import OptimizerMixin from .rng_state_loader import RngLoaderMixin from .scheduler import SchedulerMixin -from .sequence_parallel import SequenceParallelMixin +from .sequence_parallel import SequenceParallelContextManager, SequenceParallelMixin diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py index 3930c6cb3..362acb88e 100644 --- a/src/axolotl/core/trainers/mixins/sequence_parallel.py +++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py @@ -1,16 +1,86 @@ -"""Module for Axolotl trainer sequence parallelism mixin""" +""" +Module for Axolotl trainer sequence parallelism mixin and training context manager +""" +import functools import logging +import torch import torch.distributed as dist from datasets import Dataset +from torch import nn from torch.utils.data import DistributedSampler, Sampler +from torch.utils.hooks import RemovableHandle -from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group +from axolotl.monkeypatch.attention.ring_attn import ( + RingAttnFunc, + get_ring_attn_group, + update_ring_attn_params, +) LOG = logging.getLogger(__name__) +def apply_sequence_parallelism( + batch: dict[str, torch.Tensor], + local_rank: int, + local_world_size: int, + ring_attn_func: RingAttnFunc, +) -> dict[str, torch.Tensor]: + """ + Apply sequence parallelism slicing to a batch. + + Args: + batch: Batch dictionary (e.g., input_ids, attention_mask, etc.) + local_rank: Local rank in the sequence parallel group + local_world_size: World size of the sequence parallel group + ring_attn_func: The ring attention function to use + + Returns: + Sliced batch dictionary. + """ + # Update ring attention params if needed + if batch.get("position_ids") is not None: + update_ring_attn_params(position_ids=batch["position_ids"]) + + # Slice batch for sequence parallel processing + total_seq_len = batch["input_ids"].size(1) + for key in batch: + if ( + key in batch + and isinstance(batch[key], torch.Tensor) + and batch[key].dim() > 1 + and batch[key].size(1) == total_seq_len + ): + + if ring_attn_func in [ + RingAttnFunc.VARLEN_LLAMA3, + RingAttnFunc.BATCH_RING, + ]: + # Split in sequential fashion and grab this rank's chunk + batch[key] = ( + batch[key].chunk(local_world_size, dim=1)[local_rank].contiguous() + ) + elif ring_attn_func is RingAttnFunc.BATCH_ZIGZAG: + chunks = batch[key].chunk(2 * local_world_size, dim=1) + + # Take rank's chunk and opposing chunk for zigzag pattern + selected_chunks = [ + chunks[local_rank], + chunks[2 * local_world_size - local_rank - 1], + ] + batch[key] = torch.cat(selected_chunks, dim=1).contiguous() + elif ring_attn_func is RingAttnFunc.BATCH_STRIPE: + # Split into striped data and stack + tensor = torch.stack( + batch[key].split(local_world_size, dim=1), + dim=1, + ).transpose(1, 2) + batch[key] = tensor[:, local_rank].contiguous() + + return batch + + class SequenceParallelMixin: """ Mixin class for sequence parallelism support in trainers. @@ -87,3 +157,157 @@ class SequenceParallelMixin: return self._create_sequence_parallel_sampler( eval_dataset, shuffle=False, is_eval=True ) + + +class SequenceParallelContextManager: + """ + Context manager for sequence parallelism operations. + + This class provides a context that will automatically apply sequence parallelism + during model forward passes using a pre-forward hook, and gather outputs from + across the sequence parallelism group using a post-forward hook. + """ + + def __init__( + self, + model: nn.Module, + sequence_parallel_degree: int, + ring_attn_func: RingAttnFunc, + ): + self.model = model + self.sequence_parallel_degree = sequence_parallel_degree + self.ring_attn_func = ring_attn_func + self.process_group = get_ring_attn_group() + + # Initialize sequence parallel group details + self.local_rank = dist.get_rank(self.process_group) + self.local_world_size = dist.get_world_size(self.process_group) + + # Will store hook handles for removal + self.hook_handles: list[RemovableHandle] = [] + + # Create a partially applied version of the apply_sequence_parallelism function + # with pre-configured params + self.apply_sequence_parallelism = functools.partial( + apply_sequence_parallelism, + local_rank=self.local_rank, + local_world_size=self.local_world_size, + ring_attn_func=self.ring_attn_func, + ) + + def __enter__(self): + # Forward pre-hook to apply sequence parallelism + def sequence_parallel_pre_hook(_, args, kwargs): + # Apply sequence parallelism to kwargs + kwargs = self.apply_sequence_parallelism(batch=kwargs) + return args, kwargs + + # Forward post-hook to gather outputs + def sequence_parallel_post_hook(_, __, output): + # Gather the sharded outputs + return self.gather_outputs(output) + + # Register both hooks + self.hook_handles.append( + self.model.register_forward_pre_hook( + sequence_parallel_pre_hook, with_kwargs=True + ) + ) + self.hook_handles.append( + self.model.register_forward_hook(sequence_parallel_post_hook) + ) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Remove all hooks + for handle in self.hook_handles: + handle.remove() + self.hook_handles = [] + + def gather_outputs(self, output): + """Gather sharded outputs from all ranks and reconstruct the full tensor.""" + # Handle different output formats (dict, tensor, etc.) + if isinstance(output, dict): + gathered_output = {} + for key, value in output.items(): + if isinstance(value, torch.Tensor) and value.dim() > 1: + # Gather logits or other sequence-sharded tensors + gathered_value = self.gather_tensor(value) + gathered_output[key] = gathered_value + else: + gathered_value = value.clone() + dist.all_reduce( + gathered_value, op=dist.ReduceOp.SUM, group=self.process_group + ) + gathered_output[key] = gathered_value + return gathered_output + if isinstance(output, torch.Tensor): + return self.gather_tensor(output) + + return output + + def gather_tensor(self, tensor): + """Gather a sharded tensor from all ranks.""" + # Prepare tensors for all_gather + world_size = self.local_world_size + + # Create list to store tensors from all ranks + gathered_tensors = [torch.zeros_like(tensor) for _ in range(world_size)] + + # All-gather operation + dist.all_gather(gathered_tensors, tensor, group=self.process_group) + + # Concatenate along sequence dimension (typically dim=1) + if self.ring_attn_func in [RingAttnFunc.VARLEN_LLAMA3, RingAttnFunc.BATCH_RING]: + # Simple concatenation for standard sharding + return torch.cat(gathered_tensors, dim=1) + + if self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG: + # Each rank has a pattern of (rank, world_size*2-rank-1) + reconstituted_tensors = [None] * (world_size * 2) + + # First, split each gathered tensor into its two chunks + for rank, gathered_tensor in enumerate(gathered_tensors): + # Each tensor contains two chunks in the sequence dimension + chunk_size = gathered_tensor.size(1) // 2 + chunk1, chunk2 = gathered_tensor.split(chunk_size, dim=1) + + # Place chunks in their original positions + reconstituted_tensors[rank] = chunk1 + reconstituted_tensors[world_size * 2 - rank - 1] = chunk2 + + # Concatenate the reconstituted tensors in the correct order + return torch.cat(reconstituted_tensors, dim=1) + + # Otherwise, RingAttnFunc.BATCH_STRIPE + # In striping, each rank has every world_size-th slice + batch_size = tensor.size(0) + hidden_dim = tensor.size(-1) + + # First, determine the full sequence length + total_seq_len = 0 + for t in gathered_tensors: + total_seq_len += t.size(1) + + # Create a tensor to hold the unstriped result + result = torch.zeros( + batch_size, + total_seq_len, + hidden_dim, + dtype=tensor.dtype, + device=tensor.device, + ) + + # For each rank's tensor, distribute its slices to the correct positions + for rank, gathered_tensor in enumerate(gathered_tensors): + # The rank's tensor contains every world_size-th slice + # starting from its rank position + seq_len = gathered_tensor.size(1) + for i in range(seq_len): + # Calculate the position in the full tensor + pos = i * world_size + rank + if pos < total_seq_len: + result[:, pos] = gathered_tensor[:, i] + + return result diff --git a/src/axolotl/train.py b/src/axolotl/train.py index e003c8b67..d116ea4fd 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -6,6 +6,7 @@ import os import signal import sys import weakref +from contextlib import nullcontext from pathlib import Path from typing import Any, Dict @@ -25,6 +26,9 @@ from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder +from axolotl.core.trainers.mixins.sequence_parallel import ( + SequenceParallelContextManager, +) from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import cleanup_distributed @@ -185,16 +189,28 @@ def execute_training( trainer: The configured trainer object. resume_from_checkpoint: Path to checkpoint to resume from, if applicable. """ - LOG.info("Starting trainer...") - if cfg.flash_optimum: - with torch.backends.cuda.sdp_kernel( - # TODO configure these from the YAML w/ sdp_kernel_kwargs: ... + # Define the context managers to use + flash_context = ( + torch.backends.cuda.sdp_kernel( enable_flash=True, enable_math=True, enable_mem_efficient=True, - ): - trainer.train(resume_from_checkpoint=resume_from_checkpoint) - else: + ) + if cfg.flash_optimum + else nullcontext() + ) + sequence_parallel_context = ( + SequenceParallelContextManager( + model=trainer.model, + sequence_parallel_degree=cfg.sequence_parallel_degree, + ring_attn_func=cfg.ring_attn_func, + ) + if cfg.sequence_parallel_degree > 1 + else nullcontext() + ) + + LOG.info("Starting trainer...") + with flash_context, sequence_parallel_context: trainer.train(resume_from_checkpoint=resume_from_checkpoint) diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index 738ef0dc5..45facf832 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -1,20 +1,12 @@ -""" -Data collators for axolotl to pad labels and position_ids for packed sequences. Also -includes logic for handling sequence parallelism collation. -""" +"""Data collators for axolotl to pad labels and position_ids for packed sequences""" from dataclasses import dataclass from typing import Any import numpy as np -import torch -import torch.distributed as dist from transformers import PreTrainedTokenizerBase from transformers.utils import PaddingStrategy -from axolotl.monkeypatch.attention.ring_attn import update_ring_attn_params -from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc - @dataclass class DataCollatorForSeq2Seq: @@ -49,8 +41,6 @@ class DataCollatorForSeq2Seq: The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). return_tensors (`str`): The type of Tensor to return. Allowable values are "np", "pt" and "tf". - sequence_parallel_degree (`int`): - The degree of sequence parallelism. Default to 1 for no sequence parallelism. """ tokenizer: PreTrainedTokenizerBase @@ -61,17 +51,6 @@ class DataCollatorForSeq2Seq: label_pad_token_id: int = -100 position_pad_token_id: int = 0 return_tensors: str = "pt" - sequence_parallel_degree: int = 1 - ring_attn_func: RingAttnFunc | None = None - - def __post_init__(self): - if self.sequence_parallel_degree > 1: - from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group - - # Get information about our position in the SP group - sp_group = get_ring_attn_group() - self.local_rank = dist.get_rank(group=sp_group) - self.local_world_size = dist.get_world_size(group=sp_group) def __call__(self, features, return_tensors=None): has_attn_mask = "attention_mask" in features[0].keys() @@ -141,62 +120,8 @@ class DataCollatorForSeq2Seq: ) features["decoder_input_ids"] = decoder_input_ids - if self.sequence_parallel_degree > 1: - features = self.apply_sequence_parallelism(features) - return features - def apply_sequence_parallelism( - self, batch: dict[str, torch.Tensor] - ) -> torch.Tensor: - """ - Apply sequence parallelism slicing to a batch. - - Args: - batch: Batch dictionary from parent collator. - - Returns: - Sliced batch dictionary. - """ - # Get local (start, end) for sequence parallelism slicing - total_seq_len = batch["input_ids"].size(1) - - # Update params for varlen ring attention calculation - if batch.get("position_ids") is not None: - update_ring_attn_params(position_ids=batch["position_ids"]) - - # Slice batch for sequence parallel processing - for key in batch: - if batch[key].size(1) == total_seq_len: - if self.ring_attn_func in [ - RingAttnFunc.VARLEN_LLAMA3, - RingAttnFunc.BATCH_RING, - ]: - batch[key] = ( - batch[key] - .chunk(self.local_world_size, dim=1)[self.local_rank] - .contiguous() - ) - elif self.ring_attn_func is RingAttnFunc.BATCH_ZIGZAG: - chunks = batch[key].chunk(2 * self.local_world_size, dim=1) - - # Take rank's chunk and opposing chunk for zigzag pattern - selected_chunks = [ - chunks[self.local_rank], - chunks[2 * self.local_world_size - self.local_rank - 1], - ] - batch[key] = torch.cat(selected_chunks, dim=1).contiguous() - elif self.ring_attn_func is RingAttnFunc.BATCH_STRIPE: - # TODO(djsaunde): This doesn't seem to work as expected - # Split into striped data and stack - tensor = torch.stack( - batch[key].split(self.local_world_size, dim=1), - dim=1, - ).transpose(1, 2) - batch[key] = tensor[:, self.local_rank].contiguous() - - return batch - @dataclass class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index b527dce08..e5ea44aa0 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -126,9 +126,6 @@ def normalize_config(cfg): with open(ds_config_path, encoding="utf-8") as f: cfg.deepspeed = json.load(f) - if cfg.sequence_parallel_degree is None: - cfg.sequence_parallel_degree = 1 - if cfg.saves_per_epoch: save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs) if save_steps < 1.0: # prevent saves on every step diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 732ae60cf..f68d160df 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -18,6 +18,7 @@ from pydantic import ( ) from transformers.utils.import_utils import is_torch_npu_available +from axolotl.utils.distributed import is_main_process from axolotl.utils.schemas.datasets import ( DatasetConfig, DPODataset, @@ -718,9 +719,10 @@ class AxolotlInputConfig( and data.get("eval_sample_packing") is None and not data.get("eval_table_size") ): - LOG.info( - "explicitly setting `eval_sample_packing` to match `sample_packing`" - ) + if is_main_process(): + LOG.info( + "explicitly setting `eval_sample_packing` to match `sample_packing`" + ) data["eval_sample_packing"] = True if ( @@ -1149,22 +1151,17 @@ class AxolotlInputConfig( return data - @field_validator("sequence_parallel_degree", mode="after") - @classmethod - def check_sequence_parallel_degree(cls, value, info): - if not value: - value = 1 - - if value > 1: - if not info.data.get("flash_attention"): + @model_validator(mode="after") + def check_sequence_parallel_degree(self): + if not self.sequence_parallel_degree: + self.sequence_parallel_degree = 1 + elif self.sequence_parallel_degree > 1: + if not self.flash_attention: raise ValueError( "flash_attention: true must be set with sequence_parallel_degree > 1" ) - if ( - info.data.get("sample_packing") - and not info.data["micro_batch_size"] == 1 - ): + if self.sample_packing and self.micro_batch_size > 1: raise ValueError( "micro_batch_size must be set to 1 when sample_packing is enabled" "due to a `ring-flash-attn` requirement" @@ -1182,44 +1179,43 @@ class AxolotlInputConfig( # TODO: monkeypatch / callback to average losses correctly across SP ranks # / fix gradient scaling across SP ranks. Losses, grads should be scaled # according to the proportion of non-padding tokens per rank. - LOG.warning( - "Sequence parallelism (SP) is enabled with " - f"sequence_parallel_degree={value}. Please note that logged losses may " - "differ slightly to the non-SP losses due to transformers Trainer " - "implementation details. Please see " - "https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " - "for more details." - ) + if is_main_process(): + LOG.warning( + "Sequence parallelism (SP) is enabled with " + f"sequence_parallel_degree={self.sequence_parallel_degree}. " + "Please note that logged losses may differ slightly to the non-SP " + "losses due to transformers Trainer implementation details. " + "Please see https://github.com/axolotl-ai-cloud/axolotl/pull/2495#issuecomment-2784022042 " + "for more details." + ) - return value + return self - @field_validator("ring_attn_func", mode="after") - @classmethod - def check_ring_attn_func(cls, value, info): - if not info.data.get("sequence_parallel_degree", 1) > 1: - return value + @model_validator(mode="after") + def validate_ring_attn_func(self): + if getattr(self, "sequence_parallel_degree", 1) == 1: + return self from axolotl.monkeypatch.attention.ring_attn.patch import RingAttnFunc - if value is not None: - # Set the ring attention function if passed in config + if self.ring_attn_func is not None: valid_funcs = list(RingAttnFunc) - if value in valid_funcs: - value = RingAttnFunc(value) + if self.ring_attn_func in valid_funcs: + self.ring_attn_func = RingAttnFunc(self.ring_attn_func) else: raise ValueError( - f"ring_attn_func: {value} must be one of {valid_funcs}" + f"ring_attn_func: {self.ring_attn_func} must be in {valid_funcs}" ) else: # Default ring attention function selection - sample_packing = info.data.get("sample_packing") - value = ( + sample_packing = getattr(self, "sample_packing", False) + self.ring_attn_func = ( RingAttnFunc.VARLEN_LLAMA3 if sample_packing else RingAttnFunc.BATCH_RING ) - return value + return self @model_validator(mode="before") @classmethod diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index c1154be68..3dc9ae3f6 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -348,7 +348,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): load_from_cache_file=not cfg.is_preprocess, desc="Add position_id column (PoSE)", ) - elif cfg.sample_packing or cfg.sequence_parallel_degree > 1: + elif cfg.sample_packing: drop_long_kwargs = {} if filter_map_kwargs: drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)" @@ -358,7 +358,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): **filter_map_kwargs, **drop_long_kwargs, ) - if cfg.eval_sample_packing or cfg.sequence_parallel_degree > 1: + if cfg.eval_sample_packing: if eval_dataset: eval_dataset = eval_dataset.map( add_position_ids, diff --git a/tests/e2e/patched/test_mixtral_samplepack.py b/tests/e2e/patched/test_mixtral_samplepack.py index 2d4f97084..f035b1f28 100644 --- a/tests/e2e/patched/test_mixtral_samplepack.py +++ b/tests/e2e/patched/test_mixtral_samplepack.py @@ -99,6 +99,7 @@ class TestMixtral(unittest.TestCase): "bf16": "auto", } ) + cfg = validate_config(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) diff --git a/tests/e2e/patched/test_sp.py b/tests/e2e/patched/test_sp.py index 70a601f63..6e1e2f2cb 100644 --- a/tests/e2e/patched/test_sp.py +++ b/tests/e2e/patched/test_sp.py @@ -2,14 +2,19 @@ # pylint: disable=redefined-outer-name,unused-argument +import functools +import sys from unittest.mock import MagicMock, patch import pytest import torch from accelerate.state import PartialState +from axolotl.core.trainers.mixins.sequence_parallel import apply_sequence_parallelism from axolotl.monkeypatch.attention.ring_attn import ( + RingAttnFunc, get_ring_attn_group, + register_ring_attn, set_ring_attn_group, ) from axolotl.utils.dict import DictDefault @@ -47,6 +52,27 @@ def fixture_cfg(): return cfg +@pytest.fixture +def sequence_parallel_batch(): + """Create a test batch for sequence parallelism tests.""" + batch_size = 1 + seq_len = 8 + + # Create test tensors + input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len) + attention_mask = torch.ones(batch_size, seq_len) + position_ids = torch.arange(seq_len).expand(batch_size, seq_len) + + # Create test batch + batch = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + } + + return batch + + class TestRingAttention: """Tests for the ring attention functionality.""" @@ -73,11 +99,6 @@ class TestRingAttention: self, mock_world_size, mock_rank, mock_new_group, partial_state ): """Test that ring attention groups are created correctly.""" - from axolotl.monkeypatch.attention.ring_attn import ( - RingAttnFunc, - register_ring_attn, - ) - # Setup mocks mock_world_size.return_value = 8 # 8 GPUs total mock_rank.return_value = 3 # GPU #3 @@ -101,88 +122,308 @@ class TestRingAttention: set_ring_attn_group(None) -# Mock a simplified DataCollator test -@patch("axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group") -@patch("torch.distributed.get_rank") -@patch("torch.distributed.get_world_size") -def test_sequence_parallel_slicing( - mock_world_size, mock_rank, mock_get_group, partial_state -): - """Test the basic sequence slicing logic without full collator instantiation.""" - # Setup mocks - mock_get_group.return_value = MagicMock() - mock_rank.return_value = 1 # Second GPU - mock_world_size.return_value = 4 # 4 GPUs total +class TestConfigValidation: + """Tests for validating sequence parallelism configurations.""" - # Create a sample batch - batch = { - "input_ids": torch.tensor( - [ - [101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112], - [201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212], - ] - ), - "attention_mask": torch.ones(2, 12), - } + @pytest.fixture(autouse=True) + def setup_mocks(self, monkeypatch): + """Set up mocks for all tests in this class.""" + # Mock the ring_flash_attn module + monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock()) - # Simplified slicing logic from SequenceParallelDataCollator - def slice_batch(batch, rank, world_size): - result = {} - for key in batch: - seq_len = batch[key].shape[1] - slice_size = seq_len // world_size - start_idx = rank * slice_size - end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len - result[key] = batch[key][:, start_idx:end_idx] - return result + # Mock the is_main_process function to return True + monkeypatch.setattr( + "axolotl.utils.schemas.config.is_main_process", lambda: True + ) - # Slice the batch - result = slice_batch( - batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value - ) + @pytest.fixture + def base_cfg(self): + """Create a base configuration for testing.""" + return DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}], + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 1e-3, + "output_dir": "./model-out", + "sequence_len": 512, + "special_tokens": {"pad_token": "<|endoftext|>"}, + } + ) - # Check slicing - assert result["input_ids"].shape == (2, 3) # 12 tokens / 4 GPUs = 3 tokens per GPU - expected_input_ids = torch.tensor( + @pytest.mark.parametrize( + "config_updates, expected_values, should_pass, error_msg", [ - [104, 105, 106], # Second slice of first sequence - [204, 205, 206], # Second slice of second sequence - ] + # Valid configuration + ( + {"sequence_parallel_degree": 2, "flash_attention": True}, + {"sequence_parallel_degree": 2, "flash_attention": True}, + True, + None, + ), + # Default sequence_parallel_degree + ({}, {"sequence_parallel_degree": 1}, True, None), + # Invalid: sequence_parallel_degree > 1 without flash_attention + ( + {"sequence_parallel_degree": 2, "flash_attention": False}, + None, + False, + "flash_attention: true must be set", + ), + # Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1 + ( + { + "sequence_parallel_degree": 2, + "flash_attention": True, + "sample_packing": True, + "micro_batch_size": 2, + "pad_to_sequence_len": True, + }, + None, + False, + "micro_batch_size must be set to 1", + ), + ], + ids=[ + "valid_config", + "default_sp_degree", + "without_flash_attention", + "sample_packing_with_large_batch", + ], ) - assert torch.all(result["input_ids"] == expected_input_ids) + def test_sequence_parallel_config_validation( + self, base_cfg, config_updates, expected_values, should_pass, error_msg + ): + """Test various sequence parallelism configuration scenarios.""" + from axolotl.utils.schemas.config import AxolotlInputConfig + + # Apply updates to base config + cfg = base_cfg + cfg.update(config_updates) + + if should_pass: + # Should validate without errors + config = AxolotlInputConfig(**cfg) + + # Check expected values + for key, value in expected_values.items(): + assert getattr(config, key) == value + else: + # Should raise exception + with pytest.raises(ValueError) as excinfo: + AxolotlInputConfig(**cfg) + assert error_msg in str(excinfo.value) + + @pytest.mark.parametrize( + "ring_attn_func, sample_packing, expected_func", + [ + (None, True, RingAttnFunc.VARLEN_LLAMA3), + (None, False, RingAttnFunc.BATCH_RING), + ], + ids=["default_with_sample_packing", "default_without_sample_packing"], + ) + def test_ring_attn_func_validation( + self, base_cfg, ring_attn_func, sample_packing, expected_func + ): + """Test ring_attn_func validation and defaults.""" + from axolotl.utils.schemas.config import AxolotlInputConfig + + # Apply updates to base config + cfg = base_cfg | { + "sequence_parallel_degree": 2, + "flash_attention": True, + "sample_packing": sample_packing, + } + + if ring_attn_func is not None: + cfg["ring_attn_func"] = ring_attn_func + + # Should validate without errors + config = AxolotlInputConfig(**cfg) + + # Check ring_attn_func value + assert config.ring_attn_func.value == expected_func + + def test_invalid_ring_attn_func(self, base_cfg): + """Test that an invalid ring_attn_func is rejected.""" + from axolotl.utils.schemas.config import AxolotlInputConfig + + # Invalid configuration with invalid ring_attn_func + cfg = base_cfg | { + "sequence_parallel_degree": 2, + "flash_attention": True, + "ring_attn_func": "INVALID_FUNC", + } + + # Should raise ValidationError + with pytest.raises(ValueError) as excinfo: + AxolotlInputConfig(**cfg) + + # Verify error message + assert "ring_attn_func: INVALID_FUNC must be in" in str(excinfo.value) -@patch.dict("sys.modules", {"ring_flash_attn": MagicMock()}) -def test_config_validation_with_valid_inputs(cfg): - """Test that valid sequence parallelism configurations pass validation.""" - # Import the actual model class with appropriate mocks - from axolotl.utils.schemas.config import AxolotlInputConfig +class TestApplySequenceParallelism: + """Tests for the apply_sequence_parallelism function.""" - # Valid configuration: sequence_parallel_degree > 1 and flash_attention is True - cfg = cfg | { - "sequence_parallel_degree": 2, - "flash_attention": True, - } + @pytest.fixture(autouse=True) + def mock_distributed(self, monkeypatch): + """Mock torch.distributed functions for testing.""" + # Mock is_initialized to return True + monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True) - # Should validate without errors - config = AxolotlInputConfig(**cfg) - assert config.sequence_parallel_degree == 2 - assert config.flash_attention is True + # Mock get_rank to return 0 by default + monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0) + # Mock get_world_size to return 2 by default + monkeypatch.setattr( + torch.distributed, "get_world_size", lambda *args, **kwargs: 2 + ) -def test_config_validation_with_invalid_inputs(cfg): - """Test that invalid sequence parallelism configurations fail validation.""" - from axolotl.utils.schemas.config import AxolotlInputConfig + # Mock the process group + monkeypatch.setattr( + "axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group", + MagicMock, + ) - # Invalid configuration: sequence_parallel_degree > 1 but flash_attention is False - cfg = cfg | { - "sequence_parallel_degree": 2, - "flash_attention": False, - } + # Mock update_ring_attn_params + monkeypatch.setattr( + "axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params", + lambda **kwargs: None, + ) - # Should raise ValidationError - with pytest.raises(ValueError) as excinfo: - AxolotlInputConfig(**cfg) + def test_world_size_one(self, sequence_parallel_batch): + """Test that function returns original batch when world size is 1.""" + result = apply_sequence_parallelism( + batch=sequence_parallel_batch, + local_rank=0, + local_world_size=1, + ring_attn_func=RingAttnFunc.BATCH_RING, + ) - # Verify error message - assert "flash_attention: true must be set" in str(excinfo.value) + # Should return the original batch unchanged + assert result == sequence_parallel_batch + + def test_batch_ring_rank0(self, sequence_parallel_batch): + """Test BATCH_RING sharding for rank 0 in a 2-process group.""" + batch = sequence_parallel_batch + seq_len = batch["input_ids"].size(1) + + result = apply_sequence_parallelism( + batch=batch, + local_rank=0, + local_world_size=2, + ring_attn_func=RingAttnFunc.BATCH_RING, + ) + + # Check that sequence dimension was sharded correctly + assert result["input_ids"].shape[1] == seq_len // 2 + assert result["attention_mask"].shape[1] == seq_len // 2 + + # Verify content: rank 0 should get the first half of the sequence + assert torch.equal(result["input_ids"], batch["input_ids"][:, : seq_len // 2]) + assert torch.equal( + result["position_ids"], batch["position_ids"][:, : seq_len // 2] + ) + + def test_batch_ring_rank1(self, sequence_parallel_batch): + """Test BATCH_RING sharding for rank 1 in a 2-process group.""" + batch = sequence_parallel_batch + seq_len = batch["input_ids"].size(1) + original_input_ids = batch["input_ids"].clone() + + result = apply_sequence_parallelism( + batch=batch, + local_rank=1, + local_world_size=2, + ring_attn_func=RingAttnFunc.BATCH_RING, + ) + + # Verify content: rank 1 should get the second half of the sequence + assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :]) + + def test_batch_zigzag(self, sequence_parallel_batch): + """Test BATCH_ZIGZAG sharding pattern.""" + batch = sequence_parallel_batch + original_input_ids = batch["input_ids"].clone() + seq_len = batch["input_ids"].size(1) + + # Test rank 0 + result_rank0 = apply_sequence_parallelism( + batch={k: v.clone() for k, v in batch.items()}, + local_rank=0, + local_world_size=2, + ring_attn_func=RingAttnFunc.BATCH_ZIGZAG, + ) + + # Test rank 1 + result_rank1 = apply_sequence_parallelism( + batch={k: v.clone() for k, v in batch.items()}, + local_rank=1, + local_world_size=2, + ring_attn_func=RingAttnFunc.BATCH_ZIGZAG, + ) + + # Checks for both ranks + assert result_rank0["input_ids"].shape[1] == seq_len // 2 + assert result_rank1["input_ids"].shape[1] == seq_len // 2 + + # For a 2-rank system with 8 tokens, check specific zigzag pattern + # Rank 0 should get chunks [0, 1] and [6, 7] + # Rank 1 should get chunks [2, 3] and [4, 5] + if seq_len == 8: + # Create expected tensors for comparison + rank0_expected = torch.cat( + [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1 + ) + + rank1_expected = torch.cat( + [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1 + ) + + assert torch.equal(result_rank0["input_ids"], rank0_expected) + assert torch.equal(result_rank1["input_ids"], rank1_expected) + + def test_partial_application(self, sequence_parallel_batch): + """Test that we can create a partially applied version of the function.""" + batch = sequence_parallel_batch + original_input_ids = batch["input_ids"].clone() + + # Create a partially applied function + rank0_ring_parallel = functools.partial( + apply_sequence_parallelism, + local_rank=0, + local_world_size=2, + ring_attn_func=RingAttnFunc.BATCH_RING, + ) + + # Use the partially applied function + result = rank0_ring_parallel(batch=batch) + + # Verify it works as expected + assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2 + assert torch.equal( + result["input_ids"], + original_input_ids[:, : original_input_ids.shape[1] // 2], + ) + + def test_missing_position_ids(self, sequence_parallel_batch): + """Test handling of batch without position_ids.""" + # Create a batch without position_ids + batch = { + k: v for k, v in sequence_parallel_batch.items() if k != "position_ids" + } + original_input_ids = batch["input_ids"].clone() + + # This should run without error even though position_ids is missing + result = apply_sequence_parallelism( + batch=batch, + local_rank=0, + local_world_size=2, + ring_attn_func=RingAttnFunc.BATCH_RING, + ) + + # Verification should pass + assert "position_ids" not in result + assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2