native support for modal cloud from CLI (#2237)

* native support for modal cloud from CLI

* do lm_eval in cloud too

* Fix the sub call to lm-eval

* lm_eval option to not post eval, and append not extend

* cache bust when using branch, grab sha of latest image tag, update lm-eval dep

* allow minimal yaml for lm eval

* include modal in requirements

* update link in README to include utm

* pr feedback

* use chat template

* revision support

* apply chat template as arg

* add wandb name support, allow explicit a100-40gb

* cloud is optional

* handle accidental setting of tasks with a single task str

* document the modal cloud yaml for clarity [skip ci]

* cli docs

* support spawn vs remote for lm-eval

* Add support for additional docker commands in modal image build

* cloud config shouldn't be a dir

* Update README.md

Co-authored-by: Charles Frye <cfrye59@gmail.com>

* fix annotation args

---------

Co-authored-by: Charles Frye <cfrye59@gmail.com>
This commit is contained in:
Wing Lian
2025-01-30 11:34:02 -05:00
committed by GitHub
parent 268543a3be
commit 8779997ba5
12 changed files with 834 additions and 53 deletions

View File

@@ -0,0 +1,56 @@
"""
launch axolotl in supported cloud platforms
"""
from pathlib import Path
from typing import Union
import yaml
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
"""Load and validate cloud configuration."""
# Load cloud configuration.
with open(cloud_config, encoding="utf-8") as file:
cloud_cfg: DictDefault = DictDefault(yaml.safe_load(file))
return cloud_cfg
def do_cli_preprocess(
cloud_config: Union[Path, str],
config: Union[Path, str],
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.preprocess(config_yaml)
def do_cli_train(
cloud_config: Union[Path, str],
config: Union[Path, str],
accelerate: bool = True,
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.train(config_yaml, accelerate=accelerate)
def do_cli_lm_eval(
cloud_config: Union[Path, str],
config: Union[Path, str],
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.lm_eval(config_yaml)

View File

@@ -0,0 +1,18 @@
"""
base class for cloud platforms from cli
"""
from abc import ABC, abstractmethod
class Cloud(ABC):
"""
Abstract base class for cloud platforms.
"""
@abstractmethod
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
pass
@abstractmethod
def train(self, config_yaml: str, accelerate: bool = True) -> str:
pass

View File

@@ -0,0 +1,282 @@
"""
Modal Cloud support from CLI
"""
import copy
import json
import os
import subprocess # nosec B404
from pathlib import Path
from random import randint
import modal
from axolotl.cli.cloud.base import Cloud
def run_cmd(cmd: str, run_folder: str, volumes=None):
"""Run a command inside a folder, with Modal Volume reloading before and commit on success."""
# Ensure volumes contain latest files.
if volumes:
for _, vol in volumes.items():
vol.reload()
# modal workaround so it doesn't use the automounted axolotl
new_env = copy.deepcopy(os.environ)
if "PYTHONPATH" in new_env:
del new_env["PYTHONPATH"]
# Propagate errors from subprocess.
if exit_code := subprocess.call( # nosec B603
cmd.split(), cwd=run_folder, env=new_env
):
exit(exit_code) # pylint: disable=consider-using-sys-exit
# Commit writes to volume.
if volumes:
for _, vol in volumes.items():
vol.commit()
class ModalCloud(Cloud):
"""
Modal Cloud implementation.
"""
def __init__(self, config, app=None):
self.config = config
if not app:
app = modal.App()
self.app = app
self.volumes = {}
if config.volumes:
for volume_config in config.volumes:
_, mount, vol = self.create_volume(volume_config)
self.volumes[mount] = (vol, volume_config)
def get_env(self):
res = {
"HF_DATASETS_CACHE": "/workspace/data/huggingface-cache/datasets",
"HF_HUB_CACHE": "/workspace/data/huggingface-cache/hub",
}
for key in self.config.get("env", []):
if isinstance(key, str):
if val := os.environ.get(key, ""):
res[key] = val
elif isinstance(key, dict):
(key_, val) = list(key.items())[0]
res[key_] = val
return res
def get_image(self):
docker_tag = "main-py3.11-cu124-2.5.1"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"
# grab the sha256 hash from docker hub for this image+tag
# this ensures that we always get the latest image for this tag, even if it's already cached
try:
manifest = subprocess.check_output( # nosec B602
f"docker manifest inspect {docker_image}",
shell=True,
).decode("utf-8")
sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
except subprocess.CalledProcessError:
sha256_hash = None
# create the image
if sha256_hash:
image = modal.Image.from_registry(f"axolotlai/axolotl@{sha256_hash}")
else:
image = modal.Image.from_registry(docker_image)
dockerfile_commands = []
if self.config.dockerfile_commands:
dockerfile_commands.extend(self.config.dockerfile_commands)
# branch
if self.config.branch:
dockerfile_commands.extend(
[
# Random id for cache busting of branch commits
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
]
)
if dockerfile_commands:
image = image.dockerfile_commands(dockerfile_commands)
if env := self.get_env():
image = image.env(env)
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
return image
def get_secrets(self):
res = []
if self.config.secrets:
for key in self.config.get("secrets", []):
# pylint: disable=duplicate-code
if isinstance(key, str):
if val := os.environ.get(key, ""):
res.append(modal.Secret.from_dict({key: val}))
elif isinstance(key, dict):
(key_, val) = list(key.items())[0]
res.append(modal.Secret.from_dict({key_: val}))
return res
def create_volume(self, volume_config):
name = volume_config.name
mount = volume_config.mount
return name, mount, modal.Volume.from_name(name, create_if_missing=True)
def get_ephemeral_disk_size(self):
return 1000 * 525 # 1 TiB
def get_preprocess_timeout(self):
if self.config.timeout_preprocess:
return int(self.config.timeout_preprocess)
return 60 * 60 * 3 # 3 hours
def get_preprocess_memory(self):
memory = 128 # default to 128GiB
if self.config.memory:
memory = int(self.config.memory)
if self.config.memory_preprocess:
memory = int(self.config.memory_preprocess)
return 1024 * memory
def get_preprocess_env(self):
return self.app.function(
image=self.get_image(),
volumes={k: v[0] for k, v in self.volumes.items()},
cpu=8.0,
ephemeral_disk=self.get_ephemeral_disk_size(),
memory=self.get_preprocess_memory(),
timeout=self.get_preprocess_timeout(),
secrets=self.get_secrets(),
)
def preprocess(self, config_yaml: str, *args, **kwargs):
modal_fn = self.get_preprocess_env()(_preprocess)
with modal.enable_output():
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
volumes={k: v[0] for k, v in self.volumes.items()},
*args,
**kwargs,
)
def get_train_timeout(self):
if self.config.timeout:
return int(self.config.timeout)
return 60 * 60 * 24 # 24 hours
def get_train_gpu(self): # pylint: disable=too-many-return-statements
count = self.config.gpu_count or 1
family = self.config.gpu.lower() or "l40s"
if family == "l40s":
return modal.gpu.L40S(count=count)
if family in ["a100", "a100-40gb"]:
return modal.gpu.A100(count=count, size="40GB")
if family == "a100-80gb":
return modal.gpu.A100(count=count, size="80GB")
if family in ["a10", "a10g"]:
return modal.gpu.A10G(count=count)
if family == "h100":
return modal.gpu.H100(count=count)
if family == "t4":
return modal.gpu.T4(count=count)
if family == "l4":
return modal.gpu.L4(count=count)
raise ValueError(f"Unsupported GPU family: {family}")
def get_train_memory(self):
memory = 128 # default to 128GiB
if self.config.memory:
memory = int(self.config.memory)
return 1024 * memory
def get_train_env(self):
return self.app.function(
image=self.get_image(),
volumes={k: v[0] for k, v in self.volumes.items()},
cpu=16.0,
gpu=self.get_train_gpu(),
memory=self.get_train_memory(),
timeout=self.get_train_timeout(),
secrets=self.get_secrets(),
)
def train(self, config_yaml: str, accelerate: bool = True):
modal_fn = self.get_train_env()(_train)
with modal.enable_output():
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
accelerate=accelerate,
volumes={k: v[0] for k, v in self.volumes.items()},
)
def lm_eval(self, config_yaml: str):
modal_fn = self.get_train_env()(_lm_eval)
with modal.enable_output():
with self.app.run(detach=True):
if self.config.get("spawn", False):
modal_fn_exec = modal_fn.spawn
else:
modal_fn_exec = modal_fn.remote
modal_fn_exec(
config_yaml,
volumes={k: v[0] for k, v in self.volumes.items()},
)
def _preprocess(config_yaml: str, volumes=None):
Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
run_cmd(
"axolotl preprocess /workspace/artifacts/axolotl/config.yaml --dataset-processes=8",
run_folder,
volumes,
)
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
if accelerate:
accelerate_args = "--accelerate"
else:
accelerate_args = "--no-accelerate"
run_cmd(
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)
def _lm_eval(config_yaml: str, volumes=None):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
run_cmd(
"axolotl lm-eval /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)

View File

@@ -15,6 +15,7 @@ from axolotl.cli.utils import (
fetch_from_github,
filter_none_kwargs,
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -27,21 +28,28 @@ def cli():
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def preprocess(config: str, **kwargs) -> None:
def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
"""
Preprocess datasets before training.
Args:
config: Path to `axolotl` config YAML file.
cloud: Path to a cloud accelerator configuration file.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
from axolotl.cli.preprocess import do_cli
if cloud:
from axolotl.cli.cloud import do_cli_preprocess
do_cli(config=config, **kwargs)
do_cli_preprocess(cloud_config=cloud, config=config)
else:
from axolotl.cli.preprocess import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@@ -51,47 +59,56 @@ def preprocess(config: str, **kwargs) -> None:
default=True,
help="Use accelerate launch for multi-GPU training",
)
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def train(config: str, accelerate: bool, **kwargs) -> None:
def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs) -> None:
"""
Train or fine-tune a model.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
cloud: Path to a cloud accelerator configuration file
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
from axolotl.cli.cloud import do_cli_train
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
if accelerate:
accelerate_args = []
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port", None)
accelerate_args.append("--main_process_port")
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num-processes")
accelerate_args.append(str(num_processes))
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
else:
accelerate_args = []
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port", None)
accelerate_args.append("--main_process_port")
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num-processes")
accelerate_args.append(str(num_processes))
base_cmd = ["accelerate", "launch"]
base_cmd.extend(accelerate_args)
base_cmd.extend(["-m", "axolotl.cli.train"])
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
base_cmd = ["accelerate", "launch"]
base_cmd.extend(accelerate_args)
base_cmd.extend(["-m", "axolotl.cli.train"])
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.train import do_cli
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
else:
from axolotl.cli.train import do_cli
do_cli(config=config, **kwargs)
do_cli(config=config, **kwargs)
@cli.command()
@@ -210,7 +227,6 @@ def merge_lora(config: str, **kwargs) -> None:
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
@@ -237,6 +253,9 @@ def fetch(directory: str, dest: Optional[str]) -> None:
fetch_from_github(f"{directory}/", dest)
cli.add_command(lm_eval)
def main():
cli()

View File

@@ -2,9 +2,9 @@
Module for the Plugin for LM Eval Harness
"""
import subprocess # nosec
from datetime import datetime
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
@@ -18,25 +18,20 @@ class LMEvalPlugin(BasePlugin):
return "axolotl.integrations.lm_eval.LMEvalArgs"
def post_train_unload(self, cfg):
tasks = ",".join(cfg.lm_eval_tasks)
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
output_path = cfg.output_dir
output_path += "" if cfg.output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
subprocess.run( # nosec
[
"lm_eval",
"--model",
"hf",
"--model_args",
f"pretrained={cfg.output_dir}{fa2}{dtype}",
"--tasks",
tasks,
"--batch_size",
str(cfg.lm_eval_batch_size),
"--output_path",
output_path,
],
check=True,
)
if cfg.lm_eval_post_train:
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16,
flash_attention=cfg.flash_attention,
output_dir=cfg.output_dir,
batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id,
):
subprocess.run( # nosec
lm_eval_args,
check=True,
)

View File

@@ -13,3 +13,5 @@ class LMEvalArgs(BaseModel):
lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8
lm_eval_post_train: Optional[bool] = True
lm_eval_model: Optional[str] = None

View File

@@ -0,0 +1,119 @@
"""
axolotl CLI for running lm_eval tasks
"""
import subprocess # nosec
from collections import defaultdict
from datetime import datetime
from typing import Optional
import click
import yaml
from axolotl.utils.dict import DictDefault
def build_lm_eval_command(
tasks: list[str],
bfloat16=True,
flash_attention=False,
output_dir="./",
batch_size=8,
wandb_project=None,
wandb_entity=None,
wandb_name=None,
model=None,
revision=None,
apply_chat_template=None,
fewshot_as_multiturn=None,
):
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
if isinstance(tasks, str):
tasks = [tasks]
for task in tasks:
num_fewshot = "-1"
task_parts = task.split(":")
task_name = task_parts[0]
if len(task_parts) == 2:
task_name, num_fewshot = task_parts
tasks_by_num_fewshot[str(num_fewshot)].append(task_name)
for num_fewshot, tasks_list in tasks_by_num_fewshot.items():
tasks_str = ",".join(tasks_list)
num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
pretrained = "pretrained="
pretrained += model if model else output_dir
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
revision = f",revision={revision}" if revision else ""
output_path = output_dir
output_path += "" if output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
lm_eval_args = [
"lm_eval",
"--model",
"hf",
"--model_args",
f"{pretrained}{fa2}{dtype}{revision}",
"--tasks",
tasks_str,
"--batch_size",
str(batch_size),
"--output_path",
output_path,
]
wandb_args = []
if wandb_project:
wandb_args.append(f"project={wandb_project}")
if wandb_entity:
wandb_args.append(f"entity={wandb_entity}")
if wandb_name:
wandb_args.append(f"name={wandb_name}")
if wandb_args:
lm_eval_args.append("--wandb_args")
lm_eval_args.append(",".join(wandb_args))
if apply_chat_template:
lm_eval_args.append("--apply_chat_template")
if num_fewshot_val:
lm_eval_args.append("--num_fewshot")
lm_eval_args.append(str(num_fewshot_val))
if apply_chat_template and fewshot_as_multiturn:
lm_eval_args.append("--fewshot_as_multiturn")
yield lm_eval_args
@click.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
def lm_eval(config: str, cloud: Optional[str] = None):
"""
use lm eval to evaluate a trained language model
"""
if cloud:
from axolotl.cli.cloud import do_cli_lm_eval
do_cli_lm_eval(cloud_config=cloud, config=config)
else:
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16,
flash_attention=cfg.flash_attention,
output_dir=cfg.output_dir,
batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=cfg.lm_eval_model or cfg.hub_model_id,
revision=cfg.revision,
apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn,
):
subprocess.run( # nosec
lm_eval_args,
check=True,
)