Sequence parallelism quick follow-ups; remove ModelCallback (#2450)

* guard return if ring attn alrady registered

* add docs link, bits in multi-gpu docs, remove save model callback (subsumed by HF trainers)

* configurable heads_k_stride from ring-flash-attn hf adapter
This commit is contained in:
Dan Saunders
2025-03-31 09:13:42 -04:00
committed by GitHub
parent cf0c79d52e
commit 5410195e0b
10 changed files with 56 additions and 31 deletions

View File

@@ -18,6 +18,7 @@ Axolotl supports several methods for multi-GPU training:
- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- Sequence parallelism
- FSDP + QLoRA
## DeepSpeed {#sec-deepspeed}
@@ -66,6 +67,28 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Sequence parallelism {#sec-sequence-parallelism}
We support sequence parallelism (SP) via the
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
allows one to split up sequences across GPUs, which is useful in the event that a
single sequence causes OOM errors during model training.
First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
or from source with `pip install .[ring-flash-attn]`.
Your Axolotl YAML config should contain the following lines:
```{.yaml}
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
flash_attention: true # Required with sequence parallelism
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
heads_k_stride: 1
```
See our [dedicated guide](sequence_parallelism.qmd) for more details.
### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).