From 0e9945e3b91e853b36e97c0dbd29bfd778382511 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 26 Aug 2025 09:29:50 -0400 Subject: [PATCH] deploy training jobs to baseten w truss in axolotl cli (#3086) [skip ci] * deploy training jobs to baseten w truss in axolotl cli * cleanup --- examples/cloud/baseten.yaml | 10 +++ src/axolotl/cli/cloud/__init__.py | 13 +++- src/axolotl/cli/cloud/baseten/__init__.py | 48 +++++++++++++ src/axolotl/cli/cloud/baseten/template/run.sh | 9 +++ .../cli/cloud/baseten/template/train_sft.py | 71 +++++++++++++++++++ 5 files changed, 149 insertions(+), 2 deletions(-) create mode 100644 examples/cloud/baseten.yaml create mode 100644 src/axolotl/cli/cloud/baseten/__init__.py create mode 100644 src/axolotl/cli/cloud/baseten/template/run.sh create mode 100644 src/axolotl/cli/cloud/baseten/template/train_sft.py diff --git a/examples/cloud/baseten.yaml b/examples/cloud/baseten.yaml new file mode 100644 index 000000000..23c4b52d6 --- /dev/null +++ b/examples/cloud/baseten.yaml @@ -0,0 +1,10 @@ +provider: baseten +project_name: + +secrets: + - HF_TOKEN + - WANDB_API_KEY + +gpu: h100 +gpu_count: 8 +node_count: 1 diff --git a/src/axolotl/cli/cloud/__init__.py b/src/axolotl/cli/cloud/__init__.py index bf12ab8cb..60f6a51ce 100644 --- a/src/axolotl/cli/cloud/__init__.py +++ b/src/axolotl/cli/cloud/__init__.py @@ -7,6 +7,8 @@ from typing import Literal import yaml +from axolotl.cli.cloud.base import Cloud +from axolotl.cli.cloud.baseten import BasetenCloud from axolotl.cli.cloud.modal_ import ModalCloud from axolotl.utils.dict import DictDefault @@ -38,8 +40,15 @@ def do_cli_train( cwd=None, **kwargs, ) -> None: - cloud_cfg = load_cloud_cfg(cloud_config) - cloud = ModalCloud(cloud_cfg) + cloud_cfg: DictDefault = load_cloud_cfg(cloud_config) + provider = cloud_cfg.provider or "modal" + cloud: Cloud | None + if provider == "modal": + cloud = ModalCloud(cloud_cfg) + elif provider == "baseten": + cloud = BasetenCloud(cloud_cfg.to_dict()) + else: + raise ValueError(f"Unsupported cloud provider: {provider}") with open(config, "r", encoding="utf-8") as file: config_yaml = file.read() local_dirs = {} diff --git a/src/axolotl/cli/cloud/baseten/__init__.py b/src/axolotl/cli/cloud/baseten/__init__.py new file mode 100644 index 000000000..914504de3 --- /dev/null +++ b/src/axolotl/cli/cloud/baseten/__init__.py @@ -0,0 +1,48 @@ +"""Baseten Cloud CLI""" + +import shutil +import subprocess # nosec B404 +import tempfile +from os.path import dirname +from typing import Literal + +import yaml + +from axolotl.cli.cloud.base import Cloud + + +class BasetenCloud(Cloud): + """Baseten Cloud Axolotl CLI""" + + def __init__(self, config: dict): + self.config = config + + def preprocess(self, config_yaml: str, *args, **kwargs) -> None: + raise NotImplementedError( + "Separate preprocess function for Baseten is not " + "implemented and will happen during hte train step." + ) + + def train( + self, + config_yaml: str, + launcher: Literal["accelerate", "torchrun", "python"] = "accelerate", + launcher_args: list[str] | None = None, + local_dirs: dict[str, str] | None = None, # pylint: disable=unused-argument + **kwargs, + ): + with tempfile.TemporaryDirectory() as tmp_dir: + config = self.config.copy() + config["launcher"] = launcher + config["launcher_args"] = launcher_args + with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout: + yaml.dump(config, cloud_fout) + with open(tmp_dir + "/train.yaml", "w", encoding="utf-8") as config_fout: + config_fout.write(config_yaml) + shutil.copyfile(dirname(__file__) + "/template/run.sh", tmp_dir + "/run.sh") + shutil.copyfile( + dirname(__file__) + "/template/train_sft.py", tmp_dir + "/train_sft.py" + ) + subprocess.run( # nosec B603 B607 + ["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False + ) diff --git a/src/axolotl/cli/cloud/baseten/template/run.sh b/src/axolotl/cli/cloud/baseten/template/run.sh new file mode 100644 index 000000000..37dc9688f --- /dev/null +++ b/src/axolotl/cli/cloud/baseten/template/run.sh @@ -0,0 +1,9 @@ +#!/bin/bash +set -eux + +export NCCL_SOCKET_IFNAME="^docker0,lo" +export NCCL_IB_DISABLE=0 +export NCCL_TIMEOUT=1800000 + +axolotl preprocess train.yaml +axolotl train train.yaml --launcher ${AXOLOTL_LAUNCHER} ${AXOLOTL_LAUNCHER_ARGS} diff --git a/src/axolotl/cli/cloud/baseten/template/train_sft.py b/src/axolotl/cli/cloud/baseten/template/train_sft.py new file mode 100644 index 000000000..137fb9171 --- /dev/null +++ b/src/axolotl/cli/cloud/baseten/template/train_sft.py @@ -0,0 +1,71 @@ +""" +Baseten Training Script for Axolotl +""" + +# pylint: skip-file +import yaml +from truss.base import truss_config + +# Import necessary classes from the Baseten Training SDK +from truss_train import definitions + +cloud_config = yaml.safe_load(open("cloud.yaml", "r")) +gpu = cloud_config.get("gpu", "h100") +gpu_count = int(cloud_config.get("gpu_count", 1)) +node_count = int(cloud_config.get("node_count", 1)) +project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project" +secrets = cloud_config.get("secrets", []) +launcher = cloud_config.get("launcher", "accelerate") +launcher_args = cloud_config.get("launcher_args", []) +script_name = "run.sh" + +launcher_args_str = "" +if launcher_args: + launcher_args_str = "-- " + " ".join(launcher_args) + +# 1. Define a base image for your training job +# must use torch 2.7.0 for vllm +BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1" + +# 2. Define the Runtime Environment for the Training Job +# This includes start commands and environment variables.a +# Secrets from the baseten workspace like API keys are referenced using +# `SecretReference`. + +env_vars = { + "AXOLOTL_LAUNCHER": launcher, + "AXOLOTL_LAUNCHER_ARGS": launcher_args_str, +} +for secret_name in secrets: + env_vars[secret_name] = definitions.SecretReference(name=secret_name) + +training_runtime = definitions.Runtime( + start_commands=[ # Example: list of commands to run your training script + f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'" + ], + environment_variables=env_vars, +) + +# 3. Define the Compute Resources for the Training Job +training_compute = definitions.Compute( + node_count=node_count, + accelerator=truss_config.AcceleratorSpec( + accelerator=truss_config.Accelerator.H100, + count=gpu_count, + ), +) + +# 4. Define the Training Job +# This brings together the image, compute, and runtime configurations. +my_training_job = definitions.TrainingJob( + image=definitions.Image(base_image=BASE_IMAGE), + compute=training_compute, + runtime=training_runtime, +) + + +# This config will be pushed using the Truss CLI. +# The association of the job to the project happens at the time of push. +first_project_with_job = definitions.TrainingProject( + name=project_name, job=my_training_job +)