Files
axolotl/src/axolotl/cli/main.py
Wing Lian 6f15da4cac make it easier for agents to discover docs (#3579) [skip ci]
* make it easier for agents to discover docs

* fixup pr comments
2026-04-06 10:00:55 -07:00

461 lines
15 KiB
Python

"""Click CLI definitions for various axolotl commands."""
import os
import subprocess # nosec B404
from typing import Literal, Optional
import click
from dotenv import load_dotenv
import axolotl
from axolotl.cli.args import (
EvaluateCliArgs,
PreprocessCliArgs,
QuantizeCliArgs,
TrainerCliArgs,
VllmServeCliArgs,
)
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
build_command,
fetch_from_github,
filter_none_kwargs,
generate_config_files,
launch_training,
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import set_misc_env, set_pytorch_cuda_alloc_conf
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig
LOG = get_logger(__name__)
LAUNCHER_COMMAND_MAPPING = {
"accelerate": ["accelerate", "launch"],
"torchrun": ["torchrun"],
}
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
"""Axolotl CLI - Train and fine-tune large language models"""
print_axolotl_text_art()
load_dotenv()
set_pytorch_cuda_alloc_conf()
set_misc_env()
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
"""
Preprocess datasets before training.
Args:
config: Path to `axolotl` config YAML file.
cloud: Path to a cloud accelerator configuration file.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if cloud:
from axolotl.cli.cloud import do_cli_preprocess
do_cli_preprocess(cloud_config=cloud, config=config)
else:
from axolotl.cli.preprocess import do_cli
do_cli(config=config, **kwargs)
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for multi-GPU training",
)
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@click.option(
"--sweep",
type=click.Path(exists=True, path_type=str),
help="YAML config for sweeping hyperparameters",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@click.pass_context
def train(
ctx: click.Context,
config: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
cloud: str | None = None,
sweep: str | None = None,
**kwargs,
):
"""
Train or fine-tune a model.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
launcher: Launcher to use for multi-GPU training ("accelerate", "torchrun", or "python").
cloud: Path to a cloud accelerator configuration file
sweep: Path to YAML config for sweeping hyperparameters.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
# Handle Ray launcher override
_launcher = None if kwargs.get("use_ray") else launcher
# Process each configuration
for cfg_file, is_group in generate_config_files(config, sweep):
try:
use_exec = is_group is not True
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
except subprocess.CalledProcessError as exc:
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:
raise exc
finally:
# Only delete temp files, not the original config
if cfg_file != config:
os.unlink(cfg_file)
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for multi-GPU evaluation",
)
@add_options_from_dataclass(EvaluateCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@click.pass_context
def evaluate(ctx: click.Context, config: str, launcher: str, **kwargs):
"""
Evaluate a model.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
launcher: Launcher to use for multi-GPU evaluation ("accelerate", "torchrun", or "python").
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if launcher in LAUNCHER_COMMAND_MAPPING:
base_cmd = (
LAUNCHER_COMMAND_MAPPING[launcher]
+ launcher_args
+ ["-m", "axolotl.cli.evaluate"]
)
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.evaluate import do_cli
do_cli(config=config, **kwargs)
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for multi-GPU inference",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@click.pass_context
def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kwargs):
"""
Run inference with a trained model.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
launcher: Launcher to use for multi-GPU inference ("accelerate", "torchrun", or "python").
gradio: Whether to use Gradio browser interface or command line for inference.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if launcher in LAUNCHER_COMMAND_MAPPING:
base_cmd = (
LAUNCHER_COMMAND_MAPPING[launcher]
+ launcher_args
+ ["-m", "axolotl.cli.inference"]
)
if config:
base_cmd.append(config)
if gradio:
base_cmd.append("--gradio")
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.inference import do_cli
do_cli(config=config, gradio=gradio, **kwargs)
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for weight merging",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@click.pass_context
def merge_sharded_fsdp_weights(
ctx: click.Context, config: str, launcher: str, **kwargs
):
"""
Merge sharded FSDP model weights.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
launcher: Launcher to use for weight merging ("accelerate", "torchrun", or "python").
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if launcher in LAUNCHER_COMMAND_MAPPING:
base_cmd = (
LAUNCHER_COMMAND_MAPPING[launcher]
+ launcher_args
+ ["-m", "axolotl.cli.merge_sharded_fsdp_weights"]
)
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.merge_sharded_fsdp_weights import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def merge_lora(config: str, **kwargs):
"""
Merge trained LoRA adapters into a base model.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
from axolotl.cli.merge_lora import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument(
"directory", type=click.Choice(["examples", "deepspeed_configs", "docs"])
)
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]):
"""
Fetch example configs or other resources.
Available directories:
- examples: Example configuration files
- deepspeed_configs: DeepSpeed configuration files
- docs: Full documentation (Quarto markdown files)
Args:
directory: One of `examples`, `deepspeed_configs`, `docs`.
dest: Optional destination directory.
"""
fetch_from_github(f"{directory}/", dest)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(VllmServeCliArgs)
@filter_none_kwargs
def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
from axolotl.cli.vllm_serve import do_vllm_serve
do_vllm_serve(config, cli_args)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(QuantizeCliArgs)
@filter_none_kwargs
def quantize(config: str, **cli_args: QuantizeCliArgs):
from axolotl.cli.quantize import do_quantize
do_quantize(config, cli_args)
@cli.command()
@click.argument("model", type=click.Path(exists=True, path_type=str))
@click.argument("output", type=click.Path(exists=False, path_type=str))
def delinearize_llama4(model: str, output: str):
from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4
do_delinearize_llama4(model, output)
@cli.command("agent-docs")
@click.argument("topic", required=False, default=None)
@click.option("--list", "list_topics", is_flag=True, help="List available topics")
def agent_docs(topic: Optional[str], list_topics: bool):
"""Show agent-optimized documentation.
Prints reference docs designed for AI coding agents.
These docs are bundled with the package — no network access needed.
\b
Examples:
axolotl agent-docs # overview (start here)
axolotl agent-docs grpo # GRPO reference
axolotl agent-docs sft # SFT reference
axolotl agent-docs --list # list all topics
"""
from axolotl.cli.agent_docs import get_doc, list_topics as _list_topics
if list_topics:
for name, title in _list_topics().items():
click.echo(f" {name:25s} {title}")
return
if topic is None:
topic = "overview"
try:
click.echo(get_doc(topic))
except FileNotFoundError as exc:
raise click.BadParameter(str(exc)) from exc
@cli.command("config-schema")
@click.option(
"--format",
"output_format",
type=click.Choice(["json", "yaml"]),
default="json",
help="Output format (default: json)",
)
@click.option("--field", help="Show schema for a specific field only")
def config_schema(output_format: str, field: Optional[str]):
"""Dump the full config JSON schema.
Useful for AI agents and tooling to discover all available config options,
their types, defaults, and descriptions.
\b
Examples:
axolotl config-schema # full JSON schema
axolotl config-schema --format yaml # YAML format
axolotl config-schema --field adapter # single field
"""
import json
try:
schema = AxolotlInputConfig.model_json_schema()
except (TypeError, ValueError, AttributeError) as exc:
# Fallback: dump field names, types, and defaults when full schema
# generation fails (e.g. torch.dtype not JSON-serializable)
LOG.warning(
"Full JSON schema generation failed, using simplified fallback: %s", exc
)
fields = {}
for name, field_info in AxolotlInputConfig.model_fields.items():
entry = {}
if field_info.description:
entry["description"] = field_info.description
if field_info.default is not None:
try:
json.dumps(field_info.default)
entry["default"] = field_info.default
except (TypeError, ValueError):
entry["default"] = str(field_info.default)
annotation = field_info.annotation
if annotation is not None:
entry["type"] = str(annotation)
fields[name] = entry
schema = {
"properties": fields,
"_note": "simplified schema (full generation failed)",
}
if field:
props = schema.get("properties", {})
if field not in props:
# Try case-insensitive match
matches = [k for k in props if k.lower() == field.lower()]
if matches:
field = matches[0]
else:
raise click.BadParameter(
f"Unknown field: {field!r}. "
f"Omit --field to dump the full schema, "
f"or pipe to jq: axolotl config-schema | jq '.properties | keys'"
)
schema = {field: props[field]}
if output_format == "yaml":
import yaml # pylint: disable=import-outside-toplevel
click.echo(yaml.dump(schema, default_flow_style=False, sort_keys=False))
else:
click.echo(json.dumps(schema, indent=2))
cli.add_command(lm_eval)
def main():
cli()
if __name__ == "__main__":
main()