---
title: Sequence Parallelism
description: Train with long sequences split across multiple GPUs.
---

Sequence parallelism splits each sequence across multiple GPUs, so you can train on
sequences that are too long to fit in a single GPU's memory. Each GPU processes a
different portion of the sequence, and the partial results are combined through a ring
communication pattern.

## When to Use Sequence Parallelism

Use sequence parallelism when:

- You need to train with sequence lengths that don't fit into a single GPU's memory
- You have multiple GPUs available
- You're experiencing OOM (Out Of Memory) errors with long sequences

## Configuration

To enable sequence parallelism, add the following to your configuration file:

```yaml
# Set to a divisor (> 1) of the number of GPUs available
context_parallel_size: 4  # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:
```

`context_parallel_size` should be a divisor (greater than 1) of the total number of GPUs. For example:

- With 8 GPUs, valid values are 2, 4, or 8
- With 4 GPUs, valid values are 2 or 4

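The divisor rule above can be checked with a few lines of Python. This is only an illustrative sketch: `valid_context_parallel_sizes` is a hypothetical helper, not part of Axolotl, and it simply enumerates the divisors greater than 1 of a given GPU count.

```python
def valid_context_parallel_sizes(num_gpus: int) -> list[int]:
    """Return every divisor of num_gpus that is greater than 1 (hypothetical helper)."""
    if num_gpus < 2:
        return []  # sequence parallelism needs at least 2 GPUs
    return [size for size in range(2, num_gpus + 1) if num_gpus % size == 0]


print(valid_context_parallel_sizes(8))  # [2, 4, 8]
print(valid_context_parallel_sizes(4))  # [2, 4]
```
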
## Implementation Details

When sequence parallelism is enabled:

1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
2. The data collator handles the chunking of `input_ids`, `attention_mask`, `labels`, and `position_ids`
3. Position IDs are adjusted to maintain proper relative positions
4. The trainer uses special ring communication patterns for attention operations

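The chunking in steps 1 to 3 can be pictured with a small pure-Python sketch. It is not Axolotl's collator: `split_for_rank` is a hypothetical helper, plain lists stand in for tensors, and the contiguous-slice layout is an assumption made for illustration. Because `position_ids` is sliced along with the other fields, each rank in this sketch keeps the positions its tokens had in the full sequence.

```python
def split_for_rank(batch: dict[str, list[int]], rank: int, world: int) -> dict[str, list[int]]:
    """Return the contiguous slice of every field that belongs to `rank` (hypothetical helper)."""
    shard = {}
    for name, values in batch.items():
        if len(values) % world != 0:
            raise ValueError(f"{name}: length {len(values)} is not divisible by {world}")
        chunk = len(values) // world
        shard[name] = values[rank * chunk:(rank + 1) * chunk]
    return shard


seq_len = 8
batch = {
    "input_ids": list(range(100, 100 + seq_len)),
    "attention_mask": [1] * seq_len,
    "labels": list(range(100, 100 + seq_len)),
    "position_ids": list(range(seq_len)),  # positions in the full sequence
}

# With 4 ranks, each rank sees 2 tokens and their original positions.
for rank in range(4):
    print(rank, split_for_rank(batch, rank, 4)["position_ids"])
```
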
## Requirements

To use sequence parallelism, you need:

- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install with:
  - `pip install "axolotl[ring-flash-attn]"` (preferred)
  - `pip install "ring-flash-attn>=0.1.4"`

## Limitations

- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
- Communication between GPUs may add a small performance overhead

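The requirements and limitations in this document can be checked before launching a run. The sketch below is hypothetical and not part of Axolotl: `preflight` takes an already-parsed config as a plain dict, and it assumes the `ring-flash-attn` package is importable under the module name `ring_flash_attn`.

```python
import importlib.util


def preflight(config: dict, num_gpus: int, module_name: str = "ring_flash_attn") -> list[str]:
    """Return a list of human-readable problems; an empty list means no problem was found."""
    problems = []
    size = config.get("context_parallel_size") or 1
    if num_gpus < 2:
        problems.append("sequence parallelism needs at least 2 GPUs")
    if size < 2 or num_gpus % size != 0:
        problems.append(
            f"context_parallel_size={size} is not a divisor (> 1) of {num_gpus} GPUs"
        )
    if not config.get("flash_attention", False):
        problems.append("flash_attention: true is required")
    # Assumption: the package installs a top-level module with this name.
    if importlib.util.find_spec(module_name) is None:
        problems.append(f"module {module_name!r} is not importable")
    return problems


# "json" stands in for the real module so the example runs anywhere.
for problem in preflight({"context_parallel_size": 3}, num_gpus=8, module_name="json"):
    print(problem)
```
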
## Example

```yaml
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192

...

context_parallel_size: 4  # Split each sequence into 4 parts, one per GPU
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:

...
```

This will train the Llama 3 8B model with 8K context length, with each sequence split
into 4 subsequences of length 2048 across 4 GPUs.

## Sample Packing with Sequence Parallelism

Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:

1. Samples are first packed together
2. The packed sequences are then divided across GPUs in the sequence parallel group
3. Position IDs are automatically adjusted to maintain proper relative positions

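A toy illustration of the three steps above, under stated assumptions: both helpers are hypothetical, plain lists stand in for tensors, position IDs are assumed to restart at 0 for each packed sample, and the split is assumed to be contiguous. It is a sketch of the idea, not Axolotl's packing code.

```python
def pack_position_ids(sample_lengths: list[int]) -> list[int]:
    """Position IDs for samples packed back to back; each sample restarts at 0."""
    position_ids = []
    for length in sample_lengths:
        position_ids.extend(range(length))
    return position_ids


def split_across_ranks(values: list[int], world: int) -> list[list[int]]:
    """Divide a packed sequence into `world` equal contiguous shards."""
    chunk, remainder = divmod(len(values), world)
    if remainder:
        raise ValueError(f"length {len(values)} is not divisible by {world}")
    return [values[r * chunk:(r + 1) * chunk] for r in range(world)]


packed = pack_position_ids([3, 5])       # two samples packed into 8 tokens
shards = split_across_ranks(packed, 2)   # one shard per GPU in the group
print(packed)  # [0, 1, 2, 0, 1, 2, 3, 4]
print(shards)  # [[0, 1, 2, 0], [1, 2, 3, 4]]
```

Note that a sample boundary can fall inside a shard (the first shard here holds all of sample one and the first token of sample two), which is why the position IDs travel with the tokens.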
## Effect on Batch Size

When using sequence parallelism, your effective global batch size is **divided** by `context_parallel_size`. This happens because:

- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
- The number of distinct batches processed per step decreases

For example:

- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `context_parallel_size=4`: only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4

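The arithmetic in the example above can be written out as a short function. `global_batch_size` is a hypothetical helper for illustration only; it ignores gradient accumulation and any other setting that affects batch size.

```python
def global_batch_size(num_gpus: int, micro_batch_size: int, context_parallel_size: int = 1) -> int:
    """Samples processed per step across all GPUs (hypothetical helper)."""
    if num_gpus % context_parallel_size != 0:
        raise ValueError("context_parallel_size must divide the number of GPUs")
    # GPUs in the same sequence parallel group share one batch.
    independent_groups = num_gpus // context_parallel_size
    return independent_groups * micro_batch_size


print(global_batch_size(8, 2))                           # 16
print(global_batch_size(8, 2, context_parallel_size=4))  # 4
```

If the original global batch size matters for your run, one option is to scale `micro_batch_size` or `gradient_accumulation_steps` up by the same factor, memory permitting.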