From 51c326150b1b540da2ec34670621d13ebde092ff Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 6 Mar 2025 16:25:53 +0000 Subject: [PATCH] pytest --- src/axolotl/core/trainer_builder.py | 22 +- src/axolotl/core/trainers/base.py | 86 ++-- .../attention}/ring_attn.py | 26 +- src/axolotl/utils/collators/__init__.py | 5 + .../utils/collators/sequence_parallel.py | 433 ++++++++++++++++++ src/axolotl/utils/models.py | 2 +- .../e2e/multigpu/test_sequence_parallelism.py | 114 +++++ .../e2e/patched/test_sequence_parallelism.py | 221 +++++++++ 8 files changed, 863 insertions(+), 46 deletions(-) rename src/axolotl/{utils => monkeypatch/attention}/ring_attn.py (66%) create mode 100644 src/axolotl/utils/collators/sequence_parallel.py create mode 100644 tests/e2e/multigpu/test_sequence_parallelism.py create mode 100644 tests/e2e/patched/test_sequence_parallelism.py diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index a1b2ac27a..2e2883c89 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -80,7 +80,10 @@ from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, MambaDataCollator, + SequenceParallelDataCollator, + SequenceParallelPackedDataCollator, V2BatchSamplerDataCollatorForSeq2Seq, + V2SequenceParallelPackedDataCollator, ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.models import ensure_dtype @@ -880,15 +883,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): if "max_length" in kwargs: kwargs.pop("max_length") elif use_batch_sampler_collator: - if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES: - collator = V2BatchSamplerDataCollatorForSeq2Seq - elif ( + if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or ( self.cfg.model_config_type in ["llama"] and self.cfg.flash_attention is not True ): - collator = V2BatchSamplerDataCollatorForSeq2Seq + if 
self.cfg.sequence_parallel_size > 1: + collator = V2SequenceParallelPackedDataCollator + else: + collator = V2BatchSamplerDataCollatorForSeq2Seq else: - collator = BatchSamplerDataCollatorForSeq2Seq + if self.cfg.sequence_parallel_size > 1: + collator = SequenceParallelPackedDataCollator + else: + collator = BatchSamplerDataCollatorForSeq2Seq else: if self.cfg.processor_type and self.processor: collator = MultiModalChatDataCollator @@ -910,7 +917,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): else: collator = DataCollatorForKD else: - collator = DataCollatorForSeq2Seq + if self.cfg.sequence_parallel_size > 1: + collator = SequenceParallelDataCollator + else: + collator = DataCollatorForSeq2Seq kwargs["return_tensors"] = "pt" diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 4a1ba5a02..2f61a6da4 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -1,6 +1,4 @@ -""" -module for customized trainers -""" +"""Module for customized trainers.""" from __future__ import annotations @@ -12,6 +10,7 @@ from functools import wraps from typing import Any, Dict, Literal, Optional import torch +import torch.distributed as dist import torch.nn.functional as F from datasets import Dataset from peft.optimizers import create_loraplus_optimizer @@ -27,8 +26,8 @@ from trl.trainer.utils import pad_to_length from typing_extensions import override from axolotl.integrations.base import BaseOptimizerFactory +from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group from axolotl.monkeypatch.relora import ReLoRAScheduler -from axolotl.utils.ring_attn import get_ring_attn_group from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.schedulers import ( RexLR, @@ -40,7 +39,7 @@ from axolotl.utils.schedulers import ( if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp -LOG = logging.getLogger("axolotl.core.trainer_builder") +LOG = 
logging.getLogger(__name__) def _sanitize_kwargs_for_tagging(tag_names, kwargs=None): @@ -810,40 +809,57 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer): ) -> torch.Tensor: """ Perform a training step on a batch of inputs. - - Note: we are subclassing `transformers.trainer.Trainer` in order to compute - parameters needed for the ring flash attention implementation we're using. - - Args: - model (`nn.Module`): - The model to train. - inputs (`Dict[str, Union[torch.Tensor, Any]]`): - The inputs and targets of the model. - - The dictionary will be unpacked before being fed to the model. Most models expect the targets under the - argument `labels`. Check your model's documentation for all accepted arguments. - - Return: - `torch.Tensor`: The tensor with training loss on this batch. """ if self.args.sequence_parallel_size > 1: - if "attention_mask" in inputs: - # Calculate sequence lengths from attention mask - seq_lens = inputs["attention_mask"].sum(dim=1).tolist() - total_seq_len = ( - inputs["attention_mask"].shape[0] - * inputs["attention_mask"].shape[1] - ) - else: - # Assume all sequences are the same length if no mask is provided - batch_size = inputs["input_ids"].shape[0] - seq_len = inputs["input_ids"].shape[1] - seq_lens = [seq_len] * batch_size - total_seq_len = batch_size * seq_len + # At this point, inputs should already be partitioned by the sequence parallel data collator + # We'll just log some information about the partitioned data + batch_size = inputs["input_ids"].shape[0] + seq_len = inputs["input_ids"].shape[1] - self._update_ring_flash_attn_params(seq_lens, total_seq_len) + # Get rank and SP information + sp_group = get_ring_attn_group() + rank = dist.get_rank() + sp_rank = dist.get_rank(group=sp_group) if sp_group else rank + world_size = ( + dist.get_world_size(group=sp_group) + if sp_group + else dist.get_world_size() + ) - return super().training_step(model, inputs, num_items_in_batch) + # Sample tokens from our slice to verify 
partitioning + sample_start = ( + inputs["input_ids"][0, :5].tolist() + if seq_len >= 5 + else inputs["input_ids"][0, :].tolist() + ) + sample_end = ( + inputs["input_ids"][0, -5:].tolist() + if seq_len >= 5 + else inputs["input_ids"][0, :].tolist() + ) + + LOG.info( + f"GPU {rank} (SP rank {sp_rank}) | Step {self.state.global_step} | " + f"Slice shape: batch_size={batch_size}, seq_len={seq_len} | " + f"Sample start: {sample_start}, end: {sample_end}" + ) + + # Calculate the full sequence length across all GPUs in this SP group + full_seq_len = seq_len * world_size + + # Pass the partitioned sequence information to ring flash attention + self._update_ring_flash_attn_params([seq_len] * batch_size, full_seq_len) + + # Get the loss from the parent implementation + loss = super().training_step(model, inputs, num_items_in_batch) + + if self.args.sequence_parallel_size > 1: + rank = dist.get_rank() + LOG.info( + f"GPU {rank} | Step {self.state.global_step} | Loss: {loss.item()}" + ) + + return loss def _update_ring_flash_attn_params(self, packed_seq_lens, total_seq_len): """ diff --git a/src/axolotl/utils/ring_attn.py b/src/axolotl/monkeypatch/attention/ring_attn.py similarity index 66% rename from src/axolotl/utils/ring_attn.py rename to src/axolotl/monkeypatch/attention/ring_attn.py index 5552a9047..d6e245820 100644 --- a/src/axolotl/utils/ring_attn.py +++ b/src/axolotl/monkeypatch/attention/ring_attn.py @@ -1,4 +1,6 @@ -"""Ring attention group registration and utils.""" +"""Ring attention group registration and flash attention patching.""" + +from typing import Any import torch.distributed as dist from accelerate.logging import get_logger @@ -12,11 +14,11 @@ LOG = get_logger(__name__) RING_ATTN_GROUP = None -def get_ring_attn_group(): +def get_ring_attn_group() -> Any: return RING_ATTN_GROUP -def set_ring_attn_group(ring_attn_group): +def set_ring_attn_group(ring_attn_group: Any): global RING_ATTN_GROUP # pylint: disable=global-statement RING_ATTN_GROUP = 
ring_attn_group @@ -39,6 +41,10 @@ def register_ring_attn(sequence_parallel_size: int): f"must evenly divide world_size ({world_size})" ) + # Detailed logging of group formation + rank = dist.get_rank() + group_assignments = {} + for i in range(world_size // sequence_parallel_size): ring_attn_ranks = list( range( @@ -47,7 +53,19 @@ def register_ring_attn(sequence_parallel_size: int): ) ) group = dist.new_group(ranks=ring_attn_ranks, backend="nccl") - if dist.get_rank() in ring_attn_ranks: + + # Track which GPUs are in which groups + for r in ring_attn_ranks: + group_assignments[r] = i + + if rank in ring_attn_ranks: set_ring_attn_group(group) + LOG.info( + f"GPU {rank} assigned to sequence parallel group {i} with ranks {ring_attn_ranks}" + ) + + # Log the full group assignment structure + if rank == 0: + LOG.info(f"Sequence parallel group assignments: {group_assignments}") substitute_hf_flash_attn(get_ring_attn_group(), sequence_parallel_size) diff --git a/src/axolotl/utils/collators/__init__.py b/src/axolotl/utils/collators/__init__.py index 8c60f223c..66105d20d 100644 --- a/src/axolotl/utils/collators/__init__.py +++ b/src/axolotl/utils/collators/__init__.py @@ -9,3 +9,8 @@ from .batching import ( # noqa: F401 V2BatchSamplerDataCollatorForSeq2Seq, ) from .mamba import MambaDataCollator # noqa: F401 +from .sequence_parallel import ( # noqa: F401 + SequenceParallelDataCollator, + SequenceParallelPackedDataCollator, + V2SequenceParallelPackedDataCollator, +) diff --git a/src/axolotl/utils/collators/sequence_parallel.py b/src/axolotl/utils/collators/sequence_parallel.py new file mode 100644 index 000000000..b41a07b60 --- /dev/null +++ b/src/axolotl/utils/collators/sequence_parallel.py @@ -0,0 +1,433 @@ +"""Module for sequence parallelism data collators.""" + +from dataclasses import dataclass + +import torch +import torch.distributed as dist +from accelerate.logging import get_logger + +from axolotl.logging_config import configure_logging +from 
axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group +from axolotl.utils.collators.batching import ( + BatchSamplerDataCollatorForSeq2Seq, + DataCollatorForSeq2Seq, + V2BatchSamplerDataCollatorForSeq2Seq, +) + +configure_logging() +LOG = get_logger(__name__) + + +def find_sample_boundaries(position_ids): + """ + Find the boundaries between packed samples in a sequence by looking for + where position_ids decrease. + + Returns: + List of boundary indices for each sequence in the batch + """ + batch_boundaries = [] + + for i in range(position_ids.shape[0]): + seq = position_ids[i] + boundaries = [] + for j in range(1, len(seq)): + if seq[j] < seq[j - 1]: + boundaries.append(j) + batch_boundaries.append(boundaries) + + return batch_boundaries + + +def adjust_position_ids_for_slice(position_ids, start_idx): + """ + Adjust position IDs for a sliced sequence to maintain proper relative positions. + This handles the case where position IDs might not be contiguous due to sample packing. 
+ """ + # Convert to tensor if not already + if not isinstance(position_ids, torch.Tensor): + position_ids = torch.tensor( + position_ids, + device=position_ids.device if hasattr(position_ids, "device") else None, + ) + + # Find the boundaries between samples (where position_ids reset) + adjusted_pos_ids = position_ids.clone() + + # Process each sequence in the batch + for i in range(position_ids.shape[0]): + seq = position_ids[i] + + # Find sample boundaries + boundaries = [] + for j in range(1, len(seq)): + if seq[j] < seq[j - 1]: + boundaries.append(j) + + # Debug: log the found boundaries + LOG.debug(f"Sequence {i}: Found sample boundaries at positions {boundaries}") + + # No need to adjust if there are no boundaries or this is a single sample + if not boundaries: + old_values = seq[0:5].tolist() # Sample of original values + adjusted_pos_ids[i] = seq - start_idx + new_values = adjusted_pos_ids[i, 0:5].tolist() # Sample of new values + LOG.debug( + f"Sequence {i}: No boundaries, subtracting {start_idx} uniformly. Example values before: {old_values}, after: {new_values}" + ) + continue + + # Adjust each segment separately + prev_boundary = 0 + for boundary_idx, boundary in enumerate(boundaries): + segment = seq[prev_boundary:boundary] + old_values = segment[ + 0 : min(5, len(segment)) + ].tolist() # Sample of original values + adjusted_pos_ids[i, prev_boundary:boundary] -= start_idx + new_values = adjusted_pos_ids[ + i, prev_boundary : min(prev_boundary + 5, boundary) + ].tolist() # Sample of new values + LOG.debug( + f"Sequence {i}, Segment {boundary_idx}: Adjusting positions {prev_boundary}-{boundary-1}. 
Example values before: {old_values}, after: {new_values}" + ) + prev_boundary = boundary + + # Last segment + segment = seq[prev_boundary:] + old_values = segment[ + 0 : min(5, len(segment)) + ].tolist() # Sample of original values + adjusted_pos_ids[i, prev_boundary:] -= start_idx + new_values = adjusted_pos_ids[ + i, prev_boundary : min(prev_boundary + 5, len(seq)) + ].tolist() # Sample of new values + LOG.debug( + f"Sequence {i}, Last segment: Adjusting positions {prev_boundary}-end. Example values before: {old_values}, after: {new_values}" + ) + + return adjusted_pos_ids + + +def check_for_boundary_splits(boundaries, slice_start, slice_end): + """ + Check if any sample boundaries fall near the edge of a sequence slice. + These edge cases could cause issues with gradient computation. + + Args: + boundaries: List of indices where sample boundaries occur + slice_start: Start index of this GPU's slice + slice_end: End index of this GPU's slice + + Returns: + List of potentially problematic boundaries + """ + # Consider a boundary "near" an edge if it's within 5 tokens + buffer_size = 5 + problem_boundaries = [] + + for boundary in boundaries: + # Check if boundary is near the start of the slice + if slice_start <= boundary < slice_start + buffer_size: + problem_boundaries.append((boundary, "start", boundary - slice_start)) + # Check if boundary is near the end of the slice + elif slice_end - buffer_size <= boundary < slice_end: + problem_boundaries.append((boundary, "end", slice_end - boundary)) + + return problem_boundaries + + +@dataclass +class SequenceParallelPackedDataCollator(BatchSamplerDataCollatorForSeq2Seq): + """ + Data collator for sequence parallelism with sample packing. + Combines multiple samples into a packed sequence, then slices it for each GPU. 
+ """ + + debug_level: str = "debug" # Can be "debug" for more verbose output + + def __call__(self, features, return_tensors=None): + # First, use the parent collator to handle sample packing and padding + batch = super().__call__(features, return_tensors=return_tensors) + + sp_group = get_ring_attn_group() + if sp_group is None: + return batch # Not using sequence parallelism + + # Get information about our position in the SP group + rank = dist.get_rank(group=sp_group) + world_size = dist.get_world_size(group=sp_group) + + # Enable debug level if requested + if self.debug_level == "debug": + original_shapes = { + k: v.shape if hasattr(v, "shape") else None for k, v in batch.items() + } + LOG.info(f"GPU {rank}: Original batch shapes: {original_shapes}") + + if "position_ids" in batch: + # Find and log sample boundaries before slicing + boundaries = find_sample_boundaries(batch["position_ids"]) + for i, seq_boundaries in enumerate(boundaries): + LOG.info( + f"GPU {rank}: Sequence {i} has {len(seq_boundaries)} packed samples with boundaries at {seq_boundaries}" + ) + + # Process keys that need to be sliced + for key in ["input_ids", "attention_mask", "labels"]: + if key in batch: + seq_len = batch[key].shape[1] + slice_size = seq_len // world_size + start_idx = rank * slice_size + end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len + + LOG.info( + f"GPU {rank}: Slicing {key} from {start_idx} to {end_idx} (total len: {seq_len})" + ) + + if self.debug_level == "debug" and key == "input_ids": + # Log portions of the input to verify correct slicing + for i in range( + min(2, batch[key].shape[0]) + ): # Look at up to 2 sequences + # Sample the beginning, middle and end of the sequence before slicing + start_sample = batch[key][i, 0:5].tolist() + mid_sample = batch[key][ + i, seq_len // 2 : seq_len // 2 + 5 + ].tolist() + end_sample = batch[key][i, -5:].tolist() + LOG.info( + f"GPU {rank}, Seq {i} before slicing: start={start_sample}, mid={mid_sample}, 
end={end_sample}" + ) + + batch[key] = batch[key][:, start_idx:end_idx] + + if self.debug_level == "debug" and key == "input_ids": + # Log after slicing to verify + for i in range(min(2, batch[key].shape[0])): + sliced_sample = batch[key][i, 0:5].tolist() + sliced_end = batch[key][i, -5:].tolist() + LOG.info( + f"GPU {rank}, Seq {i} after slicing: start={sliced_sample}, end={sliced_end}" + ) + + # Handle position_ids specially if present (important for packed sequences) + if "position_ids" in batch: + # For position_ids, we need to adjust them after slicing + # Each position_id should be relative to its slice + pos_ids = batch["position_ids"] + seq_len = pos_ids.shape[1] + slice_size = seq_len // world_size + start_idx = rank * slice_size + end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len + + # Find boundaries before slicing + if self.debug_level == "debug": + full_boundaries = find_sample_boundaries(pos_ids) + + # Check for boundaries that fall near slice edges + for i, boundaries in enumerate(full_boundaries): + problem_boundaries = check_for_boundary_splits( + boundaries, start_idx, end_idx + ) + if problem_boundaries: + LOG.warning( + f"GPU {rank}: Sequence {i} has sample boundaries near slice edges: {problem_boundaries}" + ) + + batch["position_ids"] = pos_ids[:, start_idx:end_idx] + + # Find boundaries after slicing to verify correct transfer + if self.debug_level == "debug": + sliced_boundaries = find_sample_boundaries(batch["position_ids"]) + for i, boundaries in enumerate(sliced_boundaries): + LOG.info( + f"GPU {rank}: After slicing, sequence {i} has boundaries at {boundaries}" + ) + + # Adjust position_ids to be relative to the start of this slice + # Only subtract if not the first GPU in the group + if rank > 0: + # Find boundaries between samples in the position_ids + # This preserves the sample packing structure + old_pos_ids = batch["position_ids"].clone() + batch["position_ids"] = adjust_position_ids_for_slice( + 
batch["position_ids"], start_idx + ) + + if self.debug_level == "debug": + # Compare before and after adjustment + for i in range(min(2, old_pos_ids.shape[0])): + before = old_pos_ids[i, 0:10].tolist() + after = batch["position_ids"][i, 0:10].tolist() + LOG.info( + f"GPU {rank}, Seq {i} position_ids adjustment: before={before}, after={after}" + ) + + # Add gradient norm tracking for debugging + if self.debug_level == "debug": + # Attach hook to track gradient norms during backward pass + def hook_fn(grad): + norm = grad.norm().item() + LOG.info(f"GPU {rank}: Gradient norm = {norm:.4f}") + # Record any abnormally high gradients + if norm > 10.0: + LOG.warning(f"GPU {rank}: High gradient norm detected: {norm:.4f}") + return grad + + # Apply hook to input_ids embeddings if it goes through backward pass + if "input_ids" in batch and batch["input_ids"].requires_grad: + batch["input_ids"].register_hook(hook_fn) + + return batch + + +@dataclass +class V2SequenceParallelPackedDataCollator(V2BatchSamplerDataCollatorForSeq2Seq): + """ + Data collator for sequence parallelism with V2 sample packing. 
+ """ + + debug_level: str = "debug" # Can be "debug" for more verbose output + + def __call__(self, features, return_tensors=None): + # Implementation similar to SequenceParallelPackedDataCollator with V2 base + # First, use the parent collator to handle sample packing and padding + batch = super().__call__(features, return_tensors=return_tensors) + + sp_group = get_ring_attn_group() + if sp_group is None: + return batch # Not using sequence parallelism + + # Get information about our position in the SP group + rank = dist.get_rank(group=sp_group) + world_size = dist.get_world_size(group=sp_group) + + # Enable debug level if requested + if self.debug_level == "debug": + original_shapes = { + k: v.shape if hasattr(v, "shape") else None for k, v in batch.items() + } + LOG.info(f"GPU {rank}: Original batch shapes: {original_shapes}") + + if "position_ids" in batch: + # Find and log sample boundaries before slicing + boundaries = find_sample_boundaries(batch["position_ids"]) + for i, seq_boundaries in enumerate(boundaries): + LOG.info( + f"GPU {rank}: Sequence {i} has {len(seq_boundaries)} packed samples with boundaries at {seq_boundaries}" + ) + + # Process keys that need to be sliced + for key in ["input_ids", "attention_mask", "labels"]: + if key in batch: + seq_len = batch[key].shape[1] + slice_size = seq_len // world_size + start_idx = rank * slice_size + end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len + + if self.debug_level == "debug" and key == "input_ids": + # Log portions of the input to verify correct slicing + for i in range( + min(2, batch[key].shape[0]) + ): # Look at up to 2 sequences + # Sample the beginning, middle and end of the sequence before slicing + start_sample = batch[key][i, 0:5].tolist() + mid_sample = batch[key][ + i, seq_len // 2 : seq_len // 2 + 5 + ].tolist() + end_sample = batch[key][i, -5:].tolist() + LOG.info( + f"GPU {rank}, Seq {i} before slicing: start={start_sample}, mid={mid_sample}, end={end_sample}" + ) + 
+ batch[key] = batch[key][:, start_idx:end_idx] + + # Handle position_ids specially (same as in SequenceParallelPackedDataCollator) + if "position_ids" in batch: + # Implementation identical to the one in SequenceParallelPackedDataCollator + pos_ids = batch["position_ids"] + seq_len = pos_ids.shape[1] + slice_size = seq_len // world_size + start_idx = rank * slice_size + end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len + + # Find boundaries before slicing + if self.debug_level == "debug": + full_boundaries = find_sample_boundaries(pos_ids) + + # Check for boundaries that fall near slice edges + for i, boundaries in enumerate(full_boundaries): + problem_boundaries = check_for_boundary_splits( + boundaries, start_idx, end_idx + ) + if problem_boundaries: + LOG.warning( + f"GPU {rank}: Sequence {i} has sample boundaries near slice edges: {problem_boundaries}" + ) + + batch["position_ids"] = pos_ids[:, start_idx:end_idx] + + # Adjust position_ids to be relative to the start of this slice + if rank > 0: + batch["position_ids"] = adjust_position_ids_for_slice( + batch["position_ids"], start_idx + ) + + return batch + + +@dataclass +class SequenceParallelDataCollator(DataCollatorForSeq2Seq): + """ + Data collator for sequence parallelism without sample packing. 
+ """ + + debug_level: str = "debug" # Can be "debug" for more verbose output + + def __call__(self, features, return_tensors=None): + # First, use the parent collator to pad everything correctly + batch = super().__call__(features, return_tensors=return_tensors) + + sp_group = get_ring_attn_group() + if sp_group is None: + return batch # Not using sequence parallelism + + # Get information about our position in the SP group + rank = dist.get_rank(group=sp_group) + world_size = dist.get_world_size(group=sp_group) + + if self.debug_level == "debug": + original_shapes = { + k: v.shape if hasattr(v, "shape") else None for k, v in batch.items() + } + LOG.info(f"GPU {rank}: Original batch shapes: {original_shapes}") + + # Process keys that need to be sliced + for key in ["input_ids", "attention_mask", "labels"]: + if key in batch: + seq_len = batch[key].shape[1] + slice_size = seq_len // world_size + start_idx = rank * slice_size + end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len + + LOG.info( + f"GPU {rank}: Slicing {key} from {start_idx} to {end_idx} (total len: {seq_len})" + ) + batch[key] = batch[key][:, start_idx:end_idx] + + # Handle position_ids if present + if "position_ids" in batch: + pos_ids = batch["position_ids"] + seq_len = pos_ids.shape[1] + slice_size = seq_len // world_size + start_idx = rank * slice_size + end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len + + batch["position_ids"] = pos_ids[:, start_idx:end_idx] + + # For non-packed sequences, we can simply subtract start_idx from all position_ids + if rank > 0: + batch["position_ids"] -= start_idx + + return batch diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 3f2374608..3fc971ca6 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -548,7 +548,7 @@ class ModelLoader: patch_self_attn_lora(self.cfg) if self.cfg.sequence_parallel_size > 1: - from axolotl.utils.ring_attn import register_ring_attn + from 
axolotl.monkeypatch.attention.ring_attn import register_ring_attn # Initialize ring attention for sequence parallelism if enabled. # This must be done after model initialization but before the first forward pass, diff --git a/tests/e2e/multigpu/test_sequence_parallelism.py b/tests/e2e/multigpu/test_sequence_parallelism.py new file mode 100644 index 000000000..9619cf690 --- /dev/null +++ b/tests/e2e/multigpu/test_sequence_parallelism.py @@ -0,0 +1,114 @@ +"""Tests for end-to-end sequence parallelism integration.""" +import os +import tempfile + +import pytest +import torch +import yaml + +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + + +def test_integration_with_config(): + """Test end-to-end training configuration setup for sequence parallelism.""" + # Define a test config directly in code instead of loading from file + config_dict = { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "LlamaTokenizer", + "is_llama_derived_model": True, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + } + ], + "load_in_8bit": False, + "sequence_len": 1024, + "sequence_parallel_size": 2, + "flash_attention": True, + "sample_packing": True, + "pad_to_sequence_len": True, + "micro_batch_size": 2, + "num_epochs": 1, + "max_steps": 10, + "gradient_accumulation_steps": 1, + "warmup_steps": 2, + "optimizer": "adamw_bnb_8bit", + "lr_scheduler": "cosine", + "learning_rate": 2.0e-4, + "weight_decay": 0.0, + "val_set_size": 0.05, + "eval_steps": 5, + "save_steps": 10, + } + + # Create a temp dir for output + with tempfile.TemporaryDirectory() as temp_dir: + config_dict["output_dir"] = temp_dir + + # Also write to a file for completeness + config_path = os.path.join(temp_dir, "sp_config.yml") + with open(config_path, "w", encoding="utf-8") as f: + yaml.dump(config_dict, f) + + # Convert to DictDefault and validate + cfg = DictDefault(config_dict) + cfg = validate_config(cfg) + 
normalize_config(cfg) + + # Verify sequence parallelism settings were properly processed + assert cfg.sequence_parallel_size == 2 + assert cfg.flash_attention is True + + # Check if the sequence_parallel_size was propagated to the training args + from axolotl.core.training_args import AxolotlTrainingArguments + + # pylint: disable=unexpected-keyword-arg + training_args = AxolotlTrainingArguments( + output_dir=temp_dir, sequence_parallel_size=cfg.sequence_parallel_size + ) + assert training_args.sequence_parallel_size == 2 + + +def test_ring_attn_group_creation(): + """Test that ring attention groups are properly created in a multi-GPU environment.""" + # First ensure we're in a distributed environment + if not torch.distributed.is_initialized(): + # Skip this test if not in distributed mode + pytest.skip( + "This test requires a properly initialized torch.distributed environment" + ) + + from axolotl.monkeypatch.attention.ring_attn import ( + get_ring_attn_group, + register_ring_attn, + ) + + # Get the current rank and world size + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + + # Only run if we have an even number of GPUs + if world_size % 2 != 0: + pytest.skip(f"Need an even number of GPUs, but got {world_size}") + + # Register with sequence parallel size of 2 + register_ring_attn(sequence_parallel_size=2) + + # Get the ring attention group + group = get_ring_attn_group() + + # Verify the group exists + assert group is not None + + # Calculate expected group members + group_id = rank // 2 + expected_start = group_id * 2 + expected_group = list(range(expected_start, expected_start + 2)) + + # Verify our rank is in the expected group + assert rank in expected_group + + # Clean up by synchronizing all processes + torch.distributed.barrier() diff --git a/tests/e2e/patched/test_sequence_parallelism.py b/tests/e2e/patched/test_sequence_parallelism.py new file mode 100644 index 000000000..6d7c64305 --- /dev/null +++ 
b/tests/e2e/patched/test_sequence_parallelism.py
@@ -0,0 +1,221 @@
+"""Tests for sequence parallelism functionality."""
+# pylint: disable=redefined-outer-name,unused-argument
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+from accelerate.state import PartialState
+
+# Mock ring_flash_attn in sys.modules so the imports below succeed even when it is not installed
+ring_flash_attn_mock = MagicMock()
+with patch.dict("sys.modules", {"ring_flash_attn": ring_flash_attn_mock}):
+    from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
+    from axolotl.utils.collators.sequence_parallel import (
+        adjust_position_ids_for_slice,
+        check_for_boundary_splits,
+        find_sample_boundaries,
+    )
+
+
+# Create a fixture for PartialState
+@pytest.fixture
+def partial_state():
+    """Create a real PartialState instance for testing."""
+    # This initializes a PartialState for a non-distributed environment
+    state = PartialState()
+    return state
+
+
+class TestSequenceParallelHelpers:
+    """Test helper functions used in sequence parallelism."""
+
+    def test_find_sample_boundaries(self):
+        """Test detection of boundaries in position_ids."""
+        # Create sample position_ids with multiple sequences
+        position_ids = torch.tensor(
+            [
+                # First sequence with 2 samples (boundary at index 5)
+                [0, 1, 2, 3, 4, 0, 1, 2, 3],
+                # Second sequence with 3 samples (boundaries at 3 and 7)
+                [0, 1, 2, 0, 1, 2, 3, 0, 1],
+            ]
+        )
+
+        boundaries = find_sample_boundaries(position_ids)
+
+        assert len(boundaries) == 2
+        assert boundaries[0] == [5]  # First sequence has boundary at index 5
+        assert boundaries[1] == [3, 7]  # Second sequence has boundaries at 3 and 7
+
+    def test_adjust_position_ids_for_slice(self, partial_state):
+        """Test position_ids adjustment for sequence slices."""
+        # Create sample position_ids with multiple sequences
+        position_ids = torch.tensor(
+            [
+                # First sequence with 2 samples
+                [0, 1, 2, 3, 4, 0, 1, 2, 3],
+                # Second sequence with 3 samples
+                [0, 1, 2, 0, 1, 2, 3, 0, 1],
+            ]
+        )
+
+        # Adjust as if this were the second slice (start_idx = 4)
+        adjusted = adjust_position_ids_for_slice(position_ids, start_idx=4)
+
+        # For first sequence: [0,1,2,3,4,0,1,2,3] -> [-4,-3,-2,-1,0,-4,-3,-2,-1]
+        # For second sequence: [0,1,2,0,1,2,3,0,1] -> [-4,-3,-2,-4,-3,-2,-1,-4,-3]
+        expected_first_seq = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3]) - 4
+        expected_second_seq = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1]) - 4
+
+        assert torch.all(adjusted[0] == expected_first_seq)
+        assert torch.all(adjusted[1] == expected_second_seq)
+
+    def test_check_for_boundary_splits(self):
+        """Test detection of boundaries near slice edges."""
+        # Boundaries at positions 10, 25, 40
+        boundaries = [10, 25, 40]
+
+        # Test case where two boundaries are near edges (one at start, one at end)
+        problems = check_for_boundary_splits(boundaries, slice_start=8, slice_end=30)
+        assert (
+            len(problems) == 2
+        )  # Both boundary at 10 (near start) and 25 (near end) are problems
+
+        # Check first problem - boundary near start
+        assert problems[0][0] == 10  # The boundary position
+        assert problems[0][1] == "start"  # Type of issue
+        assert problems[0][2] == 2  # Distance from start
+
+        # Check second problem - boundary near end
+        assert problems[1][0] == 25  # The boundary position
+        assert problems[1][1] == "end"  # Type of issue
+        assert problems[1][2] == 5  # Distance from end
+
+        # Test case with only one problem at the end
+        problems = check_for_boundary_splits(boundaries, slice_start=15, slice_end=27)
+        assert len(problems) == 1  # Only boundary at 25 is near the end
+        assert problems[0][0] == 25  # The boundary
+        assert problems[0][1] == "end"  # Type of issue
+
+        # Test case with no problems
+        problems = check_for_boundary_splits(boundaries, slice_start=12, slice_end=20)
+        assert len(problems) == 0
+
+
+class TestRingAttention:
+    """Tests for the ring attention functionality."""
+
+    @patch("torch.distributed.new_group")
+    @patch("torch.distributed.get_rank")
+    @patch("torch.distributed.get_world_size")
+    def test_register_ring_attn(
+        self, mock_world_size, mock_rank, mock_new_group, partial_state
+    ):
+        """Test that ring attention groups are created correctly."""
+        from axolotl.monkeypatch.attention.ring_attn import register_ring_attn
+
+        # Setup mocks
+        mock_world_size.return_value = 8  # 8 GPUs total
+        mock_rank.return_value = 3  # GPU #3
+        mock_group = MagicMock()
+        mock_new_group.return_value = mock_group
+
+        # Call register_ring_attn with size 4
+        register_ring_attn(sequence_parallel_size=4)
+
+        # Verify the number of calls without examining the arguments
+        assert mock_new_group.call_count == 2
+
+        # Sanity check (already implied by the call count above)
+        mock_new_group.assert_called()
+
+    @patch("torch.distributed.get_rank")
+    @patch("torch.distributed.get_world_size")
+    def test_get_ring_attn_group_no_registration(
+        self, mock_world_size, mock_rank, partial_state
+    ):
+        """Test that get_ring_attn_group returns None when no group has been registered."""
+        # Setup mocks
+        mock_world_size.return_value = 4
+        mock_rank.return_value = 0
+
+        # Get the group without registration
+        group = get_ring_attn_group()
+
+        # Verify that None was returned
+        assert group is None
+
+
+# Exercise the collator's sequence-slicing logic in isolation, with distributed calls mocked
+@patch("axolotl.utils.collators.sequence_parallel.get_ring_attn_group")
+@patch("torch.distributed.get_rank")
+@patch("torch.distributed.get_world_size")
+def test_sequence_parallel_slicing(
+    mock_world_size, mock_rank, mock_get_group, partial_state
+):
+    """Test the basic sequence slicing logic without full collator instantiation."""
+    # Setup mocks
+    mock_get_group.return_value = MagicMock()
+    mock_rank.return_value = 1  # Second GPU
+    mock_world_size.return_value = 4  # 4 GPUs total
+
+    # Create a sample batch
+    batch = {
+        "input_ids": torch.tensor(
+            [
+                [101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112],
+                [201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212],
+            ]
+        ),
+        "attention_mask": torch.ones(2, 12),
+    }
+
+    # Simplified slicing logic from SequenceParallelDataCollator
+    def slice_batch(batch, rank, world_size):
+        result = {}
+        for key in batch:
+            seq_len = batch[key].shape[1]
+            slice_size = seq_len // world_size
+            start_idx = rank * slice_size
+            end_idx = start_idx + slice_size if rank < world_size - 1 else seq_len
+            result[key] = batch[key][:, start_idx:end_idx]
+        return result
+
+    # Slice the batch
+    result = slice_batch(
+        batch, rank=mock_rank.return_value, world_size=mock_world_size.return_value
+    )
+
+    # Check slicing
+    assert result["input_ids"].shape == (2, 3)  # 12 tokens / 4 GPUs = 3 tokens per GPU
+    expected_input_ids = torch.tensor(
+        [
+            [104, 105, 106],  # Second slice of first sequence
+            [204, 205, 206],  # Second slice of second sequence
+        ]
+    )
+    assert torch.all(result["input_ids"] == expected_input_ids)
+
+
+# Simple test for configuration validation
+@pytest.mark.parametrize(
+    "config,should_validate",
+    [
+        ({"sequence_parallel_size": 2, "flash_attention": True}, True),
+        ({"sequence_parallel_size": 2, "flash_attention": False}, False),
+        ({"sequence_parallel_size": 1, "flash_attention": False}, True),
+    ],
+)
+def test_sequence_parallel_config_requirements(config, should_validate):
+    """Test basic sequence parallelism configuration requirements."""
+
+    # Simple validation function that mimics the actual validator
+    def validate_sp_config(config):
+        if config.get("sequence_parallel_size", 1) > 1 and not config.get(
+            "flash_attention", False
+        ):
+            return False
+        return True
+
+    assert validate_sp_config(config) == should_validate