---
title: Context Parallelism
description: Train with long sequences split across multiple GPUs.
---
Context parallelism is a technique that splits sequences across multiple GPUs,
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
GPU processes a different portion of the sequence, and the results are aggregated
through a ring communication pattern.
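
As a rough illustration of the ring pattern, the toy schedule below shows which key/value block each GPU would hold at each step if every rank passed its block to the next rank after each step. This is a minimal sketch of generic ring attention, not the `ring-flash-attn` implementation; the function name `ring_schedule` and the direction of rotation are assumptions made for illustration.

```python
def ring_schedule(world_size: int) -> list[list[int]]:
    """Return schedule[step][rank]: the key/value block a rank holds at a step.

    Assumes each rank starts with its own block and, after every step,
    sends it to the next rank while receiving from the previous one.
    """
    return [
        [(rank - step) % world_size for rank in range(world_size)]
        for step in range(world_size)
    ]


for step, blocks in enumerate(ring_schedule(4)):
    print(f"step {step}: {blocks}")
# step 0: [0, 1, 2, 3]
# step 1: [3, 0, 1, 2]
# step 2: [2, 3, 0, 1]
# step 3: [1, 2, 3, 0]
```

After as many steps as there are GPUs in the group, every rank has seen every key/value block once, which is what lets each rank compute attention for its own slice of the sequence.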
## When to Use Context Parallelism
Use context parallelism when:
- You need to train with sequence lengths that don't fit into a single GPU's memory
- You have multiple GPUs available
- You're experiencing OOM (Out Of Memory) errors with long sequences
## Configuration
To enable context parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
context_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:
```
The `context_parallel_degree` must be greater than 1 and must divide the total number of GPUs evenly. For example:
- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
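
The rule above can be checked with a few lines of Python. This is a small helper for illustration only; `valid_context_parallel_degrees` is a hypothetical name, not part of Axolotl.

```python
def valid_context_parallel_degrees(num_gpus: int) -> list[int]:
    """Return every divisor of num_gpus that is greater than 1."""
    return [d for d in range(2, num_gpus + 1) if num_gpus % d == 0]


print(valid_context_parallel_degrees(8))  # [2, 4, 8]
print(valid_context_parallel_degrees(4))  # [2, 4]
```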
## Implementation Details
When context parallelism is enabled:
1. Each sequence is divided into equal chunks across the GPUs in a context parallel group
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
3. Position IDs are adjusted to maintain proper relative positions
4. The trainer uses special ring communication patterns for attention operations
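
The chunking in steps 1 to 3 can be sketched in plain Python. This is not Axolotl's data collator: `split_for_rank` is a hypothetical helper, and the contiguous split shown here is an assumption, since the actual layout may differ depending on the ring attention function in use.

```python
def split_for_rank(values: list[int], degree: int, rank: int) -> list[int]:
    """Return the contiguous chunk of `values` owned by `rank`.

    Assumes the sequence length is divisible by the context parallel degree.
    """
    if len(values) % degree != 0:
        raise ValueError("sequence length must be divisible by the degree")
    chunk = len(values) // degree
    return values[rank * chunk:(rank + 1) * chunk]


input_ids = [101, 7, 8, 9, 10, 11, 12, 102]  # toy token IDs
position_ids = list(range(len(input_ids)))    # global positions 0..7

for rank in range(4):
    ids = split_for_rank(input_ids, 4, rank)
    pos = split_for_rank(position_ids, 4, rank)
    print(f"rank {rank}: input_ids={ids} position_ids={pos}")
```

Note that each rank keeps the global position IDs of its chunk instead of restarting at 0, so relative positions across chunks stay consistent.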
## Requirements
To use context parallelism, you need:
- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install with either:
  - `pip install "axolotl[ring-flash-attn]"` (preferred)
  - `pip install "ring-flash-attn>=0.1.4"`
## Limitations
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
- Communication between GPUs may add a small performance overhead
## Example
```yaml
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192
...
context_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:
...
```
This will train the Llama 3 8B model with 8K context length, with each sequence split
into 4 subsequences of length 2048 across 4 GPUs.
## Sample Packing with Context Parallelism
Context parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
1. Samples are first packed together
2. The packed sequences are then divided across GPUs in the context parallel group
3. Position IDs are automatically adjusted to maintain proper relative positions
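
The sketch below illustrates the order of operations on toy data. It assumes that packed samples are marked by position IDs that restart at 0 for each sample and that the split across GPUs is contiguous; both are assumptions for illustration, and the helper names are hypothetical.

```python
def packed_position_ids(sample_lengths: list[int]) -> list[int]:
    """Position IDs for a packed sequence, restarting at 0 for each sample."""
    ids: list[int] = []
    for length in sample_lengths:
        ids.extend(range(length))
    return ids


def chunk_for_rank(values: list[int], degree: int, rank: int) -> list[int]:
    """Contiguous chunk of a packed sequence owned by `rank`."""
    size = len(values) // degree
    return values[rank * size:(rank + 1) * size]


# Two samples of length 3 and 5 packed into one sequence of length 8.
position_ids = packed_position_ids([3, 5])
print(position_ids)                        # [0, 1, 2, 0, 1, 2, 3, 4]
print(chunk_for_rank(position_ids, 2, 0))  # [0, 1, 2, 0]
print(chunk_for_rank(position_ids, 2, 1))  # [1, 2, 3, 4]
```

The second sample straddles the chunk boundary, and its position IDs continue on the second rank instead of restarting, which preserves sample boundaries after the split.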
## Effect on Batch Size
When using context parallelism, your effective global batch size is **divided** by the `context_parallel_degree`. This happens because:
- Each group of `context_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no context parallelism: 8 different batches processed per step
- With 8 GPUs and `context_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
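
The arithmetic in this example can be written as a short function. This is a sketch of the calculation only; `global_batch_size` is a hypothetical helper, and it ignores gradient accumulation.

```python
def global_batch_size(
    num_gpus: int, micro_batch_size: int, context_parallel_degree: int = 1
) -> int:
    """Samples processed per step across all GPUs.

    GPUs in the same context parallel group share one batch, so only
    num_gpus // context_parallel_degree distinct batches run per step.
    """
    if num_gpus % context_parallel_degree != 0:
        raise ValueError("context_parallel_degree must divide num_gpus")
    groups = num_gpus // context_parallel_degree
    return groups * micro_batch_size


print(global_batch_size(8, 2))                             # 16
print(global_batch_size(8, 2, context_parallel_degree=4))  # 4
```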