Compare commits
1 Commits
fix/diffus
...
feat/lmeva
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e37a768960 |
@@ -67,8 +67,16 @@ def do_cli_lm_eval(
|
|||||||
cloud_config: Path | str,
|
cloud_config: Path | str,
|
||||||
config: Path | str,
|
config: Path | str,
|
||||||
) -> None:
|
) -> None:
|
||||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
|
||||||
cloud = ModalCloud(cloud_cfg)
|
provider = cloud_cfg.provider or "modal"
|
||||||
|
cloud: Cloud | None
|
||||||
|
if provider == "modal":
|
||||||
|
cloud = ModalCloud(cloud_cfg)
|
||||||
|
elif provider == "baseten":
|
||||||
|
cloud = BasetenCloud(cloud_cfg.to_dict())
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported cloud provider: {provider}")
|
||||||
|
|
||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
config_yaml = file.read()
|
config_yaml = file.read()
|
||||||
cloud.lm_eval(config_yaml)
|
cloud.lm_eval(config_yaml)
|
||||||
|
|||||||
@@ -46,3 +46,23 @@ class BasetenCloud(Cloud):
|
|||||||
subprocess.run( # nosec B603 B607
|
subprocess.run( # nosec B603 B607
|
||||||
["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
|
["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def lm_eval(
|
||||||
|
self,
|
||||||
|
config_yaml: str,
|
||||||
|
):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
config = self.config.copy()
|
||||||
|
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
|
||||||
|
yaml.dump(config, cloud_fout)
|
||||||
|
with open(tmp_dir + "/eval.yaml", "w", encoding="utf-8") as config_fout:
|
||||||
|
config_fout.write(config_yaml)
|
||||||
|
shutil.copyfile(
|
||||||
|
dirname(__file__) + "/template/eval.sh", tmp_dir + "/eval.sh"
|
||||||
|
)
|
||||||
|
shutil.copyfile(
|
||||||
|
dirname(__file__) + "/template/eval_sft.py", tmp_dir + "/eval_sft.py"
|
||||||
|
)
|
||||||
|
subprocess.run( # nosec B603 B607
|
||||||
|
["truss", "train", "push", "eval_sft.py"], cwd=tmp_dir, check=False
|
||||||
|
)
|
||||||
|
|||||||
8
src/axolotl/cli/cloud/baseten/template/eval.sh
Normal file
8
src/axolotl/cli/cloud/baseten/template/eval.sh
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -eux
|
||||||
|
|
||||||
|
export NCCL_SOCKET_IFNAME="^docker0,lo"
|
||||||
|
export NCCL_IB_DISABLE=0
|
||||||
|
export NCCL_TIMEOUT=1800000
|
||||||
|
|
||||||
|
axolotl lm-eval eval.yaml
|
||||||
81
src/axolotl/cli/cloud/baseten/template/eval_sft.py
Normal file
81
src/axolotl/cli/cloud/baseten/template/eval_sft.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
"""
|
||||||
|
Baseten lm-eval (evaluation) job script for Axolotl
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: skip-file
|
||||||
|
import yaml
|
||||||
|
from truss.base import truss_config
|
||||||
|
|
||||||
|
# Import necessary classes from the Baseten Training SDK
|
||||||
|
from truss_train import definitions
|
||||||
|
|
||||||
|
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
|
||||||
|
gpu = cloud_config.get("gpu", "h100")
|
||||||
|
gpu_count = (
|
||||||
|
1 # int(cloud_config.get("gpu_count", 1)) # only single GPU supported at the moment
|
||||||
|
)
|
||||||
|
node_count = (
|
||||||
|
1 # int(cloud_config.get("node_count", 1)) # only single node support for lmeval
|
||||||
|
)
|
||||||
|
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
|
||||||
|
secrets = cloud_config.get("secrets", [])
|
||||||
|
# launcher = cloud_config.get("launcher", "accelerate")
|
||||||
|
# launcher_args = cloud_config.get("launcher_args", [])
|
||||||
|
script_name = "eval.sh"
|
||||||
|
|
||||||
|
# launcher_args_str = ""
|
||||||
|
# if launcher_args:
|
||||||
|
# launcher_args_str = "-- " + " ".join(launcher_args)
|
||||||
|
|
||||||
|
# 1. Define a base image for your training job
|
||||||
|
# must use a torch 2.7.x image for vllm (the image tag below pins 2.7.1)
|
||||||
|
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
|
||||||
|
|
||||||
|
# 2. Define the Runtime Environment for the Training Job
|
||||||
|
# This includes start commands and environment variables.
|
||||||
|
# Secrets from the baseten workspace like API keys are referenced using
|
||||||
|
# `SecretReference`.
|
||||||
|
|
||||||
|
env_vars = {
|
||||||
|
# "AXOLOTL_LAUNCHER": launcher,
|
||||||
|
# "AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
|
||||||
|
}
|
||||||
|
for secret_name in secrets:
|
||||||
|
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
|
||||||
|
|
||||||
|
training_runtime = definitions.Runtime(
|
||||||
|
start_commands=[ # Example: list of commands to run your training script
|
||||||
|
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
|
||||||
|
],
|
||||||
|
environment_variables=env_vars,
|
||||||
|
cache_config=definitions.CacheConfig(
|
||||||
|
enabled=True,
|
||||||
|
),
|
||||||
|
checkpointing_config=definitions.CheckpointingConfig(
|
||||||
|
enabled=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Define the Compute Resources for the Training Job
|
||||||
|
training_compute = definitions.Compute(
|
||||||
|
node_count=node_count,
|
||||||
|
accelerator=truss_config.AcceleratorSpec(
|
||||||
|
accelerator=truss_config.Accelerator.H100,
|
||||||
|
count=gpu_count,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Define the Training Job
|
||||||
|
# This brings together the image, compute, and runtime configurations.
|
||||||
|
my_training_job = definitions.TrainingJob(
|
||||||
|
image=definitions.Image(base_image=BASE_IMAGE),
|
||||||
|
compute=training_compute,
|
||||||
|
runtime=training_runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# This config will be pushed using the Truss CLI.
|
||||||
|
# The association of the job to the project happens at the time of push.
|
||||||
|
first_project_with_job = definitions.TrainingProject(
|
||||||
|
name=project_name, job=my_training_job
|
||||||
|
)
|
||||||
@@ -44,6 +44,12 @@ training_runtime = definitions.Runtime(
|
|||||||
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
|
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
|
||||||
],
|
],
|
||||||
environment_variables=env_vars,
|
environment_variables=env_vars,
|
||||||
|
cache_config=definitions.CacheConfig(
|
||||||
|
enabled=True,
|
||||||
|
),
|
||||||
|
checkpointing_config=definitions.CheckpointingConfig(
|
||||||
|
enabled=True,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Define the Compute Resources for the Training Job
|
# 3. Define the Compute Resources for the Training Job
|
||||||
|
|||||||
Reference in New Issue
Block a user