---
title: "Multi-GPU"
format:
  html:
    toc: true
    toc-depth: 3
    number-sections: true
    code-tools: true
execute:
  enabled: false
---

This guide covers advanced training configurations for multi-GPU setups using Axolotl.

## Overview {#sec-overview}

Axolotl supports several methods for multi-GPU training:

- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- Sequence parallelism
- FSDP + QLoRA

## DeepSpeed {#sec-deepspeed}

DeepSpeed is the recommended approach for multi-GPU training due to its stability and performance. It provides various optimization levels through ZeRO stages.

### Configuration {#sec-deepspeed-config}

Add to your YAML config:

```{.yaml}
deepspeed: deepspeed_configs/zero1.json
```

### Usage {#sec-deepspeed-usage}

```{.bash}
# DeepSpeed config path set in config.yml
axolotl train config.yml

# DeepSpeed config path passed on the command line
axolotl train config.yml --deepspeed deepspeed_configs/zero1.json
```
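
You can also launch through `accelerate` directly if you want explicit control over the launcher. A minimal sketch, assuming the `axolotl.cli.train` module entry point (the `axolotl train` command above is the simpler route):

```{.bash}
# Equivalent launch via accelerate, pointing at the same YAML config
accelerate launch -m axolotl.cli.train config.yml --deepspeed deepspeed_configs/zero1.json
```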

### ZeRO Stages {#sec-zero-stages}

We provide default configurations for:

- ZeRO Stage 1 (`zero1.json`)
- ZeRO Stage 2 (`zero2.json`)
- ZeRO Stage 3 (`zero3.json`)

Choose based on your memory requirements and performance needs.
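
As a rough guide: ZeRO-1 shards optimizer states across GPUs, ZeRO-2 additionally shards gradients, and ZeRO-3 also shards the model parameters, giving the largest memory savings at the cost of extra communication. For example, if the model itself does not fit on a single GPU:

```{.yaml}
# ZeRO-3 shards optimizer states, gradients, and parameters across GPUs
deepspeed: deepspeed_configs/zero3.json
```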

## FSDP {#sec-fsdp}

### Basic FSDP Configuration {#sec-fsdp-config}

```{.yaml}
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_offload_params: true
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
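
FSDP is commonly paired with bf16 and activation checkpointing to further reduce memory. A minimal sketch, assuming the standard Axolotl `bf16` and `gradient_checkpointing` options:

```{.yaml}
# In addition to the fsdp/fsdp_config block above
bf16: true                    # bfloat16 training where the hardware supports it
gradient_checkpointing: true  # recompute activations during backward to save memory
```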

## Sequence parallelism {#sec-sequence-parallelism}

We support sequence parallelism (SP) via the
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
allows splitting individual sequences across GPUs, which is useful when a single long
sequence would otherwise cause OOM errors during training.

First, install `ring-flash-attn`. We recommend `pip install axolotl[ring-flash-attn]`,
or install from source with `pip install .[ring-flash-attn]`.
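
For reference, the same two install options as shell commands:

```{.bash}
# From PyPI, with the ring-flash-attn extra
pip install axolotl[ring-flash-attn]

# Or from a source checkout
pip install .[ring-flash-attn]
```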

Your Axolotl YAML config should contain the following lines:

```{.yaml}
sequence_parallel_degree: 4  # Split each sequence into 4 parts, one per GPU
flash_attention: true  # Required with sequence parallelism

# Optional; strides across the key dimension.
# Larger values use more memory but will make training faster.
heads_k_stride: 1
```
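
As a concrete example of how the degree interacts with your GPU count (assuming, as is typical for SP, that the total number of GPUs must be divisible by `sequence_parallel_degree`):

```{.yaml}
# On an 8-GPU node: 8 / 4 = 2 sequence-parallel groups.
# Each group processes its own micro-batches, and within a group each GPU
# holds one quarter of every sequence.
sequence_parallel_degree: 4
flash_attention: true
```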

See our [dedicated guide](sequence_parallelism.qmd) for more details.

## FSDP + QLoRA {#sec-fsdp-qlora}

For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).

## Performance Optimization {#sec-performance}

### Liger Kernel Integration {#sec-liger}

Please see the [docs](custom_integrations.qmd#liger) for more info.

## Troubleshooting {#sec-troubleshooting}

### NCCL Issues {#sec-nccl}

For NCCL-related problems, see our [NCCL troubleshooting guide](nccl.qmd).
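
A quick first step before digging into that guide is to enable NCCL's own logging; `NCCL_DEBUG` is a standard NCCL environment variable, not an Axolotl option:

```{.bash}
# Print NCCL initialization and communication details to stderr
NCCL_DEBUG=INFO axolotl train config.yml
```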

### Common Problems {#sec-common-problems}

::: {.panel-tabset}

## Memory Issues

- Reduce `micro_batch_size`
- Reduce `eval_batch_size`
- Adjust `gradient_accumulation_steps`
- Consider using a higher ZeRO stage
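
A sketch of settings that often relieve memory pressure, assuming the standard Axolotl batch-size keys:

```{.yaml}
micro_batch_size: 1                       # per-GPU batch size
eval_batch_size: 1
gradient_accumulation_steps: 8            # preserves the effective batch size
deepspeed: deepspeed_configs/zero3.json   # higher ZeRO stage shards more state
```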

## Training Instability

- Start with DeepSpeed ZeRO-2
- Monitor loss values
- Check learning rates

:::

For more detailed troubleshooting, see our [debugging guide](debugging.qmd).