---
title: Optimizations Guide
description: A guide to the performance and memory optimizations available in Axolotl.
---

Axolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models.

This guide provides a high-level overview and directs you to the detailed documentation for each feature.

## Speed Optimizations

These optimizations focus on increasing training throughput and reducing total training time.

### Sample Packing

Improves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires one of the [attention implementations](#attention-implementations) described below.

- **Config:** `sample_packing: true`
- **Learn more:** [Sample Packing](multipack.qmd)

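The pure-Python sketch below illustrates why packing helps. It packs sequence lengths using a first-fit-decreasing heuristic. That heuristic is an assumption chosen for clarity; it is not Axolotl's multipack sampler. `pack_sequences`, `cu_seqlens` and `utilization` are hypothetical helpers. `cu_seqlens` computes the cumulative sequence boundaries inside one pack. A packing-aware attention kernel uses this kind of metadata to keep packed sequences from attending to each other.

```python
from typing import List


def pack_sequences(lengths: List[int], max_len: int) -> List[List[int]]:
    """First-fit decreasing: place each sequence (longest first) into the
    first pack with enough room left; open a new pack if none fits."""
    packs: List[List[int]] = []
    space_left: List[int] = []
    for length in sorted(lengths, reverse=True):
        if length > max_len:
            raise ValueError(f"sequence of length {length} exceeds max_len={max_len}")
        for i, free in enumerate(space_left):
            if length <= free:
                packs[i].append(length)
                space_left[i] -= length
                break
        else:  # no existing pack had room
            packs.append([length])
            space_left.append(max_len - length)
    return packs


def cu_seqlens(pack: List[int]) -> List[int]:
    """Cumulative sequence boundaries inside one pack, e.g. [3, 2] -> [0, 3, 5]."""
    bounds = [0]
    for length in pack:
        bounds.append(bounds[-1] + length)
    return bounds


def utilization(packs: List[List[int]], max_len: int) -> float:
    """Fraction of token slots that hold real tokens rather than padding."""
    return sum(sum(p) for p in packs) / (len(packs) * max_len)


lengths = [900, 300, 700, 100, 1024, 200, 600]
packs = pack_sequences(lengths, max_len=1024)
print(packs)  # [[1024], [900, 100], [700, 300], [600, 200]]
print(f"{utilization(packs, 1024):.2f} packed vs {utilization([[n] for n in lengths], 1024):.2f} padded")  # 0.93 packed vs 0.53 padded
print(cu_seqlens(packs[1]))  # [0, 900, 1000]
```

In this example the seven sequences fit into four rows instead of seven padded rows, so far fewer token slots are spent on padding.
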
### Attention Implementations

Using an optimized attention implementation is critical for training speed.

- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `attn_implementation: flash_attention_2`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires an NVIDIA Ampere or newer GPU. For AMD GPUs, see [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `attn_implementation: flex_attention`. PyTorch's programmable attention API, which compiles custom masks and score modifications into fused kernels.
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `attn_implementation: sdpa`. PyTorch's native implementation.
- **[Xformers](https://github.com/facebookresearch/xformers)**: `attn_implementation: xformers`. Works with FP16.

See [Attention](attention.qmd) for the full list of backends and the canonical values.

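All of these backends compute the same scaled dot-product attention, `softmax(Q K^T / sqrt(d)) V`. They differ in speed, memory use and hardware support. The pure-Python reference below uses hypothetical helper names, and real kernels never build a dense mask. It also shows why sample packing depends on the attention backend. When one row holds several sequences, the causal mask must be block-diagonal, so a token attends only to earlier tokens of its own sequence.

```python
import math
from typing import List

Matrix = List[List[float]]


def packed_causal_mask(seq_lens: List[int]) -> List[List[bool]]:
    """mask[i][j] is True when query i may attend to key j: j <= i AND both
    tokens belong to the same packed sequence (a block-diagonal causal mask)."""
    seq_id = [s for s, n in enumerate(seq_lens) for _ in range(n)]
    total = len(seq_id)
    return [[j <= i and seq_id[i] == seq_id[j] for j in range(total)] for i in range(total)]


def attention(q: Matrix, k: Matrix, v: Matrix, mask: List[List[bool]]) -> Matrix:
    """Reference scaled dot-product attention: softmax(q k^T / sqrt(d)) v."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        scores = [
            sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) if mask[i][j] else -math.inf
            for j, kj in enumerate(k)
        ]
        m = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in scores]  # masked keys get exp(-inf) == 0.0
        z = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / z for c in range(len(v[0]))])
    return out


# Two sequences of length 2 packed into one row of 4 tokens.
mask = packed_causal_mask([2, 2])
q = k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
out = attention(q, k, v, mask)
print(out[2])  # [5.0, 6.0]: the first token of sequence 2 only sees itself
```

Without the block-diagonal mask, token 2 would also attend to tokens 0 and 1, which belong to an unrelated sequence.
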
### LoRA Optimizations

Leverages optimized kernels to accelerate LoRA training and reduce memory usage.

- **Learn more:** [LoRA Optimizations Documentation](lora_optims.qmd)

## Memory Optimizations

These techniques help you fit larger models or use bigger batch sizes on your existing hardware.

### Parameter Efficient Finetuning (LoRA & QLoRA)

Drastically reduces memory by training a small set of "adapter" parameters instead of the full model. This is the most common and effective memory-saving technique.

- **Examples:** Find configs with `lora` or `qlora` in their file names in the [Llama 3 examples directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-3).
- **Config Reference:** See `adapter`, `load_in_4bit`, and `load_in_8bit` in the [Configuration Reference](config-reference.qmd).

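The savings come from the low-rank update. LoRA keeps the base weight `W` (`d_out x d_in`) frozen. It trains two small matrices, `A` (`r x d_in`) and `B` (`d_out x r`), and computes `h = W x + (alpha / r) * B A x`. The plain-Python sketch below uses hypothetical helper names. It shows the forward pass and the trainable-parameter count for a single linear layer. QLoRA uses the same adapters on top of a 4-bit base model.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, x: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def lora_forward(w: Matrix, a: Matrix, b: Matrix, x: Vector, alpha: float) -> Vector:
    """h = W x + (alpha / r) * B (A x). W is frozen; only A (r x d_in) and B (d_out x r) train."""
    r = len(a)  # LoRA rank = number of rows in A
    base = matvec(w, x)
    update = matvec(b, matvec(a, x))
    return [h + (alpha / r) * u for h, u in zip(base, update)]


def lora_param_counts(d_out: int, d_in: int, r: int) -> Tuple[int, int]:
    """(full fine-tuning parameters, LoRA trainable parameters) for one linear layer."""
    return d_out * d_in, r * (d_out + d_in)


full, lora = lora_param_counts(4096, 4096, r=16)
print(full, lora, f"{lora / full:.2%}")  # 16777216 131072 0.78%
```

At rank 16, a 4096 x 4096 projection trains less than 1% of its parameters. `B` is conventionally initialized to zeros, so the adapted layer initially matches the frozen base layer.
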
### Gradient Checkpointing & Activation Offloading

These techniques save VRAM by changing how activations are handled.

- **Gradient Checkpointing:** recomputes activations during the backward pass instead of keeping them all in memory, trading compute time for VRAM.
- **Activation Offloading:** moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
- **Learn more:** [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)

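The pure-Python sketch below illustrates the gradient checkpointing trade-off. It backpropagates through a chain of scalar `tanh(w * x)` layers in two ways:

- storing every activation;
- storing only every k-th activation as a checkpoint and recomputing the rest segment by segment during the backward pass.

The toy model and function names are hypothetical. Both versions produce the same gradient. The checkpointed version stores fewer activations but performs extra forward computations.

```python
import math
from typing import List, Tuple


def layer(w: float, x: float) -> float:
    return math.tanh(w * x)


def layer_grad(w: float, x: float) -> float:
    """d/dx tanh(w * x) = w * (1 - tanh(w * x)**2); needs the layer's INPUT x."""
    t = math.tanh(w * x)
    return w * (1.0 - t * t)


def grad_full(ws: List[float], x: float) -> Tuple[float, int]:
    """Keep every layer input, then apply the chain rule from the last layer back."""
    inputs = []
    for w in ws:
        inputs.append(x)
        x = layer(w, x)
    g = 1.0
    for w, xi in zip(reversed(ws), reversed(inputs)):
        g *= layer_grad(w, xi)
    return g, len(inputs)  # (d output / d input, activations stored)


def grad_checkpointed(ws: List[float], x: float, every: int) -> Tuple[float, int, int]:
    """Keep only the input of every `every`-th layer; recompute the rest during backward."""
    checkpoints = {}
    forward_calls = 0
    for i, w in enumerate(ws):
        if i % every == 0:
            checkpoints[i] = x
        x = layer(w, x)
        forward_calls += 1
    g = 1.0
    for start in sorted(checkpoints, reverse=True):  # walk segments from last to first
        seg = ws[start:start + every]
        inputs = [checkpoints[start]]
        for w in seg[:-1]:  # recompute the inputs of the remaining layers in this segment
            inputs.append(layer(w, inputs[-1]))
            forward_calls += 1
        for w, xi in zip(reversed(seg), reversed(inputs)):
            g *= layer_grad(w, xi)
    return g, len(checkpoints), forward_calls


ws = [0.9, 1.1, 0.7, 1.3, 0.8, 1.2, 1.0, 0.6]
g_full, stored_full = grad_full(ws, 0.5)
g_ckpt, stored_ckpt, calls = grad_checkpointed(ws, 0.5, every=4)
print(math.isclose(g_full, g_ckpt), stored_full, stored_ckpt, calls)  # True 8 2 14
```

In this run, checkpointing cuts stored activations from 8 to 2. The cost is 6 extra layer evaluations.
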
### Layer Offloading

Offloads frozen (non-trainable) decoder layer parameters to CPU and streams them back to GPU one layer at a time during forward/backward passes using CUDA stream prefetching. Especially effective for LoRA/QLoRA, where most parameters are frozen.

- **Config:** `layer_offloading: true`
- **Learn more:** [Layer Offloading Docs](gradient_checkpointing.qmd#enabling-layer-offloading)

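Prefetching helps because host-to-device copies overlap with compute. The back-of-the-envelope model below is a sketch under simplifying assumptions:

- per-layer timings are hypothetical and uniform, not measurements;
- it models a single pass over the layers;
- the copy of layer `i + 1` runs concurrently with the compute of layer `i`.

```python
from typing import List


def serial_time(compute_ms: List[float], copy_ms: List[float]) -> float:
    """No overlap: each layer is copied to the GPU, then computed."""
    return sum(compute_ms) + sum(copy_ms)


def prefetched_time(compute_ms: List[float], copy_ms: List[float]) -> float:
    """Copy of layer i+1 runs on a side stream while layer i computes,
    so each step costs whichever of the two is slower."""
    total = copy_ms[0]  # the first layer's copy cannot be hidden
    for i, compute in enumerate(compute_ms):
        next_copy = copy_ms[i + 1] if i + 1 < len(copy_ms) else 0.0
        total += max(compute, next_copy)
    return total


compute_times = [4.0] * 32  # hypothetical per-layer compute time in ms
copy_times = [3.0] * 32     # hypothetical per-layer host-to-device copy time in ms
print(serial_time(compute_times, copy_times), prefetched_time(compute_times, copy_times))  # 224.0 131.0
```

When copies are slower than compute, transfer time dominates instead, and prefetching can hide only the compute.
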
### Cut Cross Entropy (CCE)

Reduces VRAM usage by computing the cross-entropy loss without materializing the full logits tensor over the vocabulary.

- **Learn more:** [Custom Integrations - CCE](custom_integrations.qmd#cut-cross-entropy)

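The memory saving comes from never materializing the full `[num_tokens, vocab_size]` logits tensor. The pure-Python sketch below illustrates the idea for a single token. It computes logits one vocabulary chunk at a time and keeps a running (online) log-sum-exp, so only one chunk of logits exists at any moment. The chunk size and helper names are illustrative assumptions. CCE itself is a fused GPU kernel, not this loop.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def chunked_cross_entropy(hidden: Vector, classifier: List[Vector], target: int, chunk: int) -> float:
    """-log softmax(classifier @ hidden)[target], visiting the vocabulary one chunk at a time."""
    running_max, running_sum = -math.inf, 0.0
    for start in range(0, len(classifier), chunk):
        logits = [dot(row, hidden) for row in classifier[start:start + chunk]]  # only this chunk
        new_max = max(running_max, max(logits))
        # Online log-sum-exp: rescale the previous sum to the new max, then add this chunk.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(z - new_max) for z in logits
        )
        running_max = new_max
    log_z = running_max + math.log(running_sum)
    return log_z - dot(classifier[target], hidden)


def naive_cross_entropy(hidden: Vector, classifier: List[Vector], target: int) -> float:
    """Reference version that materializes every logit at once."""
    logits = [dot(row, hidden) for row in classifier]
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[target]


vocab, dim = 10, 3
classifier = [[math.sin(i + 2 * j) for j in range(dim)] for i in range(vocab)]
hidden = [0.2, -0.1, 0.4]
print(chunked_cross_entropy(hidden, classifier, target=3, chunk=4))
print(naive_cross_entropy(hidden, classifier, target=3))
```

Both calls return the same loss (up to floating-point rounding). Only the chunked version keeps its working set independent of the vocabulary size.
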
### Liger Kernels

Provides efficient Triton kernels to improve training speed and reduce memory usage.

- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)

### Expert Kernels

Optimized kernel implementations for Mixture of Experts (MoE) model training.

- **ScatterMoE**: Triton-based MoE kernels with fused LoRA support.
- **SonicMoE**: CUTLASS-based MoE kernels for NVIDIA Hopper and Blackwell GPUs.
- **Learn more:** [Custom Integrations - Kernels Integration](custom_integrations.qmd#kernels-integration)

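MoE layers follow a common routing pattern:

- a router scores every expert for each token;
- each token is sent to its top-k experts;
- the expert outputs are combined using the router weights.

The pure-Python reference below implements that pattern with hypothetical names and toy scalar "experts". Renormalizing the top-k weights to sum to 1 is an assumption here; some models (e.g. Mixtral) do this, others do not. The grouping step (`dispatch`) is the kind of scatter/gather work that dedicated MoE kernels optimize.

```python
import math
from collections import defaultdict
from typing import Callable, Dict, List, Tuple

Route = List[Tuple[int, float]]  # (expert index, weight) pairs for one token


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]


def route(router_logits: List[List[float]], k: int) -> List[Route]:
    """Pick each token's top-k experts and renormalize their weights to sum to 1."""
    routes = []
    for logits in router_logits:
        probs = softmax(logits)
        top = sorted(range(len(probs)), key=lambda e: probs[e], reverse=True)[:k]
        norm = sum(probs[e] for e in top)
        routes.append([(e, probs[e] / norm) for e in top])
    return routes


def dispatch(routes: List[Route]) -> Dict[int, List[Tuple[int, float]]]:
    """Group (token, weight) pairs by expert so each expert handles its tokens in one batch."""
    groups: Dict[int, List[Tuple[int, float]]] = defaultdict(list)
    for token, choices in enumerate(routes):
        for expert, weight in choices:
            groups[expert].append((token, weight))
    return dict(groups)


def moe_forward(xs: List[float], router_logits: List[List[float]],
                experts: List[Callable[[float], float]], k: int) -> List[float]:
    """Scatter tokens to experts, run each expert, gather weighted results back per token."""
    out = [0.0] * len(xs)
    for expert, assigned in dispatch(route(router_logits, k)).items():
        for token, weight in assigned:  # an optimized kernel runs this whole group at once
            out[token] += weight * experts[expert](xs[token])
    return out


experts = [lambda x: 2 * x, lambda x: x + 1, lambda x: -x]  # toy scalar "experts"
router_logits = [[2.0, 1.0, -1.0], [0.0, 0.0, 3.0]]
print(moe_forward([1.0, 2.0], router_logits, experts, k=1))  # [2.0, -2.0]
```
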
## Long Context Models

Techniques to train models on sequences longer than their original context window.

### RoPE Scaling

Extends a model's context window by interpolating its Rotary Position Embeddings.

- **Config:** Set `rope_scaling` under `overrides_of_model_config`. The accepted `rope_scaling` values depend on the model, so check the respective model's config.

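Linear position interpolation is the simplest variant. It divides positions by a scale factor, so longer sequences reuse the rotation angles the model saw during pretraining. The sketch below assumes the standard RoPE frequencies `base ** (-2i / d)` with `base = 10000` and a hypothetical 4096-token trained window. It is a conceptual example. Other scaling types, such as dynamic NTK or YaRN, modify the frequencies differently.

```python
from typing import List


def rope_angles(position: float, dim: int, base: float = 10000.0, factor: float = 1.0) -> List[float]:
    """Rotation angle of each of the dim/2 frequency pairs; `factor` > 1 applies
    linear position interpolation (position / factor)."""
    scaled = position / factor
    return [scaled * base ** (-2 * i / dim) for i in range(dim // 2)]


trained_ctx, factor = 4096, 4.0
long_pos = 16383  # near the end of a window four times longer than the trained one
scaled = rope_angles(long_pos, dim=128, factor=factor)
print(scaled == rope_angles(long_pos / factor, dim=128))  # True: same angles as position 4095.75
print(long_pos / factor < trained_ctx)  # True: inside the range seen in pretraining
```
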
### Sequence Parallelism

Splits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device.

- **Learn more:** [Sequence Parallelism Documentation](sequence_parallelism.qmd)

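In the simplest contiguous scheme, each GPU in the sequence-parallel group holds one slice of the sequence. Per-device activation memory then scales roughly with `seq_len / world_size`. The sketch below shows only that split, using a hypothetical helper. It omits the communication that lets each shard attend to keys and values held on other devices.

```python
from typing import List, Sequence


def shard_sequence(tokens: Sequence[int], world_size: int) -> List[List[int]]:
    """Split one sequence into `world_size` contiguous, equal-length shards (one per rank)."""
    if len(tokens) % world_size:
        raise ValueError("sequence length must be divisible by world_size; pad it first")
    size = len(tokens) // world_size
    return [list(tokens[r * size:(r + 1) * size]) for r in range(world_size)]


shards = shard_sequence(list(range(32768)), world_size=4)
print([len(s) for s in shards])  # [8192, 8192, 8192, 8192]
print(shards[1][0], shards[1][-1])  # 8192 16383: rank 1 holds tokens 8192..16383
```

With causal attention, later shards attend to more keys than earlier ones. This is one reason load-balanced layouts are sometimes used instead of a plain contiguous split.
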
### Arctic Long Sequence Training (ALST)

ALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves:

- TiledMLP to reduce memory usage in MLP layers.
- Tiled loss functions (like [CCE](#cut-cross-entropy-cce) or [Liger](#liger-kernels)).
- Activation Offloading to CPU.

**Example:** [ALST Example Configuration](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst)

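Tiling works for the MLP because the MLP operates on each token independently. Running it over the sequence in tiles gives exactly the same result as one large call, while only one tile's hidden activations are alive at a time. The toy pure-Python sketch below demonstrates that equivalence and reports the peak hidden size. The names and weights are hypothetical, and this is not Axolotl's TiledMLP implementation.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[Vector]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def gelu(h: float) -> float:
    return 0.5 * h * (1.0 + math.erf(h / math.sqrt(2.0)))  # exact (erf-based) GELU


def mlp(tokens: Matrix, w_in: Matrix, w_out: Matrix) -> Tuple[Matrix, int]:
    """Per-token MLP on a batch; returns outputs and how many hidden values were live at once."""
    hidden = [[gelu(dot(row, x)) for row in w_in] for x in tokens]  # [n_tokens, d_hidden]
    return [[dot(row, h) for row in w_out] for h in hidden], len(tokens) * len(w_in)


def tiled_mlp(tokens: Matrix, w_in: Matrix, w_out: Matrix, tile: int) -> Tuple[Matrix, int]:
    """Same computation, one tile of tokens at a time; reports the peak live hidden size."""
    out: Matrix = []
    peak = 0
    for start in range(0, len(tokens), tile):
        tile_out, live = mlp(tokens[start:start + tile], w_in, w_out)
        out.extend(tile_out)
        peak = max(peak, live)
    return out, peak


w_in = [[0.1 * (i + j) for j in range(2)] for i in range(8)]   # d_model=2 -> d_hidden=8
w_out = [[0.05 * (i - j) for j in range(8)] for i in range(2)]  # d_hidden=8 -> d_model=2
tokens = [[math.sin(t), math.cos(t)] for t in range(1000)]
full, full_peak = mlp(tokens, w_in, w_out)
tiled, tiled_peak = tiled_mlp(tokens, w_in, w_out, tile=100)
print(full == tiled, full_peak, tiled_peak)  # True 8000 800
```
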
## Large Models (Distributed Training)

To train models that don't fit on a single GPU, you'll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes.

- **Learn more:** [Multi-GPU Guide](multi-gpu.qmd)
- **Learn more:** [Multi-Node Guide](multi-node.qmd)

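A rough estimate shows why sharding matters. Assume mixed-precision training with Adam. The ZeRO paper's accounting is then about 16 bytes per parameter:

- 2 for half-precision weights;
- 2 for gradients;
- 12 for the fp32 master weights and the two Adam moments.

Fully sharding all of these across N GPUs divides that total by N. The helper below is a hypothetical back-of-the-envelope calculator that ignores activations, buffers and fragmentation.

```python
def per_gpu_state_gb(num_params: float, num_gpus: int, sharded: bool = True,
                     bytes_per_param: int = 16) -> float:
    """Approximate per-GPU memory (GB, 1e9 bytes) for weights, gradients and Adam state.

    16 bytes/param = 2 (half-precision weights) + 2 (half-precision gradients)
    + 12 (fp32 master weights, momentum and variance). Activations, temporary
    buffers and allocator overhead are ignored.
    """
    total_bytes = num_params * bytes_per_param
    if sharded:
        total_bytes /= num_gpus  # full sharding splits all three across GPUs
    return total_bytes / 1e9


print(per_gpu_state_gb(7e9, num_gpus=8, sharded=False))  # 112.0: each GPU holds a full replica
print(per_gpu_state_gb(7e9, num_gpus=8))                 # 14.0 per GPU with full sharding
```

A 7B-parameter model needs roughly 112 GB for these states alone, more than a single 80 GB GPU provides. Sharding across 8 GPUs brings this to about 14 GB each, before activations.
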
### N-D Parallelism (Beta)

For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). Combining them lets you train extremely large models by overcoming multiple bottlenecks at once.

- **Learn more:** [N-D Parallelism Guide](nd_parallelism.qmd)

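When parallelism dimensions are composed, the product of their sizes must equal the number of GPUs. Every global rank then maps to one coordinate per dimension. The sketch below illustrates that bookkeeping with hypothetical dimension labels: `dp`, `cp` and `tp` stand for data, context/sequence and tensor parallelism. Its row-major ordering is chosen for illustration only; the device-mesh layout Axolotl actually uses may differ.

```python
import math
from typing import Dict


def rank_to_coords(rank: int, dims: Dict[str, int]) -> Dict[str, int]:
    """Map a global rank to one coordinate per parallelism dimension.
    Row-major order: the last dimension varies fastest."""
    world_size = math.prod(dims.values())
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside world size {world_size}")
    coords: Dict[str, int] = {}
    rest = rank
    for name in reversed(list(dims)):
        rest, coords[name] = divmod(rest, dims[name])
    return {name: coords[name] for name in dims}


dims = {"dp": 2, "cp": 2, "tp": 2}  # hypothetical 2 x 2 x 2 layout on 8 GPUs
print(math.prod(dims.values()))  # 8
print(rank_to_coords(5, dims))   # {'dp': 1, 'cp': 0, 'tp': 1}
```
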
## Quantization

Techniques to reduce the precision of model weights for memory savings.

### 4-bit Training (QLoRA)

The recommended approach for quantization-based training. It loads the base model in 4-bit using `bitsandbytes` and then trains QLoRA adapters. See [Parameter Efficient Finetuning](#parameter-efficient-finetuning-lora--qlora) for details.

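The plain-Python sketch below shows where the memory savings come from. It applies blockwise absmax quantization to 4-bit integers, so each block of weights stores small integer codes plus one scale. This is a simplified stand-in. bitsandbytes' NF4 format instead maps values to a fixed, non-uniform 4-bit codebook. The block size of 64 and the helper names are assumptions for illustration.

```python
from typing import List, Tuple


def quantize_4bit(weights: List[float], block: int = 64) -> Tuple[List[int], List[float]]:
    """Blockwise symmetric absmax quantization to signed 4-bit codes in [-7, 7]."""
    codes: List[int] = []
    scales: List[float] = []
    for start in range(0, len(weights), block):
        chunk = weights[start:start + block]
        scale = max(abs(w) for w in chunk) / 7 or 1.0  # guard against all-zero blocks
        scales.append(scale)
        codes.extend(max(-7, min(7, round(w / scale))) for w in chunk)
    return codes, scales


def dequantize_4bit(codes: List[int], scales: List[float], block: int = 64) -> List[float]:
    return [c * scales[i // block] for i, c in enumerate(codes)]


weights = [((i * 37) % 101 - 50) / 50 for i in range(256)]  # deterministic toy weights in [-1, 1]
codes, scales = quantize_4bit(weights)
restored = dequantize_4bit(codes, scales)
max_err = max(abs(w - r) for w, r in zip(weights, restored))
print(len(scales), max_err <= max(scales) / 2 + 1e-12)  # 4 True (error <= half a step, plus float noise)
# Storage: 0.5 byte per 4-bit code + one 4-byte fp32 scale per 64 weights, vs 2 bytes for bf16.
print(0.5 + 4 / 64, "bytes/weight vs 2")  # 0.5625 bytes/weight vs 2
```

QLoRA's double quantization additionally compresses the per-block scales, which this sketch does not model.
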
### FP8 Training

Enables training with 8-bit floating point precision on supported hardware (e.g., NVIDIA Hopper series GPUs) for significant speed and memory gains.

- **Example:** [Llama 3 FP8 FSDP Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-3/3b-fp8-fsdp2.yaml)

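FP8 training commonly pairs the low-precision format with scaling factors, so that tensors make use of the format's narrow range. The pure-Python sketch below makes these assumptions:

- the E4M3 format (3 mantissa bits, maximum finite value 448);
- per-tensor scaling;
- no modeling of subnormals or NaN encoding.

The function names are hypothetical; they are not part of any FP8 library.

```python
import math
from typing import List, Tuple

E4M3_MAX = 448.0  # largest finite E4M3 value: 1.75 * 2**8


def round_to_e4m3(x: float) -> float:
    """Round to the nearest value with 3 mantissa bits, saturating at +/-448.
    Simplification: subnormals and NaN encodings are not modeled."""
    if x == 0.0:
        return 0.0
    m, e = math.frexp(abs(x))       # abs(x) == m * 2**e with 0.5 <= m < 1
    significand = 2.0 * m           # in [1, 2): implicit leading 1 plus 3 fraction bits
    q = round(significand * 8) / 8  # Python's round() is round-half-to-even
    return math.copysign(min(q * 2.0 ** (e - 1), E4M3_MAX), x)


def fp8_quantize(xs: List[float]) -> Tuple[List[float], float]:
    """Per-tensor scaling: map the tensor's absolute max onto E4M3_MAX, then round."""
    amax = max(abs(v) for v in xs) or 1.0
    scale = E4M3_MAX / amax
    return [round_to_e4m3(v * scale) for v in xs], scale


def fp8_dequantize(qs: List[float], scale: float) -> List[float]:
    return [q / scale for q in qs]


print(round_to_e4m3(3.3), round_to_e4m3(1000.0))  # 3.25 448.0
xs = [0.013, -0.42, 1.7, -3.9]
qs, scale = fp8_quantize(xs)
print(fp8_dequantize(qs, scale))
```

With only 3 mantissa bits, each normal value is stored with a relative error of at most 1/16. Per-tensor scaling keeps values away from the range limits.
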
### Quantization Aware Training (QAT)

Simulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model.

- **Learn more:** [QAT Documentation](qat.qmd)

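The core mechanism is fake quantization. In the forward pass, weights are quantized and immediately dequantized, so the model sees the rounding error. The backward pass uses a straight-through estimator that treats the rounding as the identity. The minimal pure-Python sketch below assumes symmetric int8 fake quantization with a fixed scale. The names are hypothetical; this is not the torchao QAT API.

```python
def fake_quant(x: float, scale: float, qmin: int = -128, qmax: int = 127) -> float:
    """Quantize to int8 and immediately dequantize, so the forward pass sees rounding error."""
    q = max(qmin, min(qmax, round(x / scale)))
    return q * scale


def fake_quant_grad(x: float, scale: float, qmin: int = -128, qmax: int = 127) -> float:
    """Straight-through estimator: treat rounding as the identity inside the
    representable range and block the gradient where the value was clipped."""
    return 1.0 if qmin * scale <= x <= qmax * scale else 0.0


def sgd_step(w: float, x: float, y: float, scale: float, lr: float = 0.1) -> float:
    """One update on the loss (fake_quant(w) * x - y) ** 2 using the STE gradient."""
    wq = fake_quant(w, scale)
    dloss_dwq = 2.0 * (wq * x - y) * x
    return w - lr * dloss_dwq * fake_quant_grad(w, scale)


w = 0.123
for _ in range(3):
    w = sgd_step(w, x=1.0, y=0.5, scale=0.05)
print(w, fake_quant(w, 0.05))
```

The full-precision weight keeps receiving small updates even when its quantized value does not change. This lets training account for the quantization grid. A weight far outside the int8 range receives no gradient through the fake-quantizer.
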
### GPTQ

Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.

- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)

### MoE Expert Quantization

Quantizes MoE expert weights on load to reduce VRAM when training MoE models with adapters. Required for Transformers v5+ MoE models where experts use fused `nn.Parameter` tensors.

- **Config:** `quantize_moe_experts: true`
- **Learn more:** [MoE Expert Quantization](expert_quantization.qmd)