--- title: "Multi-GPU" format: html: toc: true toc-depth: 3 # number-sections: true code-tools: true execute: enabled: false --- This guide covers advanced training configurations for multi-GPU setups using Axolotl. ## Overview {#sec-overview} When training on multiple GPUs, Axolotl supports 3 sharding/parallelism strategies. Additionally, you can layer specific optimization features on top of that strategy. You generally cannot combine these strategies; they are mutually exclusive. 1. **DeepSpeed**: Powerful optimization library, supports ZeRO stages 1-3. 2. **FSDP (Fully Sharded Data Parallel)**: PyTorch's native sharding implementation (Recommended). 3. **DDP (Distributed Data Parallel)**: PyTorch's native parallelism implementation (Default if neither of the above are selected). These features can often be combined with the strategies above: * **Sequence Parallelism**: Splits long sequences across GPUs (Compatible with DDP, DeepSpeed, and FSDP). * **FSDP + QLoRA**: Combines 4-bit quantization with FSDP (Specific to FSDP). ## DeepSpeed {#sec-deepspeed} ### Configuration {#sec-deepspeed-config} Add to your YAML config: ```{.yaml} deepspeed: deepspeed_configs/zero1.json ``` ### Usage {#sec-deepspeed-usage} ```{.bash} # Fetch deepspeed configs (if not already present) axolotl fetch deepspeed_configs # Passing arg via config axolotl train config.yml # Passing arg via cli axolotl train config.yml --deepspeed deepspeed_configs/zero1.json ``` ### ZeRO Stages {#sec-zero-stages} We provide default configurations for: - ZeRO Stage 1 (`zero1.json`) - ZeRO Stage 1 with torch compile (`zero1_torch_compile.json`) - ZeRO Stage 2 (`zero2.json`) - ZeRO Stage 3 (`zero3.json`) - ZeRO Stage 3 with bf16 (`zero3_bf16.json`) - ZeRO Stage 3 with bf16 and CPU offload params(`zero3_bf16_cpuoffload_params.json`) - ZeRO Stage 3 with bf16 and CPU offload params and optimizer (`zero3_bf16_cpuoffload_all.json`) ::: {.callout-tip} Choose the configuration that offloads the least amount to memory while still being able to fit on VRAM for best performance. Start from Stage 1 -> Stage 2 -> Stage 3. ::: ## Fully Sharded Data Parallel (FSDP) {#sec-fsdp} FSDP allows you to shard model parameters, gradients, and optimizer states across data parallel workers. ::: {.callout-note} FSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in an upcoming release of Axolotl. ::: ### FSDP + QLoRA {#sec-fsdp-qlora} For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd). ### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2} To migrate your config from FSDP1 to FSDP2, you must use the `fsdp_version` top-level config field to specify the FSDP version, and also follow the config field mapping below to update field names. #### Config mapping FSDP1 | FSDP2 -------- | -------- fsdp_sharding_strategy | reshard_after_forward fsdp_backward_prefetch_policy | **REMOVED** fsdp_backward_prefetch | **REMOVED** fsdp_forward_prefetch | **REMOVED** fsdp_sync_module_states | **REMOVED** fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading fsdp_state_dict_type | state_dict_type fsdp_use_orig_params | **REMOVED** fsdp_activation_checkpointing | activation_checkpointing For more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). 
### FSDP + QLoRA {#sec-fsdp-qlora}

For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).

### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2}

To migrate your config from FSDP1 to FSDP2, use the top-level `fsdp_version` config field to specify the FSDP version, and follow the config field mapping below to update field names.

#### Config mapping

FSDP1                          | FSDP2
------------------------------ | -------------------------
fsdp_sharding_strategy         | reshard_after_forward
fsdp_backward_prefetch_policy  | **REMOVED**
fsdp_backward_prefetch         | **REMOVED**
fsdp_forward_prefetch          | **REMOVED**
fsdp_sync_module_states        | **REMOVED**
fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
fsdp_state_dict_type           | state_dict_type
fsdp_use_orig_params           | **REMOVED**
fsdp_activation_checkpointing  | activation_checkpointing

For more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md).

In Axolotl, if you were using the following FSDP1 config:

```{.yaml}
fsdp_version: 1
fsdp_config:
  fsdp_offload_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
```

You can migrate to the following FSDP2 config:

```{.yaml}
fsdp_version: 2
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Qwen3DecoderLayer
  state_dict_type: FULL_STATE_DICT
  reshard_after_forward: true
```

### FSDP1 (deprecated) {#sec-fsdp-config}

::: {.callout-note}
Using `fsdp` to configure FSDP is deprecated and will be removed in an upcoming release of Axolotl. Please use `fsdp_config` as above instead.
:::

```{.yaml}
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_offload_params: true
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```

## Sequence parallelism {#sec-sequence-parallelism}

We support sequence parallelism (SP) via the [ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. SP splits long sequences across GPUs, which is useful when a single sequence would otherwise cause OOM errors during model training. See our [dedicated guide](sequence_parallelism.qmd) for more information.

## Performance Optimization {#sec-performance}

### Liger Kernel Integration {#sec-liger}

Please see [docs](custom_integrations.qmd#liger) for more info.

## Troubleshooting {#sec-troubleshooting}

### NCCL Issues {#sec-nccl}

For NCCL-related problems, see our [NCCL troubleshooting guide](nccl.qmd).

### Common Problems {#sec-common-problems}

::: {.panel-tabset}

## Memory Issues

- Reduce `micro_batch_size` (see the example snippet at the end of this page)
- Reduce `eval_batch_size`
- Adjust `gradient_accumulation_steps`
- Consider using a higher ZeRO stage

## Training Instability

- Start with DeepSpeed ZeRO-2
- Monitor loss values
- Check learning rates

:::

For more detailed troubleshooting, see our [debugging guide](debugging.qmd).
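As a concrete starting point for the memory-related settings listed above, a snippet like the following can be added to your config. The values are illustrative only and should be tuned for your model and hardware:

```{.yaml}
# Illustrative values: trade per-step memory for more accumulation steps
micro_batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 8
```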