deploy training jobs to baseten w truss in axolotl cli (#3086) [skip ci]
* deploy training jobs to baseten w truss in axolotl cli * cleanup
This commit is contained in:
10
examples/cloud/baseten.yaml
Normal file
10
examples/cloud/baseten.yaml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
provider: baseten
|
||||||
|
project_name:
|
||||||
|
|
||||||
|
secrets:
|
||||||
|
- HF_TOKEN
|
||||||
|
- WANDB_API_KEY
|
||||||
|
|
||||||
|
gpu: h100
|
||||||
|
gpu_count: 8
|
||||||
|
node_count: 1
|
||||||
@@ -7,6 +7,8 @@ from typing import Literal
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
from axolotl.cli.cloud.base import Cloud
|
||||||
|
from axolotl.cli.cloud.baseten import BasetenCloud
|
||||||
from axolotl.cli.cloud.modal_ import ModalCloud
|
from axolotl.cli.cloud.modal_ import ModalCloud
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -38,8 +40,15 @@ def do_cli_train(
|
|||||||
cwd=None,
|
cwd=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
provider = cloud_cfg.provider or "modal"
|
||||||
|
cloud: Cloud | None
|
||||||
|
if provider == "modal":
|
||||||
|
cloud = ModalCloud(cloud_cfg)
|
||||||
|
elif provider == "baseten":
|
||||||
|
cloud = BasetenCloud(cloud_cfg.to_dict())
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported cloud provider: {provider}")
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
config_yaml = file.read()
|
config_yaml = file.read()
|
||||||
local_dirs = {}
|
local_dirs = {}
|
||||||
|
|||||||
48
src/axolotl/cli/cloud/baseten/__init__.py
Normal file
48
src/axolotl/cli/cloud/baseten/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Baseten Cloud CLI"""
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
import subprocess # nosec B404
|
||||||
|
import tempfile
|
||||||
|
from os.path import dirname
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from axolotl.cli.cloud.base import Cloud
|
||||||
|
|
||||||
|
|
||||||
|
class BasetenCloud(Cloud):
|
||||||
|
"""Baseten Cloud Axolotl CLI"""
|
||||||
|
|
||||||
|
def __init__(self, config: dict):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Separate preprocess function for Baseten is not "
|
||||||
|
"implemented and will happen during hte train step."
|
||||||
|
)
|
||||||
|
|
||||||
|
def train(
|
||||||
|
self,
|
||||||
|
config_yaml: str,
|
||||||
|
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||||
|
launcher_args: list[str] | None = None,
|
||||||
|
local_dirs: dict[str, str] | None = None, # pylint: disable=unused-argument
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
config = self.config.copy()
|
||||||
|
config["launcher"] = launcher
|
||||||
|
config["launcher_args"] = launcher_args
|
||||||
|
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
|
||||||
|
yaml.dump(config, cloud_fout)
|
||||||
|
with open(tmp_dir + "/train.yaml", "w", encoding="utf-8") as config_fout:
|
||||||
|
config_fout.write(config_yaml)
|
||||||
|
shutil.copyfile(dirname(__file__) + "/template/run.sh", tmp_dir + "/run.sh")
|
||||||
|
shutil.copyfile(
|
||||||
|
dirname(__file__) + "/template/train_sft.py", tmp_dir + "/train_sft.py"
|
||||||
|
)
|
||||||
|
subprocess.run( # nosec B603 B607
|
||||||
|
["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
|
||||||
|
)
|
||||||
9
src/axolotl/cli/cloud/baseten/template/run.sh
Normal file
9
src/axolotl/cli/cloud/baseten/template/run.sh
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -eux
|
||||||
|
|
||||||
|
export NCCL_SOCKET_IFNAME="^docker0,lo"
|
||||||
|
export NCCL_IB_DISABLE=0
|
||||||
|
export NCCL_TIMEOUT=1800000
|
||||||
|
|
||||||
|
axolotl preprocess train.yaml
|
||||||
|
axolotl train train.yaml --launcher ${AXOLOTL_LAUNCHER} ${AXOLOTL_LAUNCHER_ARGS}
|
||||||
71
src/axolotl/cli/cloud/baseten/template/train_sft.py
Normal file
71
src/axolotl/cli/cloud/baseten/template/train_sft.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""
|
||||||
|
Baseten Training Script for Axolotl
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: skip-file
|
||||||
|
import yaml
|
||||||
|
from truss.base import truss_config
|
||||||
|
|
||||||
|
# Import necessary classes from the Baseten Training SDK
|
||||||
|
from truss_train import definitions
|
||||||
|
|
||||||
|
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
|
||||||
|
gpu = cloud_config.get("gpu", "h100")
|
||||||
|
gpu_count = int(cloud_config.get("gpu_count", 1))
|
||||||
|
node_count = int(cloud_config.get("node_count", 1))
|
||||||
|
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
|
||||||
|
secrets = cloud_config.get("secrets", [])
|
||||||
|
launcher = cloud_config.get("launcher", "accelerate")
|
||||||
|
launcher_args = cloud_config.get("launcher_args", [])
|
||||||
|
script_name = "run.sh"
|
||||||
|
|
||||||
|
launcher_args_str = ""
|
||||||
|
if launcher_args:
|
||||||
|
launcher_args_str = "-- " + " ".join(launcher_args)
|
||||||
|
|
||||||
|
# 1. Define a base image for your training job
|
||||||
|
# must use torch 2.7.0 for vllm
|
||||||
|
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
|
||||||
|
|
||||||
|
# 2. Define the Runtime Environment for the Training Job
|
||||||
|
# This includes start commands and environment variables.a
|
||||||
|
# Secrets from the baseten workspace like API keys are referenced using
|
||||||
|
# `SecretReference`.
|
||||||
|
|
||||||
|
env_vars = {
|
||||||
|
"AXOLOTL_LAUNCHER": launcher,
|
||||||
|
"AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
|
||||||
|
}
|
||||||
|
for secret_name in secrets:
|
||||||
|
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
|
||||||
|
|
||||||
|
training_runtime = definitions.Runtime(
|
||||||
|
start_commands=[ # Example: list of commands to run your training script
|
||||||
|
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
|
||||||
|
],
|
||||||
|
environment_variables=env_vars,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Define the Compute Resources for the Training Job
|
||||||
|
training_compute = definitions.Compute(
|
||||||
|
node_count=node_count,
|
||||||
|
accelerator=truss_config.AcceleratorSpec(
|
||||||
|
accelerator=truss_config.Accelerator.H100,
|
||||||
|
count=gpu_count,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Define the Training Job
|
||||||
|
# This brings together the image, compute, and runtime configurations.
|
||||||
|
my_training_job = definitions.TrainingJob(
|
||||||
|
image=definitions.Image(base_image=BASE_IMAGE),
|
||||||
|
compute=training_compute,
|
||||||
|
runtime=training_runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# This config will be pushed using the Truss CLI.
|
||||||
|
# The association of the job to the project happens at the time of push.
|
||||||
|
first_project_with_job = definitions.TrainingProject(
|
||||||
|
name=project_name, job=my_training_job
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user