From bb1cae1a20dd10bce644319d6fc26b1d51c1d666 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 30 Jul 2025 15:46:56 -0400 Subject: [PATCH] CLI: add --launcher option, support launcher args, cleanup, refactor (#2924) * add --launcher option; explicit True/False bool args; small cleanup * refactor * add torchrun, accelerate cli args * add rdzv arg default + tests * update _quarto * coderabbit * fix * we can't set rdvz_id independently across nodes * coderabbit * fix tests --- _quarto.yml | 12 +- docs/cli.qmd | 26 +- docs/multi-node.qmd | 14 +- src/axolotl/cli/args.py | 2 - src/axolotl/cli/cloud/__init__.py | 27 +- src/axolotl/cli/cloud/base.py | 10 +- src/axolotl/cli/cloud/modal_.py | 41 ++- src/axolotl/cli/config.py | 13 +- src/axolotl/cli/delinearize_llama4.py | 2 - src/axolotl/cli/evaluate.py | 6 - src/axolotl/cli/inference.py | 2 - src/axolotl/cli/main.py | 222 ++++++------ src/axolotl/cli/merge_lora.py | 2 - src/axolotl/cli/merge_sharded_fsdp_weights.py | 2 - src/axolotl/cli/preprocess.py | 2 - src/axolotl/cli/train.py | 6 - src/axolotl/cli/utils.py | 330 ------------------ src/axolotl/cli/utils/__init__.py | 23 ++ src/axolotl/cli/utils/args.py | 120 +++++++ src/axolotl/cli/utils/fetch.py | 142 ++++++++ src/axolotl/cli/utils/load.py | 52 +++ src/axolotl/cli/{ => utils}/sweeps.py | 0 src/axolotl/cli/utils/train.py | 188 ++++++++++ tests/cli/test_cli_base.py | 26 +- tests/cli/test_cli_evaluate.py | 114 +++++- tests/cli/test_cli_inference.py | 121 ++++++- tests/cli/test_cli_interface.py | 11 +- .../test_cli_merge_sharded_fsdp_weights.py | 94 ++++- tests/cli/test_cli_sweeps.py | 2 +- tests/cli/test_cli_train.py | 189 +++++++++- tests/cli/test_utils.py | 157 +++++++++ 31 files changed, 1417 insertions(+), 541 deletions(-) delete mode 100644 src/axolotl/cli/utils.py create mode 100644 src/axolotl/cli/utils/__init__.py create mode 100644 src/axolotl/cli/utils/args.py create mode 100644 src/axolotl/cli/utils/fetch.py create mode 100644 src/axolotl/cli/utils/load.py rename src/axolotl/cli/{ => utils}/sweeps.py (100%) create mode 100644 src/axolotl/cli/utils/train.py diff --git a/_quarto.yml b/_quarto.yml index dab1ee363..250596d52 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -35,18 +35,24 @@ quartodoc: - cli.train - cli.evaluate - cli.args + - cli.art - cli.checks - cli.config + - cli.delinearize_llama4 - cli.inference - cli.merge_lora - cli.merge_sharded_fsdp_weights - cli.preprocess - - cli.sweeps - - cli.utils + - cli.quantize - cli.vllm_serve - cli.cloud.base - cli.cloud.modal_ - - cli.quantize + - cli.utils + - cli.utils.args + - cli.utils.fetch + - cli.utils.load + - cli.utils.sweeps + - cli.utils.train - title: Trainers desc: Training implementations contents: diff --git a/docs/cli.qmd b/docs/cli.qmd index f6f9b3481..d9f26dbf8 100644 --- a/docs/cli.qmd +++ b/docs/cli.qmd @@ -23,6 +23,20 @@ axolotl [config.yml] [options] The config file can be local or a URL to a raw YAML file. +### Launcher Arguments + +For commands that support multi-GPU (`train`, `evaluate`, ...), you can pass launcher-specific arguments using the `--` separator: + +```bash +# Pass torchrun arguments +axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1 + +# Pass accelerate arguments +axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml --num_processes=4 +``` + +Arguments after `--` are passed directly to the launcher (torchrun, accelerate launch, etc.). + ## Command Reference ### fetch @@ -80,7 +94,11 @@ axolotl train config.yml \ --num-epochs 3 # Training without accelerate -axolotl train config.yml --no-accelerate +axolotl train config.yml --launcher python + +# Pass launcher-specific arguments using -- separator +axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1 +axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml # Resume training from checkpoint axolotl train config.yml --resume-from-checkpoint path/to/checkpoint @@ -175,6 +193,9 @@ Evaluates a model's performance (loss etc) on the train and eval datasets. ```bash # Basic evaluation axolotl evaluate config.yml + +# Evaluation with launcher arguments +axolotl evaluate config.yml --launcher torchrun -- --nproc_per_node=2 ``` ### lm-eval @@ -287,9 +308,6 @@ axolotl preprocess config.yml --cloud cloud_config.yml # Train on cloud axolotl train config.yml --cloud cloud_config.yml -# Train without accelerate on cloud -axolotl train config.yml --cloud cloud_config.yml --no-accelerate - # Run lm-eval on cloud axolotl lm-eval config.yml --cloud cloud_config.yml ``` diff --git a/docs/multi-node.qmd b/docs/multi-node.qmd index 56d015462..16196a2d7 100644 --- a/docs/multi-node.qmd +++ b/docs/multi-node.qmd @@ -69,11 +69,19 @@ export NCCL_BUFFSIZE=2097152 Run the following on each node: +### Option 1: New Axolotl CLI with launcher args (Recommended) + +```bash +axolotl train config.yaml --launcher torchrun -- --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" +``` + +### Option 2: Direct torchrun (Legacy) + ```bash torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml ``` -Please make sure to substitute the placeholder variables. +Please make sure to substitute the placeholder variables: - `num_nodes`: Number of nodes (containing GPUs) - `gpu_per_node`: Number of gpus per node @@ -81,8 +89,6 @@ Please make sure to substitute the placeholder variables. - `head_node_port`: Port of the head node (make sure other machines can connect to this. Default 29400) - `rdzv_id`: A unique job ID that is used by the job across nodes. -::: {.callout-note} -You need to call `axolotl.cli.train` instead of `axolotl train` as the latter calls accelerate under the hood -::: +The new CLI approach (Option 1) is recommended as it provides consistent argument handling and works seamlessly with other Axolotl CLI features. More info on the available configs can be found on the Pytorch docs [here](https://pytorch.org/docs/stable/elastic/run.html) diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py index e8571a900..31d854d41 100644 --- a/src/axolotl/cli/args.py +++ b/src/axolotl/cli/args.py @@ -30,8 +30,6 @@ class TrainerCliArgs: debug_num_examples: int = field(default=0) prompter: Optional[str] = field(default=None) shard: bool = field(default=False) - main_process_port: Optional[int] = field(default=None) - num_processes: Optional[int] = field(default=None) @dataclass diff --git a/src/axolotl/cli/cloud/__init__.py b/src/axolotl/cli/cloud/__init__.py index 5cdce29dd..bf12ab8cb 100644 --- a/src/axolotl/cli/cloud/__init__.py +++ b/src/axolotl/cli/cloud/__init__.py @@ -3,7 +3,7 @@ launch axolotl in supported cloud platforms """ from pathlib import Path -from typing import Union +from typing import Literal import yaml @@ -11,7 +11,7 @@ from axolotl.cli.cloud.modal_ import ModalCloud from axolotl.utils.dict import DictDefault -def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault: +def load_cloud_cfg(cloud_config: Path | str) -> DictDefault: """Load and validate cloud configuration.""" # Load cloud configuration. with open(cloud_config, encoding="utf-8") as file: @@ -20,8 +20,8 @@ def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault: def do_cli_preprocess( - cloud_config: Union[Path, str], - config: Union[Path, str], + cloud_config: Path | str, + config: Path | str, ) -> None: cloud_cfg = load_cloud_cfg(cloud_config) cloud = ModalCloud(cloud_cfg) @@ -31,9 +31,10 @@ def do_cli_preprocess( def do_cli_train( - cloud_config: Union[Path, str], - config: Union[Path, str], - accelerate: bool = True, + cloud_config: Path | str, + config: Path | str, + launcher: Literal["accelerate", "torchrun", "python"] = "accelerate", + launcher_args: list[str] | None = None, cwd=None, **kwargs, ) -> None: @@ -44,12 +45,18 @@ def do_cli_train( local_dirs = {} if cwd and not Path(cwd).joinpath("src", "axolotl").exists(): local_dirs = {"/workspace/mounts": cwd} - cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs) + cloud.train( + config_yaml, + launcher=launcher, + launcher_args=launcher_args, + local_dirs=local_dirs, + **kwargs, + ) def do_cli_lm_eval( - cloud_config: Union[Path, str], - config: Union[Path, str], + cloud_config: Path | str, + config: Path | str, ) -> None: cloud_cfg = load_cloud_cfg(cloud_config) cloud = ModalCloud(cloud_cfg) diff --git a/src/axolotl/cli/cloud/base.py b/src/axolotl/cli/cloud/base.py index eba8be49a..c498e8691 100644 --- a/src/axolotl/cli/cloud/base.py +++ b/src/axolotl/cli/cloud/base.py @@ -3,6 +3,7 @@ base class for cloud platforms from cli """ from abc import ABC, abstractmethod +from typing import Literal class Cloud(ABC): @@ -15,5 +16,12 @@ class Cloud(ABC): pass @abstractmethod - def train(self, config_yaml: str, accelerate: bool = True) -> str: + def train( + self, + config_yaml: str, + launcher: Literal["accelerate", "torchrun", "python"] = "accelerate", + launcher_args: list[str] | None = None, + local_dirs: dict[str, str] | None = None, + **kwargs, + ): pass diff --git a/src/axolotl/cli/cloud/modal_.py b/src/axolotl/cli/cloud/modal_.py index 83cdd7b72..240c6d894 100644 --- a/src/axolotl/cli/cloud/modal_.py +++ b/src/axolotl/cli/cloud/modal_.py @@ -8,7 +8,7 @@ import os import subprocess # nosec B404 from pathlib import Path from random import randint -from typing import Optional +from typing import Literal import modal @@ -230,8 +230,9 @@ class ModalCloud(Cloud): def train( self, config_yaml: str, - accelerate: bool = True, - local_dirs: Optional[dict[str, str]] = None, + launcher: Literal["accelerate", "torchrun", "python"] = "accelerate", + launcher_args: list[str] | None = None, + local_dirs: dict[str, str] | None = None, **kwargs, ): modal_fn = self.get_train_env(local_dirs)(_train) @@ -239,7 +240,8 @@ class ModalCloud(Cloud): with self.app.run(detach=True): modal_fn.remote( config_yaml, - accelerate=accelerate, + launcher=launcher, + launcher_args=launcher_args, volumes={k: v[0] for k, v in self.volumes.items()}, **kwargs, ) @@ -270,20 +272,35 @@ def _preprocess(config_yaml: str, volumes=None): ) -def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs): +def _train( + config_yaml: str, + launcher: Literal["accelerate", "torchrun", "python"] = "accelerate", + launcher_args: list[str] | None = None, + volumes=None, + **kwargs, # pylint: disable=unused-argument +): Path("/workspace/mounts").mkdir(parents=True, exist_ok=True) with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out: f_out.write(config_yaml) run_folder = "/workspace/mounts" - if accelerate: - accelerate_args = "--accelerate" + + launcher_args = launcher_args or [] + + # Build the base command + if launcher == "accelerate": + launcher_arg = "--launcher accelerate" + elif launcher == "torchrun": + launcher_arg = "--launcher torchrun" else: - accelerate_args = "--no-accelerate" - num_processes_args = "" - if num_processes := kwargs.pop("num_processes", None): - num_processes_args = f"--num-processes {num_processes}" + launcher_arg = "--launcher python" + + # Build launcher args string + launcher_args_str = "" + if launcher_args: + launcher_args_str = "-- " + " ".join(launcher_args) + run_cmd( - f"axolotl train {accelerate_args} {num_processes_args} /workspace/mounts/config.yaml", + f"axolotl train {launcher_arg} /workspace/mounts/config.yaml {launcher_args_str}".strip(), run_folder, volumes, ) diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index cb0eece7f..ae9f1f9c4 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -197,14 +197,13 @@ def load_cfg( # If there are any options passed in the cli, if it is something that seems valid # from the yaml, then overwrite the value cfg_keys = cfg.keys() - for k, _ in kwargs.items(): - # if not strict, allow writing to cfg even if it's not in the yml already - if k in cfg_keys or not cfg.strict: - # handle booleans - if isinstance(cfg[k], bool): - cfg[k] = bool(kwargs[k]) + for key, value in kwargs.items(): + # If not strict, allow writing to cfg even if it's not in the yml already + if key in cfg_keys or not cfg.strict: + if isinstance(cfg[key], bool): + cfg[key] = bool(value) else: - cfg[k] = kwargs[k] + cfg[key] = value try: device_props = torch.cuda.get_device_properties("cuda") diff --git a/src/axolotl/cli/delinearize_llama4.py b/src/axolotl/cli/delinearize_llama4.py index c92bae930..90227fccd 100644 --- a/src/axolotl/cli/delinearize_llama4.py +++ b/src/axolotl/cli/delinearize_llama4.py @@ -9,7 +9,6 @@ from typing import Generator, Union import fire import torch from accelerate import init_empty_weights -from dotenv import load_dotenv from transformers import AutoProcessor @@ -152,5 +151,4 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None: if __name__ == "__main__": - load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index 8a847a649..9dd3b0083 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -5,7 +5,6 @@ from pathlib import Path from typing import Union import fire -from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser from axolotl.cli.args import TrainerCliArgs @@ -13,7 +12,6 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token from axolotl.cli.config import load_cfg from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.evaluate import evaluate -from axolotl.utils import patch_optimized_env from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger @@ -30,9 +28,6 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: cfg: Dictionary mapping `axolotl` config keys to values. cli_args: CLI arguments. """ - # Enable expandable segments for cuda allocation to improve VRAM usage - patch_optimized_env() - # pylint: disable=duplicate-code check_accelerate_default_config() if int(os.getenv("LOCAL_RANK", "0")) == 0: @@ -64,5 +59,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: if __name__ == "__main__": - load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index 10132cd6f..83b567b64 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -9,7 +9,6 @@ from typing import Union import fire import torch import transformers -from dotenv import load_dotenv from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer from axolotl.cli.args import InferenceCliArgs @@ -268,5 +267,4 @@ def do_cli( if __name__ == "__main__": - load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 69c1425ac..c41acc40b 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -4,12 +4,9 @@ import os import subprocess # nosec B404 -import tempfile -from pathlib import Path -from typing import Optional +from typing import Literal, Optional import click -import yaml from dotenv import load_dotenv import axolotl @@ -21,13 +18,14 @@ from axolotl.cli.args import ( VllmServeCliArgs, ) from axolotl.cli.art import print_axolotl_text_art -from axolotl.cli.sweeps import generate_sweep_configs from axolotl.cli.utils import ( add_options_from_config, add_options_from_dataclass, build_command, fetch_from_github, filter_none_kwargs, + generate_config_files, + launch_training, ) from axolotl.integrations.lm_eval.cli import lm_eval from axolotl.utils import patch_optimized_env @@ -36,12 +34,19 @@ from axolotl.utils.schemas.config import AxolotlInputConfig LOG = get_logger(__name__) +LAUNCHER_COMMAND_MAPPING = { + "accelerate": ["accelerate", "launch"], + "torchrun": ["torchrun"], +} + @click.group() @click.version_option(version=axolotl.__version__, prog_name="axolotl") def cli(): """Axolotl CLI - Train and fine-tune large language models""" print_axolotl_text_art() + load_dotenv() + patch_optimized_env() @cli.command() @@ -50,7 +55,7 @@ def cli(): @add_options_from_dataclass(PreprocessCliArgs) @add_options_from_config(AxolotlInputConfig) @filter_none_kwargs -def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None: +def preprocess(config: str, cloud: Optional[str] = None, **kwargs): """ Preprocess datasets before training. @@ -60,7 +65,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None: kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ - patch_optimized_env() if cloud: from axolotl.cli.cloud import do_cli_preprocess @@ -72,12 +76,15 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None: do_cli(config=config, **kwargs) -@cli.command() +@cli.command( + context_settings={"ignore_unknown_options": True, "allow_extra_args": True} +) @click.argument("config", type=click.Path(exists=True, path_type=str)) @click.option( - "--accelerate/--no-accelerate", - default=True, - help="Use accelerate launch for multi-GPU training", + "--launcher", + type=click.Choice(["accelerate", "torchrun", "python"]), + default="accelerate", + help="Launcher to use for multi-GPU training", ) @click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str)) @click.option( @@ -88,126 +95,81 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None: @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) @filter_none_kwargs +@click.pass_context def train( + ctx: click.Context, config: str, - accelerate: bool, - cloud: Optional[str] = None, - sweep: Optional[str] = None, + launcher: Literal["accelerate", "torchrun", "python"] = "accelerate", + cloud: str | None = None, + sweep: str | None = None, **kwargs, -) -> None: +): """ Train or fine-tune a model. Args: + ctx: Click context for extra args. config: Path to `axolotl` config YAML file. - accelerate: Whether to use `accelerate` launcher. + launcher: Launcher to use for multi-GPU training ("accelerate", "torchrun", or "python"). cloud: Path to a cloud accelerator configuration file sweep: Path to YAML config for sweeping hyperparameters. kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ - # Enable expandable segments for cuda allocation to improve VRAM usage - patch_optimized_env() + # Extract launcher args from extra args (after --) + launcher_args = ctx.args if ctx.args else [] - if "use_ray" in kwargs and kwargs["use_ray"]: - accelerate = False - if sweep: - # load the sweep configuration yaml file - with open(sweep, "r", encoding="utf-8") as fin: - sweep_config: dict[str, list] = yaml.safe_load(fin) - with open(config, "r", encoding="utf-8") as fin: - base_config: dict[str, list] = yaml.safe_load(fin) + # Handle Ray launcher override + _launcher = None if kwargs.get("use_ray") else launcher - # generate all possible configurations - permutations = generate_sweep_configs(base_config, sweep_config) - - def iter_configs(): - for perm in permutations: - # open temp directory for temporary configurations - with tempfile.TemporaryDirectory() as temp_dir: - with open( - Path(temp_dir) / "config.yaml", "w", encoding="utf-8" - ) as fout: - yaml.dump(perm, fout) - yield str(Path(temp_dir) / "config.yaml") - - else: - - def iter_configs(): - yield config - - for cfg_file in iter_configs(): - # handle errors from subprocess so we can continue rest of sweeps + # Process each configuration + for cfg_file in generate_config_files(config, sweep): try: - if accelerate: - if cloud: - from axolotl.cli.cloud import do_cli_train - - cwd = os.getcwd() - do_cli_train( - cloud_config=cloud, - config=config, - accelerate=True, - cwd=cwd, - **kwargs, - ) - else: - accelerate_args = [] - if "main_process_port" in kwargs: - main_process_port = kwargs.pop("main_process_port", None) - accelerate_args.append("--main_process_port") - accelerate_args.append(str(main_process_port)) - if "num_processes" in kwargs: - num_processes = kwargs.pop("num_processes", None) - accelerate_args.append("--num_processes") - accelerate_args.append(str(num_processes)) - - base_cmd = ["accelerate", "launch"] - base_cmd.extend(accelerate_args) - base_cmd.extend(["-m", "axolotl.cli.train"]) - if cfg_file: - base_cmd.append(cfg_file) - cmd = build_command(base_cmd, kwargs) - subprocess.run(cmd, check=True) # nosec B603 - else: - if cloud: - from axolotl.cli.cloud import do_cli_train - - do_cli_train( - cloud_config=cloud, config=config, accelerate=False, **kwargs - ) - else: - from axolotl.cli.train import do_cli - - do_cli(config=cfg_file, **kwargs) + launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args) except subprocess.CalledProcessError as exc: LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}") if not sweep: raise exc + finally: + # Only delete temp files, not the original config + if cfg_file != config: + os.unlink(cfg_file) -@cli.command() +@cli.command( + context_settings={"ignore_unknown_options": True, "allow_extra_args": True} +) @click.argument("config", type=click.Path(exists=True, path_type=str)) @click.option( - "--accelerate/--no-accelerate", - default=True, - help="Use accelerate launch for multi-GPU training", + "--launcher", + type=click.Choice(["accelerate", "torchrun", "python"]), + default="accelerate", + help="Launcher to use for multi-GPU evaluation", ) @add_options_from_dataclass(EvaluateCliArgs) @add_options_from_config(AxolotlInputConfig) @filter_none_kwargs -def evaluate(config: str, accelerate: bool, **kwargs) -> None: +@click.pass_context +def evaluate(ctx: click.Context, config: str, launcher: str, **kwargs): """ Evaluate a model. Args: + ctx: Click context for extra args. config: Path to `axolotl` config YAML file. - accelerate: Whether to use `accelerate` launcher. + launcher: Launcher to use for multi-GPU evaluation ("accelerate", "torchrun", or "python"). kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ - if accelerate: - base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"] + # Extract launcher args from extra args (after --) + launcher_args = ctx.args if ctx.args else [] + + if launcher in LAUNCHER_COMMAND_MAPPING: + base_cmd = ( + LAUNCHER_COMMAND_MAPPING[launcher] + + launcher_args + + ["-m", "axolotl.cli.evaluate"] + ) if config: base_cmd.append(config) cmd = build_command(base_cmd, kwargs) @@ -218,30 +180,42 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None: do_cli(config=config, **kwargs) -@cli.command() +@cli.command( + context_settings={"ignore_unknown_options": True, "allow_extra_args": True} +) @click.argument("config", type=click.Path(exists=True, path_type=str)) @click.option( - "--accelerate/--no-accelerate", - default=False, - help="Use accelerate launch for multi-GPU inference", + "--launcher", + type=click.Choice(["accelerate", "torchrun", "python"]), + default="accelerate", + help="Launcher to use for multi-GPU inference", ) @click.option("--gradio", is_flag=True, help="Launch Gradio interface") @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) @filter_none_kwargs -def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None: +@click.pass_context +def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kwargs): """ Run inference with a trained model. Args: + ctx: Click context for extra args. config: Path to `axolotl` config YAML file. - accelerate: Whether to use `accelerate` launcher. + launcher: Launcher to use for multi-GPU inference ("accelerate", "torchrun", or "python"). gradio: Whether to use Gradio browser interface or command line for inference. kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ - if accelerate: - base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"] + # Extract launcher args from extra args (after --) + launcher_args = ctx.args if ctx.args else [] + + if launcher in LAUNCHER_COMMAND_MAPPING: + base_cmd = ( + LAUNCHER_COMMAND_MAPPING[launcher] + + launcher_args + + ["-m", "axolotl.cli.inference"] + ) if config: base_cmd.append(config) if gradio: @@ -254,33 +228,42 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None: do_cli(config=config, gradio=gradio, **kwargs) -@cli.command() +@cli.command( + context_settings={"ignore_unknown_options": True, "allow_extra_args": True} +) @click.argument("config", type=click.Path(exists=True, path_type=str)) @click.option( - "--accelerate/--no-accelerate", - default=True, - help="Use accelerate launch for weight merging", + "--launcher", + type=click.Choice(["accelerate", "torchrun", "python"]), + default="accelerate", + help="Launcher to use for weight merging", ) @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) @filter_none_kwargs -def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None: +@click.pass_context +def merge_sharded_fsdp_weights( + ctx: click.Context, config: str, launcher: str, **kwargs +): """ Merge sharded FSDP model weights. Args: + ctx: Click context for extra args. config: Path to `axolotl` config YAML file. - accelerate: Whether to use `accelerate` launcher. + launcher: Launcher to use for weight merging ("accelerate", "torchrun", or "python"). kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ - if accelerate: - base_cmd = [ - "accelerate", - "launch", - "-m", - "axolotl.cli.merge_sharded_fsdp_weights", - ] + # Extract launcher args from extra args (after --) + launcher_args = ctx.args if ctx.args else [] + + if launcher in LAUNCHER_COMMAND_MAPPING: + base_cmd = ( + LAUNCHER_COMMAND_MAPPING[launcher] + + launcher_args + + ["-m", "axolotl.cli.merge_sharded_fsdp_weights"] + ) if config: base_cmd.append(config) cmd = build_command(base_cmd, kwargs) @@ -296,7 +279,7 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None: @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) @filter_none_kwargs -def merge_lora(config: str, **kwargs) -> None: +def merge_lora(config: str, **kwargs): """ Merge trained LoRA adapters into a base model. @@ -313,7 +296,7 @@ def merge_lora(config: str, **kwargs) -> None: @cli.command() @click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"])) @click.option("--dest", help="Destination directory") -def fetch(directory: str, dest: Optional[str]) -> None: +def fetch(directory: str, dest: Optional[str]): """ Fetch example configs or other resources. @@ -351,7 +334,7 @@ def quantize(config: str, **cli_args: QuantizeCliArgs): @cli.command() @click.argument("model", type=click.Path(exists=True, path_type=str)) @click.argument("output", type=click.Path(exists=False, path_type=str)) -def delinearize_llama4(model: str, output: str) -> None: +def delinearize_llama4(model: str, output: str): from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4 do_delinearize_llama4(model, output) @@ -365,5 +348,4 @@ def main(): if __name__ == "__main__": - load_dotenv() main() diff --git a/src/axolotl/cli/merge_lora.py b/src/axolotl/cli/merge_lora.py index d639b3aee..422593a48 100644 --- a/src/axolotl/cli/merge_lora.py +++ b/src/axolotl/cli/merge_lora.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import Union import fire -from dotenv import load_dotenv from axolotl.cli.config import load_cfg from axolotl.cli.utils import load_model_and_tokenizer @@ -88,5 +87,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: if __name__ == "__main__": - load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/merge_sharded_fsdp_weights.py b/src/axolotl/cli/merge_sharded_fsdp_weights.py index b0880ce21..c08d30ec8 100644 --- a/src/axolotl/cli/merge_sharded_fsdp_weights.py +++ b/src/axolotl/cli/merge_sharded_fsdp_weights.py @@ -17,7 +17,6 @@ from accelerate.utils import ( WEIGHTS_NAME, is_torch_version, ) -from dotenv import load_dotenv from huggingface_hub import split_torch_state_dict_into_shards from safetensors.torch import save_file as safe_save_file from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner @@ -204,5 +203,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): if __name__ == "__main__": - load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index 595eb8aac..5d692c315 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -9,7 +9,6 @@ import fire import transformers from accelerate import init_empty_weights from colorama import Fore -from dotenv import load_dotenv from transformers import AutoModelForCausalLM from axolotl.cli.args import PreprocessCliArgs @@ -109,5 +108,4 @@ def do_cli( if __name__ == "__main__": - load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index d0cf8455b..7f0b0bdd2 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -7,7 +7,6 @@ from typing import Union import fire from accelerate import Accelerator -from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser from axolotl.cli.args import TrainerCliArgs @@ -16,7 +15,6 @@ from axolotl.cli.config import load_cfg from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.integrations.base import PluginManager from axolotl.train import train -from axolotl.utils import patch_optimized_env from axolotl.utils.config import normalize_config, resolve_dtype from axolotl.utils.dict import DictDefault @@ -31,9 +29,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): cfg: Dictionary mapping `axolotl` config keys to values. cli_args: Training-specific CLI arguments. """ - # Enable expandable segments for cuda allocation to improve VRAM usage - patch_optimized_env() - check_accelerate_default_config() if int(os.getenv("LOCAL_RANK", "0")) == 0: check_user_token() @@ -122,5 +117,4 @@ def ray_train_func(kwargs: dict): if __name__ == "__main__": - load_dotenv() fire.Fire(do_cli) diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py deleted file mode 100644 index d28795361..000000000 --- a/src/axolotl/cli/utils.py +++ /dev/null @@ -1,330 +0,0 @@ -"""Utility methods for axolotl CLI.""" - -import concurrent.futures -import dataclasses -import hashlib -import json -from functools import wraps -from pathlib import Path -from types import NoneType -from typing import Any, Callable, Type, Union, get_args, get_origin - -import click -import requests -from pydantic import BaseModel -from transformers import ( - PreTrainedModel, - PreTrainedTokenizer, - PreTrainedTokenizerFast, - ProcessorMixin, -) - -from axolotl.loaders import load_processor, load_tokenizer -from axolotl.loaders.model import ModelLoader -from axolotl.utils.dict import DictDefault -from axolotl.utils.logging import get_logger - -LOG = get_logger(__name__) - - -def strip_optional_type(field_type: type | str | None): - """ - Extracts the non-`None` type from an `Optional` / `Union` type. - - Args: - field_type: Type of field for Axolotl CLI command. - - Returns: - If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise - returns the input type unchanged. - """ - if get_origin(field_type) is Union and type(None) in get_args(field_type): - field_type = next( - t for t in get_args(field_type) if not isinstance(t, NoneType) - ) - - return field_type - - -def filter_none_kwargs(func: Callable) -> Callable: - """ - Wraps function to remove `None`-valued `kwargs`. - - Args: - func: Function to wrap. - - Returns: - Wrapped function. - """ - - @wraps(func) - def wrapper(*args, **kwargs) -> Callable: - """Filters out `None`-valued `kwargs`.""" - filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} - - return func(*args, **filtered_kwargs) - - return wrapper - - -def add_options_from_dataclass(config_class: Type[Any]) -> Callable: - """ - Create Click options from the fields of a dataclass. - - Args: - config_class: Dataclass with fields to parse from the CLI. - - Returns: - Function decorator for Axolotl CLI command. - """ - - def decorator(function: Callable) -> Callable: - # Process dataclass fields in reverse order for correct option ordering - for field in reversed(dataclasses.fields(config_class)): - field_type = strip_optional_type(field.type) - - if field_type == bool: - field_name = field.name.replace("_", "-") - option_name = f"--{field_name}/--no-{field_name}" - function = click.option( - option_name, - default=field.default, - help=field.metadata.get("description"), - )(function) - else: - option_name = f"--{field.name.replace('_', '-')}" - function = click.option( - option_name, - type=field_type, - default=field.default, - help=field.metadata.get("description"), - )(function) - - return function - - return decorator - - -def add_options_from_config(config_class: Type[BaseModel]) -> Callable: - """ - Create Click options from the fields of a Pydantic model. - - Args: - config_class: PyDantic model with fields to parse from the CLI - - Returns: - Function decorator for Axolotl CLI command. - """ - - def decorator(function: Callable) -> Callable: - # Process model fields in reverse order for correct option ordering - for name, field in reversed(config_class.model_fields.items()): - field_type = strip_optional_type(field.annotation) - - if field_type == bool: - field_name = name.replace("_", "-") - option_name = f"--{field_name}/--no-{field_name}" - function = click.option( - option_name, default=None, help=field.description - )(function) - else: - option_name = f"--{name.replace('_', '-')}" - function = click.option( - option_name, default=None, help=field.description - )(function) - - return function - - return decorator - - -def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]: - """ - Build command list from base command and options. - - Args: - base_cmd: Command without options. - options: Options to parse and append to base command. - - Returns: - List of strings giving shell command. - """ - cmd = base_cmd.copy() - - for key, value in options.items(): - if value is None: - continue - - key = key.replace("_", "-") - - if isinstance(value, bool): - if value: - cmd.append(f"--{key}") - else: - cmd.extend([f"--{key}", str(value)]) - - return cmd - - -def download_file( - file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str -) -> tuple[str, str]: - """ - Download a single file and return its processing status. - - Args: - file_info: Tuple of (file_path, remote_sha). - raw_base_url: Base URL for raw GitHub content. - dest_path: Local destination directory. - dir_prefix: Directory prefix to filter files. - - Returns: - Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'. - """ - file_path, remote_sha = file_info - raw_url = f"{raw_base_url}/{file_path}" - dest_file = dest_path / file_path.split(dir_prefix)[-1] - - # Check if file exists and needs updating - if dest_file.exists(): - with open(dest_file, "rb") as file: - content = file.read() - # Calculate git blob SHA - blob = b"blob " + str(len(content)).encode() + b"\0" + content - local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest() - - if local_sha == remote_sha: - print(f"Skipping {file_path} (unchanged)") - return file_path, "unchanged" - - print(f"Updating {file_path}") - status = "new" - else: - print(f"Downloading {file_path}") - status = "new" - - # Create directories if needed - dest_file.parent.mkdir(parents=True, exist_ok=True) - - # Download and save file - try: - response = requests.get(raw_url, timeout=30) - response.raise_for_status() - - with open(dest_file, "wb") as file: - file.write(response.content) - - return file_path, status - except (requests.RequestException, IOError) as request_error: - print(f"Error downloading {file_path}: {str(request_error)}") - return file_path, "error" - - -def fetch_from_github( - dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5 -) -> None: - """ - Sync files from a specific directory in the GitHub repository. - Only downloads files that don't exist locally or have changed. - - Args: - dir_prefix: Directory prefix to filter files (e.g., 'examples/', - 'deepspeed_configs/'). - dest_dir: Local destination directory. - max_workers: Maximum number of concurrent downloads. - """ - api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1" - raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main" - - # Get repository tree with timeout - response = requests.get(api_url, timeout=30) - response.raise_for_status() - tree = json.loads(response.text) - - # Filter for files and get their SHA - files = { - item["path"]: item["sha"] - for item in tree["tree"] - if item["type"] == "blob" and item["path"].startswith(dir_prefix) - } - - if not files: - raise click.ClickException(f"No files found in {dir_prefix}") - - # Default destination directory is the last part of dir_prefix - default_dest = Path(dir_prefix.rstrip("/")) - dest_path = Path(dest_dir) if dest_dir else default_dest - - # Keep track of processed files for summary - files_processed: dict[str, list[str]] = { - "new": [], - "updated": [], - "unchanged": [], - "error": [], - } - - # Process files in parallel using ThreadPoolExecutor - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_file = { - executor.submit( - download_file, - (file_path, remote_sha), - raw_base_url, - dest_path, - dir_prefix, - ): file_path - for file_path, remote_sha in files.items() - } - - # Process completed tasks as they finish - for future in concurrent.futures.as_completed(future_to_file): - file_path = future_to_file[future] - try: - file_path, status = future.result() - files_processed[status].append(file_path) - except (requests.RequestException, IOError) as request_error: - print(f"Error processing {file_path}: {str(request_error)}") - files_processed["error"].append(file_path) - - # Log summary - LOG.info("\nSync Summary:") - LOG.info(f"New files: {len(files_processed['new'])}") - LOG.info(f"Updated files: {len(files_processed['updated'])}") - LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}") - if files_processed["error"]: - LOG.info(f"Failed files: {len(files_processed['error'])}") - - -def load_model_and_tokenizer( - *, - cfg: DictDefault, - inference: bool = False, -) -> tuple[ - PreTrainedModel, - PreTrainedTokenizer | PreTrainedTokenizerFast | Any, - ProcessorMixin | None, -]: - """ - Helper function for loading a model, tokenizer, and processor specified in the given `axolotl` - config. - - Args: - cfg: Dictionary mapping `axolotl` config keys to values. - inference: Boolean denoting inference mode. - - Returns: - Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin). - """ - LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") - tokenizer = load_tokenizer(cfg) - - LOG.info("loading model...") - model_loader = ModelLoader(cfg, tokenizer, inference=inference) - model, _ = model_loader.load() - - processor = None - if cfg.is_multimodal: - LOG.info("loading processor...") - processor = load_processor(cfg, tokenizer) - - return model, tokenizer, processor diff --git a/src/axolotl/cli/utils/__init__.py b/src/axolotl/cli/utils/__init__.py new file mode 100644 index 000000000..583130339 --- /dev/null +++ b/src/axolotl/cli/utils/__init__.py @@ -0,0 +1,23 @@ +"""Init for axolotl.cli.utils module.""" + +from .args import ( + add_options_from_config, + add_options_from_dataclass, + filter_none_kwargs, +) +from .fetch import fetch_from_github +from .load import load_model_and_tokenizer +from .sweeps import generate_sweep_configs +from .train import build_command, generate_config_files, launch_training + +__all__ = [ + "filter_none_kwargs", + "add_options_from_dataclass", + "add_options_from_config", + "build_command", + "generate_config_files", + "generate_sweep_configs", + "load_model_and_tokenizer", + "launch_training", + "fetch_from_github", +] diff --git a/src/axolotl/cli/utils/args.py b/src/axolotl/cli/utils/args.py new file mode 100644 index 000000000..3aea1a378 --- /dev/null +++ b/src/axolotl/cli/utils/args.py @@ -0,0 +1,120 @@ +"""Utilities for axolotl CLI args.""" + +import dataclasses +from functools import wraps +from types import NoneType +from typing import Any, Callable, Type, Union, get_args, get_origin + +import click +from pydantic import BaseModel + + +def _strip_optional_type(field_type: type | str | None): + """ + Extracts the non-`None` type from an `Optional` / `Union` type. + + Args: + field_type: Type of field for Axolotl CLI command. + + Returns: + If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise + returns the input type unchanged. + """ + if get_origin(field_type) is Union and type(None) in get_args(field_type): + field_type = next( + t for t in get_args(field_type) if not isinstance(t, NoneType) + ) + + return field_type + + +def filter_none_kwargs(func: Callable) -> Callable: + """ + Wraps function to remove `None`-valued `kwargs`. + + Args: + func: Function to wrap. + + Returns: + Wrapped function. + """ + + @wraps(func) + def wrapper(*args, **kwargs) -> Callable: + """Filters out `None`-valued `kwargs`.""" + filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} + + return func(*args, **filtered_kwargs) + + return wrapper + + +def add_options_from_dataclass(config_class: Type[Any]) -> Callable: + """ + Create Click options from the fields of a dataclass. + + Args: + config_class: Dataclass with fields to parse from the CLI. + + Returns: + Function decorator for Axolotl CLI command. + """ + + def decorator(function: Callable) -> Callable: + # Process dataclass fields in reverse order for correct option ordering + for field in reversed(dataclasses.fields(config_class)): + field_type = _strip_optional_type(field.type) + + if field_type == bool: + field_name = field.name.replace("_", "-") + option_name = f"--{field_name}/--no-{field_name}" + function = click.option( + option_name, + default=field.default, + help=field.metadata.get("description"), + )(function) + else: + option_name = f"--{field.name.replace('_', '-')}" + function = click.option( + option_name, + type=field_type, + default=field.default, + help=field.metadata.get("description"), + )(function) + + return function + + return decorator + + +def add_options_from_config(config_class: Type[BaseModel]) -> Callable: + """ + Create Click options from the fields of a Pydantic model. + + Args: + config_class: PyDantic model with fields to parse from the CLI + + Returns: + Function decorator for Axolotl CLI command. + """ + + def decorator(function: Callable) -> Callable: + # Process model fields in reverse order for correct option ordering + for name, field in reversed(config_class.model_fields.items()): + field_type = _strip_optional_type(field.annotation) + + if field_type == bool: + field_name = name.replace("_", "-") + option_name = f"--{field_name}/--no-{field_name}" + function = click.option( + option_name, default=None, help=field.description + )(function) + else: + option_name = f"--{name.replace('_', '-')}" + function = click.option( + option_name, default=None, help=field.description + )(function) + + return function + + return decorator diff --git a/src/axolotl/cli/utils/fetch.py b/src/axolotl/cli/utils/fetch.py new file mode 100644 index 000000000..441b7f6f7 --- /dev/null +++ b/src/axolotl/cli/utils/fetch.py @@ -0,0 +1,142 @@ +"""Utilities for axolotl fetch CLI command.""" + +import concurrent.futures +import hashlib +import json +from pathlib import Path + +import click +import requests + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def _download_file( + file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str +) -> tuple[str, str]: + """ + Download a single file and return its processing status. + + Args: + file_info: Tuple of (file_path, remote_sha). + raw_base_url: Base URL for raw GitHub content. + dest_path: Local destination directory. + dir_prefix: Directory prefix to filter files. + + Returns: + Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'. + """ + file_path, remote_sha = file_info + raw_url = f"{raw_base_url}/{file_path}" + dest_file = dest_path / file_path.split(dir_prefix)[-1] + + # Check if file exists and needs updating + if dest_file.exists(): + with open(dest_file, "rb") as file: + content = file.read() + # Calculate git blob SHA + blob = b"blob " + str(len(content)).encode() + b"\0" + content + local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest() + + if local_sha == remote_sha: + print(f"Skipping {file_path} (unchanged)") + return file_path, "unchanged" + + print(f"Updating {file_path}") + status = "updated" + else: + print(f"Downloading {file_path}") + status = "new" + + # Create directories if needed + dest_file.parent.mkdir(parents=True, exist_ok=True) + + # Download and save file + try: + response = requests.get(raw_url, timeout=30) + response.raise_for_status() + + with open(dest_file, "wb") as file: + file.write(response.content) + + return file_path, status + except (requests.RequestException, IOError) as request_error: + print(f"Error downloading {file_path}: {str(request_error)}") + return file_path, "error" + + +def fetch_from_github( + dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5 +) -> None: + """ + Sync files from a specific directory in the GitHub repository. + Only downloads files that don't exist locally or have changed. + + Args: + dir_prefix: Directory prefix to filter files (e.g., 'examples/', + 'deepspeed_configs/'). + dest_dir: Local destination directory. + max_workers: Maximum number of concurrent downloads. + """ + api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1" + raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main" + + # Get repository tree with timeout + response = requests.get(api_url, timeout=30) + response.raise_for_status() + tree = json.loads(response.text) + + # Filter for files and get their SHA + files = { + item["path"]: item["sha"] + for item in tree["tree"] + if item["type"] == "blob" and item["path"].startswith(dir_prefix) + } + + if not files: + raise click.ClickException(f"No files found in {dir_prefix}") + + # Default destination directory is the last part of dir_prefix + default_dest = Path(dir_prefix.rstrip("/")) + dest_path = Path(dest_dir) if dest_dir else default_dest + + # Keep track of processed files for summary + files_processed: dict[str, list[str]] = { + "new": [], + "updated": [], + "unchanged": [], + "error": [], + } + + # Process files in parallel using ThreadPoolExecutor + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_file = { + executor.submit( + _download_file, + (file_path, remote_sha), + raw_base_url, + dest_path, + dir_prefix, + ): file_path + for file_path, remote_sha in files.items() + } + + # Process completed tasks as they finish + for future in concurrent.futures.as_completed(future_to_file): + file_path = future_to_file[future] + try: + file_path, status = future.result() + files_processed[status].append(file_path) + except (requests.RequestException, IOError) as request_error: + print(f"Error processing {file_path}: {str(request_error)}") + files_processed["error"].append(file_path) + + # Log summary + LOG.info("\nSync Summary:") + LOG.info(f"New files: {len(files_processed['new'])}") + LOG.info(f"Updated files: {len(files_processed['updated'])}") + LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}") + if files_processed["error"]: + LOG.info(f"Failed files: {len(files_processed['error'])}") diff --git a/src/axolotl/cli/utils/load.py b/src/axolotl/cli/utils/load.py new file mode 100644 index 000000000..610a81306 --- /dev/null +++ b/src/axolotl/cli/utils/load.py @@ -0,0 +1,52 @@ +"""Utilities for model, tokenizer, etc. loading.""" + +from typing import Any + +from transformers import ( + PreTrainedModel, + PreTrainedTokenizer, + PreTrainedTokenizerFast, + ProcessorMixin, +) + +from axolotl.loaders import load_processor, load_tokenizer +from axolotl.loaders.model import ModelLoader +from axolotl.utils.dict import DictDefault +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +def load_model_and_tokenizer( + *, + cfg: DictDefault, + inference: bool = False, +) -> tuple[ + PreTrainedModel, + PreTrainedTokenizer | PreTrainedTokenizerFast | Any, + ProcessorMixin | None, +]: + """ + Helper function for loading a model, tokenizer, and processor specified in the + given `axolotl` config. + + Args: + cfg: Dictionary mapping `axolotl` config keys to values. + inference: Boolean denoting inference mode. + + Returns: + Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin). + """ + LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") + tokenizer = load_tokenizer(cfg) + + LOG.info("loading model...") + model_loader = ModelLoader(cfg, tokenizer, inference=inference) + model, _ = model_loader.load() + + processor = None + if cfg.is_multimodal: + LOG.info("loading processor...") + processor = load_processor(cfg, tokenizer) + + return model, tokenizer, processor diff --git a/src/axolotl/cli/sweeps.py b/src/axolotl/cli/utils/sweeps.py similarity index 100% rename from src/axolotl/cli/sweeps.py rename to src/axolotl/cli/utils/sweeps.py diff --git a/src/axolotl/cli/utils/train.py b/src/axolotl/cli/utils/train.py new file mode 100644 index 000000000..61d05e52b --- /dev/null +++ b/src/axolotl/cli/utils/train.py @@ -0,0 +1,188 @@ +"""Utilities for axolotl train CLI command.""" + +import os +import subprocess # nosec +import tempfile +from typing import Any, Iterator, Literal + +import yaml + +from axolotl.cli.utils.sweeps import generate_sweep_configs + + +def _add_default_rdzv_args(launcher_args: list[str]) -> list[str]: + """ + Add default RDZV arguments if rdzv_endpoint is set but rdzv_backend/rdzv_id are missing. + + Args: + launcher_args: List of launcher arguments + + Returns: + Updated launcher args with defaults added if needed + """ + args = launcher_args.copy() + + # Check if rdzv_endpoint is present + has_rdzv_endpoint = any("--rdzv_endpoint" in arg for arg in args) + + if has_rdzv_endpoint: + # Check if rdzv_backend is already provided + has_rdzv_backend = any("--rdzv_backend" in arg for arg in args) + if not has_rdzv_backend: + args.extend(["--rdzv_backend", "c10d"]) + + # Check if rdzv_id is already provided + has_rdzv_id = any("--rdzv_id" in arg for arg in args) + if not has_rdzv_id: + import uuid + + args.extend(["--rdzv_id", str(uuid.uuid4())[:8]]) + + return args + + +def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]: + """ + Build command list from base command and options. + + Args: + base_cmd: Command without options. + options: Options to parse and append to base command. + + Returns: + List of strings giving shell command. + """ + cmd = base_cmd.copy() + + for key, value in options.items(): + if value is None: + continue + + key = key.replace("_", "-") + cmd.append(f"--{key}={value}") + + return cmd + + +def generate_config_files(config: str, sweep: str | None) -> Iterator[str]: + """Generate list of configuration files to process.""" + if not sweep: + yield config + return + + # Load sweep and base configurations + with open(sweep, "r", encoding="utf-8") as fin: + sweep_config: dict[str, list] = yaml.safe_load(fin) + with open(config, "r", encoding="utf-8") as fin: + base_config: dict[str, list] = yaml.safe_load(fin) + + # Generate all possible configurations + permutations = generate_sweep_configs(base_config, sweep_config) + for permutation in permutations: + # pylint: disable=consider-using-with + temp_file = tempfile.NamedTemporaryFile( + mode="w", + suffix=".yaml", + delete=False, + encoding="utf-8", + ) + yaml.dump(permutation, temp_file) + temp_file.close() + yield temp_file.name + + +def launch_training( + cfg_file: str, + launcher: Literal["accelerate", "torchrun", "python"] | None, + cloud: str | None, + kwargs: dict, + launcher_args: list[str] | None = None, +) -> None: + """Execute training with the given configuration.""" + launcher_args = launcher_args or [] + + if cloud: + _launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args) + elif launcher: + if launcher == "accelerate": + _launch_accelerate_training(cfg_file, kwargs, launcher_args) + elif launcher == "torchrun": + _launch_torchrun_training(cfg_file, kwargs, launcher_args) + elif launcher == "python": + _launch_python_training(cfg_file, kwargs) + + +def _launch_cloud_training( + cloud: str, + cfg_file: str, + launcher: Literal["accelerate", "torchrun", "python"] | None, + kwargs: dict, + launcher_args: list[str] | None = None, +) -> None: + """Execute training via cloud launcher.""" + from axolotl.cli.cloud import do_cli_train + + launcher_args = launcher_args or [] + cwd = os.getcwd() if launcher else None + + do_cli_train( + cloud_config=cloud, + config=cfg_file, + launcher=launcher or "accelerate", + launcher_args=launcher_args, + cwd=cwd, + **kwargs, + ) + + +def _launch_accelerate_training( + cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None +) -> None: + """Execute training via accelerate launcher.""" + launcher_args = launcher_args or [] + internal_launcher_args = [] + + # Extract launcher-specific arguments from kwargs (legacy support) + if "main_process_port" in kwargs: + main_process_port = kwargs.pop("main_process_port") + internal_launcher_args.extend(["--main_process_port", str(main_process_port)]) + + if "num_processes" in kwargs: + num_processes = kwargs.pop("num_processes") + internal_launcher_args.extend(["--num_processes", str(num_processes)]) + + # Combine internal args with user-provided launcher args + all_launcher_args = internal_launcher_args + launcher_args + + base_cmd = ( + ["accelerate", "launch"] + all_launcher_args + ["-m", "axolotl.cli.train"] + ) + if cfg_file: + base_cmd.append(cfg_file) + + cmd = build_command(base_cmd, kwargs) + subprocess.run(cmd, check=True) # nosec B603 + + +def _launch_torchrun_training( + cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None +) -> None: + """Execute training via torchrun launcher.""" + launcher_args = launcher_args or [] + + # Add default RDZV arguments if rdzv_endpoint is set + launcher_args = _add_default_rdzv_args(launcher_args) + + base_cmd = ["torchrun"] + launcher_args + ["-m", "axolotl.cli.train"] + if cfg_file: + base_cmd.append(cfg_file) + + cmd = build_command(base_cmd, kwargs) + subprocess.run(cmd, check=True) # nosec B603 + + +def _launch_python_training(cfg_file: str, kwargs: dict) -> None: + """Execute training via python launcher.""" + from axolotl.cli.train import do_cli + + do_cli(config=cfg_file, **kwargs) diff --git a/tests/cli/test_cli_base.py b/tests/cli/test_cli_base.py index 6dbae045f..4b880d44a 100644 --- a/tests/cli/test_cli_base.py +++ b/tests/cli/test_cli_base.py @@ -17,16 +17,23 @@ class BaseCliTest: command: Command to test (train/evaluate) """ # Test missing config file - result = cli_runner.invoke(cli, [command, "--no-accelerate"]) + result = cli_runner.invoke(cli, [command, "--launcher", "python"]) assert result.exit_code != 0 # Test non-existent config file - result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"]) + result = cli_runner.invoke( + cli, [command, "nonexistent.yml", "--launcher", "python"] + ) assert result.exit_code != 0 assert "Error: Invalid value for 'CONFIG'" in result.output def _test_basic_execution( - self, cli_runner, tmp_path: Path, valid_test_config: str, command: str + self, + cli_runner, + tmp_path: Path, + valid_test_config: str, + command: str, + train: bool = True, ): """Test basic execution with accelerate. @@ -35,6 +42,7 @@ class BaseCliTest: tmp_path: Temporary path fixture valid_test_config: Valid config fixture command: Command to test (train/evaluate) + train: Whether to test training (default) or evaluation """ config_path = tmp_path / "config.yml" config_path.write_text(valid_test_config) @@ -43,15 +51,21 @@ class BaseCliTest: result = cli_runner.invoke(cli, [command, str(config_path)]) assert mock.called - assert mock.call_args.args[0] == [ + + expected = [ "accelerate", "launch", "-m", f"axolotl.cli.{command}", str(config_path), - "--debug-num-examples", - "0", + "--debug=False", + "--debug-text-only=False", + "--debug-num-examples=0", ] + if train: + expected.append("--shard=False") + + assert mock.call_args.args[0] == expected assert mock.call_args.kwargs == {"check": True} assert result.exit_code == 0 diff --git a/tests/cli/test_cli_evaluate.py b/tests/cli/test_cli_evaluate.py index d8eb41467..a191bf957 100644 --- a/tests/cli/test_cli_evaluate.py +++ b/tests/cli/test_cli_evaluate.py @@ -1,5 +1,7 @@ """Tests for evaluate CLI command.""" +# pylint: disable=duplicate-code + from unittest.mock import patch from axolotl.cli.main import cli @@ -18,7 +20,9 @@ class TestEvaluateCommand(BaseCliTest): def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config): """Test basic successful execution""" - self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate") + self._test_basic_execution( + cli_runner, tmp_path, valid_test_config, "evaluate", train=False + ) def test_evaluate_basic_execution_no_accelerate( self, cli_runner, tmp_path, valid_test_config @@ -27,13 +31,15 @@ class TestEvaluateCommand(BaseCliTest): config_path = tmp_path / "config.yml" config_path.write_text(valid_test_config) + # pylint: disable=duplicate-code with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate: result = cli_runner.invoke( cli, [ "evaluate", str(config_path), - "--no-accelerate", + "--launcher", + "python", ], catch_exceptions=False, ) @@ -55,7 +61,8 @@ class TestEvaluateCommand(BaseCliTest): "2", "--sequence-len", "128", - "--no-accelerate", + "--launcher", + "python", ], catch_exceptions=False, ) @@ -65,3 +72,104 @@ class TestEvaluateCommand(BaseCliTest): cfg = mock_evaluate.call_args[0][0] assert cfg.micro_batch_size == 2 assert cfg.sequence_len == 128 + + def test_evaluate_with_launcher_args_torchrun( + self, cli_runner, tmp_path, valid_test_config + ): + """Test evaluate with torchrun launcher arguments""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "evaluate", + str(config_path), + "--launcher", + "torchrun", + "--", + "--nproc_per_node=2", + "--nnodes=1", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to torchrun + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "torchrun" + assert "--nproc_per_node=2" in called_cmd + assert "--nnodes=1" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.evaluate" in called_cmd + + def test_evaluate_with_launcher_args_accelerate( + self, cli_runner, tmp_path, valid_test_config + ): + """Test evaluate with accelerate launcher arguments""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "evaluate", + str(config_path), + "--launcher", + "accelerate", + "--", + "--config_file=accelerate_config.yml", + "--num_processes=4", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to accelerate + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + assert "--config_file=accelerate_config.yml" in called_cmd + assert "--num_processes=4" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.evaluate" in called_cmd + + def test_evaluate_backward_compatibility_no_launcher_args( + self, cli_runner, tmp_path, valid_test_config + ): + """Test that existing evaluate commands work without launcher args""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "evaluate", + str(config_path), + "--launcher", + "accelerate", + "--micro-batch-size", + "2", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify no launcher args contamination + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + # Should not contain any extra launcher args + launcher_section = called_cmd[2 : called_cmd.index("-m")] + assert ( + len(launcher_section) == 0 + ) # No launcher args between 'launch' and '-m' diff --git a/tests/cli/test_cli_inference.py b/tests/cli/test_cli_inference.py index b8effa3d2..3394c189d 100644 --- a/tests/cli/test_cli_inference.py +++ b/tests/cli/test_cli_inference.py @@ -1,5 +1,7 @@ """pytest tests for axolotl CLI inference command.""" +# pylint: disable=duplicate-code + from unittest.mock import patch from axolotl.cli.main import cli @@ -10,7 +12,7 @@ def test_inference_basic(cli_runner, config_path): with patch("axolotl.cli.inference.do_inference") as mock: result = cli_runner.invoke( cli, - ["inference", str(config_path), "--no-accelerate"], + ["inference", str(config_path), "--launcher", "python"], catch_exceptions=False, ) @@ -23,9 +25,124 @@ def test_inference_gradio(cli_runner, config_path): with patch("axolotl.cli.inference.do_inference_gradio") as mock: result = cli_runner.invoke( cli, - ["inference", str(config_path), "--no-accelerate", "--gradio"], + ["inference", str(config_path), "--launcher", "python", "--gradio"], catch_exceptions=False, ) assert mock.called assert result.exit_code == 0 + + +def test_inference_with_launcher_args_torchrun(cli_runner, config_path): + """Test inference with torchrun launcher arguments""" + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "inference", + str(config_path), + "--launcher", + "torchrun", + "--", + "--nproc_per_node=2", + "--nnodes=1", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to torchrun + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "torchrun" + assert "--nproc_per_node=2" in called_cmd + assert "--nnodes=1" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.inference" in called_cmd + + +def test_inference_with_launcher_args_accelerate(cli_runner, config_path): + """Test inference with accelerate launcher arguments""" + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "inference", + str(config_path), + "--launcher", + "accelerate", + "--", + "--config_file=accelerate_config.yml", + "--num_processes=4", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to accelerate + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + assert "--config_file=accelerate_config.yml" in called_cmd + assert "--num_processes=4" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.inference" in called_cmd + + +def test_inference_gradio_with_launcher_args(cli_runner, config_path): + """Test inference with gradio and launcher arguments""" + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "inference", + str(config_path), + "--launcher", + "accelerate", + "--gradio", + "--", + "--num_processes=2", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify both gradio flag and launcher args are present + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + assert "--num_processes=2" in called_cmd + assert "--gradio" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.inference" in called_cmd + + +def test_inference_backward_compatibility_no_launcher_args(cli_runner, config_path): + """Test that existing inference commands work without launcher args""" + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "inference", + str(config_path), + "--launcher", + "accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify no launcher args contamination + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + # Should not contain any extra launcher args + launcher_section = called_cmd[2 : called_cmd.index("-m")] + assert len(launcher_section) == 0 # No launcher args between 'launch' and '-m' diff --git a/tests/cli/test_cli_interface.py b/tests/cli/test_cli_interface.py index 8b5fec17f..ebd91ea60 100644 --- a/tests/cli/test_cli_interface.py +++ b/tests/cli/test_cli_interface.py @@ -18,11 +18,10 @@ def test_build_command(): assert result == [ "accelerate", "launch", - "--learning-rate", - "0.0001", - "--batch-size", - "8", - "--debug", + "--learning-rate=0.0001", + "--batch-size=8", + "--debug=True", + "--use-fp16=False", ] @@ -38,7 +37,7 @@ def test_invalid_command_options(cli_runner): ], ) assert result.exit_code != 0 - assert "No such option" in result.output + assert "does not exist" in result.output def test_required_config_argument(cli_runner): diff --git a/tests/cli/test_cli_merge_sharded_fsdp_weights.py b/tests/cli/test_cli_merge_sharded_fsdp_weights.py index ec96b4ed4..4f6a973ea 100644 --- a/tests/cli/test_cli_merge_sharded_fsdp_weights.py +++ b/tests/cli/test_cli_merge_sharded_fsdp_weights.py @@ -11,9 +11,101 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path): """Test merge_sharded_fsdp_weights command without accelerate""" with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock: result = cli_runner.invoke( - cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"] + cli, + ["merge-sharded-fsdp-weights", str(config_path), "--launcher", "python"], ) assert mock.called assert mock.call_args.kwargs["config"] == str(config_path) assert result.exit_code == 0 + + +def test_merge_sharded_fsdp_weights_with_launcher_args_torchrun( + cli_runner, config_path +): + """Test merge-sharded-fsdp-weights with torchrun launcher arguments""" + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "merge-sharded-fsdp-weights", + str(config_path), + "--launcher", + "torchrun", + "--", + "--nproc_per_node=2", + "--nnodes=1", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to torchrun + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "torchrun" + assert "--nproc_per_node=2" in called_cmd + assert "--nnodes=1" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd + + +def test_merge_sharded_fsdp_weights_with_launcher_args_accelerate( + cli_runner, config_path +): + """Test merge-sharded-fsdp-weights with accelerate launcher arguments""" + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "merge-sharded-fsdp-weights", + str(config_path), + "--launcher", + "accelerate", + "--", + "--config_file=accelerate_config.yml", + "--num_processes=4", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to accelerate + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + assert "--config_file=accelerate_config.yml" in called_cmd + assert "--num_processes=4" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd + + +def test_merge_sharded_fsdp_weights_backward_compatibility_no_launcher_args( + cli_runner, config_path +): + """Test that existing merge-sharded-fsdp-weights commands work without launcher args""" + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "merge-sharded-fsdp-weights", + str(config_path), + "--launcher", + "accelerate", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify no launcher args contamination + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + # Should not contain any extra launcher args + launcher_section = called_cmd[2 : called_cmd.index("-m")] + assert len(launcher_section) == 0 # No launcher args between 'launch' and '-m' diff --git a/tests/cli/test_cli_sweeps.py b/tests/cli/test_cli_sweeps.py index 40b360717..1b14f5aca 100644 --- a/tests/cli/test_cli_sweeps.py +++ b/tests/cli/test_cli_sweeps.py @@ -2,7 +2,7 @@ unit tests for generating sweep configurations """ -from axolotl.cli.main import generate_sweep_configs +from axolotl.cli.utils import generate_sweep_configs def test_generate_sweep_configs_no_pairs(): diff --git a/tests/cli/test_cli_train.py b/tests/cli/test_cli_train.py index 473913599..9b266f057 100644 --- a/tests/cli/test_cli_train.py +++ b/tests/cli/test_cli_train.py @@ -1,5 +1,7 @@ """Tests for train CLI command.""" +# pylint: disable=duplicate-code + from unittest.mock import MagicMock, patch from axolotl.cli.main import cli @@ -18,7 +20,9 @@ class TestTrainCommand(BaseCliTest): def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config): """Test basic successful execution""" - self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train") + self._test_basic_execution( + cli_runner, tmp_path, valid_test_config, "train", train=True + ) def test_train_basic_execution_no_accelerate( self, cli_runner, tmp_path, valid_test_config @@ -37,7 +41,8 @@ class TestTrainCommand(BaseCliTest): [ "train", str(config_path), - "--no-accelerate", + "--launcher", + "python", ], catch_exceptions=False, ) @@ -59,11 +64,10 @@ class TestTrainCommand(BaseCliTest): [ "train", str(config_path), - "--learning-rate", - "1e-4", - "--micro-batch-size", - "2", - "--no-accelerate", + "--learning-rate=1e-4", + "--micro-batch-size=2", + "--launcher", + "python", ], catch_exceptions=False, ) @@ -73,3 +77,174 @@ class TestTrainCommand(BaseCliTest): cfg = mock_train.call_args[1]["cfg"] assert cfg["learning_rate"] == 1e-4 assert cfg["micro_batch_size"] == 2 + + def test_train_with_launcher_args_torchrun( + self, cli_runner, tmp_path, valid_test_config + ): + """Test train with torchrun launcher arguments""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--launcher", + "torchrun", + "--", + "--nproc_per_node=2", + "--nnodes=1", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to torchrun + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "torchrun" + assert "--nproc_per_node=2" in called_cmd + assert "--nnodes=1" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.train" in called_cmd + + def test_train_with_launcher_args_accelerate( + self, cli_runner, tmp_path, valid_test_config + ): + """Test train with accelerate launcher arguments""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--launcher", + "accelerate", + "--", + "--config_file=accelerate_config.yml", + "--num_processes=4", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify launcher args are passed to accelerate + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + assert "--config_file=accelerate_config.yml" in called_cmd + assert "--num_processes=4" in called_cmd + assert "-m" in called_cmd + assert "axolotl.cli.train" in called_cmd + + def test_train_backward_compatibility_no_launcher_args( + self, cli_runner, tmp_path, valid_test_config + ): + """Test that existing train commands work without launcher args""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--launcher", + "accelerate", + "--learning-rate", + "1e-4", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + # Verify no launcher args contamination + called_cmd = mock_subprocess.call_args.args[0] + assert called_cmd[0] == "accelerate" + assert called_cmd[1] == "launch" + # Should not contain any extra launcher args + launcher_section = called_cmd[2 : called_cmd.index("-m")] + assert ( + len(launcher_section) == 0 + ) # No launcher args between 'launch' and '-m' + + def test_train_mixed_args_with_launcher_args( + self, cli_runner, tmp_path, valid_test_config + ): + """Test train with both regular CLI args and launcher args""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + with patch("subprocess.run") as mock_subprocess: + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--launcher", + "torchrun", + "--learning-rate", + "2e-4", + "--micro-batch-size", + "4", + "--", + "--nproc_per_node=8", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_subprocess.assert_called_once() + + called_cmd = mock_subprocess.call_args.args[0] + # Verify launcher args + assert "--nproc_per_node=8" in called_cmd + # Verify axolotl args are also present + assert "--learning-rate=2e-4" in called_cmd + assert "--micro-batch-size=4" in called_cmd + + def test_train_cloud_with_launcher_args( + self, cli_runner, tmp_path, valid_test_config + ): + """Test train with cloud and launcher arguments""" + config_path = tmp_path / "config.yml" + config_path.write_text(valid_test_config) + + cloud_path = tmp_path / "cloud.yml" + cloud_path.write_text("provider: modal\ngpu: a100") + + with patch("axolotl.cli.cloud.do_cli_train") as mock_cloud_train: + result = cli_runner.invoke( + cli, + [ + "train", + str(config_path), + "--cloud", + str(cloud_path), + "--launcher", + "torchrun", + "--", + "--nproc_per_node=4", + "--nnodes=2", + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + mock_cloud_train.assert_called_once() + + # Verify cloud training was called with launcher args + call_kwargs = mock_cloud_train.call_args.kwargs + assert call_kwargs["launcher"] == "torchrun" + assert call_kwargs["launcher_args"] == ["--nproc_per_node=4", "--nnodes=2"] diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index 2dab5bba9..a3e4e9887 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -72,3 +72,160 @@ def test_fetch_from_github_network_error(): with patch("requests.get", side_effect=requests.RequestException): with pytest.raises(requests.RequestException): fetch_from_github("examples/", None) + + +def assert_launcher_args_in_command( + mock_subprocess_call, + launcher: str, + expected_launcher_args: list[str], + command_module: str, +): + """ + Helper function to verify launcher arguments are properly passed in subprocess calls. + + Args: + mock_subprocess_call: The mock subprocess.run call + launcher: Expected launcher ("accelerate", "torchrun", etc.) + expected_launcher_args: List of expected launcher arguments + command_module: Expected module name (e.g., "axolotl.cli.train") + """ + assert mock_subprocess_call.called, "subprocess.run should have been called" + called_cmd = mock_subprocess_call.call_args.args[0] + + # Verify launcher + assert ( + called_cmd[0] == launcher + ), f"Expected launcher {launcher}, got {called_cmd[0]}" + + # Verify launcher args are present + for arg in expected_launcher_args: + assert ( + arg in called_cmd + ), f"Expected launcher arg '{arg}' not found in command: {called_cmd}" + + # Verify module is present + assert "-m" in called_cmd, "Expected -m flag for module execution" + assert ( + command_module in called_cmd + ), f"Expected module {command_module} not found in command: {called_cmd}" + + +def assert_no_launcher_args_contamination(mock_subprocess_call, launcher: str): + """ + Helper function to verify no unwanted launcher arguments are present. + + Args: + mock_subprocess_call: The mock subprocess.run call + launcher: Expected launcher ("accelerate", "torchrun", etc.) + """ + assert mock_subprocess_call.called, "subprocess.run should have been called" + called_cmd = mock_subprocess_call.call_args.args[0] + + if launcher == "accelerate": + # For accelerate, launcher args should be between 'launch' and '-m' + launch_idx = called_cmd.index("launch") + m_idx = called_cmd.index("-m") + launcher_section = called_cmd[launch_idx + 1 : m_idx] + assert ( + len(launcher_section) == 0 + ), f"Unexpected launcher args found: {launcher_section}" + elif launcher == "torchrun": + # For torchrun, launcher args should be between 'torchrun' and '-m' + torchrun_idx = called_cmd.index("torchrun") + m_idx = called_cmd.index("-m") + launcher_section = called_cmd[torchrun_idx + 1 : m_idx] + assert ( + len(launcher_section) == 0 + ), f"Unexpected launcher args found: {launcher_section}" + + +@pytest.fixture +def common_launcher_args(): + """Fixture providing common launcher argument combinations for testing.""" + return { + "torchrun": ["--nproc_per_node=2", "--nnodes=1"], + "accelerate": ["--config_file=accelerate_config.yml", "--num_processes=4"], + } + + +def test_add_default_rdzv_args_with_endpoint(): + """Test that default RDZV args are added when rdzv_endpoint is present.""" + from axolotl.cli.utils.train import _add_default_rdzv_args + + launcher_args = ["--nnodes=2", "--rdzv_endpoint=127.0.0.1:29400"] + result = _add_default_rdzv_args(launcher_args) + + # Should have added rdzv_backend + assert "--rdzv_backend" in result + assert "c10d" in result + + # Original args should still be present + assert "--nnodes=2" in result + assert "--rdzv_endpoint=127.0.0.1:29400" in result + + +def test_add_default_rdzv_args_with_existing_backend(): + """Test that existing rdzv_backend is not overridden.""" + from axolotl.cli.utils.train import _add_default_rdzv_args + + launcher_args = [ + "--nnodes=2", + "--rdzv_endpoint=127.0.0.1:29400", + "--rdzv_backend=static", + ] + result = _add_default_rdzv_args(launcher_args) + + # Should not add another rdzv_backend + backend_count = sum(1 for arg in result if "--rdzv_backend" in arg) + assert backend_count == 1 + assert "--rdzv_backend=static" in result + + +def test_add_default_rdzv_args_with_existing_id(): + """Test that existing rdzv_id is not overridden.""" + from axolotl.cli.utils.train import _add_default_rdzv_args + + launcher_args = [ + "--nnodes=2", + "--rdzv_endpoint=127.0.0.1:29400", + "--rdzv_id=my_job_123", + ] + result = _add_default_rdzv_args(launcher_args) + + # Should not add another rdzv_id + id_count = sum(1 for arg in result if "--rdzv_id" in arg) + assert id_count == 1 + assert "--rdzv_id=my_job_123" in result + + # Should still add rdzv_backend + assert "--rdzv_backend" in result + assert "c10d" in result + + +def test_add_default_rdzv_args_without_endpoint(): + """Test that no RDZV args are added when rdzv_endpoint is not present.""" + from axolotl.cli.utils.train import _add_default_rdzv_args + + launcher_args = ["--nnodes=2", "--nproc_per_node=4"] + result = _add_default_rdzv_args(launcher_args) + + # Should not add any rdzv args + assert "--rdzv_backend" not in result + assert result == launcher_args + + +def test_add_default_rdzv_args_with_all_existing(): + """Test that no defaults are added when all RDZV args are present.""" + from axolotl.cli.utils.train import _add_default_rdzv_args + + launcher_args = [ + "--nnodes=2", + "--rdzv_endpoint=127.0.0.1:29400", + "--rdzv_backend=static", + "--rdzv_id=existing_job", + ] + result = _add_default_rdzv_args(launcher_args) + + # Should not add any additional args + assert len(result) == len(launcher_args) + assert result == launcher_args