# Optimizations Guide

Axolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models. This guide provides a high-level overview and directs you to the detailed documentation for each feature.
## Speed Optimizations

These optimizations focus on increasing training throughput and reducing total training time.
### Sample Packing

Improves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires enabling one of the attention implementations below.

- **Config:** `sample_packing: true`
- **Learn more:** Sample Packing
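For illustration, a minimal packing setup might look like the following sketch (model, dataset, and other required fields omitted):

```yaml
# Pack multiple short examples into each training sequence, up to sequence_len tokens.
sample_packing: true
sequence_len: 4096
# Packing needs one of the attention backends from the next section, e.g. Flash Attention 2.
flash_attention: true
```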
### Attention Implementations

Using an optimized attention implementation is critical for training speed.

- **Flash Attention 2:** `flash_attention: true`. (Recommended) The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check AMD Support.
- **Flex Attention:** `flex_attention: true`.
- **SDP Attention:** `sdp_attention: true`. PyTorch's native implementation.
- **Xformers:** `xformers_attention: true`. Works with FP16.

**Note:** You should only enable one attention backend.
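As a sketch, pick a single backend and leave the other flags unset:

```yaml
# Enable exactly one attention implementation.
flash_attention: true
# sdp_attention: true
# xformers_attention: true
# flex_attention: true
```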
### LoRA Optimizations

Leverages optimized kernels to accelerate LoRA training and reduce memory usage.

- **Learn more:** LoRA Optimizations Documentation
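An illustrative sketch of what this looks like in a config; the kernel flag names below are assumptions, so check the LoRA Optimizations documentation for the authoritative list:

```yaml
# Only applies when training a LoRA/QLoRA adapter.
adapter: lora
# Optimized LoRA kernels (flag names assumed here for illustration).
lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true
```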
## Memory Optimizations

These techniques help you fit larger models or use bigger batch sizes on your existing hardware.
### Parameter Efficient Finetuning (LoRA & QLoRA)

Drastically reduces memory by training a small set of "adapter" parameters instead of the full model. This is the most common and effective memory-saving technique.

- **Examples:** Find configs with `lora` or `qlora` in the examples directory.
- **Config Reference:** See `adapter`, `load_in_4bit`, and `load_in_8bit` in the Configuration Reference.
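For example, a typical QLoRA setup looks roughly like this (hyperparameter values are illustrative):

```yaml
# Load the base model in 4-bit and train LoRA adapters on top (QLoRA).
adapter: qlora
load_in_4bit: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
```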
### Gradient Checkpointing & Activation Offloading

These techniques save VRAM by changing how activations are handled.

- **Gradient Checkpointing:** Re-computes activations during the backward pass, trading compute time for VRAM.
- **Activation Offloading:** Moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
- **Learn more:** Gradient Checkpointing and Offloading Docs
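A minimal sketch of both options; the activation offloading key name is an assumption here, so verify it against the linked docs:

```yaml
# Re-compute activations during the backward pass instead of storing them.
gradient_checkpointing: true
# Optionally move saved activations off the GPU as well (key name assumed).
activation_offloading: true
```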
### Cut Cross Entropy (CCE)

Reduces VRAM usage by using an optimized cross-entropy loss calculation.

- **Learn more:** Custom Integrations - CCE
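CCE is enabled through Axolotl's plugin system; a sketch, assuming the Cut Cross Entropy integration described in the linked docs:

```yaml
# Load the Cut Cross Entropy integration and switch the loss to it.
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
```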
### Liger Kernels

Provides efficient Triton kernels to improve training speed and reduce memory usage.

- **Learn more:** Custom Integrations - Liger Kernels
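Liger is also plugin-based; a sketch of commonly enabled kernels (which kernels apply depends on the model architecture, so treat this as illustrative):

```yaml
# Load the Liger integration and enable individual kernels.
plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
```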
## Long Context Models

Techniques to train models on sequences longer than their original context window.
### RoPE Scaling

Extends a model's context window by interpolating its Rotary Position Embeddings.

- **Config:** Pass the `rope_scaling` config under `overrides_of_model_config:`. To learn which RoPE scaling fields a model accepts, check that model's config.
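For example, a linear-interpolation setup might be passed through like this; the `rope_scaling` field names and values are model-dependent and shown only as an illustration:

```yaml
overrides_of_model_config:
  rope_scaling:
    # Illustrative values for linear interpolation; consult the model's config
    # for the exact field names it expects.
    type: linear
    factor: 2.0
```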
### Sequence Parallelism

Splits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device.

- **Learn more:** Sequence Parallelism Documentation
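As a sketch, splitting each sequence across 4 GPUs:

```yaml
# Shard each training sequence across this many GPUs.
sequence_parallel_degree: 4
# Long sequences also benefit from an optimized attention backend.
flash_attention: true
```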
### Arctic Long Sequence Training (ALST)

ALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves:

- TiledMLP to reduce memory usage in MLP layers.
- Tiled loss functions (like CCE).
- Activation offloading to CPU.

- **Example:** ALST Example Configuration
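A rough sketch of how these pieces combine in a config; the key names below are assumptions, so rely on the ALST example configuration above for the authoritative settings:

```yaml
# Key names are assumptions, shown only to illustrate the recipe.
tiled_mlp: true              # TiledMLP for the MLP layers
cut_cross_entropy: true      # a tiled loss such as CCE
activation_offloading: true  # move activations to CPU
```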
## Large Models (Distributed Training)

To train models that don’t fit on a single GPU, you’ll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes.

- **Learn more:** Multi-GPU Guide
- **Learn more:** Multi-Node Guide
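For instance, DeepSpeed ZeRO-3 can be enabled by pointing at one of the bundled config files (the path shown is an example):

```yaml
# Shard weights, gradients, and optimizer states with DeepSpeed ZeRO-3.
deepspeed: deepspeed_configs/zero3_bf16.json
```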
### N-D Parallelism (Beta)

For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., data, tensor, and sequence parallelism). This is a powerful approach for training extremely large models by overcoming multiple bottlenecks at once.

- **Learn more:** N-D Parallelism Guide
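A sketch of composing parallelism degrees across 8 GPUs; the key names are assumptions, so check the guide for the exact options:

```yaml
# Illustrative only: 2-way tensor parallel x 2-way context parallel
# x 2-way sharded data parallel = 8 GPUs. Key names are assumptions.
tensor_parallel_size: 2
context_parallel_size: 2
dp_shard_size: 2
```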
## Quantization

Techniques to reduce the precision of model weights for memory savings.
### 4-bit Training (QLoRA)

The recommended approach for quantization-based training. It loads the base model in 4-bit using bitsandbytes and then trains LoRA adapters on top (QLoRA). See Adapter Finetuning for details.
### FP8 Training

Enables training with 8-bit floating-point precision on supported hardware (e.g., NVIDIA Hopper-series GPUs) for significant speed and memory gains.

- **Example:** Llama 3 FP8 FSDP Example
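A sketch, assuming a top-level FP8 toggle as used in the linked example; verify the exact settings against that config:

```yaml
# Train matmuls in FP8 on supported GPUs (e.g. Hopper).
# Key name assumed; see the FP8 FSDP example for the exact settings.
fp8: true
```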
### Quantization Aware Training (QAT)

Simulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model.

- **Learn more:** QAT Documentation
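A sketch of a QAT section; the field names are assumptions, and the QAT documentation lists the supported options:

```yaml
qat:
  # Fake-quantize weights to 4-bit and activations to 8-bit during training.
  # Field names are assumptions; see the QAT docs for the exact schema.
  weight_dtype: int4
  activation_dtype: int8
  group_size: 32
```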
### GPTQ

Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.

- **Example:** GPTQ LoRA Example
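A sketch of finetuning a LoRA adapter on a GPTQ-quantized base model; the model id is illustrative and the `gptq` flag is an assumption, so follow the linked example for the exact config:

```yaml
# Base model already quantized with GPTQ; train a LoRA adapter on top.
base_model: TheBloke/Llama-2-7B-GPTQ   # illustrative model id
gptq: true
adapter: lora
```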