diff --git a/README.md b/README.md
index 6ee3237e7..674a9dcf1 100644
--- a/README.md
+++ b/README.md
@@ -217,7 +217,7 @@ If you love axolotl, consider sponsoring the project by reaching out directly to

---

-- [Modal](https://modal.com/) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.
+- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune large language models, run protein folding simulations, and much more.

---
diff --git a/docs/cli.qmd b/docs/cli.qmd
new file mode 100644
index 000000000..5b494ab5d
--- /dev/null
+++ b/docs/cli.qmd
@@ -0,0 +1,256 @@
# Axolotl CLI Documentation

The Axolotl CLI provides a streamlined interface for training and fine-tuning large language models. This guide covers
the CLI commands, their usage, and common examples.

### Table of Contents

- Basic Commands
- Command Reference
  - fetch
  - preprocess
  - train
  - inference
  - merge-lora
  - merge-sharded-fsdp-weights
  - evaluate
  - lm-eval
- Legacy CLI Usage
- Remote Compute with Modal Cloud
  - Cloud Configuration
  - Running on Modal Cloud
  - Cloud Configuration Options


### Basic Commands

All Axolotl commands follow this general structure:

```bash
axolotl <command> [config.yml] [options]
```

The config file can be local or a URL to a raw YAML file.

### Command Reference

#### fetch

Downloads example configurations and deepspeed configs to your local machine.

```bash
# Get example YAML files
axolotl fetch examples

# Get deepspeed config files
axolotl fetch deepspeed_configs

# Specify custom destination
axolotl fetch examples --dest path/to/folder
```

#### preprocess

Preprocesses and tokenizes your dataset before training. This is recommended for large datasets.

```bash
# Basic preprocessing
axolotl preprocess config.yml

# Preprocessing with one GPU
CUDA_VISIBLE_DEVICES="0" axolotl preprocess config.yml

# Debug mode to see processed examples
axolotl preprocess config.yml --debug

# Debug with limited examples
axolotl preprocess config.yml --debug --debug-num-examples 5
```

Configuration options:

```yaml
dataset_prepared_path: Local folder for saving preprocessed data
push_dataset_to_hub: HuggingFace repo to push preprocessed data (optional)
```

#### train

Trains or fine-tunes a model using the configuration specified in your YAML file.

```bash
# Basic training
axolotl train config.yml

# Train and set/override specific options
axolotl train config.yml \
    --learning-rate 1e-4 \
    --micro-batch-size 2 \
    --num-epochs 3

# Training without accelerate
axolotl train config.yml --no-accelerate

# Resume training from checkpoint
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
```

#### inference

Runs inference using your trained model in either CLI or Gradio interface mode.
+ +```bash +# CLI inference with LoRA +axolotl inference config.yml --lora-model-dir="./outputs/lora-out" + +# CLI inference with full model +axolotl inference config.yml --base-model="./completed-model" + +# Gradio web interface +axolotl inference config.yml --gradio \ + --lora-model-dir="./outputs/lora-out" + +# Inference with input from file +cat prompt.txt | axolotl inference config.yml \ + --base-model="./completed-model" +``` + +#### merge-lora + +Merges trained LoRA adapters into the base model. + +```bash +# Basic merge +axolotl merge-lora config.yml + +# Specify LoRA directory (usually used with checkpoints) +axolotl merge-lora config.yml --lora-model-dir="./lora-output/checkpoint-100" + +# Merge using CPU (if out of GPU memory) +CUDA_VISIBLE_DEVICES="" axolotl merge-lora config.yml +``` + +Configuration options: + +```yaml +gpu_memory_limit: Limit GPU memory usage +lora_on_cpu: Load LoRA weights on CPU +``` + +#### merge-sharded-fsdp-weights + +Merges sharded FSDP model checkpoints into a single combined checkpoint. + +```bash +# Basic merge +axolotl merge-sharded-fsdp-weights config.yml +``` + +#### evaluate + +Evaluates a model's performance using metrics specified in the config. + +```bash +# Basic evaluation +axolotl evaluate config.yml +``` + +#### lm-eval + +Runs LM Evaluation Harness on your model. + +```bash +# Basic evaluation +axolotl lm-eval config.yml + +# Evaluate specific tasks +axolotl lm-eval config.yml --tasks arc_challenge,hellaswag +``` + +Configuration options: + +```yaml +lm_eval_tasks: List of tasks to evaluate +lm_eval_batch_size: Batch size for evaluation +output_dir: Directory to save evaluation results +``` + +### Legacy CLI Usage + +While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI: + +```bash +# Preprocess +python -m axolotl.cli.preprocess config.yml + +# Train +accelerate launch -m axolotl.cli.train config.yml + +# Inference +accelerate launch -m axolotl.cli.inference config.yml \ + --lora_model_dir="./outputs/lora-out" + +# Gradio interface +accelerate launch -m axolotl.cli.inference config.yml \ + --lora_model_dir="./outputs/lora-out" --gradio +``` + +### Remote Compute with Modal Cloud + +Axolotl supports running training and inference workloads on Modal cloud infrastructure. This is configured using a +cloud YAML file alongside your regular Axolotl config. 
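Axolotl already pins the Modal client in `requirements.txt`, but you still need local Modal credentials before any `--cloud` command will work. A minimal setup sketch, assuming you already have a Modal account (check Modal's documentation in case the authentication command has changed):

```bash
# Install the Modal client (already included in axolotl's requirements)
pip install modal

# Authenticate with Modal; this opens a browser and stores a token locally
modal token new
```
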
+ +#### Cloud Configuration + +Create a cloud config YAML with your Modal settings: + +```yaml +# cloud_config.yml +provider: modal +gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4 +gpu_count: 1 # Number of GPUs to use +timeout: 86400 # Maximum runtime in seconds (24 hours) +branch: main # Git branch to use (optional) + +volumes: # Persistent storage volumes + - name: axolotl-cache + mount: /workspace/cache + +env: # Environment variables + - WANDB_API_KEY + - HF_TOKEN +``` + +#### Running on Modal Cloud + +Commands that support the --cloud flag: + +```bash +# Preprocess on cloud +axolotl preprocess config.yml --cloud cloud_config.yml + +# Train on cloud +axolotl train config.yml --cloud cloud_config.yml + +# Train without accelerate on cloud +axolotl train config.yml --cloud cloud_config.yml --no-accelerate + +# Run lm-eval on cloud +axolotl lm-eval config.yml --cloud cloud_config.yml +``` + +#### Cloud Configuration Options + +```yaml +provider: compute provider, currently only `modal` is supported +gpu: GPU type to use +gpu_count: Number of GPUs (default: 1) +memory: RAM in GB (default: 128) +timeout: Maximum runtime in seconds +timeout_preprocess: Preprocessing timeout +branch: Git branch to use +docker_tag: Custom Docker image tag +volumes: List of persistent storage volumes +env: Environment variables to pass +secrets: Secrets to inject +``` diff --git a/examples/cloud/modal.yaml b/examples/cloud/modal.yaml new file mode 100644 index 000000000..195031494 --- /dev/null +++ b/examples/cloud/modal.yaml @@ -0,0 +1,28 @@ +project_name: +volumes: + - name: axolotl-data + mount: /workspace/data + - name: axolotl-artifacts + mount: /workspace/artifacts + +# environment variables from local to set as secrets +secrets: + - HF_TOKEN + - WANDB_API_KEY + +# Which branch of axolotl to use remotely +branch: + +# additional custom commands when building the image +dockerfile_commands: + +gpu: h100 +gpu_count: 1 + +# Train specific configurations +memory: 128 +timeout: 86400 + +# Preprocess specific configurations +memory_preprocess: 32 +timeout_preprocess: 14400 diff --git a/requirements.txt b/requirements.txt index 446fa94a6..061229c69 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,7 @@ hf_transfer sentencepiece gradio==3.50.2 +modal==0.70.5 pydantic==2.6.3 addict fire diff --git a/scripts/motd b/scripts/motd index b3ffa165e..bc123c312 100644 --- a/scripts/motd +++ b/scripts/motd @@ -1,10 +1,15 @@ - dP dP dP - 88 88 88 - .d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88 - 88' `88 `8bd8' 88' `88 88 88' `88 88 88 - 88. .88 .d88b. 88. .88 88 88. .88 88 88 - `88888P8 dP' `dP `88888P' dP `88888P' dP dP + #@@ #@@ @@# @@# + @@ @@ @@ @@ =@@# @@ #@ =@@#. + @@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@ + #@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@ + @@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@ + @@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@ + @@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@ + =@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@ + @@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@ + =@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@ + @@@@ @@@@@@@@@@@@@@@@ Welcome to the axolotl cloud image! 
If you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands:
diff --git a/src/axolotl/cli/cloud/__init__.py b/src/axolotl/cli/cloud/__init__.py
new file mode 100644
index 000000000..fde46e397
--- /dev/null
+++ b/src/axolotl/cli/cloud/__init__.py
@@ -0,0 +1,56 @@
"""
launch axolotl in supported cloud platforms
"""
from pathlib import Path
from typing import Union

import yaml

from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault


def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
    """Load and validate cloud configuration."""
    # Load cloud configuration.
    with open(cloud_config, encoding="utf-8") as file:
        cloud_cfg: DictDefault = DictDefault(yaml.safe_load(file))
    return cloud_cfg


def do_cli_preprocess(
    cloud_config: Union[Path, str],
    config: Union[Path, str],
) -> None:
    print_axolotl_text_art()
    cloud_cfg = load_cloud_cfg(cloud_config)
    cloud = ModalCloud(cloud_cfg)
    with open(config, "r", encoding="utf-8") as file:
        config_yaml = file.read()
    cloud.preprocess(config_yaml)


def do_cli_train(
    cloud_config: Union[Path, str],
    config: Union[Path, str],
    accelerate: bool = True,
) -> None:
    print_axolotl_text_art()
    cloud_cfg = load_cloud_cfg(cloud_config)
    cloud = ModalCloud(cloud_cfg)
    with open(config, "r", encoding="utf-8") as file:
        config_yaml = file.read()
    cloud.train(config_yaml, accelerate=accelerate)


def do_cli_lm_eval(
    cloud_config: Union[Path, str],
    config: Union[Path, str],
) -> None:
    print_axolotl_text_art()
    cloud_cfg = load_cloud_cfg(cloud_config)
    cloud = ModalCloud(cloud_cfg)
    with open(config, "r", encoding="utf-8") as file:
        config_yaml = file.read()
    cloud.lm_eval(config_yaml)
diff --git a/src/axolotl/cli/cloud/base.py b/src/axolotl/cli/cloud/base.py
new file mode 100644
index 000000000..44d1b0c17
--- /dev/null
+++ b/src/axolotl/cli/cloud/base.py
@@ -0,0 +1,18 @@
"""
base class for cloud platforms from cli
"""
from abc import ABC, abstractmethod


class Cloud(ABC):
    """
    Abstract base class for cloud platforms.
    """

    @abstractmethod
    def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
        pass

    @abstractmethod
    def train(self, config_yaml: str, accelerate: bool = True) -> str:
        pass
diff --git a/src/axolotl/cli/cloud/modal_.py b/src/axolotl/cli/cloud/modal_.py
new file mode 100644
index 000000000..bcc47ead9
--- /dev/null
+++ b/src/axolotl/cli/cloud/modal_.py
@@ -0,0 +1,282 @@
"""
Modal Cloud support from CLI
"""
import copy
import json
import os
import subprocess  # nosec B404
from pathlib import Path
from random import randint

import modal

from axolotl.cli.cloud.base import Cloud


def run_cmd(cmd: str, run_folder: str, volumes=None):
    """Run a command inside a folder, with Modal Volume reloading before and commit on success."""
    # Ensure volumes contain latest files.
    if volumes:
        for _, vol in volumes.items():
            vol.reload()

    # modal workaround so it doesn't use the automounted axolotl
    new_env = copy.deepcopy(os.environ)
    if "PYTHONPATH" in new_env:
        del new_env["PYTHONPATH"]

    # Propagate errors from subprocess.
    if exit_code := subprocess.call(  # nosec B603
        cmd.split(), cwd=run_folder, env=new_env
    ):
        exit(exit_code)  # pylint: disable=consider-using-sys-exit

    # Commit writes to volume.
    if volumes:
        for _, vol in volumes.items():
            vol.commit()


class ModalCloud(Cloud):
    """
    Modal Cloud implementation.
    """

    def __init__(self, config, app=None):
        self.config = config
        if not app:
            app = modal.App()
        self.app = app

        self.volumes = {}
        if config.volumes:
            for volume_config in config.volumes:
                _, mount, vol = self.create_volume(volume_config)
                self.volumes[mount] = (vol, volume_config)

    def get_env(self):
        res = {
            "HF_DATASETS_CACHE": "/workspace/data/huggingface-cache/datasets",
            "HF_HUB_CACHE": "/workspace/data/huggingface-cache/hub",
        }

        for key in self.config.get("env", []):
            if isinstance(key, str):
                if val := os.environ.get(key, ""):
                    res[key] = val
            elif isinstance(key, dict):
                (key_, val) = list(key.items())[0]
                res[key_] = val
        return res

    def get_image(self):
        docker_tag = "main-py3.11-cu124-2.5.1"
        if self.config.docker_tag:
            docker_tag = self.config.docker_tag
        docker_image = f"axolotlai/axolotl:{docker_tag}"

        # grab the sha256 hash from docker hub for this image+tag
        # this ensures that we always get the latest image for this tag, even if it's already cached
        try:
            manifest = subprocess.check_output(  # nosec B602
                f"docker manifest inspect {docker_image}",
                shell=True,
            ).decode("utf-8")
            sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
        except subprocess.CalledProcessError:
            sha256_hash = None

        # create the image
        if sha256_hash:
            image = modal.Image.from_registry(f"axolotlai/axolotl@{sha256_hash}")
        else:
            image = modal.Image.from_registry(docker_image)

        dockerfile_commands = []
        if self.config.dockerfile_commands:
            dockerfile_commands.extend(self.config.dockerfile_commands)

        # branch
        if self.config.branch:
            dockerfile_commands.extend(
                [
                    # Random id for cache busting of branch commits
                    f"RUN echo '{str(randint(0, 1000000))}'",  # nosec B311
                    f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
                ]
            )

        if dockerfile_commands:
            image = image.dockerfile_commands(dockerfile_commands)

        if env := self.get_env():
            image = image.env(env)

        image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")

        return image

    def get_secrets(self):
        res = []
        if self.config.secrets:
            for key in self.config.get("secrets", []):
                # pylint: disable=duplicate-code
                if isinstance(key, str):
                    if val := os.environ.get(key, ""):
                        res.append(modal.Secret.from_dict({key: val}))
                elif isinstance(key, dict):
                    (key_, val) = list(key.items())[0]
                    res.append(modal.Secret.from_dict({key_: val}))
        return res

    def create_volume(self, volume_config):
        name = volume_config.name
        mount = volume_config.mount
        return name, mount, modal.Volume.from_name(name, create_if_missing=True)

    def get_ephemeral_disk_size(self):
        return 1024 * 1024  # 1 TiB, expressed in MiB

    def get_preprocess_timeout(self):
        if self.config.timeout_preprocess:
            return int(self.config.timeout_preprocess)
        return 60 * 60 * 3  # 3 hours

    def get_preprocess_memory(self):
        memory = 128  # default to 128 GiB
        if self.config.memory:
            memory = int(self.config.memory)
        if self.config.memory_preprocess:
            memory = int(self.config.memory_preprocess)
        return 1024 * memory

    def get_preprocess_env(self):
        return self.app.function(
            image=self.get_image(),
            volumes={k: v[0] for k, v in self.volumes.items()},
            cpu=8.0,
            ephemeral_disk=self.get_ephemeral_disk_size(),
            memory=self.get_preprocess_memory(),
            timeout=self.get_preprocess_timeout(),
            secrets=self.get_secrets(),
        )

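    # Runs `axolotl preprocess` remotely: wraps the module-level _preprocess helper
    # in the Modal function environment built above, and launches it detached so the
    # job keeps running even if the local session disconnects.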
    def preprocess(self, config_yaml: str, *args, **kwargs):
        modal_fn = self.get_preprocess_env()(_preprocess)
        with modal.enable_output():
            with self.app.run(detach=True):
                modal_fn.remote(
                    config_yaml,
                    volumes={k: v[0] for k, v in self.volumes.items()},
                    *args,
                    **kwargs,
                )

    def get_train_timeout(self):
        if self.config.timeout:
            return int(self.config.timeout)
        return 60 * 60 * 24  # 24 hours

    def get_train_gpu(self):  # pylint: disable=too-many-return-statements
        count = self.config.gpu_count or 1
        family = (self.config.gpu or "l40s").lower()

        if family == "l40s":
            return modal.gpu.L40S(count=count)
        if family in ["a100", "a100-40gb"]:
            return modal.gpu.A100(count=count, size="40GB")
        if family == "a100-80gb":
            return modal.gpu.A100(count=count, size="80GB")
        if family in ["a10", "a10g"]:
            return modal.gpu.A10G(count=count)
        if family == "h100":
            return modal.gpu.H100(count=count)
        if family == "t4":
            return modal.gpu.T4(count=count)
        if family == "l4":
            return modal.gpu.L4(count=count)
        raise ValueError(f"Unsupported GPU family: {family}")

    def get_train_memory(self):
        memory = 128  # default to 128 GiB
        if self.config.memory:
            memory = int(self.config.memory)
        return 1024 * memory

    def get_train_env(self):
        return self.app.function(
            image=self.get_image(),
            volumes={k: v[0] for k, v in self.volumes.items()},
            cpu=16.0,
            gpu=self.get_train_gpu(),
            memory=self.get_train_memory(),
            timeout=self.get_train_timeout(),
            secrets=self.get_secrets(),
        )

    def train(self, config_yaml: str, accelerate: bool = True):
        modal_fn = self.get_train_env()(_train)
        with modal.enable_output():
            with self.app.run(detach=True):
                modal_fn.remote(
                    config_yaml,
                    accelerate=accelerate,
                    volumes={k: v[0] for k, v in self.volumes.items()},
                )

    def lm_eval(self, config_yaml: str):
        modal_fn = self.get_train_env()(_lm_eval)
        with modal.enable_output():
            with self.app.run(detach=True):
                if self.config.get("spawn", False):
                    modal_fn_exec = modal_fn.spawn
                else:
                    modal_fn_exec = modal_fn.remote
                modal_fn_exec(
                    config_yaml,
                    volumes={k: v[0] for k, v in self.volumes.items()},
                )


def _preprocess(config_yaml: str, volumes=None):
    Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
    with open(
        "/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
    ) as f_out:
        f_out.write(config_yaml)
    run_folder = "/workspace/artifacts/axolotl"
    run_cmd(
        "axolotl preprocess /workspace/artifacts/axolotl/config.yaml --dataset-processes=8",
        run_folder,
        volumes,
    )


def _train(config_yaml: str, accelerate: bool = True, volumes=None):
    Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
    with open(
        "/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
    ) as f_out:
        f_out.write(config_yaml)
    run_folder = "/workspace/artifacts/axolotl"
    if accelerate:
        accelerate_args = "--accelerate"
    else:
        accelerate_args = "--no-accelerate"
    run_cmd(
        f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
        run_folder,
        volumes,
    )


def _lm_eval(config_yaml: str, volumes=None):
    Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
    with open(
        "/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
    ) as f_out:
        f_out.write(config_yaml)
    run_folder = "/workspace/artifacts/axolotl"
    run_cmd(
        "axolotl lm-eval /workspace/artifacts/axolotl/config.yaml",
        run_folder,
        volumes,
    )
diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index 801fbc80d..e8551511e 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -15,6 +15,7 @@ from axolotl.cli.utils import (
    fetch_from_github,
    filter_none_kwargs,
)
+from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -27,21 +28,28 @@ def cli():


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
+@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
-def preprocess(config: str, **kwargs) -> None:
+def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
    """
    Preprocess datasets before training.

    Args:
        config: Path to `axolotl` config YAML file.
+        cloud: Path to a cloud accelerator configuration file.
        kwargs: Additional keyword arguments which correspond to CLI args or
            `axolotl` config options.
    """
-    from axolotl.cli.preprocess import do_cli
+    if cloud:
+        from axolotl.cli.cloud import do_cli_preprocess

-    do_cli(config=config, **kwargs)
+        do_cli_preprocess(cloud_config=cloud, config=config)
+    else:
+        from axolotl.cli.preprocess import do_cli
+
+        do_cli(config=config, **kwargs)


@cli.command()
@@ -51,47 +59,56 @@ def preprocess(config: str, **kwargs) -> None:
    default=True,
    help="Use accelerate launch for multi-GPU training",
)
+@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
-def train(config: str, accelerate: bool, **kwargs) -> None:
+def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs) -> None:
    """
    Train or fine-tune a model.

    Args:
        config: Path to `axolotl` config YAML file.
        accelerate: Whether to use `accelerate` launcher.
+        cloud: Path to a cloud accelerator configuration file.
        kwargs: Additional keyword arguments which correspond to CLI args or
            `axolotl` config options.
""" # Enable expandable segments for cuda allocation to improve VRAM usage set_pytorch_cuda_alloc_conf() + from axolotl.cli.cloud import do_cli_train if "use_ray" in kwargs and kwargs["use_ray"]: accelerate = False if accelerate: - accelerate_args = [] - if "main_process_port" in kwargs: - main_process_port = kwargs.pop("main_process_port", None) - accelerate_args.append("--main_process_port") - accelerate_args.append(str(main_process_port)) - if "num_processes" in kwargs: - num_processes = kwargs.pop("num_processes", None) - accelerate_args.append("--num-processes") - accelerate_args.append(str(num_processes)) + if cloud: + do_cli_train(cloud_config=cloud, config=config, accelerate=True) + else: + accelerate_args = [] + if "main_process_port" in kwargs: + main_process_port = kwargs.pop("main_process_port", None) + accelerate_args.append("--main_process_port") + accelerate_args.append(str(main_process_port)) + if "num_processes" in kwargs: + num_processes = kwargs.pop("num_processes", None) + accelerate_args.append("--num-processes") + accelerate_args.append(str(num_processes)) - base_cmd = ["accelerate", "launch"] - base_cmd.extend(accelerate_args) - base_cmd.extend(["-m", "axolotl.cli.train"]) - if config: - base_cmd.append(config) - cmd = build_command(base_cmd, kwargs) - subprocess.run(cmd, check=True) # nosec B603 + base_cmd = ["accelerate", "launch"] + base_cmd.extend(accelerate_args) + base_cmd.extend(["-m", "axolotl.cli.train"]) + if config: + base_cmd.append(config) + cmd = build_command(base_cmd, kwargs) + subprocess.run(cmd, check=True) # nosec B603 else: - from axolotl.cli.train import do_cli + if cloud: + do_cli_train(cloud_config=cloud, config=config, accelerate=False) + else: + from axolotl.cli.train import do_cli - do_cli(config=config, **kwargs) + do_cli(config=config, **kwargs) @cli.command() @@ -210,7 +227,6 @@ def merge_lora(config: str, **kwargs) -> None: Args: config: Path to `axolotl` config YAML file. - accelerate: Whether to use `accelerate` launcher. kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ @@ -237,6 +253,9 @@ def fetch(directory: str, dest: Optional[str]) -> None: fetch_from_github(f"{directory}/", dest) +cli.add_command(lm_eval) + + def main(): cli() diff --git a/src/axolotl/integrations/lm_eval/__init__.py b/src/axolotl/integrations/lm_eval/__init__.py index f1daa2000..0cbc8a49d 100644 --- a/src/axolotl/integrations/lm_eval/__init__.py +++ b/src/axolotl/integrations/lm_eval/__init__.py @@ -2,9 +2,9 @@ Module for the Plugin for LM Eval Harness """ import subprocess # nosec -from datetime import datetime from axolotl.integrations.base import BasePlugin +from axolotl.integrations.lm_eval.cli import build_lm_eval_command from .args import LMEvalArgs # pylint: disable=unused-import. 
@@ -18,25 +18,20 @@ class LMEvalPlugin(BasePlugin):
        return "axolotl.integrations.lm_eval.LMEvalArgs"

    def post_train_unload(self, cfg):
-        tasks = ",".join(cfg.lm_eval_tasks)
-        fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
-        dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
-        output_path = cfg.output_dir
-        output_path += "" if cfg.output_dir.endswith("/") else "/"
-        output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
-        subprocess.run(  # nosec
-            [
-                "lm_eval",
-                "--model",
-                "hf",
-                "--model_args",
-                f"pretrained={cfg.output_dir}{fa2}{dtype}",
-                "--tasks",
-                tasks,
-                "--batch_size",
-                str(cfg.lm_eval_batch_size),
-                "--output_path",
-                output_path,
-            ],
-            check=True,
-        )
+        if cfg.lm_eval_post_train:
+            # pylint: disable=duplicate-code
+            for lm_eval_args in build_lm_eval_command(
+                cfg.lm_eval_tasks,
+                bfloat16=cfg.bfloat16 or cfg.bf16,
+                flash_attention=cfg.flash_attention,
+                output_dir=cfg.output_dir,
+                batch_size=cfg.lm_eval_batch_size,
+                wandb_project=cfg.wandb_project,
+                wandb_entity=cfg.wandb_entity,
+                wandb_name=cfg.wandb_name,
+                model=cfg.lm_eval_model or cfg.hub_model_id,
+            ):
+                subprocess.run(  # nosec
+                    lm_eval_args,
+                    check=True,
+                )
diff --git a/src/axolotl/integrations/lm_eval/args.py b/src/axolotl/integrations/lm_eval/args.py
index f58e6a6e3..721f560e3 100644
--- a/src/axolotl/integrations/lm_eval/args.py
+++ b/src/axolotl/integrations/lm_eval/args.py
@@ -13,3 +13,5 @@ class LMEvalArgs(BaseModel):

    lm_eval_tasks: List[str] = []
    lm_eval_batch_size: Optional[int] = 8
+    lm_eval_post_train: Optional[bool] = True
+    lm_eval_model: Optional[str] = None
diff --git a/src/axolotl/integrations/lm_eval/cli.py b/src/axolotl/integrations/lm_eval/cli.py
new file mode 100644
index 000000000..4a9bbafe6
--- /dev/null
+++ b/src/axolotl/integrations/lm_eval/cli.py
@@ -0,0 +1,119 @@
"""
axolotl CLI for running lm_eval tasks
"""
import subprocess  # nosec
from collections import defaultdict
from datetime import datetime
from typing import Optional

import click
import yaml

from axolotl.utils.dict import DictDefault


def build_lm_eval_command(
    tasks: list[str],
    bfloat16=True,
    flash_attention=False,
    output_dir="./",
    batch_size=8,
    wandb_project=None,
    wandb_entity=None,
    wandb_name=None,
    model=None,
    revision=None,
    apply_chat_template=None,
    fewshot_as_multiturn=None,
):
    tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
    if isinstance(tasks, str):
        tasks = [tasks]
    for task in tasks:
        num_fewshot = "-1"
        task_parts = task.split(":")
        task_name = task_parts[0]
        if len(task_parts) == 2:
            task_name, num_fewshot = task_parts
        tasks_by_num_fewshot[str(num_fewshot)].append(task_name)

    for num_fewshot, tasks_list in tasks_by_num_fewshot.items():
        tasks_str = ",".join(tasks_list)
        num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
        pretrained = "pretrained="
        pretrained += model if model else output_dir
        fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
        dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
        # use a separate name so `revision` isn't clobbered across loop iterations
        revision_arg = f",revision={revision}" if revision else ""
        output_path = output_dir
        output_path += "" if output_dir.endswith("/") else "/"
        output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
        lm_eval_args = [
            "lm_eval",
            "--model",
            "hf",
            "--model_args",
            f"{pretrained}{fa2}{dtype}{revision_arg}",
            "--tasks",
            tasks_str,
            "--batch_size",
            str(batch_size),
            "--output_path",
            output_path,
        ]
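        # Optionally forward Weights & Biases settings to lm_eval via its
        # --wandb_args flag so evaluation runs log alongside training runs.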
+ wandb_args = [] + if wandb_project: + wandb_args.append(f"project={wandb_project}") + if wandb_entity: + wandb_args.append(f"entity={wandb_entity}") + if wandb_name: + wandb_args.append(f"name={wandb_name}") + if wandb_args: + lm_eval_args.append("--wandb_args") + lm_eval_args.append(",".join(wandb_args)) + if apply_chat_template: + lm_eval_args.append("--apply_chat_template") + if num_fewshot_val: + lm_eval_args.append("--num_fewshot") + lm_eval_args.append(str(num_fewshot_val)) + if apply_chat_template and fewshot_as_multiturn: + lm_eval_args.append("--fewshot_as_multiturn") + + yield lm_eval_args + + +@click.command() +@click.argument("config", type=click.Path(exists=True, path_type=str)) +@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str)) +def lm_eval(config: str, cloud: Optional[str] = None): + """ + use lm eval to evaluate a trained language model + """ + + if cloud: + from axolotl.cli.cloud import do_cli_lm_eval + + do_cli_lm_eval(cloud_config=cloud, config=config) + else: + with open(config, encoding="utf-8") as file: + cfg: DictDefault = DictDefault(yaml.safe_load(file)) + + # pylint: disable=duplicate-code + for lm_eval_args in build_lm_eval_command( + cfg.lm_eval_tasks, + bfloat16=cfg.bfloat16 or cfg.bf16, + flash_attention=cfg.flash_attention, + output_dir=cfg.output_dir, + batch_size=cfg.lm_eval_batch_size, + wandb_project=cfg.wandb_project, + wandb_entity=cfg.wandb_entity, + wandb_name=cfg.wandb_name, + model=cfg.lm_eval_model or cfg.hub_model_id, + revision=cfg.revision, + apply_chat_template=cfg.apply_chat_template, + fewshot_as_multiturn=cfg.fewshot_as_multiturn, + ): + subprocess.run( # nosec + lm_eval_args, + check=True, + )