---
title: Sequence Parallelism
description: Train with long sequences split across multiple GPUs.
---

Sequence parallelism is a technique that splits sequences across multiple GPUs,
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
GPU processes a different portion of the sequence, and the results are aggregated
through a ring communication pattern.
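
As a mental model, the split looks like the following minimal PyTorch sketch. This is illustrative only; the function name and shapes are hypothetical, not Axolotl's internals:

```python
import torch

def split_for_rank(input_ids: torch.Tensor, rank: int, cp_size: int) -> torch.Tensor:
    """Return the contiguous chunk of the sequence owned by this rank."""
    seq_len = input_ids.size(-1)
    assert seq_len % cp_size == 0, "sequence length must divide evenly across ranks"
    chunk = seq_len // cp_size
    return input_ids[..., rank * chunk : (rank + 1) * chunk]

# An 8192-token sequence on a 4-GPU group: each rank sees 2048 tokens.
ids = torch.arange(8192).unsqueeze(0)                # shape (1, 8192)
print(split_for_rank(ids, rank=1, cp_size=4).shape)  # torch.Size([1, 2048])
```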

## When to Use Sequence Parallelism

Use sequence parallelism when:

- You need to train with sequence lengths that don't fit into a single GPU's memory
- You have multiple GPUs available
- You're experiencing OOM (out-of-memory) errors with long sequences

## Configuration

To enable sequence parallelism, add the following to your configuration file:

```yaml
# Set to a divisor (> 1) of the number of GPUs available
context_parallel_size: 4  # Split sequences across 4 GPUs

# Optional; strides across the key dimension. Larger values use more memory
# but should make training faster.
heads_k_stride: 1

# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:
```

The `context_parallel_size` must be a divisor of the total number of GPUs (a sanity-check sketch follows the examples). For example:

- With 8 GPUs, valid values are 2, 4, or 8
- With 4 GPUs, valid values are 2 or 4
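
A hypothetical check that mirrors this rule; `validate_context_parallel_size` is illustrative and not part of Axolotl's API:

```python
def validate_context_parallel_size(world_size: int, cp_size: int) -> None:
    """Hypothetical check: cp_size must be a divisor (> 1) of the GPU count."""
    if cp_size <= 1:
        raise ValueError("context_parallel_size must be > 1 to enable sequence parallelism")
    if world_size % cp_size != 0:
        raise ValueError(
            f"context_parallel_size={cp_size} must evenly divide world_size={world_size}"
        )

validate_context_parallel_size(world_size=8, cp_size=4)   # OK
# validate_context_parallel_size(world_size=8, cp_size=3) # would raise ValueError
```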

## Implementation Details

When sequence parallelism is enabled:

1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
2. The data collator handles the chunking of `input_ids`, `attention_mask`, `labels`, and `position_ids` (a simplified sketch follows this list)
3. Position IDs are adjusted to maintain proper relative positions
4. The trainer uses special ring communication patterns for attention operations
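
The chunking in steps 1–3 can be pictured like this. This is hypothetical code, not the actual collator; one common way to "adjust" position IDs is to keep each token's global position so relative offsets stay correct:

```python
import torch

def chunk_for_rank(batch: dict, rank: int, cp_size: int) -> dict:
    """Slice one rank's contiguous chunk out of each per-token tensor."""
    seq_len = batch["input_ids"].size(1)
    chunk = seq_len // cp_size
    sl = slice(rank * chunk, (rank + 1) * chunk)
    out = {key: batch[key][:, sl] for key in ("input_ids", "attention_mask", "labels")}
    # Keep absolute positions so each token's relative offsets stay correct.
    out["position_ids"] = torch.arange(seq_len).unsqueeze(0)[:, sl]
    return out

batch = {k: torch.ones(1, 8, dtype=torch.long)
         for k in ("input_ids", "attention_mask", "labels")}
print(chunk_for_rank(batch, rank=1, cp_size=2)["position_ids"])  # tensor([[4, 5, 6, 7]])
```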

## Requirements

To use sequence parallelism, you need:

- Multiple GPUs (at least 2)
- The `ring-flash-attn` package. Install with either:
  - `pip install "axolotl[ring-flash-attn]"` (preferred)
  - `pip install "ring-flash-attn>=0.1.4"`

## Limitations

- Flash attention must be enabled for this to work (`attn_implementation: flash_attention_2` in config YAML)
- May have a small performance overhead due to communication between GPUs

## Example

```yaml
base_model: meta-llama/Llama-3-8B-Instruct
sequence_len: 8192

...

context_parallel_size: 4  # Split each sequence into 4 parts, one per GPU

# Optional; strides across the key dimension. Larger values use more memory
# but should make training faster.
heads_k_stride: 1

# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
ring_attn_func:

...
```

This will train the Llama 3 8B model with an 8K context length, with each sequence split
into 4 subsequences of length 2048, one per GPU in the sequence-parallel group.

## Sample Packing with Sequence Parallelism

Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:

1. Samples are first packed together
2. The packed sequences are then divided across GPUs in the sequence parallel group
3. Position IDs are automatically adjusted to maintain proper relative positions (see the toy example below)
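
A toy illustration of that ordering, with made-up tensors and sizes; pack first, then split:

```python
import torch

# Step 1: pack two short samples into one sequence, restarting position IDs
# at each sample boundary.
samples = [torch.arange(100, 105), torch.arange(200, 203)]  # token IDs
packed = torch.cat(samples)
position_ids = torch.cat([torch.arange(len(s)) for s in samples])
# position_ids -> tensor([0, 1, 2, 3, 4, 0, 1, 2])

# Step 2: split the packed sequence across a 2-GPU sequence-parallel group.
cp_size = 2
chunk = packed.numel() // cp_size
shards = [(packed[i * chunk:(i + 1) * chunk],
           position_ids[i * chunk:(i + 1) * chunk]) for i in range(cp_size)]
```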

## Effect on Batch Size

When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:

- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases

For example (see the arithmetic sketch below):

- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `context_parallel_size: 4`: only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4
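
The arithmetic as a small sketch; the helper is illustrative, not an Axolotl function:

```python
def global_batch_size(world_size: int, cp_size: int, micro_batch_size: int,
                      grad_accum_steps: int = 1) -> int:
    """Distinct batches per step = number of data-parallel groups."""
    data_parallel_groups = world_size // cp_size
    return data_parallel_groups * micro_batch_size * grad_accum_steps

print(global_batch_size(world_size=8, cp_size=1, micro_batch_size=2))  # 16 (no SP)
print(global_batch_size(world_size=8, cp_size=4, micro_batch_size=2))  # 4
```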