Example for Slurm and various fixes (#3038) [skip ci]
* slurm example and make preprocess play nicely
* start slurm if its init file exists
* remove incorrect comment
* feat: add slurm docs

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
examples/slurm/README.md (new file, 66 lines)
@@ -0,0 +1,66 @@
# SLURM Multi-Node Training

This directory contains an example SLURM script for running Axolotl training jobs across multiple nodes in a SLURM cluster.

## Prerequisites

- Access to a SLURM cluster with GPU nodes
- Axolotl installed on all nodes (see [installation docs](https://docs.axolotl.ai/docs/installation.html))

## Usage

### Standard SLURM Clusters

1. Copy [`axolotl.slurm`](./axolotl.slurm) to your working directory.
2. Place your Axolotl config file (`train.yaml`) in the same directory (a minimal config is sketched at the end of this section).
3. Set the appropriate environment variables for the job:

```bash
export HF_TOKEN="your-huggingface-token"

# metric tracking
# export WANDB_API_KEY="your-wandb-api-key"
# ...
```

4. Submit the job. Note that `#SBATCH` directives are not shell-expanded, so the node count is also passed directly via `--nodes`:

```bash
sbatch --nodes=2 --export=ALL,NUM_NODES=2,NUM_TRAINERS=8,PRIMARY_ADDR=<master-node>,PRIMARY_PORT=29400 axolotl.slurm
```

Where:

- `NUM_NODES`: Number of nodes to use (keep it in sync with `--nodes`)
- `NUM_TRAINERS`: GPUs per node (typically 8)
- `PRIMARY_ADDR`: Hostname/IP of the master node (one way to look this up is sketched below)
- `PRIMARY_PORT`: Port for distributed training (default: 29400)
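
If you do not know the master node's hostname ahead of time, one way to find it is to expand the partition's node list and take the first host. A minimal sketch, assuming standard `sinfo`/`scontrol` tooling and a partition named `gpu` (adjust for your cluster):

```bash
# Expand the partition's compressed node list (e.g. "node[01-04]") into
# individual hostnames and use the first one as the rendezvous address.
NODELIST=$(sinfo --partition=gpu --noheader --format="%N")
PRIMARY_ADDR=$(scontrol show hostnames "$NODELIST" | head -n1)
echo "rendezvous host: $PRIMARY_ADDR"
```

Whatever you choose, `PRIMARY_ADDR:PRIMARY_PORT` must be reachable from every node in the job.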
5. (Optional) Run other SLURM commands:

```bash
# check job info (scontrol expects the numeric job ID printed by sbatch)
scontrol show job <job-id>

# check job queue
squeue

# check cluster status
sinfo
```
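
For step 2, any valid Axolotl config will do. As a starting point, here is a minimal sketch written out with a heredoc; the model and dataset are placeholders, and the values are illustrative rather than tuned (the keys follow Axolotl's config schema):

```bash
cat > train.yaml <<'EOF'
# Placeholder model and dataset -- swap in what you actually train on.
base_model: NousResearch/Llama-2-7b-hf
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
output_dir: ./outputs/slurm-example
sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
learning_rate: 2.0e-5
bf16: auto
EOF
```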
### RunPod Instant Clusters

Axolotl also works with RunPod Instant Clusters, which provide managed SLURM clusters with zero configuration.

1. **Deploy a SLURM Cluster**:
   - Go to [RunPod Instant Clusters](https://console.runpod.io/cluster)
   - Click "Create a Cluster"
   - Choose your GPU type, node count, and region
   - Choose an [Axolotl cloud docker image](https://docs.axolotl.ai/docs/docker.html#cloud)
   - Deploy the cluster
2. **Connect to the Controller Node**: Find the controller node in the RunPod console and connect via SSH.
3. **Follow the instructions in [Standard SLURM Clusters](#standard-slurm-clusters)**.

## Additional Resources

- [Axolotl Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [SLURM Documentation](https://slurm.schedmd.com/documentation.html)
- [RunPod SLURM Clusters Guide](https://docs.runpod.io/instant-clusters/slurm-clusters)
examples/slurm/axolotl.slurm (new file, 20 lines)
@@ -0,0 +1,20 @@
#!/bin/bash
# Prior to running this script, export your HF_TOKEN and WANDB_API_KEY to your environment; e.g.
# export HF_TOKEN="..."
# export WANDB_API_KEY="..."
#

# ---------- SBATCH commands ---------- #
#SBATCH --job-name=axolotl-slurm-multinode
#SBATCH --ntasks-per-node=1
# NOTE: `#SBATCH` lines are not shell-expanded, so `--nodes=$NUM_NODES` would not
# work here; pass the node count at submit time instead (see the README):
#   sbatch --nodes=<n> --export=ALL,NUM_NODES=<n>,... axolotl.slurm
#SBATCH --gpus-per-task=8
#SBATCH --cpus-per-task=128

# Skip the barrier torch.distributed performs after process group init
export TORCH_DIST_INIT_BARRIER=0

# Prepare/tokenize the dataset on each node before training starts
srun axolotl preprocess train.yaml

# Launch training; everything after `--` is forwarded to torchrun
srun axolotl train train.yaml --launcher torchrun -- \
    --nproc_per_node=$NUM_TRAINERS --nnodes=$NUM_NODES \
    --rdzv_id axolotl-cli --rdzv_backend c10d --rdzv_endpoint "${PRIMARY_ADDR}:${PRIMARY_PORT}" --rdzv-conf="join_timeout=1800"
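
For orientation, the flags after `--` are handed to torchrun, so the launch is roughly equivalent to the plain torchrun invocation below (a sketch; the exact Python module axolotl dispatches to is an internal detail and may differ between versions):

```bash
torchrun \
    --nproc_per_node="$NUM_TRAINERS" --nnodes="$NUM_NODES" \
    --rdzv_id=axolotl-cli --rdzv_backend=c10d \
    --rdzv_endpoint="${PRIMARY_ADDR}:${PRIMARY_PORT}" \
    --rdzv-conf="join_timeout=1800" \
    -m axolotl.cli.train train.yaml
```

The c10d rendezvous lets the per-node `srun` tasks discover each other at `PRIMARY_ADDR:PRIMARY_PORT`, and `join_timeout=1800` gives slow-starting nodes up to half an hour to join.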
@@ -81,5 +81,13 @@ if [ ! -L "/workspace/axolotl/outputs" ]; then
    ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
fi

# start the runpod slurm init
SLURM_INIT="${SLURM_INIT:-/slurm-init.sh}"

if [[ -f "$SLURM_INIT" ]]; then
    echo "[entrypoint] running $SLURM_INIT..."
    bash "$SLURM_INIT"
fi

# Execute the passed arguments (CMD)
exec "$@"
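
Because `SLURM_INIT` falls back to `/slurm-init.sh` via a `${VAR:-default}` expansion, the init script location can be overridden when the container is started. A hypothetical example (the image name and script path are placeholders, not values from this commit):

```bash
docker run --gpus all \
    -e SLURM_INIT=/opt/custom-slurm-init.sh \
    <axolotl-cloud-image>
```

If the file does not exist, the entrypoint simply skips it and proceeds to `exec "$@"`.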
@@ -1,6 +1,7 @@
"""Data handling specific to SFT."""

import functools
import os
import tempfile
from typing import Literal
@@ -104,6 +105,9 @@ def _prepare_standard_dataset(
    finally:
        loader.cleanup()

    if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
        return train_dataset, eval_dataset, -1, prompters

    # Validate sample packing configuration for evaluation
    if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
        total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
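
This is the "make preprocess play nicely" piece of the commit: when the environment flag is set, dataset preparation returns early with a sentinel step count of -1 instead of computing total steps, which would otherwise require a live distributed context. The flag appears to be set by the preprocess entrypoint itself, but exporting it manually exercises the same path (a debugging sketch, assuming a shell with axolotl installed):

```bash
# Force the preprocess-only code path: datasets are prepared and the
# step/packing validation below is skipped entirely.
AXOLOTL_IS_PREPROCESS=1 axolotl preprocess train.yaml
```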
@@ -51,7 +51,10 @@ def init_distributed_state():
    global distributed_state  # pylint: disable=global-statement
    if distributed_state is None:
        timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
        try:
            distributed_state = PartialState(timeout=timedelta(seconds=timeout))
        except ValueError:
            pass


def get_distributed_state() -> PartialState | None:
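
The init timeout is configurable via `AXOLOTL_NCCL_TIMEOUT` (in seconds, default 1800). On clusters where nodes come up slowly, export a larger value before submitting; since the README's `sbatch` command uses `--export=ALL`, the variable propagates to every task automatically (the hostname below is a placeholder):

```bash
# Allow up to an hour before distributed init/collectives time out.
export AXOLOTL_NCCL_TIMEOUT=3600
sbatch --nodes=2 --export=ALL,NUM_NODES=2,NUM_TRAINERS=8,PRIMARY_ADDR=node01,PRIMARY_PORT=29400 axolotl.slurm
```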