native support for modal cloud from CLI (#2237)
* native support for modal cloud from CLI * do lm_eval in cloud too * Fix the sub call to lm-eval * lm_eval option to not post eval, and append not extend * cache bust when using branch, grab sha of latest image tag, update lm-eval dep * allow minimal yaml for lm eval * include modal in requirements * update link in README to include utm * pr feedback * use chat template * revision support * apply chat template as arg * add wandb name support, allow explicit a100-40gb * cloud is optional * handle accidental setting of tasks with a single task str * document the modal cloud yaml for clarity [skip ci] * cli docs * support spawn vs remote for lm-eval * Add support for additional docker commands in modal image build * cloud config shouldn't be a dir * Update README.md Co-authored-by: Charles Frye <cfrye59@gmail.com> * fix annotation args --------- Co-authored-by: Charles Frye <cfrye59@gmail.com>
This commit is contained in:
@@ -217,7 +217,7 @@ If you love axolotl, consider sponsoring the project by reaching out directly to
|
||||
|
||||
---
|
||||
|
||||
- [Modal](https://modal.com/) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.
|
||||
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune large language models, run protein folding simulations, and much more.
|
||||
|
||||
---
|
||||
|
||||
|
||||
256
docs/cli.qmd
Normal file
256
docs/cli.qmd
Normal file
@@ -0,0 +1,256 @@
|
||||
# Axolotl CLI Documentation
|
||||
|
||||
The Axolotl CLI provides a streamlined interface for training and fine-tuning large language models. This guide covers
|
||||
the CLI commands, their usage, and common examples.
|
||||
|
||||
### Table of Contents
|
||||
|
||||
- Basic Commands
|
||||
- Command Reference
|
||||
- fetch
|
||||
- preprocess
|
||||
- train
|
||||
- inference
|
||||
- merge-lora
|
||||
- merge-sharded-fsdp-weights
|
||||
- evaluate
|
||||
- lm-eval
|
||||
- Legacy CLI Usage
|
||||
- Remote Compute with Modal Cloud
|
||||
- Cloud Configuration
|
||||
- Running on Modal Cloud
|
||||
- Cloud Configuration Options
|
||||
|
||||
|
||||
### Basic Commands
|
||||
|
||||
All Axolotl commands follow this general structure:
|
||||
|
||||
```bash
|
||||
axolotl <command> [config.yml] [options]
|
||||
```
|
||||
|
||||
The config file can be local or a URL to a raw YAML file.
|
||||
|
||||
### Command Reference
|
||||
|
||||
#### fetch
|
||||
|
||||
Downloads example configurations and deepspeed configs to your local machine.
|
||||
|
||||
```bash
|
||||
# Get example YAML files
|
||||
axolotl fetch examples
|
||||
|
||||
# Get deepspeed config files
|
||||
axolotl fetch deepspeed_configs
|
||||
|
||||
# Specify custom destination
|
||||
axolotl fetch examples --dest path/to/folder
|
||||
```
|
||||
|
||||
#### preprocess
|
||||
|
||||
Preprocesses and tokenizes your dataset before training. This is recommended for large datasets.
|
||||
|
||||
```bash
|
||||
# Basic preprocessing
|
||||
axolotl preprocess config.yml
|
||||
|
||||
# Preprocessing with one GPU
|
||||
CUDA_VISIBLE_DEVICES="0" axolotl preprocess config.yml
|
||||
|
||||
# Debug mode to see processed examples
|
||||
axolotl preprocess config.yml --debug
|
||||
|
||||
# Debug with limited examples
|
||||
axolotl preprocess config.yml --debug --debug-num-examples 5
|
||||
```
|
||||
|
||||
Configuration options:
|
||||
|
||||
```yaml
|
||||
dataset_prepared_path: Local folder for saving preprocessed data
|
||||
push_dataset_to_hub: HuggingFace repo to push preprocessed data (optional)
|
||||
```
|
||||
|
||||
#### train
|
||||
|
||||
Trains or fine-tunes a model using the configuration specified in your YAML file.
|
||||
|
||||
```bash
|
||||
# Basic training
|
||||
axolotl train config.yml
|
||||
|
||||
# Train and set/override specific options
|
||||
axolotl train config.yml \
|
||||
--learning-rate 1e-4 \
|
||||
--micro-batch-size 2 \
|
||||
--num-epochs 3
|
||||
|
||||
# Training without accelerate
|
||||
axolotl train config.yml --no-accelerate
|
||||
|
||||
# Resume training from checkpoint
|
||||
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
|
||||
```
|
||||
|
||||
#### inference
|
||||
|
||||
Runs inference using your trained model in either CLI or Gradio interface mode.
|
||||
|
||||
```bash
|
||||
# CLI inference with LoRA
|
||||
axolotl inference config.yml --lora-model-dir="./outputs/lora-out"
|
||||
|
||||
# CLI inference with full model
|
||||
axolotl inference config.yml --base-model="./completed-model"
|
||||
|
||||
# Gradio web interface
|
||||
axolotl inference config.yml --gradio \
|
||||
--lora-model-dir="./outputs/lora-out"
|
||||
|
||||
# Inference with input from file
|
||||
cat prompt.txt | axolotl inference config.yml \
|
||||
--base-model="./completed-model"
|
||||
```
|
||||
|
||||
#### merge-lora
|
||||
|
||||
Merges trained LoRA adapters into the base model.
|
||||
|
||||
```bash
|
||||
# Basic merge
|
||||
axolotl merge-lora config.yml
|
||||
|
||||
# Specify LoRA directory (usually used with checkpoints)
|
||||
axolotl merge-lora config.yml --lora-model-dir="./lora-output/checkpoint-100"
|
||||
|
||||
# Merge using CPU (if out of GPU memory)
|
||||
CUDA_VISIBLE_DEVICES="" axolotl merge-lora config.yml
|
||||
```
|
||||
|
||||
Configuration options:
|
||||
|
||||
```yaml
|
||||
gpu_memory_limit: Limit GPU memory usage
|
||||
lora_on_cpu: Load LoRA weights on CPU
|
||||
```
|
||||
|
||||
#### merge-sharded-fsdp-weights
|
||||
|
||||
Merges sharded FSDP model checkpoints into a single combined checkpoint.
|
||||
|
||||
```bash
|
||||
# Basic merge
|
||||
axolotl merge-sharded-fsdp-weights config.yml
|
||||
```
|
||||
|
||||
#### evaluate
|
||||
|
||||
Evaluates a model's performance using metrics specified in the config.
|
||||
|
||||
```bash
|
||||
# Basic evaluation
|
||||
axolotl evaluate config.yml
|
||||
```
|
||||
|
||||
#### lm-eval
|
||||
|
||||
Runs LM Evaluation Harness on your model.
|
||||
|
||||
```bash
|
||||
# Basic evaluation
|
||||
axolotl lm-eval config.yml
|
||||
|
||||
# Evaluate specific tasks
|
||||
axolotl lm-eval config.yml --tasks arc_challenge,hellaswag
|
||||
```
|
||||
|
||||
Configuration options:
|
||||
|
||||
```yaml
|
||||
lm_eval_tasks: List of tasks to evaluate
|
||||
lm_eval_batch_size: Batch size for evaluation
|
||||
output_dir: Directory to save evaluation results
|
||||
```
|
||||
|
||||
### Legacy CLI Usage
|
||||
|
||||
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
|
||||
|
||||
```bash
|
||||
# Preprocess
|
||||
python -m axolotl.cli.preprocess config.yml
|
||||
|
||||
# Train
|
||||
accelerate launch -m axolotl.cli.train config.yml
|
||||
|
||||
# Inference
|
||||
accelerate launch -m axolotl.cli.inference config.yml \
|
||||
--lora_model_dir="./outputs/lora-out"
|
||||
|
||||
# Gradio interface
|
||||
accelerate launch -m axolotl.cli.inference config.yml \
|
||||
--lora_model_dir="./outputs/lora-out" --gradio
|
||||
```
|
||||
|
||||
### Remote Compute with Modal Cloud
|
||||
|
||||
Axolotl supports running training and inference workloads on Modal cloud infrastructure. This is configured using a
|
||||
cloud YAML file alongside your regular Axolotl config.
|
||||
|
||||
#### Cloud Configuration
|
||||
|
||||
Create a cloud config YAML with your Modal settings:
|
||||
|
||||
```yaml
|
||||
# cloud_config.yml
|
||||
provider: modal
|
||||
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
|
||||
gpu_count: 1 # Number of GPUs to use
|
||||
timeout: 86400 # Maximum runtime in seconds (24 hours)
|
||||
branch: main # Git branch to use (optional)
|
||||
|
||||
volumes: # Persistent storage volumes
|
||||
- name: axolotl-cache
|
||||
mount: /workspace/cache
|
||||
|
||||
env: # Environment variables
|
||||
- WANDB_API_KEY
|
||||
- HF_TOKEN
|
||||
```
|
||||
|
||||
#### Running on Modal Cloud
|
||||
|
||||
Commands that support the --cloud flag:
|
||||
|
||||
```bash
|
||||
# Preprocess on cloud
|
||||
axolotl preprocess config.yml --cloud cloud_config.yml
|
||||
|
||||
# Train on cloud
|
||||
axolotl train config.yml --cloud cloud_config.yml
|
||||
|
||||
# Train without accelerate on cloud
|
||||
axolotl train config.yml --cloud cloud_config.yml --no-accelerate
|
||||
|
||||
# Run lm-eval on cloud
|
||||
axolotl lm-eval config.yml --cloud cloud_config.yml
|
||||
```
|
||||
|
||||
#### Cloud Configuration Options
|
||||
|
||||
```yaml
|
||||
provider: compute provider, currently only `modal` is supported
|
||||
gpu: GPU type to use
|
||||
gpu_count: Number of GPUs (default: 1)
|
||||
memory: RAM in GB (default: 128)
|
||||
timeout: Maximum runtime in seconds
|
||||
timeout_preprocess: Preprocessing timeout
|
||||
branch: Git branch to use
|
||||
docker_tag: Custom Docker image tag
|
||||
volumes: List of persistent storage volumes
|
||||
env: Environment variables to pass
|
||||
secrets: Secrets to inject
|
||||
```
|
||||
28
examples/cloud/modal.yaml
Normal file
28
examples/cloud/modal.yaml
Normal file
@@ -0,0 +1,28 @@
|
||||
project_name:
|
||||
volumes:
|
||||
- name: axolotl-data
|
||||
mount: /workspace/data
|
||||
- name: axolotl-artifacts
|
||||
mount: /workspace/artifacts
|
||||
|
||||
# environment variables from local to set as secrets
|
||||
secrets:
|
||||
- HF_TOKEN
|
||||
- WANDB_API_KEY
|
||||
|
||||
# Which branch of axolotl to use remotely
|
||||
branch:
|
||||
|
||||
# additional custom commands when building the image
|
||||
dockerfile_commands:
|
||||
|
||||
gpu: h100
|
||||
gpu_count: 1
|
||||
|
||||
# Train specific configurations
|
||||
memory: 128
|
||||
timeout: 86400
|
||||
|
||||
# Preprocess specific configurations
|
||||
memory_preprocess: 32
|
||||
timeout_preprocess: 14400
|
||||
@@ -25,6 +25,7 @@ hf_transfer
|
||||
sentencepiece
|
||||
gradio==3.50.2
|
||||
|
||||
modal==0.70.5
|
||||
pydantic==2.6.3
|
||||
addict
|
||||
fire
|
||||
|
||||
17
scripts/motd
17
scripts/motd
@@ -1,10 +1,15 @@
|
||||
|
||||
dP dP dP
|
||||
88 88 88
|
||||
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
|
||||
88' `88 `8bd8' 88' `88 88 88' `88 88 88
|
||||
88. .88 .d88b. 88. .88 88 88. .88 88 88
|
||||
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
|
||||
#@@ #@@ @@# @@#
|
||||
@@ @@ @@ @@ =@@# @@ #@ =@@#.
|
||||
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
|
||||
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
|
||||
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
|
||||
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
|
||||
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
|
||||
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
|
||||
@@@@ @@@@@@@@@@@@@@@@
|
||||
|
||||
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory ie empty, run the following commands:
|
||||
|
||||
|
||||
56
src/axolotl/cli/cloud/__init__.py
Normal file
56
src/axolotl/cli/cloud/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
launch axolotl in supported cloud platforms
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.cloud.modal_ import ModalCloud
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
|
||||
"""Load and validate cloud configuration."""
|
||||
# Load cloud configuration.
|
||||
with open(cloud_config, encoding="utf-8") as file:
|
||||
cloud_cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
return cloud_cfg
|
||||
|
||||
|
||||
def do_cli_preprocess(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
config_yaml = file.read()
|
||||
cloud.preprocess(config_yaml)
|
||||
|
||||
|
||||
def do_cli_train(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
accelerate: bool = True,
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
config_yaml = file.read()
|
||||
cloud.train(config_yaml, accelerate=accelerate)
|
||||
|
||||
|
||||
def do_cli_lm_eval(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
config_yaml = file.read()
|
||||
cloud.lm_eval(config_yaml)
|
||||
18
src/axolotl/cli/cloud/base.py
Normal file
18
src/axolotl/cli/cloud/base.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
base class for cloud platforms from cli
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Cloud(ABC):
|
||||
"""
|
||||
Abstract base class for cloud platforms.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def train(self, config_yaml: str, accelerate: bool = True) -> str:
|
||||
pass
|
||||
282
src/axolotl/cli/cloud/modal_.py
Normal file
282
src/axolotl/cli/cloud/modal_.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Modal Cloud support from CLI
|
||||
"""
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import subprocess # nosec B404
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
|
||||
import modal
|
||||
|
||||
from axolotl.cli.cloud.base import Cloud
|
||||
|
||||
|
||||
def run_cmd(cmd: str, run_folder: str, volumes=None):
|
||||
"""Run a command inside a folder, with Modal Volume reloading before and commit on success."""
|
||||
# Ensure volumes contain latest files.
|
||||
if volumes:
|
||||
for _, vol in volumes.items():
|
||||
vol.reload()
|
||||
|
||||
# modal workaround so it doesn't use the automounted axolotl
|
||||
new_env = copy.deepcopy(os.environ)
|
||||
if "PYTHONPATH" in new_env:
|
||||
del new_env["PYTHONPATH"]
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call( # nosec B603
|
||||
cmd.split(), cwd=run_folder, env=new_env
|
||||
):
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
|
||||
# Commit writes to volume.
|
||||
if volumes:
|
||||
for _, vol in volumes.items():
|
||||
vol.commit()
|
||||
|
||||
|
||||
class ModalCloud(Cloud):
|
||||
"""
|
||||
Modal Cloud implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, config, app=None):
|
||||
self.config = config
|
||||
if not app:
|
||||
app = modal.App()
|
||||
self.app = app
|
||||
|
||||
self.volumes = {}
|
||||
if config.volumes:
|
||||
for volume_config in config.volumes:
|
||||
_, mount, vol = self.create_volume(volume_config)
|
||||
self.volumes[mount] = (vol, volume_config)
|
||||
|
||||
def get_env(self):
|
||||
res = {
|
||||
"HF_DATASETS_CACHE": "/workspace/data/huggingface-cache/datasets",
|
||||
"HF_HUB_CACHE": "/workspace/data/huggingface-cache/hub",
|
||||
}
|
||||
|
||||
for key in self.config.get("env", []):
|
||||
if isinstance(key, str):
|
||||
if val := os.environ.get(key, ""):
|
||||
res[key] = val
|
||||
elif isinstance(key, dict):
|
||||
(key_, val) = list(key.items())[0]
|
||||
res[key_] = val
|
||||
return res
|
||||
|
||||
def get_image(self):
|
||||
docker_tag = "main-py3.11-cu124-2.5.1"
|
||||
if self.config.docker_tag:
|
||||
docker_tag = self.config.docker_tag
|
||||
docker_image = f"axolotlai/axolotl:{docker_tag}"
|
||||
|
||||
# grab the sha256 hash from docker hub for this image+tag
|
||||
# this ensures that we always get the latest image for this tag, even if it's already cached
|
||||
try:
|
||||
manifest = subprocess.check_output( # nosec B602
|
||||
f"docker manifest inspect {docker_image}",
|
||||
shell=True,
|
||||
).decode("utf-8")
|
||||
sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
|
||||
except subprocess.CalledProcessError:
|
||||
sha256_hash = None
|
||||
|
||||
# create the image
|
||||
if sha256_hash:
|
||||
image = modal.Image.from_registry(f"axolotlai/axolotl@{sha256_hash}")
|
||||
else:
|
||||
image = modal.Image.from_registry(docker_image)
|
||||
|
||||
dockerfile_commands = []
|
||||
if self.config.dockerfile_commands:
|
||||
dockerfile_commands.extend(self.config.dockerfile_commands)
|
||||
|
||||
# branch
|
||||
if self.config.branch:
|
||||
dockerfile_commands.extend(
|
||||
[
|
||||
# Random id for cache busting of branch commits
|
||||
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
|
||||
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
|
||||
]
|
||||
)
|
||||
|
||||
if dockerfile_commands:
|
||||
image = image.dockerfile_commands(dockerfile_commands)
|
||||
|
||||
if env := self.get_env():
|
||||
image = image.env(env)
|
||||
|
||||
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||
|
||||
return image
|
||||
|
||||
def get_secrets(self):
|
||||
res = []
|
||||
if self.config.secrets:
|
||||
for key in self.config.get("secrets", []):
|
||||
# pylint: disable=duplicate-code
|
||||
if isinstance(key, str):
|
||||
if val := os.environ.get(key, ""):
|
||||
res.append(modal.Secret.from_dict({key: val}))
|
||||
elif isinstance(key, dict):
|
||||
(key_, val) = list(key.items())[0]
|
||||
res.append(modal.Secret.from_dict({key_: val}))
|
||||
return res
|
||||
|
||||
def create_volume(self, volume_config):
|
||||
name = volume_config.name
|
||||
mount = volume_config.mount
|
||||
return name, mount, modal.Volume.from_name(name, create_if_missing=True)
|
||||
|
||||
def get_ephemeral_disk_size(self):
|
||||
return 1000 * 525 # 1 TiB
|
||||
|
||||
def get_preprocess_timeout(self):
|
||||
if self.config.timeout_preprocess:
|
||||
return int(self.config.timeout_preprocess)
|
||||
return 60 * 60 * 3 # 3 hours
|
||||
|
||||
def get_preprocess_memory(self):
|
||||
memory = 128 # default to 128GiB
|
||||
if self.config.memory:
|
||||
memory = int(self.config.memory)
|
||||
if self.config.memory_preprocess:
|
||||
memory = int(self.config.memory_preprocess)
|
||||
return 1024 * memory
|
||||
|
||||
def get_preprocess_env(self):
|
||||
return self.app.function(
|
||||
image=self.get_image(),
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
cpu=8.0,
|
||||
ephemeral_disk=self.get_ephemeral_disk_size(),
|
||||
memory=self.get_preprocess_memory(),
|
||||
timeout=self.get_preprocess_timeout(),
|
||||
secrets=self.get_secrets(),
|
||||
)
|
||||
|
||||
def preprocess(self, config_yaml: str, *args, **kwargs):
|
||||
modal_fn = self.get_preprocess_env()(_preprocess)
|
||||
with modal.enable_output():
|
||||
with self.app.run(detach=True):
|
||||
modal_fn.remote(
|
||||
config_yaml,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_train_timeout(self):
|
||||
if self.config.timeout:
|
||||
return int(self.config.timeout)
|
||||
return 60 * 60 * 24 # 24 hours
|
||||
|
||||
def get_train_gpu(self): # pylint: disable=too-many-return-statements
|
||||
count = self.config.gpu_count or 1
|
||||
family = self.config.gpu.lower() or "l40s"
|
||||
|
||||
if family == "l40s":
|
||||
return modal.gpu.L40S(count=count)
|
||||
if family in ["a100", "a100-40gb"]:
|
||||
return modal.gpu.A100(count=count, size="40GB")
|
||||
if family == "a100-80gb":
|
||||
return modal.gpu.A100(count=count, size="80GB")
|
||||
if family in ["a10", "a10g"]:
|
||||
return modal.gpu.A10G(count=count)
|
||||
if family == "h100":
|
||||
return modal.gpu.H100(count=count)
|
||||
if family == "t4":
|
||||
return modal.gpu.T4(count=count)
|
||||
if family == "l4":
|
||||
return modal.gpu.L4(count=count)
|
||||
raise ValueError(f"Unsupported GPU family: {family}")
|
||||
|
||||
def get_train_memory(self):
|
||||
memory = 128 # default to 128GiB
|
||||
if self.config.memory:
|
||||
memory = int(self.config.memory)
|
||||
return 1024 * memory
|
||||
|
||||
def get_train_env(self):
|
||||
return self.app.function(
|
||||
image=self.get_image(),
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
cpu=16.0,
|
||||
gpu=self.get_train_gpu(),
|
||||
memory=self.get_train_memory(),
|
||||
timeout=self.get_train_timeout(),
|
||||
secrets=self.get_secrets(),
|
||||
)
|
||||
|
||||
def train(self, config_yaml: str, accelerate: bool = True):
|
||||
modal_fn = self.get_train_env()(_train)
|
||||
with modal.enable_output():
|
||||
with self.app.run(detach=True):
|
||||
modal_fn.remote(
|
||||
config_yaml,
|
||||
accelerate=accelerate,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
)
|
||||
|
||||
def lm_eval(self, config_yaml: str):
|
||||
modal_fn = self.get_train_env()(_lm_eval)
|
||||
with modal.enable_output():
|
||||
with self.app.run(detach=True):
|
||||
if self.config.get("spawn", False):
|
||||
modal_fn_exec = modal_fn.spawn
|
||||
else:
|
||||
modal_fn_exec = modal_fn.remote
|
||||
modal_fn_exec(
|
||||
config_yaml,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
)
|
||||
|
||||
|
||||
def _preprocess(config_yaml: str, volumes=None):
|
||||
Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/artifacts/axolotl"
|
||||
run_cmd(
|
||||
"axolotl preprocess /workspace/artifacts/axolotl/config.yaml --dataset-processes=8",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
|
||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/artifacts/axolotl"
|
||||
if accelerate:
|
||||
accelerate_args = "--accelerate"
|
||||
else:
|
||||
accelerate_args = "--no-accelerate"
|
||||
run_cmd(
|
||||
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
|
||||
def _lm_eval(config_yaml: str, volumes=None):
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/artifacts/axolotl"
|
||||
run_cmd(
|
||||
"axolotl lm-eval /workspace/artifacts/axolotl/config.yaml",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
@@ -15,6 +15,7 @@ from axolotl.cli.utils import (
|
||||
fetch_from_github,
|
||||
filter_none_kwargs,
|
||||
)
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
|
||||
@@ -27,21 +28,28 @@ def cli():
|
||||
|
||||
@cli.command()
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
||||
@add_options_from_dataclass(PreprocessCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def preprocess(config: str, **kwargs) -> None:
|
||||
def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
"""
|
||||
Preprocess datasets before training.
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
cloud: Path to a cloud accelerator configuration file.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
from axolotl.cli.preprocess import do_cli
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_preprocess
|
||||
|
||||
do_cli(config=config, **kwargs)
|
||||
do_cli_preprocess(cloud_config=cloud, config=config)
|
||||
else:
|
||||
from axolotl.cli.preprocess import do_cli
|
||||
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@@ -51,47 +59,56 @@ def preprocess(config: str, **kwargs) -> None:
|
||||
default=True,
|
||||
help="Use accelerate launch for multi-GPU training",
|
||||
)
|
||||
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def train(config: str, accelerate: bool, **kwargs) -> None:
|
||||
def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
"""
|
||||
Train or fine-tune a model.
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
cloud: Path to a cloud accelerator configuration file
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
if "use_ray" in kwargs and kwargs["use_ray"]:
|
||||
accelerate = False
|
||||
|
||||
if accelerate:
|
||||
accelerate_args = []
|
||||
if "main_process_port" in kwargs:
|
||||
main_process_port = kwargs.pop("main_process_port", None)
|
||||
accelerate_args.append("--main_process_port")
|
||||
accelerate_args.append(str(main_process_port))
|
||||
if "num_processes" in kwargs:
|
||||
num_processes = kwargs.pop("num_processes", None)
|
||||
accelerate_args.append("--num-processes")
|
||||
accelerate_args.append(str(num_processes))
|
||||
if cloud:
|
||||
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
|
||||
else:
|
||||
accelerate_args = []
|
||||
if "main_process_port" in kwargs:
|
||||
main_process_port = kwargs.pop("main_process_port", None)
|
||||
accelerate_args.append("--main_process_port")
|
||||
accelerate_args.append(str(main_process_port))
|
||||
if "num_processes" in kwargs:
|
||||
num_processes = kwargs.pop("num_processes", None)
|
||||
accelerate_args.append("--num-processes")
|
||||
accelerate_args.append(str(num_processes))
|
||||
|
||||
base_cmd = ["accelerate", "launch"]
|
||||
base_cmd.extend(accelerate_args)
|
||||
base_cmd.extend(["-m", "axolotl.cli.train"])
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
base_cmd = ["accelerate", "launch"]
|
||||
base_cmd.extend(accelerate_args)
|
||||
base_cmd.extend(["-m", "axolotl.cli.train"])
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
else:
|
||||
from axolotl.cli.train import do_cli
|
||||
if cloud:
|
||||
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
|
||||
else:
|
||||
from axolotl.cli.train import do_cli
|
||||
|
||||
do_cli(config=config, **kwargs)
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@@ -210,7 +227,6 @@ def merge_lora(config: str, **kwargs) -> None:
|
||||
|
||||
Args:
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
@@ -237,6 +253,9 @@ def fetch(directory: str, dest: Optional[str]) -> None:
|
||||
fetch_from_github(f"{directory}/", dest)
|
||||
|
||||
|
||||
cli.add_command(lm_eval)
|
||||
|
||||
|
||||
def main():
|
||||
cli()
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
Module for the Plugin for LM Eval Harness
|
||||
"""
|
||||
import subprocess # nosec
|
||||
from datetime import datetime
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
|
||||
|
||||
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
@@ -18,25 +18,20 @@ class LMEvalPlugin(BasePlugin):
|
||||
return "axolotl.integrations.lm_eval.LMEvalArgs"
|
||||
|
||||
def post_train_unload(self, cfg):
|
||||
tasks = ",".join(cfg.lm_eval_tasks)
|
||||
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
|
||||
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
|
||||
output_path = cfg.output_dir
|
||||
output_path += "" if cfg.output_dir.endswith("/") else "/"
|
||||
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
subprocess.run( # nosec
|
||||
[
|
||||
"lm_eval",
|
||||
"--model",
|
||||
"hf",
|
||||
"--model_args",
|
||||
f"pretrained={cfg.output_dir}{fa2}{dtype}",
|
||||
"--tasks",
|
||||
tasks,
|
||||
"--batch_size",
|
||||
str(cfg.lm_eval_batch_size),
|
||||
"--output_path",
|
||||
output_path,
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
if cfg.lm_eval_post_train:
|
||||
# pylint: disable=duplicate-code
|
||||
for lm_eval_args in build_lm_eval_command(
|
||||
cfg.lm_eval_tasks,
|
||||
bfloat16=cfg.bfloat16 or cfg.bf16,
|
||||
flash_attention=cfg.flash_attention,
|
||||
output_dir=cfg.output_dir,
|
||||
batch_size=cfg.lm_eval_batch_size,
|
||||
wandb_project=cfg.wandb_project,
|
||||
wandb_entity=cfg.wandb_entity,
|
||||
wandb_name=cfg.wandb_name,
|
||||
model=cfg.lm_eval_model or cfg.hub_model_id,
|
||||
):
|
||||
subprocess.run( # nosec
|
||||
lm_eval_args,
|
||||
check=True,
|
||||
)
|
||||
|
||||
@@ -13,3 +13,5 @@ class LMEvalArgs(BaseModel):
|
||||
|
||||
lm_eval_tasks: List[str] = []
|
||||
lm_eval_batch_size: Optional[int] = 8
|
||||
lm_eval_post_train: Optional[bool] = True
|
||||
lm_eval_model: Optional[str] = None
|
||||
|
||||
119
src/axolotl/integrations/lm_eval/cli.py
Normal file
119
src/axolotl/integrations/lm_eval/cli.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
axolotl CLI for running lm_eval tasks
|
||||
"""
|
||||
import subprocess # nosec
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import yaml
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def build_lm_eval_command(
|
||||
tasks: list[str],
|
||||
bfloat16=True,
|
||||
flash_attention=False,
|
||||
output_dir="./",
|
||||
batch_size=8,
|
||||
wandb_project=None,
|
||||
wandb_entity=None,
|
||||
wandb_name=None,
|
||||
model=None,
|
||||
revision=None,
|
||||
apply_chat_template=None,
|
||||
fewshot_as_multiturn=None,
|
||||
):
|
||||
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
for task in tasks:
|
||||
num_fewshot = "-1"
|
||||
task_parts = task.split(":")
|
||||
task_name = task_parts[0]
|
||||
if len(task_parts) == 2:
|
||||
task_name, num_fewshot = task_parts
|
||||
tasks_by_num_fewshot[str(num_fewshot)].append(task_name)
|
||||
|
||||
for num_fewshot, tasks_list in tasks_by_num_fewshot.items():
|
||||
tasks_str = ",".join(tasks_list)
|
||||
num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
|
||||
pretrained = "pretrained="
|
||||
pretrained += model if model else output_dir
|
||||
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
|
||||
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
|
||||
revision = f",revision={revision}" if revision else ""
|
||||
output_path = output_dir
|
||||
output_path += "" if output_dir.endswith("/") else "/"
|
||||
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
lm_eval_args = [
|
||||
"lm_eval",
|
||||
"--model",
|
||||
"hf",
|
||||
"--model_args",
|
||||
f"{pretrained}{fa2}{dtype}{revision}",
|
||||
"--tasks",
|
||||
tasks_str,
|
||||
"--batch_size",
|
||||
str(batch_size),
|
||||
"--output_path",
|
||||
output_path,
|
||||
]
|
||||
wandb_args = []
|
||||
if wandb_project:
|
||||
wandb_args.append(f"project={wandb_project}")
|
||||
if wandb_entity:
|
||||
wandb_args.append(f"entity={wandb_entity}")
|
||||
if wandb_name:
|
||||
wandb_args.append(f"name={wandb_name}")
|
||||
if wandb_args:
|
||||
lm_eval_args.append("--wandb_args")
|
||||
lm_eval_args.append(",".join(wandb_args))
|
||||
if apply_chat_template:
|
||||
lm_eval_args.append("--apply_chat_template")
|
||||
if num_fewshot_val:
|
||||
lm_eval_args.append("--num_fewshot")
|
||||
lm_eval_args.append(str(num_fewshot_val))
|
||||
if apply_chat_template and fewshot_as_multiturn:
|
||||
lm_eval_args.append("--fewshot_as_multiturn")
|
||||
|
||||
yield lm_eval_args
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
||||
def lm_eval(config: str, cloud: Optional[str] = None):
|
||||
"""
|
||||
use lm eval to evaluate a trained language model
|
||||
"""
|
||||
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_lm_eval
|
||||
|
||||
do_cli_lm_eval(cloud_config=cloud, config=config)
|
||||
else:
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
for lm_eval_args in build_lm_eval_command(
|
||||
cfg.lm_eval_tasks,
|
||||
bfloat16=cfg.bfloat16 or cfg.bf16,
|
||||
flash_attention=cfg.flash_attention,
|
||||
output_dir=cfg.output_dir,
|
||||
batch_size=cfg.lm_eval_batch_size,
|
||||
wandb_project=cfg.wandb_project,
|
||||
wandb_entity=cfg.wandb_entity,
|
||||
wandb_name=cfg.wandb_name,
|
||||
model=cfg.lm_eval_model or cfg.hub_model_id,
|
||||
revision=cfg.revision,
|
||||
apply_chat_template=cfg.apply_chat_template,
|
||||
fewshot_as_multiturn=cfg.fewshot_as_multiturn,
|
||||
):
|
||||
subprocess.run( # nosec
|
||||
lm_eval_args,
|
||||
check=True,
|
||||
)
|
||||
Reference in New Issue
Block a user