From eb2c87b525fd6767e3e09c0e6e6d4612f902263d Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 8 Aug 2025 08:02:03 -0400
Subject: [PATCH] Example for Slurm and various fixes (#3038) [skip ci]

* slurm example and make preprocess play nicely

* start slurm if its init file exists

* remove incorrect comment

* feat: add slurm docs

---------

Co-authored-by: NanoCode012
---
 examples/slurm/README.md         | 66 ++++++++++++++++++++++++++++++++
 examples/slurm/axolotl.slurm     | 20 ++++++++++
 scripts/cloud-entrypoint.sh      |  8 ++++
 src/axolotl/utils/data/sft.py    |  4 ++
 src/axolotl/utils/distributed.py |  5 ++-
 5 files changed, 102 insertions(+), 1 deletion(-)
 create mode 100644 examples/slurm/README.md
 create mode 100644 examples/slurm/axolotl.slurm

diff --git a/examples/slurm/README.md b/examples/slurm/README.md
new file mode 100644
index 000000000..4c116b713
--- /dev/null
+++ b/examples/slurm/README.md
@@ -0,0 +1,66 @@
+# SLURM Multi-Node Training
+
+This directory contains an example SLURM script for running Axolotl training jobs across multiple nodes in a SLURM cluster.
+
+## Prerequisites
+
+- Access to a SLURM cluster with GPU nodes
+- Axolotl installed on all nodes (see [installation docs](https://docs.axolotl.ai/docs/installation.html))
+
+## Usage
+
+### Standard SLURM Clusters
+
+1. Copy [`axolotl.slurm`](./axolotl.slurm) to your working directory.
+2. Place your Axolotl config file (`train.yaml`) in the same directory.
+3. Set the appropriate environment variables for the job:
+   ```bash
+   export HF_TOKEN="your-huggingface-token"
+
+   # metric tracking
+   # export WANDB_API_KEY="your-wandb-api-key"
+   # ...
+   ```
+4. Submit the job:
+   ```bash
+   sbatch --export=ALL,NUM_NODES=2,NUM_TRAINERS=8,PRIMARY_ADDR=<primary-node-addr>,PRIMARY_PORT=29400 axolotl.slurm
+   ```
+
+   Where:
+   - `NUM_NODES`: Number of nodes to use
+   - `NUM_TRAINERS`: GPUs per node (typically 8)
+   - `PRIMARY_ADDR`: Hostname/IP of the primary (rank 0) node
+   - `PRIMARY_PORT`: Port for distributed training (default: 29400)
+
+5. (Optional) Run other SLURM commands:
+   ```bash
+   # check job info (use the job ID printed by sbatch)
+   scontrol show job <job-id>
+
+   # check job queue
+   squeue
+
+   # check cluster status
+   sinfo
+   ```
+
+### RunPod Instant Clusters
+
+Axolotl works with RunPod Instant Clusters, which provide managed SLURM clusters with zero configuration.
+
+1. **Deploy a SLURM Cluster**:
+   - Go to [RunPod Instant Clusters](https://console.runpod.io/cluster)
+   - Click "Create a Cluster"
+   - Choose your GPU type, node count, and region
+   - Choose an [Axolotl cloud Docker image](https://docs.axolotl.ai/docs/docker.html#cloud)
+   - Deploy the cluster
+
+2. **Connect to the Controller Node**: Find the controller node in the RunPod console and connect via SSH
+
+3. **Follow the instructions in [Standard SLURM Clusters](#standard-slurm-clusters)**
+
+## Additional Resources
+
+- [Axolotl Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
+- [SLURM Documentation](https://slurm.schedmd.com/documentation.html)
+- [RunPod SLURM Clusters Guide](https://docs.runpod.io/instant-clusters/slurm-clusters)
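The README's submit command relies on `--export=ALL` to forward `NUM_NODES`, `NUM_TRAINERS`, `PRIMARY_ADDR`, and `PRIMARY_PORT` into the job environment. One caveat worth knowing: sbatch does not expand environment variables inside `#SBATCH` directives, so the `$NUM_NODES` in the script below is only honored if the node count is also given on the command line (command-line options override in-script directives). A minimal submission sketch, with placeholder token, node count, and address:

```bash
# Minimal submission sketch; the token value, node count of 2, and the
# 10.65.0.2 address are placeholders. --export=ALL forwards the caller's
# environment (HF_TOKEN, NUM_NODES, ...) into the job so the batch script
# and axolotl can read them.
export HF_TOKEN="hf_..."

# Pass --nodes explicitly: sbatch does not expand $NUM_NODES inside #SBATCH
# directives, and command-line flags take precedence over them anyway.
sbatch --nodes=2 \
    --export=ALL,NUM_NODES=2,NUM_TRAINERS=8,PRIMARY_ADDR=10.65.0.2,PRIMARY_PORT=29400 \
    axolotl.slurm

# Confirm the job is queued or running.
squeue -u "$USER"
```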
diff --git a/examples/slurm/axolotl.slurm b/examples/slurm/axolotl.slurm
new file mode 100644
index 000000000..741d68ced
--- /dev/null
+++ b/examples/slurm/axolotl.slurm
@@ -0,0 +1,20 @@
+#!/bin/bash
+# Prior to running this script, export your HF_TOKEN and WANDB_API_KEY to your environment, e.g.
+# export HF_TOKEN="..."
+# export WANDB_API_KEY="..."
+#
+
+# ---------- SBATCH commands ---------- #
+#SBATCH --job-name=axolotl-slurm-multinode
+#SBATCH --ntasks-per-node=1
+#SBATCH --nodes=$NUM_NODES
+#SBATCH --gpus-per-task=8
+#SBATCH --cpus-per-task=128
+
+export TORCH_DIST_INIT_BARRIER=0
+
+srun axolotl preprocess train.yaml
+
+srun axolotl train train.yaml --launcher torchrun -- \
+    --nproc_per_node=$NUM_TRAINERS --nnodes=$NUM_NODES \
+    --rdzv_id axolotl-cli --rdzv_backend c10d --rdzv_endpoint "${PRIMARY_ADDR}:${PRIMARY_PORT}" --rdzv-conf="join_timeout=1800"
diff --git a/scripts/cloud-entrypoint.sh b/scripts/cloud-entrypoint.sh
index a5505e9ad..c98e7c0d0 100755
--- a/scripts/cloud-entrypoint.sh
+++ b/scripts/cloud-entrypoint.sh
@@ -81,5 +81,13 @@ if [ ! -L "/workspace/axolotl/outputs" ]; then
     ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
 fi
 
+# start the runpod slurm init
+SLURM_INIT="${SLURM_INIT:-/slurm-init.sh}"
+
+if [[ -f "$SLURM_INIT" ]]; then
+    echo "[entrypoint] running $SLURM_INIT..."
+    bash "$SLURM_INIT"
+fi
+
 # Execute the passed arguments (CMD)
 exec "$@"
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 3189b29c3..975f26e71 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -1,6 +1,7 @@
 """Data handling specific to SFT."""
 
 import functools
+import os
 import tempfile
 from typing import Literal
 
@@ -104,6 +105,9 @@ def _prepare_standard_dataset(
     finally:
         loader.cleanup()
 
+    if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
+        return train_dataset, eval_dataset, -1, prompters
+
     # Validate sample packing configuration for evaluation
     if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
         total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
diff --git a/src/axolotl/utils/distributed.py b/src/axolotl/utils/distributed.py
index d2d1075cb..48771fd97 100644
--- a/src/axolotl/utils/distributed.py
+++ b/src/axolotl/utils/distributed.py
@@ -51,7 +51,10 @@ def init_distributed_state():
     global distributed_state  # pylint: disable=global-statement
     if distributed_state is None:
         timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
-        distributed_state = PartialState(timeout=timedelta(seconds=timeout))
+        try:
+            distributed_state = PartialState(timeout=timedelta(seconds=timeout))
+        except ValueError:
+            pass
 
 
 def get_distributed_state() -> PartialState | None:
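Taken together, the Python changes let `axolotl preprocess` run outside a `torchrun` rendezvous: `init_distributed_state` now tolerates `PartialState` failing to initialize, and `_prepare_standard_dataset` returns early (with a dummy `-1` step count) when `AXOLOTL_IS_PREPROCESS=1` instead of computing eval steps that need distributed state. A rough single-node sketch of the same two-phase flow the SLURM script runs; the config name and GPU count are assumptions, and it presumes the preprocess command sets `AXOLOTL_IS_PREPROCESS` (the patch only reads that variable):

```bash
# Single-node replay of the two phases in axolotl.slurm; assumes train.yaml in
# the current directory and 8 local GPUs (adjust --nproc_per_node to match).

# Phase 1: preprocess/tokenize outside torchrun. Setting AXOLOTL_IS_PREPROCESS
# is assumed to happen inside the preprocess command; the patched sft.py reads
# it and returns early rather than computing eval steps with no distributed
# state, while the patched distributed.py tolerates PartialState init failing.
axolotl preprocess train.yaml

# Phase 2: the actual training run through torchrun, mirroring the rendezvous
# flags from axolotl.slurm but pointed at localhost.
axolotl train train.yaml --launcher torchrun -- \
    --nproc_per_node=8 --nnodes=1 \
    --rdzv_id axolotl-cli --rdzv_backend c10d --rdzv_endpoint "127.0.0.1:29400"
```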