add SP doc, review comments

Dan Saunders
2025-03-18 20:04:48 +00:00
parent 411df76a97
commit c1a58339e8
7 changed files with 115 additions and 17 deletions

View File

@@ -623,6 +623,9 @@ ddp_broadcast_buffers:
# Sequence parallelism
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
# subsequences, or set to 4 to split into four equal-sized subsequences.
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
sequence_parallel_degree:
# Path to torch distx for optim 'adamw_anyprecision'

View File

@@ -0,0 +1,90 @@
---
title: Sequence Parallelism
description: Train with long sequences split across multiple GPUs.
---
# Sequence Parallelism
Sequence parallelism is a technique that splits sequences across multiple GPUs,
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
GPU processes a different portion of the sequence, and the results are aggregated
through a ring communication pattern.
## When to Use Sequence Parallelism
Use sequence parallelism when:
- You need to train with sequence lengths that don't fit into a single GPU's memory
- You have multiple GPUs available
- You're experiencing OOM (Out Of Memory) errors with long sequences
## Configuration
To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
```
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
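As a quick illustration, that constraint boils down to a simple divisibility check. This is only a sketch; the function name and error message below are hypothetical, not Axolotl's actual validation:
```python
def validate_sequence_parallel_degree(num_gpus: int, sequence_parallel_degree: int) -> None:
    """Raise if the configured degree cannot evenly split the available GPUs."""
    if sequence_parallel_degree > 1 and num_gpus % sequence_parallel_degree != 0:
        raise ValueError(
            f"sequence_parallel_degree={sequence_parallel_degree} must evenly divide "
            f"the number of GPUs ({num_gpus})"
        )
```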
## Implementation Details
When sequence parallelism is enabled:
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids (sketched after this list)
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
4. The trainer uses special ring communication patterns for attention operations
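Step 2 can be sketched roughly as follows. This is illustrative only, assuming a batch of 2-D per-token tensors; `slice_batch_for_rank` and its signature are hypothetical, not Axolotl's actual collator:
```python
import torch

def slice_batch_for_rank(batch: dict, rank: int, sp_degree: int) -> dict:
    """Give each rank in the sequence parallel group a contiguous slice of every per-token tensor."""
    seq_len = batch["input_ids"].shape[1]
    assert seq_len % sp_degree == 0, "sequence length must split evenly across the group"
    chunk = seq_len // sp_degree
    start, end = rank * chunk, (rank + 1) * chunk
    return {
        # Slice tensors that have a token dimension (input_ids, attention_mask,
        # labels, position_ids); pass anything else through unchanged.
        key: val[:, start:end]
        if isinstance(val, torch.Tensor) and val.ndim == 2 and val.shape[1] == seq_len
        else val
        for key, val in batch.items()
    }
```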
## Requirements
To use sequence parallelism, you need:
- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install it with either:
  - `pip install axolotl[ring-flash-attn]` (preferred)
  - `pip install "ring-flash-attn>=0.1.4"`
## Limitations
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
- May have a small performance overhead due to communication between GPUs
## Example
```yaml
# Example config with sequence parallelism
base_model: meta-llama/Meta-Llama-3-8B-Instruct
sequence_len: 8192
sequence_parallel_degree: 2 # Split each sequence into 2 parts
flash_attention: true # Required with sequence parallelism
...
```
This will train the Llama 3 8B model with 8K context length, with each sequence split
into 2 subsequences of length 4096 across 2 GPUs.
## Sample Packing with Sequence Parallelism
Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
1. Samples are first packed together
2. The packed sequences are then divided across GPUs in the sequence parallel group
3. Position IDs are automatically adjusted to maintain proper relative positions
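The position ID adjustment relies on knowing where each packed sample starts inside a slice. A minimal sketch of that boundary detection (the helper name is hypothetical; Axolotl's actual `adjust_position_ids_for_slice` does more than this):
```python
import torch

def find_packed_sample_starts(position_ids: torch.Tensor) -> torch.Tensor:
    """Return indices where a new packed sample begins in a 1-D position_ids tensor.

    Position IDs count up within each sample and reset at the next one, so a
    boundary is any index where the value stops increasing.
    """
    resets = torch.where(position_ids[1:] <= position_ids[:-1])[0] + 1
    return torch.cat([position_ids.new_zeros(1, dtype=torch.long), resets])

# Three packed samples of lengths 4, 3, and 2:
pos = torch.tensor([0, 1, 2, 3, 0, 1, 2, 0, 1])
print(find_packed_sample_starts(pos))  # tensor([0, 4, 7])
```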
## Effect on Batch Size
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
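A back-of-the-envelope helper for this arithmetic (a minimal sketch; the function is not part of Axolotl):
```python
def effective_global_batch_size(
    num_gpus: int,
    sequence_parallel_degree: int,
    micro_batch_size: int,
    gradient_accumulation_steps: int = 1,
) -> int:
    """GPUs in the same sequence parallel group share one batch, so only
    num_gpus // sequence_parallel_degree groups contribute distinct batches per step."""
    data_parallel_groups = num_gpus // sequence_parallel_degree
    return data_parallel_groups * micro_batch_size * gradient_accumulation_steps

print(effective_global_batch_size(8, 1, 2))  # 16 -- no sequence parallelism
print(effective_global_batch_size(8, 4, 2))  # 4 -- degree 4 divides the global batch by 4
```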

View File

@@ -64,6 +64,3 @@ schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3
# for sequence parallelism
yunchang==0.6.0

View File

@@ -119,7 +119,8 @@ setup(
],
},
extras_require={
"flash-attn": ["flash-attn==2.7.4.post1", "ring-flash-attn>=0.1.4"],
"flash-attn": ["flash-attn==2.7.4.post1"],
"ring-flash-attn": ["ring-flash-attn>=0.1.4", "yunchang==0.6.0"],
"deepspeed": [
"deepspeed==0.16.4",
"deepspeed-kernels",

View File

@@ -590,7 +590,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
if self.args.sample_packing and not self.args.pretraining:
train_dataset = train_dataset.remove_columns(["length"])
if not (self.args.sample_packing and not self.args.pretraining):
if not self.args.sample_packing or self.args.pretraining:
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)

View File

@@ -166,10 +166,22 @@ def setup_signal_handler(
)
def train_context_manager(enable=False) -> contextlib.AbstractContextManager:
"""Configure CUDA SDP kernel settings if enabled."""
if enable:
def train_context_manager(
flash_optimum: bool = False,
) -> contextlib.AbstractContextManager:
"""
Instantiate CUDA SDP kernel context manager if `flash_optimum` is `True`.
Args:
flash_optimum: Whether to enable efficient backends for SDP attention.
Returns:
Context manager for temporarily enabling efficient backends for SDP attention
if `flash_optimum` is `True`, or `contextlib.nullcontext` otherwise.
"""
if flash_optimum:
return torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
@@ -190,7 +202,7 @@ def execute_training(
resume_from_checkpoint: Path to checkpoint to resume from, if applicable.
"""
LOG.info("Starting trainer...")
context_manager = train_context_manager(cfg.flash_optimum)
context_manager = train_context_manager(flash_optimum=cfg.flash_optimum)
with context_manager:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)

View File

@@ -17,19 +17,14 @@ logger = logging.getLogger(__name__)
def adjust_position_ids_for_slice(
position_ids: list | torch.Tensor, start_idx: int
position_ids: torch.Tensor, start_idx: int
) -> torch.Tensor:
"""
Adjust position IDs for a sliced sequence to maintain proper relative positions.
This handles the case where position IDs might not be contiguous due to sample packing.
This handles the case where position IDs might not be contiguous due to sample
packing.
"""
# Convert to tensor if not already
if not isinstance(position_ids, torch.Tensor):
position_ids = torch.tensor(
position_ids,
device=position_ids.device if hasattr(position_ids, "device") else None,
)
# Find the boundaries between samples (where position_ids reset)
adjusted_pos_ids = position_ids.clone()