---
title: "Mixed Precision Training"
format:
  html:
    toc: true
    toc-depth: 3
    number-sections: true
    code-tools: true
execute:
  enabled: false
---

Mixed precision training uses lower precision data types to reduce memory usage and increase training speed while maintaining model quality. Axolotl supports several mixed precision formats:

- **FP16** - Half precision 16-bit (Pascal generation+)
- **BF16** - Brain Float 16-bit (Ampere generation+)
- **FP8** - 8-bit floating point (Hopper generation+)

## FP16 Mixed Precision {#sec-fp16}

### Overview {#sec-fp16-overview}

FP16 is the traditional half-precision format. It is supported on older GPUs but can be less numerically stable than BF16.

### Configuration {#sec-fp16-config}

```{.yaml}
fp16: true
```

### FP16 Considerations {#sec-fp16-considerations}

- May require gradient scaling to prevent underflow
- Less numerically stable than BF16
- Can cause training instability with some model architectures
- Consider using BF16 if your hardware supports it

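To make the underflow point concrete, here is a minimal sketch using only the Python standard library. It rounds a value to FP16 with `struct` and shows why scaling helps. The gradient value and the loss scale of 2^16 are illustrative assumptions, not values taken from Axolotl or PyTorch.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


grad = 1e-8           # illustrative small gradient (assumption)
loss_scale = 65536.0  # 2**16, an illustrative loss scale (assumption)

unscaled = to_fp16(grad)             # below FP16's smallest subnormal, becomes 0.0
scaled = to_fp16(grad * loss_scale)  # large enough to be represented in FP16
recovered = scaled / loss_scale      # unscale again in higher precision

print(f"unscaled FP16 gradient: {unscaled}")
print(f"recovered after scaling: {recovered:.3e}")
```

Without scaling the gradient is lost entirely; with scaling it survives the round trip to within FP16's rounding error.
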
## BF16 Mixed Precision {#sec-bf16}

### Overview {#sec-bf16-overview}

BF16 (Brain Float 16) offers better numerical stability than FP16 and is the recommended mixed precision format for modern GPUs. It provides the same dynamic range as FP32 while using half the memory.

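The range-versus-precision trade-off can be illustrated with a small standard-library sketch. It simulates BF16 by keeping the top 16 bits of the FP32 encoding (with round-to-nearest-even), which is an approximation written for this example and not how PyTorch implements the type. The helper names are hypothetical.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to IEEE half precision; values beyond FP16's range become inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


def to_bf16(x: float) -> float:
    """Simulate BF16: keep the upper 16 bits of the float32 bit pattern.

    Uses round-to-nearest-even on the dropped bits. NaN and inf inputs
    are not handled specially in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


large = 1e30   # well inside FP32 range, far outside FP16 range
fine = 1.001   # needs more mantissa bits than BF16 has

print("large:", to_fp16(large), to_bf16(large))
print("fine: ", to_fp16(fine), to_bf16(fine))
```

BF16 keeps the large value finite where FP16 overflows, but it rounds `1.001` to `1.0`, while FP16 keeps more of the fraction.
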
### Configuration {#sec-bf16-config}

```{.yaml}
# Pick one of the following settings.

# Automatic BF16 detection (recommended)
bf16: auto

# Or explicitly enable
bf16: true

# For evaluation with BF16
bf16: full  # Equivalent to bf16_full_eval in the HF trainer
```

## FP8 Mixed Precision {#sec-fp8}

::: {.callout-note}
FP8 support is experimental and requires compatible hardware (H100, H200) and recent PyTorch versions with TorchAO.
:::

### What is FP8? {#sec-fp8-overview}

FP8 (8-bit floating point) can significantly reduce training time compared to FP16/BF16 while maintaining training stability. Axolotl's implementation uses PyTorch's TorchAO library with the "tensorwise" scaling strategy.

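As a rough illustration of what "tensorwise" scaling means, the sketch below computes one scale factor for a whole tensor from its largest absolute value, then maps the values into the FP8 range. This is not TorchAO's code: the function names are hypothetical, the use of the E4M3 format (largest finite value 448) is an assumption, and the final rounding to 8 bits is omitted.

```python
E4M3_MAX = 448.0  # largest finite value of the FP8 E4M3 format


def tensorwise_scale(values):
    """Return one scale factor shared by every element of the tensor."""
    amax = max(abs(v) for v in values)
    return E4M3_MAX / amax if amax > 0 else 1.0


def scale_to_fp8_range(values):
    """Scale values so the largest magnitude lands on E4M3_MAX, then clamp."""
    scale = tensorwise_scale(values)
    scaled = [max(-E4M3_MAX, min(E4M3_MAX, v * scale)) for v in values]
    return scaled, scale


weights = [0.02, -0.5, 0.125, 0.0031]  # illustrative values (assumption)
scaled, scale = scale_to_fp8_range(weights)

print("scale:", scale)
print("scaled:", scaled)
# Dividing by `scale` after the low-precision matmul restores the original magnitude.
```

A single scale per tensor is cheap to compute, at the cost of coarser resolution for small elements when one outlier dominates the maximum.
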
### Requirements {#sec-fp8-software}

- Hopper+ GPUs (H100/H200)
- PyTorch 2.7+ (+ compatible TorchAO version)
- CUDA 12.4+

### Configuration {#sec-fp8-config}

Add to your YAML config:

```{.yaml}
# Enable FP8 mixed precision
fp8: true

# Optional: Enable FP8 for FSDP all-gather operations
fp8_enable_fsdp_float8_all_gather: true

# Enable torch.compile (almost always necessary for FP8 speedups)
torch_compile: true
```

::: {.callout-important}
**torch.compile is critical for FP8 performance**

FP8 training requires `torch_compile: true` to see meaningful speedups. Without compilation, FP8 may actually be slower and use more memory than FP16/BF16.
:::

### Advanced FP8 Configs {#sec-fp8-advanced}

For [FSDP](multi-gpu.qmd#sec-fsdp) (Fully Sharded Data Parallel) training:

```{.yaml}
fp8: true
fp8_enable_fsdp_float8_all_gather: true

torch_compile: true

# FSDP configuration
fsdp_version: 2
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: LlamaDecoderLayer
  state_dict_type: FULL_STATE_DICT
  reshard_after_forward: true
```

## Best Practices {#sec-best-practices}

### Choosing Precision Format {#sec-choosing-format}

- **Start with automatic detection**: `bf16: auto`
- **For Hopper+ (H100/H200)**: Try FP8 + torch.compile for maximum speed
- **For Ampere and Ada (A100, RTX 30/40 series)**: Use BF16
- **For older Pascal/Turing GPUs**: Use FP16 with caution
- **For very old or unsupported GPUs**: Use FP32

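The list above can be sketched as a small lookup on the GPU's CUDA compute capability. The function name is hypothetical, and the thresholds (6.0 for Pascal, 8.0 for Ampere, 9.0 for Hopper) are NVIDIA compute capability numbers that this guide does not state, so treat them as assumptions of the example.

```python
def recommend_precision(major: int, minor: int = 0) -> str:
    """Map a CUDA compute capability to the precision format suggested above."""
    capability = (major, minor)
    if capability >= (9, 0):   # Hopper and newer
        return "fp8"
    if capability >= (8, 0):   # Ampere (8.0, 8.6) and Ada (8.9)
        return "bf16"
    if capability >= (6, 0):   # Pascal, Volta, Turing
        return "fp16"
    return "fp32"              # anything older


# With PyTorch installed, the capability of the current device can be read via
# torch.cuda.get_device_capability(), which returns a (major, minor) tuple.
for cc in [(9, 0), (8, 6), (7, 5), (5, 2)]:
    print(cc, "->", recommend_precision(*cc))
```

This only encodes the rule of thumb from the list; `bf16: auto` remains the simplest starting point.
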
### Validation and Testing {#sec-validation}

Always validate your mixed precision setup:

- **Start with a small dataset** to verify stability
- **Monitor loss curves** for irregularities
- **Compare with an FP32 baseline** when possible
- **Check that evaluation metrics** match expectations

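One way to compare a mixed precision run against an FP32 baseline is to look at the largest relative gap between the two loss curves. The sketch below assumes you have already logged per-step losses as plain lists. The function name and the loss values are made up for illustration, and what counts as an acceptable gap depends on your model.

```python
import math


def loss_divergence(baseline, candidate):
    """Return the largest relative gap between two equally long loss curves.

    Returns math.inf as soon as the candidate run produces NaN or inf.
    """
    if len(baseline) != len(candidate):
        raise ValueError("loss curves must cover the same steps")
    worst = 0.0
    for ref, got in zip(baseline, candidate):
        if not math.isfinite(got):
            return math.inf
        worst = max(worst, abs(got - ref) / max(abs(ref), 1e-12))
    return worst


fp32_losses = [2.50, 2.10, 1.80, 1.60]  # illustrative numbers, not measurements
bf16_losses = [2.52, 2.09, 1.83, 1.61]  # illustrative numbers, not measurements

gap = loss_divergence(fp32_losses, bf16_losses)
print(f"largest relative gap: {gap:.2%}")
```

A small, stable gap suggests the lower precision run is tracking the baseline; an infinite result flags a run that diverged.
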
### FP8 Particulars {#sec-fp8-details}

- Use cases
  - Single GPU training
  - Multi GPU training with FSDP2 or DeepSpeed
- Speedups
  - Please refer to the [TorchAO FP8 training benchmarks](https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling) for expected matmul speedups for different (M, K, N) settings
  - Concrete numbers for LLaMA 3 8B training can be found [here](https://github.com/pytorch/ao/tree/main/torchao/float8#training-benchmarks)
- Known issues
  - FP8 + DDP + `torch.compile` (causes [error](https://gist.github.com/djsaunde/0c1664c32e44a64d31b5e01b4aafe5c4))
  - FP8 + FSDP2 + `torch.compile` + FSDP2 activation checkpointing tends to be _slower_ than the BF16 equivalent training
  - Flash Attention 2 does not play nicely with `torch.compile`

See `examples/llama-3/3b-fp8-fsdp2.yaml` for an optimized example config. Enabling FP8 mixed precision with FP8 all-gather yields roughly 10% more iterations per second than BF16 for a relatively small (3B-parameter) model.

For more information on multi-GPU training, see our [Multi-GPU guide](multi-gpu.qmd).