---
title: Sequence Parallelism
description: Train with long sequences split across multiple GPUs.
---

Sequence parallelism is a technique that splits sequences across multiple GPUs,
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
GPU processes a different portion of the sequence, and the results are aggregated
through a ring communication pattern.
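
Conceptually, each GPU (rank) in a sequence parallel group keeps only one contiguous
slice of every sequence. The following is a minimal, illustrative sketch of that slicing
in plain Python; it is not Axolotl's actual implementation, which lives in the data
collator and `ring-flash-attn` integration:

```python
# Illustrative only: split one sequence of token IDs into equal, contiguous
# chunks, one per GPU in a sequence-parallel group of size `sp_degree`.
def split_sequence(input_ids: list, sp_degree: int) -> list:
    assert len(input_ids) % sp_degree == 0, "sequence length must divide evenly"
    chunk = len(input_ids) // sp_degree
    return [input_ids[i * chunk : (i + 1) * chunk] for i in range(sp_degree)]


# A toy 8-token sequence split across 4 GPUs: each rank keeps 2 tokens.
print(split_sequence(list(range(8)), sp_degree=4))
# [[0, 1], [2, 3], [4, 5], [6, 7]]
```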

## When to Use Sequence Parallelism

Use sequence parallelism when:

- You need to train with sequence lengths that don't fit into a single GPU's memory
- You have multiple GPUs available
- You're experiencing OOM (Out Of Memory) errors with long sequences

## Configuration

To enable sequence parallelism, add the following to your configuration file:

```yaml
# Set to a divisor (> 1) of the number of GPUs available
sequence_parallel_degree: 4  # Split sequences across 4 GPUs

# Optional; strides across the key dimension. Larger values use more memory,
# but should make training faster.
heads_k_stride: 1

# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:
```

The `sequence_parallel_degree` should be a divisor of the total number of GPUs (a quick way to check this is sketched after the list). For example:

- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
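
Valid values are simply the divisors of the GPU count that are greater than 1. A quick,
illustrative way to enumerate them (not part of Axolotl):

```python
# Illustrative helper: list the values of sequence_parallel_degree that divide
# the available GPU count evenly (a degree of 1 means no sequence parallelism).
def valid_sp_degrees(num_gpus: int) -> list:
    return [d for d in range(2, num_gpus + 1) if num_gpus % d == 0]


print(valid_sp_degrees(8))  # [2, 4, 8]
print(valid_sp_degrees(4))  # [2, 4]
```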

## Implementation Details

When sequence parallelism is enabled:

1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
2. The data collator handles the chunking of `input_ids`, `attention_mask`, `labels`, and `position_ids` (a rough sketch of this slicing follows the list)
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
4. The trainer uses special ring communication patterns for attention operations
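
To make steps 1 and 2 concrete, here is a hypothetical helper that only illustrates the
idea of per-rank slicing; Axolotl's actual collator additionally handles padding,
packing, and edge cases:

```python
import torch


# Hypothetical sketch: each rank in a sequence-parallel group of size
# `sp_degree` keeps its contiguous slice of every per-token tensor.
# Position IDs are sliced the same way, so each rank retains the original
# positions of the tokens it holds.
def shard_batch_for_rank(batch: dict, rank: int, sp_degree: int) -> dict:
    sharded = {}
    for key in ("input_ids", "attention_mask", "labels", "position_ids"):
        tensor = batch[key]                  # shape: (batch_size, seq_len)
        chunk = tensor.size(1) // sp_degree
        sharded[key] = tensor[:, rank * chunk : (rank + 1) * chunk]
    return sharded


# Toy batch: one sequence of length 8, sharded for rank 1 of a group of 4.
batch = {
    "input_ids": torch.arange(8).unsqueeze(0),
    "attention_mask": torch.ones(1, 8, dtype=torch.long),
    "labels": torch.arange(8).unsqueeze(0),
    "position_ids": torch.arange(8).unsqueeze(0),
}
print(shard_batch_for_rank(batch, rank=1, sp_degree=4)["input_ids"])  # tensor([[2, 3]])
```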

## Requirements

To use sequence parallelism, you need:

- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install with:
  - `pip install axolotl[ring-flash-attn]` (preferred)
  - `pip install "ring-flash-attn>=0.1.4"`

## Limitations

- Flash attention must be enabled for this to work (`flash_attention: true` in your config YAML)
- May have a small performance overhead due to communication between GPUs

## Example

```yaml
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192

...

sequence_parallel_degree: 4  # Split each sequence into 4 parts, one per GPU
flash_attention: true  # Required with sequence parallelism

# Optional; strides across the key dimension. Larger values use more memory,
# but should make training faster.
heads_k_stride: 1

...
```

This will train the Llama 3 8B model with an 8K context length, with each sequence split
into 4 subsequences of length 2048, one per GPU in each sequence parallel group of 4.

## Sample Packing with Sequence Parallelism

Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:

1. Samples are first packed together
2. The packed sequences are then divided across GPUs in the sequence parallel group
3. Position IDs are automatically adjusted to maintain proper relative positions (see the toy illustration after this list)
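
The toy illustration below shows the idea behind step 3: with packing, position IDs
restart at 0 for each packed sample, and slicing the packed sequence per rank preserves
those per-sample positions. This is plain Python for illustration, not Axolotl's code:

```python
# Two samples of lengths 5 and 3 packed into one sequence of length 8.
# With packing, position IDs restart at 0 for each packed sample.
packed_position_ids = [0, 1, 2, 3, 4, 0, 1, 2]

# Slicing the packed sequence across a sequence-parallel group of 4 keeps the
# original per-sample positions on every rank, so relative offsets (e.g. for
# rotary embeddings) stay correct.
sp_degree = 4
chunk = len(packed_position_ids) // sp_degree
per_rank = [packed_position_ids[r * chunk : (r + 1) * chunk] for r in range(sp_degree)]
print(per_rank)  # [[0, 1], [2, 3], [4, 0], [1, 2]]
```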

## Effect on Batch Size

When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:

- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases

For example:

- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `sequence_parallel_degree: 4`: only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
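
The same arithmetic in a short, illustrative snippet (ignoring gradient accumulation):

```python
# Illustrative calculation of the effective global batch size under
# sequence parallelism. Each group of `sp_degree` GPUs shares one batch,
# so only num_gpus // sp_degree distinct batches are processed per step.
def global_batch_size(num_gpus: int, micro_batch_size: int, sp_degree: int = 1) -> int:
    data_parallel_groups = num_gpus // sp_degree
    return data_parallel_groups * micro_batch_size


print(global_batch_size(num_gpus=8, micro_batch_size=2))               # 16
print(global_batch_size(num_gpus=8, micro_batch_size=2, sp_degree=4))  # 4
```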