finish basic impl; change naming from SP -> CP to match torch

Dan Saunders
2025-06-13 09:51:06 -04:00
parent aced809989
commit 7a88de4fa8
25 changed files with 525 additions and 488 deletions


@@ -1,16 +1,16 @@
---
-title: Sequence Parallelism
+title: Context Parallelism
description: Train with long sequences split across multiple GPUs.
---
-Sequence parallelism is a technique that splits sequences across multiple GPUs,
+Context parallelism is a technique that splits sequences across multiple GPUs,
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
GPU processes a different portion of the sequence, and the results are aggregated
through a ring communication pattern.
-## When to Use Sequence Parallelism
+## When to Use Context Parallelism
-Use sequence parallelism when:
+Use context parallelism when:
- You need to train with sequence lengths that don't fit into a single GPU's memory
- You have multiple GPUs available
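For intuition about the ring communication pattern mentioned in the introduction, here is a minimal single-process sketch of how per-chunk attention results can be merged as key/value chunks circulate around a ring. It is plain numpy with no causal mask, and it does not reflect the actual `ring-flash-attn` kernels or Axolotl's internals.

```python
import numpy as np

def reference_attention(q, k, v):
    # Standard full attention, used only as a correctness check.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def ring_attention(q_chunks, k_chunks, v_chunks):
    # Each "rank" owns one query chunk and sees every key/value chunk once
    # as chunks circulate around the ring; partial results are merged with
    # a running (online) softmax.
    world = len(q_chunks)
    outputs = []
    for rank in range(world):
        q = q_chunks[rank]
        running_max = np.full((q.shape[0], 1), -np.inf)
        denom = np.zeros((q.shape[0], 1))
        acc = np.zeros_like(q)
        for step in range(world):
            src = (rank + step) % world  # chunk arriving at this ring step
            k, v = k_chunks[src], v_chunks[src]
            scores = q @ k.T / np.sqrt(q.shape[-1])
            new_max = np.maximum(running_max, scores.max(axis=-1, keepdims=True))
            correction = np.exp(running_max - new_max)
            p = np.exp(scores - new_max)
            denom = denom * correction + p.sum(axis=-1, keepdims=True)
            acc = acc * correction + p @ v
            running_max = new_max
        outputs.append(acc / denom)
    return np.concatenate(outputs)

rng = np.random.default_rng(0)
seq_len, head_dim, world = 8, 4, 4
q, k, v = (rng.normal(size=(seq_len, head_dim)) for _ in range(3))
q_c, k_c, v_c = (np.split(x, world) for x in (q, k, v))
assert np.allclose(ring_attention(q_c, k_c, v_c), reference_attention(q, k, v))
```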
@@ -18,11 +18,11 @@ Use sequence parallelism when:
## Configuration
-To enable sequence parallelism, add the following to your configuration file:
+To enable context parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
-sequence_parallel_degree: 4 # Split sequences across 4 GPUs
+context_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -30,23 +30,23 @@ heads_k_stride: 1
ring_attn_func:
```
-The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
+The `context_parallel_degree` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
## Implementation Details
-When sequence parallelism is enabled:
+When context parallelism is enabled:
-1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
+1. Each sequence is divided into equal chunks across the GPUs in a context parallel group
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
3. Position IDs are adjusted to maintain proper relative positions
4. The trainer uses special ring communication patterns for attention operations
## Requirements
-To use sequence parallelism, you need:
+To use context parallelism, you need:
- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install with:
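To make the chunking in the implementation-details list above concrete, here is a hypothetical sketch of how a batch could be sharded for one rank of a context parallel group. It is not Axolotl's actual data collator, and it assumes the sequence length divides evenly by the degree.

```python
import torch

def shard_for_cp_rank(batch: dict, cp_rank: int, cp_degree: int) -> dict:
    """Slice every per-token tensor down to this rank's chunk of the sequence."""
    seq_len = batch["input_ids"].shape[1]
    chunk = seq_len // cp_degree  # assumes seq_len is divisible by cp_degree
    start, end = cp_rank * chunk, (cp_rank + 1) * chunk
    # position_ids are sliced along with everything else, so each rank keeps
    # the absolute positions of its chunk within the full sequence.
    return {key: value[:, start:end] for key, value in batch.items()}

batch = {
    "input_ids": torch.arange(16).unsqueeze(0),
    "attention_mask": torch.ones(1, 16, dtype=torch.long),
    "labels": torch.arange(16).unsqueeze(0),
    "position_ids": torch.arange(16).unsqueeze(0),
}
print(shard_for_cp_rank(batch, cp_rank=1, cp_degree=4)["position_ids"])
# tensor([[4, 5, 6, 7]])
```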
@@ -66,7 +66,7 @@ sequence_len: 8192
...
-sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
+context_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
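As a quick sanity check on the arithmetic implied by this example config (plain Python; `num_gpus` and `tokens_per_gpu` are illustrative names, not Axolotl config keys):

```python
# Illustrative check of the example config above; not an Axolotl API.
num_gpus = 4
sequence_len = 8192
context_parallel_degree = 4

# The degree should evenly divide both the GPU count and the sequence length,
# since each sequence is split into equal chunks across the group.
assert num_gpus % context_parallel_degree == 0
assert sequence_len % context_parallel_degree == 0

tokens_per_gpu = sequence_len // context_parallel_degree
print(tokens_per_gpu)  # 2048 tokens of each 8192-token sequence per GPU
```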
@@ -79,22 +79,22 @@ ring_attn_func:
This will train the Llama 3 8B model with 8K context length, with each sequence split
into 4 subsequences of length 2048 across 4 GPUs.
-## Sample Packing with Sequence Parallelism
+## Sample Packing with Context Parallelism
-Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
+Context parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
1. Samples are first packed together
-2. The packed sequences are then divided across GPUs in the sequence parallel group
+2. The packed sequences are then divided across GPUs in the context parallel group
3. Position IDs are automatically adjusted to maintain proper relative positions
## Effect on Batch Size
-When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
+When using context parallelism, your effective global batch size is **divided** by the `context_parallel_degree`. This happens because:
-- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
+- Each group of `context_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
-- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
-- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
+- With 8 GPUs and no context parallelism: 8 different batches processed per step
+- With 8 GPUs and `context_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
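The batch-size effect can be written out directly; a short illustrative calculation (only `micro_batch_size` and `context_parallel_degree` are actual config keys, the other names are for illustration):

```python
# Effective global batch size when context parallelism shares each batch
# across a group of GPUs; illustrative only, not an Axolotl API.
num_gpus = 8
context_parallel_degree = 4
micro_batch_size = 2  # per-GPU batch size

# GPUs within one context parallel group all work on the same batch,
# so the number of distinct batches per step shrinks by the degree.
distinct_batches_per_step = num_gpus // context_parallel_degree   # 2
global_batch_size = distinct_batches_per_step * micro_batch_size  # 4

# Without context parallelism the same setup would process 8 batches
# per step for a global batch size of 16.
print(distinct_batches_per_step, global_batch_size)
```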