CLI: add --launcher option, support launcher args, cleanup, refactor (#2924)

* add --launcher option; explicit True/False bool args; small cleanup

* refactor

* add torchrun, accelerate cli args

* add rdzv arg default + tests

* update _quarto

* coderabbit

* fix

* we can't set rdzv_id independently across nodes

* coderabbit

* fix tests
This commit is contained in:
Dan Saunders
2025-07-30 15:46:56 -04:00
committed by GitHub
parent 22810c97b7
commit bb1cae1a20
31 changed files with 1417 additions and 541 deletions

View File

@@ -35,18 +35,24 @@ quartodoc:
- cli.train
- cli.evaluate
- cli.args
- cli.art
- cli.checks
- cli.config
- cli.delinearize_llama4
- cli.inference
- cli.merge_lora
- cli.merge_sharded_fsdp_weights
- cli.preprocess
- cli.sweeps
- cli.utils
- cli.quantize
- cli.vllm_serve
- cli.cloud.base
- cli.cloud.modal_
- cli.quantize
- cli.utils
- cli.utils.args
- cli.utils.fetch
- cli.utils.load
- cli.utils.sweeps
- cli.utils.train
- title: Trainers
desc: Training implementations
contents:

View File

@@ -23,6 +23,20 @@ axolotl <command> [config.yml] [options]
The config file can be local or a URL to a raw YAML file.
### Launcher Arguments
For commands that support multi-GPU (`train`, `evaluate`, ...), you can pass launcher-specific arguments using the `--` separator:
```bash
# Pass torchrun arguments
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
# Pass accelerate arguments
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml --num_processes=4
```
Arguments after `--` are passed directly to the launcher (torchrun, accelerate launch, etc.).
## Command Reference
### fetch
@@ -80,7 +94,11 @@ axolotl train config.yml \
--num-epochs 3
# Training without accelerate
axolotl train config.yml --no-accelerate
axolotl train config.yml --launcher python
# Pass launcher-specific arguments using -- separator
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml
# Resume training from checkpoint
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
@@ -175,6 +193,9 @@ Evaluates a model's performance (loss etc) on the train and eval datasets.
```bash
# Basic evaluation
axolotl evaluate config.yml
# Evaluation with launcher arguments
axolotl evaluate config.yml --launcher torchrun -- --nproc_per_node=2
```
### lm-eval
@@ -287,9 +308,6 @@ axolotl preprocess config.yml --cloud cloud_config.yml
# Train on cloud
axolotl train config.yml --cloud cloud_config.yml
# Train without accelerate on cloud
axolotl train config.yml --cloud cloud_config.yml --no-accelerate
# Run lm-eval on cloud
axolotl lm-eval config.yml --cloud cloud_config.yml
```

View File

@@ -69,11 +69,19 @@ export NCCL_BUFFSIZE=2097152
Run the following on each node:
### Option 1: New Axolotl CLI with launcher args (Recommended)
```bash
axolotl train config.yaml --launcher torchrun -- --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port"
```
### Option 2: Direct torchrun (Legacy)
```bash
torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml
```
Please make sure to substitute the placeholder variables.
Please make sure to substitute the placeholder variables:
- `num_nodes`: Number of nodes (containing GPUs)
- `gpu_per_node`: Number of gpus per node
@@ -81,8 +89,6 @@ Please make sure to substitute the placeholder variables.
- `head_node_port`: Port of the head node (make sure other machines can connect to this. Default 29400)
- `rdzv_id`: A unique job ID that is used by the job across nodes.
::: {.callout-note}
You need to call `axolotl.cli.train` instead of `axolotl train` as the latter calls accelerate under the hood
:::
The new CLI approach (Option 1) is recommended as it provides consistent argument handling and works seamlessly with other Axolotl CLI features.
More info on the available configs can be found on the Pytorch docs [here](https://pytorch.org/docs/stable/elastic/run.html)

View File

@@ -30,8 +30,6 @@ class TrainerCliArgs:
debug_num_examples: int = field(default=0)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
main_process_port: Optional[int] = field(default=None)
num_processes: Optional[int] = field(default=None)
@dataclass

View File

@@ -3,7 +3,7 @@ launch axolotl in supported cloud platforms
"""
from pathlib import Path
from typing import Union
from typing import Literal
import yaml
@@ -11,7 +11,7 @@ from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
def load_cloud_cfg(cloud_config: Path | str) -> DictDefault:
"""Load and validate cloud configuration."""
# Load cloud configuration.
with open(cloud_config, encoding="utf-8") as file:
@@ -20,8 +20,8 @@ def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
def do_cli_preprocess(
cloud_config: Union[Path, str],
config: Union[Path, str],
cloud_config: Path | str,
config: Path | str,
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
@@ -31,9 +31,10 @@ def do_cli_preprocess(
def do_cli_train(
cloud_config: Union[Path, str],
config: Union[Path, str],
accelerate: bool = True,
cloud_config: Path | str,
config: Path | str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
cwd=None,
**kwargs,
) -> None:
@@ -44,12 +45,18 @@ def do_cli_train(
local_dirs = {}
if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
local_dirs = {"/workspace/mounts": cwd}
cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
cloud.train(
config_yaml,
launcher=launcher,
launcher_args=launcher_args,
local_dirs=local_dirs,
**kwargs,
)
def do_cli_lm_eval(
cloud_config: Union[Path, str],
config: Union[Path, str],
cloud_config: Path | str,
config: Path | str,
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)

View File

@@ -3,6 +3,7 @@ base class for cloud platforms from cli
"""
from abc import ABC, abstractmethod
from typing import Literal
class Cloud(ABC):
@@ -15,5 +16,12 @@ class Cloud(ABC):
pass
@abstractmethod
def train(self, config_yaml: str, accelerate: bool = True) -> str:
def train(
self,
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None,
**kwargs,
):
pass

View File

@@ -8,7 +8,7 @@ import os
import subprocess # nosec B404
from pathlib import Path
from random import randint
from typing import Optional
from typing import Literal
import modal
@@ -230,8 +230,9 @@ class ModalCloud(Cloud):
def train(
self,
config_yaml: str,
accelerate: bool = True,
local_dirs: Optional[dict[str, str]] = None,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None,
**kwargs,
):
modal_fn = self.get_train_env(local_dirs)(_train)
@@ -239,7 +240,8 @@ class ModalCloud(Cloud):
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
accelerate=accelerate,
launcher=launcher,
launcher_args=launcher_args,
volumes={k: v[0] for k, v in self.volumes.items()},
**kwargs,
)
@@ -270,20 +272,35 @@ def _preprocess(config_yaml: str, volumes=None):
)
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
def _train(
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
volumes=None,
**kwargs, # pylint: disable=unused-argument
):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"
if accelerate:
accelerate_args = "--accelerate"
launcher_args = launcher_args or []
# Build the base command
if launcher == "accelerate":
launcher_arg = "--launcher accelerate"
elif launcher == "torchrun":
launcher_arg = "--launcher torchrun"
else:
accelerate_args = "--no-accelerate"
num_processes_args = ""
if num_processes := kwargs.pop("num_processes", None):
num_processes_args = f"--num-processes {num_processes}"
launcher_arg = "--launcher python"
# Build launcher args string
launcher_args_str = ""
if launcher_args:
launcher_args_str = "-- " + " ".join(launcher_args)
run_cmd(
f"axolotl train {accelerate_args} {num_processes_args} /workspace/mounts/config.yaml",
f"axolotl train {launcher_arg} /workspace/mounts/config.yaml {launcher_args_str}".strip(),
run_folder,
volumes,
)

View File

@@ -197,14 +197,13 @@ def load_cfg(
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
for key, value in kwargs.items():
# If not strict, allow writing to cfg even if it's not in the yml already
if key in cfg_keys or not cfg.strict:
if isinstance(cfg[key], bool):
cfg[key] = bool(value)
else:
cfg[k] = kwargs[k]
cfg[key] = value
try:
device_props = torch.cuda.get_device_properties("cuda")

View File

@@ -9,7 +9,6 @@ from typing import Generator, Union
import fire
import torch
from accelerate import init_empty_weights
from dotenv import load_dotenv
from transformers import AutoProcessor
@@ -152,5 +151,4 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -5,7 +5,6 @@ from pathlib import Path
from typing import Union
import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.args import TrainerCliArgs
@@ -13,7 +12,6 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils import patch_optimized_env
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -30,9 +28,6 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
# pylint: disable=duplicate-code
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
@@ -64,5 +59,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -9,7 +9,6 @@ from typing import Union
import fire
import torch
import transformers
from dotenv import load_dotenv
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli.args import InferenceCliArgs
@@ -268,5 +267,4 @@ def do_cli(
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -4,12 +4,9 @@
import os
import subprocess # nosec B404
import tempfile
from pathlib import Path
from typing import Optional
from typing import Literal, Optional
import click
import yaml
from dotenv import load_dotenv
import axolotl
@@ -21,13 +18,14 @@ from axolotl.cli.args import (
VllmServeCliArgs,
)
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.sweeps import generate_sweep_configs
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
build_command,
fetch_from_github,
filter_none_kwargs,
generate_config_files,
launch_training,
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env
@@ -36,12 +34,19 @@ from axolotl.utils.schemas.config import AxolotlInputConfig
LOG = get_logger(__name__)
LAUNCHER_COMMAND_MAPPING = {
"accelerate": ["accelerate", "launch"],
"torchrun": ["torchrun"],
}
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
"""Axolotl CLI - Train and fine-tune large language models"""
print_axolotl_text_art()
load_dotenv()
patch_optimized_env()
@cli.command()
@@ -50,7 +55,7 @@ def cli():
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
"""
Preprocess datasets before training.
@@ -60,7 +65,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
patch_optimized_env()
if cloud:
from axolotl.cli.cloud import do_cli_preprocess
@@ -72,12 +76,15 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
do_cli(config=config, **kwargs)
@cli.command()
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU training",
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for multi-GPU training",
)
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@click.option(
@@ -88,126 +95,81 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@click.pass_context
def train(
ctx: click.Context,
config: str,
accelerate: bool,
cloud: Optional[str] = None,
sweep: Optional[str] = None,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
cloud: str | None = None,
sweep: str | None = None,
**kwargs,
) -> None:
):
"""
Train or fine-tune a model.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
launcher: Launcher to use for multi-GPU training ("accelerate", "torchrun", or "python").
cloud: Path to a cloud accelerator configuration file
sweep: Path to YAML config for sweeping hyperparameters.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
if sweep:
# load the sweep configuration yaml file
with open(sweep, "r", encoding="utf-8") as fin:
sweep_config: dict[str, list] = yaml.safe_load(fin)
with open(config, "r", encoding="utf-8") as fin:
base_config: dict[str, list] = yaml.safe_load(fin)
# Handle Ray launcher override
_launcher = None if kwargs.get("use_ray") else launcher
# generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
def iter_configs():
for perm in permutations:
# open temp directory for temporary configurations
with tempfile.TemporaryDirectory() as temp_dir:
with open(
Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
) as fout:
yaml.dump(perm, fout)
yield str(Path(temp_dir) / "config.yaml")
else:
def iter_configs():
yield config
for cfg_file in iter_configs():
# handle errors from subprocess so we can continue rest of sweeps
# Process each configuration
for cfg_file in generate_config_files(config, sweep):
try:
if accelerate:
if cloud:
from axolotl.cli.cloud import do_cli_train
cwd = os.getcwd()
do_cli_train(
cloud_config=cloud,
config=config,
accelerate=True,
cwd=cwd,
**kwargs,
)
else:
accelerate_args = []
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port", None)
accelerate_args.append("--main_process_port")
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num_processes")
accelerate_args.append(str(num_processes))
base_cmd = ["accelerate", "launch"]
base_cmd.extend(accelerate_args)
base_cmd.extend(["-m", "axolotl.cli.train"])
if cfg_file:
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
from axolotl.cli.cloud import do_cli_train
do_cli_train(
cloud_config=cloud, config=config, accelerate=False, **kwargs
)
else:
from axolotl.cli.train import do_cli
do_cli(config=cfg_file, **kwargs)
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
except subprocess.CalledProcessError as exc:
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:
raise exc
finally:
# Only delete temp files, not the original config
if cfg_file != config:
os.unlink(cfg_file)
@cli.command()
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU training",
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for multi-GPU evaluation",
)
@add_options_from_dataclass(EvaluateCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def evaluate(config: str, accelerate: bool, **kwargs) -> None:
@click.pass_context
def evaluate(ctx: click.Context, config: str, launcher: str, **kwargs):
"""
Evaluate a model.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
launcher: Launcher to use for multi-GPU evaluation ("accelerate", "torchrun", or "python").
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if launcher in LAUNCHER_COMMAND_MAPPING:
base_cmd = (
LAUNCHER_COMMAND_MAPPING[launcher]
+ launcher_args
+ ["-m", "axolotl.cli.evaluate"]
)
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
@@ -218,30 +180,42 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None:
do_cli(config=config, **kwargs)
@cli.command()
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=False,
help="Use accelerate launch for multi-GPU inference",
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for multi-GPU inference",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
@click.pass_context
def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kwargs):
"""
Run inference with a trained model.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
launcher: Launcher to use for multi-GPU inference ("accelerate", "torchrun", or "python").
gradio: Whether to use Gradio browser interface or command line for inference.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if launcher in LAUNCHER_COMMAND_MAPPING:
base_cmd = (
LAUNCHER_COMMAND_MAPPING[launcher]
+ launcher_args
+ ["-m", "axolotl.cli.inference"]
)
if config:
base_cmd.append(config)
if gradio:
@@ -254,33 +228,42 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
do_cli(config=config, gradio=gradio, **kwargs)
@cli.command()
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for weight merging",
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for weight merging",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
@click.pass_context
def merge_sharded_fsdp_weights(
ctx: click.Context, config: str, launcher: str, **kwargs
):
"""
Merge sharded FSDP model weights.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
launcher: Launcher to use for weight merging ("accelerate", "torchrun", or "python").
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = [
"accelerate",
"launch",
"-m",
"axolotl.cli.merge_sharded_fsdp_weights",
]
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if launcher in LAUNCHER_COMMAND_MAPPING:
base_cmd = (
LAUNCHER_COMMAND_MAPPING[launcher]
+ launcher_args
+ ["-m", "axolotl.cli.merge_sharded_fsdp_weights"]
)
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
@@ -296,7 +279,7 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def merge_lora(config: str, **kwargs) -> None:
def merge_lora(config: str, **kwargs):
"""
Merge trained LoRA adapters into a base model.
@@ -313,7 +296,7 @@ def merge_lora(config: str, **kwargs) -> None:
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]) -> None:
def fetch(directory: str, dest: Optional[str]):
"""
Fetch example configs or other resources.
@@ -351,7 +334,7 @@ def quantize(config: str, **cli_args: QuantizeCliArgs):
@cli.command()
@click.argument("model", type=click.Path(exists=True, path_type=str))
@click.argument("output", type=click.Path(exists=False, path_type=str))
def delinearize_llama4(model: str, output: str) -> None:
def delinearize_llama4(model: str, output: str):
from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4
do_delinearize_llama4(model, output)
@@ -365,5 +348,4 @@ def main():
if __name__ == "__main__":
load_dotenv()
main()

View File

@@ -4,7 +4,6 @@ from pathlib import Path
from typing import Union
import fire
from dotenv import load_dotenv
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
@@ -88,5 +87,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -17,7 +17,6 @@ from accelerate.utils import (
WEIGHTS_NAME,
is_torch_version,
)
from dotenv import load_dotenv
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
@@ -204,5 +203,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -9,7 +9,6 @@ import fire
import transformers
from accelerate import init_empty_weights
from colorama import Fore
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM
from axolotl.cli.args import PreprocessCliArgs
@@ -109,5 +108,4 @@ def do_cli(
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -7,7 +7,6 @@ from typing import Union
import fire
from accelerate import Accelerator
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.args import TrainerCliArgs
@@ -16,7 +15,6 @@ from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils import patch_optimized_env
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
@@ -31,9 +29,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Training-specific CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
@@ -122,5 +117,4 @@ def ray_train_func(kwargs: dict):
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -1,330 +0,0 @@
"""Utility methods for axolotl CLI."""
import concurrent.futures
import dataclasses
import hashlib
import json
from functools import wraps
from pathlib import Path
from types import NoneType
from typing import Any, Callable, Type, Union, get_args, get_origin
import click
import requests
from pydantic import BaseModel
from transformers import (
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
)
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def strip_optional_type(field_type: type | str | None):
"""
Extracts the non-`None` type from an `Optional` / `Union` type.
Args:
field_type: Type of field for Axolotl CLI command.
Returns:
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
return field_type
def filter_none_kwargs(func: Callable) -> Callable:
    """
    Wrap `func` so that `None`-valued keyword arguments are dropped before
    the call is forwarded.

    Args:
        func: Function to wrap.

    Returns:
        Wrapped function that ignores `None`-valued keyword arguments.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        """Invoke `func` with only the non-`None` keyword arguments."""
        kept = {key: val for key, val in kwargs.items() if val is not None}
        return func(*args, **kept)

    return wrapper
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
    """
    Create Click options from the fields of a dataclass.

    Args:
        config_class: Dataclass with fields to parse from the CLI.

    Returns:
        Function decorator for Axolotl CLI command.
    """

    def decorator(function: Callable) -> Callable:
        # Apply options in reverse field order so they display in declaration
        # order on the generated --help output.
        for fld in reversed(dataclasses.fields(config_class)):
            fld_type = strip_optional_type(fld.type)
            cli_name = fld.name.replace("_", "-")

            if fld_type == bool:
                # Booleans become a paired --flag/--no-flag switch
                function = click.option(
                    f"--{cli_name}/--no-{cli_name}",
                    default=fld.default,
                    help=fld.metadata.get("description"),
                )(function)
            else:
                function = click.option(
                    f"--{cli_name}",
                    type=fld_type,
                    default=fld.default,
                    help=fld.metadata.get("description"),
                )(function)

        return function

    return decorator
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
    """
    Create Click options from the fields of a Pydantic model.

    Args:
        config_class: Pydantic model with fields to parse from the CLI.

    Returns:
        Function decorator for Axolotl CLI command.
    """

    def decorator(function: Callable) -> Callable:
        # Apply options in reverse field order so they display in declaration
        # order on the generated --help output.
        for name, field in reversed(config_class.model_fields.items()):
            option_type = strip_optional_type(field.annotation)
            cli_name = name.replace("_", "-")

            # Booleans become a paired --flag/--no-flag switch; everything
            # else is a plain option. Both default to None so unset options
            # can be filtered out downstream.
            if option_type == bool:
                option_name = f"--{cli_name}/--no-{cli_name}"
            else:
                option_name = f"--{cli_name}"

            function = click.option(
                option_name, default=None, help=field.description
            )(function)

        return function

    return decorator
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
    """
    Build command list from base command and options.

    Args:
        base_cmd: Command without options.
        options: Options to parse and append to base command.

    Returns:
        List of strings giving shell command.
    """
    cmd = base_cmd.copy()

    for key, value in options.items():
        # `None` means "not set"; skip so downstream defaults apply
        if value is None:
            continue

        key = key.replace("_", "-")
        # Pass booleans explicitly as `--key True` / `--key False` so that an
        # explicit `--no-<flag>` choice survives the subprocess boundary
        # (previously False values were silently dropped). Fire parses both
        # `--key` and `--key True` as boolean True.
        cmd.extend([f"--{key}", str(value)])

    return cmd
def download_file(
    file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> tuple[str, str]:
    """
    Download a single file and return its processing status.

    Args:
        file_info: Tuple of (file_path, remote_sha).
        raw_base_url: Base URL for raw GitHub content.
        dest_path: Local destination directory.
        dir_prefix: Directory prefix to filter files.

    Returns:
        Tuple of (file_path, status) where status is 'new', 'updated',
        'unchanged', or 'error'.
    """
    file_path, remote_sha = file_info
    raw_url = f"{raw_base_url}/{file_path}"
    dest_file = dest_path / file_path.split(dir_prefix)[-1]

    # Check if file exists and needs updating
    if dest_file.exists():
        with open(dest_file, "rb") as file:
            content = file.read()

        # Calculate git blob SHA to compare against the remote tree entry
        blob = b"blob " + str(len(content)).encode() + b"\0" + content
        local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()

        if local_sha == remote_sha:
            print(f"Skipping {file_path} (unchanged)")
            return file_path, "unchanged"

        print(f"Updating {file_path}")
        # Fix: an existing-but-changed file is "updated", not "new"; before,
        # the "updated" bucket in the sync summary could never be populated.
        status = "updated"
    else:
        print(f"Downloading {file_path}")
        status = "new"

    # Create directories if needed
    dest_file.parent.mkdir(parents=True, exist_ok=True)

    # Download and save file
    try:
        response = requests.get(raw_url, timeout=30)
        response.raise_for_status()

        with open(dest_file, "wb") as file:
            file.write(response.content)

        return file_path, status
    except (requests.RequestException, IOError) as request_error:
        print(f"Error downloading {file_path}: {str(request_error)}")
        return file_path, "error"
def fetch_from_github(
    dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
) -> None:
    """
    Sync files from a specific directory in the GitHub repository.
    Only downloads files that don't exist locally or have changed.

    Args:
        dir_prefix: Directory prefix to filter files (e.g., 'examples/',
            'deepspeed_configs/').
        dest_dir: Local destination directory.
        max_workers: Maximum number of concurrent downloads.

    Raises:
        requests.HTTPError: If the GitHub tree API request fails.
        click.ClickException: If no files under `dir_prefix` are found.
    """
    api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
    raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"

    # Get repository tree with timeout
    response = requests.get(api_url, timeout=30)
    response.raise_for_status()
    tree = json.loads(response.text)

    # Filter for files (blobs) under dir_prefix and record their git blob SHA
    files = {
        item["path"]: item["sha"]
        for item in tree["tree"]
        if item["type"] == "blob" and item["path"].startswith(dir_prefix)
    }

    if not files:
        raise click.ClickException(f"No files found in {dir_prefix}")

    # Default destination directory is the last part of dir_prefix
    default_dest = Path(dir_prefix.rstrip("/"))
    dest_path = Path(dest_dir) if dest_dir else default_dest

    # Keep track of processed files for summary
    files_processed: dict[str, list[str]] = {
        "new": [],
        "updated": [],
        "unchanged": [],
        "error": [],
    }

    # Process files in parallel using ThreadPoolExecutor
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_file = {
            executor.submit(
                download_file,
                (file_path, remote_sha),
                raw_base_url,
                dest_path,
                dir_prefix,
            ): file_path
            for file_path, remote_sha in files.items()
        }

        # Process completed tasks as they finish
        for future in concurrent.futures.as_completed(future_to_file):
            file_path = future_to_file[future]
            try:
                file_path, status = future.result()
                files_processed[status].append(file_path)
            except (requests.RequestException, IOError) as request_error:
                # Defensive: download_file already maps these to "error"
                # status; other exception types would propagate.
                print(f"Error processing {file_path}: {str(request_error)}")
                files_processed["error"].append(file_path)

    # Log summary
    LOG.info("\nSync Summary:")
    LOG.info(f"New files: {len(files_processed['new'])}")
    LOG.info(f"Updated files: {len(files_processed['updated'])}")
    LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
    if files_processed["error"]:
        LOG.info(f"Failed files: {len(files_processed['error'])}")
def load_model_and_tokenizer(
    *,
    cfg: DictDefault,
    inference: bool = False,
) -> tuple[
    PreTrainedModel,
    PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
    ProcessorMixin | None,
]:
    """
    Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
    config.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
        inference: Boolean denoting inference mode.

    Returns:
        Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin); the
        processor is None unless `cfg.is_multimodal` is truthy.
    """
    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
    tokenizer = load_tokenizer(cfg)

    LOG.info("loading model...")
    model_loader = ModelLoader(cfg, tokenizer, inference=inference)
    # Second element of the returned tuple is intentionally discarded here.
    model, _ = model_loader.load()

    processor = None
    if cfg.is_multimodal:
        LOG.info("loading processor...")
        processor = load_processor(cfg, tokenizer)

    return model, tokenizer, processor

View File

@@ -0,0 +1,23 @@
"""Init for axolotl.cli.utils module."""
from .args import (
add_options_from_config,
add_options_from_dataclass,
filter_none_kwargs,
)
from .fetch import fetch_from_github
from .load import load_model_and_tokenizer
from .sweeps import generate_sweep_configs
from .train import build_command, generate_config_files, launch_training
__all__ = [
"filter_none_kwargs",
"add_options_from_dataclass",
"add_options_from_config",
"build_command",
"generate_config_files",
"generate_sweep_configs",
"load_model_and_tokenizer",
"launch_training",
"fetch_from_github",
]

View File

@@ -0,0 +1,120 @@
"""Utilities for axolotl CLI args."""
import dataclasses
from functools import wraps
from types import NoneType
from typing import Any, Callable, Type, Union, get_args, get_origin
import click
from pydantic import BaseModel
def _strip_optional_type(field_type: type | str | None):
"""
Extracts the non-`None` type from an `Optional` / `Union` type.
Args:
field_type: Type of field for Axolotl CLI command.
Returns:
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
return field_type
def filter_none_kwargs(func: Callable) -> Callable:
    """
    Wraps function to remove `None`-valued `kwargs`.

    Args:
        func: Function to wrap.

    Returns:
        Wrapped function.
    """

    @wraps(func)
    def wrapper(*args, **kwargs) -> Any:
        """Filters out `None`-valued `kwargs` before delegating to `func`."""
        # Fix: the wrapper returns whatever `func` returns, so the previous
        # `-> Callable` return annotation was incorrect.
        filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
        return func(*args, **filtered_kwargs)

    return wrapper
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
    """
    Create Click options from the fields of a dataclass.

    Args:
        config_class: Dataclass with fields to parse from the CLI.

    Returns:
        Function decorator for Axolotl CLI command.
    """

    def decorator(function: Callable) -> Callable:
        # Process dataclass fields in reverse order for correct option ordering
        # (decorators apply bottom-up, so reversing preserves declaration order)
        for field in reversed(dataclasses.fields(config_class)):
            # NOTE(review): `field.type` is a *string* when the dataclass's
            # module uses `from __future__ import annotations` -- confirm
            # callers pass classes with resolved annotations.
            field_type = _strip_optional_type(field.type)
            if field_type == bool:
                # Boolean fields become paired flags: --flag / --no-flag
                field_name = field.name.replace("_", "-")
                option_name = f"--{field_name}/--no-{field_name}"
                function = click.option(
                    option_name,
                    default=field.default,
                    help=field.metadata.get("description"),
                )(function)
            else:
                option_name = f"--{field.name.replace('_', '-')}"
                function = click.option(
                    option_name,
                    type=field_type,
                    default=field.default,
                    help=field.metadata.get("description"),
                )(function)
        return function

    return decorator
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
    """
    Create Click options from the fields of a Pydantic model.

    Args:
        config_class: PyDantic model with fields to parse from the CLI

    Returns:
        Function decorator for Axolotl CLI command.
    """

    def decorator(function: Callable) -> Callable:
        # Process model fields in reverse order for correct option ordering
        for name, field in reversed(config_class.model_fields.items()):
            field_type = _strip_optional_type(field.annotation)
            if field_type == bool:
                # Boolean fields become paired flags; default None means
                # "not provided on the CLI" so the config file value wins.
                field_name = name.replace("_", "-")
                option_name = f"--{field_name}/--no-{field_name}"
                function = click.option(
                    option_name, default=None, help=field.description
                )(function)
            else:
                # NOTE(review): no `type=` is passed here (unlike the
                # dataclass variant), so click leaves values as strings --
                # presumably coerced downstream; verify.
                option_name = f"--{name.replace('_', '-')}"
                function = click.option(
                    option_name, default=None, help=field.description
                )(function)
        return function

    return decorator

View File

@@ -0,0 +1,142 @@
"""Utilities for axolotl fetch CLI command."""
import concurrent.futures
import hashlib
import json
from pathlib import Path
import click
import requests
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _download_file(
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> tuple[str, str]:
"""
Download a single file and return its processing status.
Args:
file_info: Tuple of (file_path, remote_sha).
raw_base_url: Base URL for raw GitHub content.
dest_path: Local destination directory.
dir_prefix: Directory prefix to filter files.
Returns:
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
"""
file_path, remote_sha = file_info
raw_url = f"{raw_base_url}/{file_path}"
dest_file = dest_path / file_path.split(dir_prefix)[-1]
# Check if file exists and needs updating
if dest_file.exists():
with open(dest_file, "rb") as file:
content = file.read()
# Calculate git blob SHA
blob = b"blob " + str(len(content)).encode() + b"\0" + content
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()
if local_sha == remote_sha:
print(f"Skipping {file_path} (unchanged)")
return file_path, "unchanged"
print(f"Updating {file_path}")
status = "updated"
else:
print(f"Downloading {file_path}")
status = "new"
# Create directories if needed
dest_file.parent.mkdir(parents=True, exist_ok=True)
# Download and save file
try:
response = requests.get(raw_url, timeout=30)
response.raise_for_status()
with open(dest_file, "wb") as file:
file.write(response.content)
return file_path, status
except (requests.RequestException, IOError) as request_error:
print(f"Error downloading {file_path}: {str(request_error)}")
return file_path, "error"
def fetch_from_github(
    dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
) -> None:
    """
    Sync files from a specific directory in the GitHub repository.
    Only downloads files that don't exist locally or have changed.

    Args:
        dir_prefix: Directory prefix to filter files (e.g., 'examples/',
            'deepspeed_configs/').
        dest_dir: Local destination directory.
        max_workers: Maximum number of concurrent downloads.

    Raises:
        requests.HTTPError: If the GitHub tree API request fails.
        click.ClickException: If no files under `dir_prefix` are found.
    """
    api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
    raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"

    # Get repository tree with timeout
    response = requests.get(api_url, timeout=30)
    response.raise_for_status()
    tree = json.loads(response.text)

    # Filter for files (blobs) under dir_prefix and record their git blob SHA
    files = {
        item["path"]: item["sha"]
        for item in tree["tree"]
        if item["type"] == "blob" and item["path"].startswith(dir_prefix)
    }

    if not files:
        raise click.ClickException(f"No files found in {dir_prefix}")

    # Default destination directory is the last part of dir_prefix
    default_dest = Path(dir_prefix.rstrip("/"))
    dest_path = Path(dest_dir) if dest_dir else default_dest

    # Keep track of processed files for summary
    files_processed: dict[str, list[str]] = {
        "new": [],
        "updated": [],
        "unchanged": [],
        "error": [],
    }

    # Process files in parallel using ThreadPoolExecutor
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_file = {
            executor.submit(
                _download_file,
                (file_path, remote_sha),
                raw_base_url,
                dest_path,
                dir_prefix,
            ): file_path
            for file_path, remote_sha in files.items()
        }

        # Process completed tasks as they finish
        for future in concurrent.futures.as_completed(future_to_file):
            file_path = future_to_file[future]
            try:
                file_path, status = future.result()
                files_processed[status].append(file_path)
            except (requests.RequestException, IOError) as request_error:
                # Defensive: _download_file already maps these to "error"
                # status; other exception types would propagate.
                print(f"Error processing {file_path}: {str(request_error)}")
                files_processed["error"].append(file_path)

    # Log summary
    LOG.info("\nSync Summary:")
    LOG.info(f"New files: {len(files_processed['new'])}")
    LOG.info(f"Updated files: {len(files_processed['updated'])}")
    LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
    if files_processed["error"]:
        LOG.info(f"Failed files: {len(files_processed['error'])}")

View File

@@ -0,0 +1,52 @@
"""Utilities for model, tokenizer, etc. loading."""
from typing import Any
from transformers import (
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
)
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def load_model_and_tokenizer(
    *,
    cfg: DictDefault,
    inference: bool = False,
) -> tuple[
    PreTrainedModel,
    PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
    ProcessorMixin | None,
]:
    """
    Helper function for loading a model, tokenizer, and processor specified in the
    given `axolotl` config.

    Args:
        cfg: Dictionary mapping `axolotl` config keys to values.
        inference: Boolean denoting inference mode.

    Returns:
        Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
    """
    # Tokenizer first; the model loader needs it.
    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
    tok = load_tokenizer(cfg)

    LOG.info("loading model...")
    loader = ModelLoader(cfg, tok, inference=inference)
    loaded_model, _ = loader.load()

    # Processor is only relevant for multimodal configs.
    proc: ProcessorMixin | None = None
    if cfg.is_multimodal:
        LOG.info("loading processor...")
        proc = load_processor(cfg, tok)

    return loaded_model, tok, proc

View File

@@ -0,0 +1,188 @@
"""Utilities for axolotl train CLI command."""
import os
import subprocess # nosec
import tempfile
from typing import Any, Iterator, Literal
import yaml
from axolotl.cli.utils.sweeps import generate_sweep_configs
def _add_default_rdzv_args(launcher_args: list[str]) -> list[str]:
"""
Add default RDZV arguments if rdzv_endpoint is set but rdzv_backend/rdzv_id are missing.
Args:
launcher_args: List of launcher arguments
Returns:
Updated launcher args with defaults added if needed
"""
args = launcher_args.copy()
# Check if rdzv_endpoint is present
has_rdzv_endpoint = any("--rdzv_endpoint" in arg for arg in args)
if has_rdzv_endpoint:
# Check if rdzv_backend is already provided
has_rdzv_backend = any("--rdzv_backend" in arg for arg in args)
if not has_rdzv_backend:
args.extend(["--rdzv_backend", "c10d"])
# Check if rdzv_id is already provided
has_rdzv_id = any("--rdzv_id" in arg for arg in args)
if not has_rdzv_id:
import uuid
args.extend(["--rdzv_id", str(uuid.uuid4())[:8]])
return args
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
    """
    Build command list from base command and options.

    Args:
        base_cmd: Command without options.
        options: Options to parse and append to base command.

    Returns:
        List of strings giving shell command.
    """
    # Render each set option as a single "--key=value" token; None means the
    # option was not provided and is omitted.
    rendered = [
        f"--{name.replace('_', '-')}={value}"
        for name, value in options.items()
        if value is not None
    ]
    return list(base_cmd) + rendered
def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
    """Generate list of configuration files to process.

    Args:
        config: Path to the base YAML config file.
        sweep: Optional path to a sweep YAML file mapping config keys to lists
            of values to permute over.

    Yields:
        Config file paths to train with: the original `config` when no sweep
        is given, otherwise one temporary YAML file per sweep permutation.
        Temp files are created with delete=False, so they persist until the
        OS cleans the temp directory.
    """
    if not sweep:
        yield config
        return

    # Load sweep and base configurations
    with open(sweep, "r", encoding="utf-8") as fin:
        sweep_config: dict[str, list] = yaml.safe_load(fin)
    with open(config, "r", encoding="utf-8") as fin:
        base_config: dict[str, list] = yaml.safe_load(fin)

    # Generate all possible configurations
    permutations = generate_sweep_configs(base_config, sweep_config)

    for permutation in permutations:
        # pylint: disable=consider-using-with
        temp_file = tempfile.NamedTemporaryFile(
            mode="w",
            suffix=".yaml",
            delete=False,
            encoding="utf-8",
        )
        yaml.dump(permutation, temp_file)
        # Close before yielding so the consumer can reopen the file on any OS
        temp_file.close()

        yield temp_file.name
def launch_training(
    cfg_file: str,
    launcher: Literal["accelerate", "torchrun", "python"] | None,
    cloud: str | None,
    kwargs: dict,
    launcher_args: list[str] | None = None,
) -> None:
    """Execute training with the given configuration."""
    extra_args = launcher_args or []

    # Cloud execution takes precedence over any local launcher.
    if cloud:
        _launch_cloud_training(cloud, cfg_file, launcher, kwargs, extra_args)
        return

    # Local execution, dispatched by launcher; no launcher means no-op.
    if launcher == "accelerate":
        _launch_accelerate_training(cfg_file, kwargs, extra_args)
    elif launcher == "torchrun":
        _launch_torchrun_training(cfg_file, kwargs, extra_args)
    elif launcher == "python":
        _launch_python_training(cfg_file, kwargs)
def _launch_cloud_training(
    cloud: str,
    cfg_file: str,
    launcher: Literal["accelerate", "torchrun", "python"] | None,
    kwargs: dict,
    launcher_args: list[str] | None = None,
) -> None:
    """Execute training via cloud launcher.

    Args:
        cloud: Path to the cloud config file.
        cfg_file: Path to the axolotl config file.
        launcher: Launcher to use on the cloud side; falls back to
            "accelerate" when not given.
        kwargs: Additional CLI kwargs forwarded to `do_cli_train`.
        launcher_args: Extra arguments passed through to the launcher.
    """
    # Imported here (not at module top) -- presumably to keep cloud
    # dependencies optional for local-only usage.
    from axolotl.cli.cloud import do_cli_train

    launcher_args = launcher_args or []

    # cwd is only pinned when an explicit launcher was requested.
    # NOTE(review): unclear why cwd is tied to `launcher` -- confirm intent.
    cwd = os.getcwd() if launcher else None
    do_cli_train(
        cloud_config=cloud,
        config=cfg_file,
        launcher=launcher or "accelerate",
        launcher_args=launcher_args,
        cwd=cwd,
        **kwargs,
    )
def _launch_accelerate_training(
    cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
) -> None:
    """Execute training via accelerate launcher.

    Args:
        cfg_file: Path to the axolotl config file.
        kwargs: CLI kwargs appended as `--key=value` options. Note:
            `main_process_port` and `num_processes` are popped out (this
            mutates the caller's dict) and handed to accelerate instead.
        launcher_args: Extra arguments inserted between `accelerate launch`
            and `-m axolotl.cli.train`.
    """
    launcher_args = launcher_args or []
    internal_launcher_args = []

    # Extract launcher-specific arguments from kwargs (legacy support)
    if "main_process_port" in kwargs:
        main_process_port = kwargs.pop("main_process_port")
        internal_launcher_args.extend(["--main_process_port", str(main_process_port)])
    if "num_processes" in kwargs:
        num_processes = kwargs.pop("num_processes")
        internal_launcher_args.extend(["--num_processes", str(num_processes)])

    # Combine internal args with user-provided launcher args; legacy args
    # come first so explicit user args can override them.
    all_launcher_args = internal_launcher_args + launcher_args

    base_cmd = (
        ["accelerate", "launch"] + all_launcher_args + ["-m", "axolotl.cli.train"]
    )
    if cfg_file:
        base_cmd.append(cfg_file)
    cmd = build_command(base_cmd, kwargs)

    # List-form invocation (shell=False): no shell-injection surface.
    subprocess.run(cmd, check=True)  # nosec B603
def _launch_torchrun_training(
    cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
) -> None:
    """Execute training via torchrun launcher.

    Args:
        cfg_file: Path to the axolotl config file.
        kwargs: CLI kwargs appended as `--key=value` options.
        launcher_args: Extra arguments inserted between `torchrun` and
            `-m axolotl.cli.train`.
    """
    launcher_args = launcher_args or []

    # Add default RDZV arguments (backend/id) if rdzv_endpoint is set
    launcher_args = _add_default_rdzv_args(launcher_args)

    base_cmd = ["torchrun"] + launcher_args + ["-m", "axolotl.cli.train"]
    if cfg_file:
        base_cmd.append(cfg_file)
    cmd = build_command(base_cmd, kwargs)

    # List-form invocation (shell=False): no shell-injection surface.
    subprocess.run(cmd, check=True)  # nosec B603
def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
    """Execute training via python launcher (in-process, no subprocess).

    Args:
        cfg_file: Path to the axolotl config file.
        kwargs: CLI overrides forwarded to `do_cli`.
    """
    # Imported lazily to avoid a circular import with the CLI entry module.
    # NOTE(review): assumption -- confirm against axolotl.cli.train imports.
    from axolotl.cli.train import do_cli

    do_cli(config=cfg_file, **kwargs)

View File

@@ -17,16 +17,23 @@ class BaseCliTest:
command: Command to test (train/evaluate)
"""
# Test missing config file
result = cli_runner.invoke(cli, [command, "--no-accelerate"])
result = cli_runner.invoke(cli, [command, "--launcher", "python"])
assert result.exit_code != 0
# Test non-existent config file
result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"])
result = cli_runner.invoke(
cli, [command, "nonexistent.yml", "--launcher", "python"]
)
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
def _test_basic_execution(
self, cli_runner, tmp_path: Path, valid_test_config: str, command: str
self,
cli_runner,
tmp_path: Path,
valid_test_config: str,
command: str,
train: bool = True,
):
"""Test basic execution with accelerate.
@@ -35,6 +42,7 @@ class BaseCliTest:
tmp_path: Temporary path fixture
valid_test_config: Valid config fixture
command: Command to test (train/evaluate)
train: Whether to test training (default) or evaluation
"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
@@ -43,15 +51,21 @@ class BaseCliTest:
result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called
assert mock.call_args.args[0] == [
expected = [
"accelerate",
"launch",
"-m",
f"axolotl.cli.{command}",
str(config_path),
"--debug-num-examples",
"0",
"--debug=False",
"--debug-text-only=False",
"--debug-num-examples=0",
]
if train:
expected.append("--shard=False")
assert mock.call_args.args[0] == expected
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0

View File

@@ -1,5 +1,7 @@
"""Tests for evaluate CLI command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
@@ -18,7 +20,9 @@ class TestEvaluateCommand(BaseCliTest):
def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate")
self._test_basic_execution(
cli_runner, tmp_path, valid_test_config, "evaluate", train=False
)
def test_evaluate_basic_execution_no_accelerate(
self, cli_runner, tmp_path, valid_test_config
@@ -27,13 +31,15 @@ class TestEvaluateCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
# pylint: disable=duplicate-code
with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--no-accelerate",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -55,7 +61,8 @@ class TestEvaluateCommand(BaseCliTest):
"2",
"--sequence-len",
"128",
"--no-accelerate",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -65,3 +72,104 @@ class TestEvaluateCommand(BaseCliTest):
cfg = mock_evaluate.call_args[0][0]
assert cfg.micro_batch_size == 2
assert cfg.sequence_len == 128
def test_evaluate_with_launcher_args_torchrun(
self, cli_runner, tmp_path, valid_test_config
):
"""Test evaluate with torchrun launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.evaluate" in called_cmd
def test_evaluate_with_launcher_args_accelerate(
self, cli_runner, tmp_path, valid_test_config
):
"""Test evaluate with accelerate launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.evaluate" in called_cmd
def test_evaluate_backward_compatibility_no_launcher_args(
self, cli_runner, tmp_path, valid_test_config
):
"""Test that existing evaluate commands work without launcher args"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--launcher",
"accelerate",
"--micro-batch-size",
"2",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert (
len(launcher_section) == 0
) # No launcher args between 'launch' and '-m'

View File

@@ -1,5 +1,7 @@
"""pytest tests for axolotl CLI inference command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
@@ -10,7 +12,7 @@ def test_inference_basic(cli_runner, config_path):
with patch("axolotl.cli.inference.do_inference") as mock:
result = cli_runner.invoke(
cli,
["inference", str(config_path), "--no-accelerate"],
["inference", str(config_path), "--launcher", "python"],
catch_exceptions=False,
)
@@ -23,9 +25,124 @@ def test_inference_gradio(cli_runner, config_path):
with patch("axolotl.cli.inference.do_inference_gradio") as mock:
result = cli_runner.invoke(
cli,
["inference", str(config_path), "--no-accelerate", "--gradio"],
["inference", str(config_path), "--launcher", "python", "--gradio"],
catch_exceptions=False,
)
assert mock.called
assert result.exit_code == 0
def test_inference_with_launcher_args_torchrun(cli_runner, config_path):
"""Test inference with torchrun launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.inference" in called_cmd
def test_inference_with_launcher_args_accelerate(cli_runner, config_path):
"""Test inference with accelerate launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.inference" in called_cmd
def test_inference_gradio_with_launcher_args(cli_runner, config_path):
"""Test inference with gradio and launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"accelerate",
"--gradio",
"--",
"--num_processes=2",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify both gradio flag and launcher args are present
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--num_processes=2" in called_cmd
assert "--gradio" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.inference" in called_cmd
def test_inference_backward_compatibility_no_launcher_args(cli_runner, config_path):
"""Test that existing inference commands work without launcher args"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert len(launcher_section) == 0 # No launcher args between 'launch' and '-m'

View File

@@ -18,11 +18,10 @@ def test_build_command():
assert result == [
"accelerate",
"launch",
"--learning-rate",
"0.0001",
"--batch-size",
"8",
"--debug",
"--learning-rate=0.0001",
"--batch-size=8",
"--debug=True",
"--use-fp16=False",
]
@@ -38,7 +37,7 @@ def test_invalid_command_options(cli_runner):
],
)
assert result.exit_code != 0
assert "No such option" in result.output
assert "does not exist" in result.output
def test_required_config_argument(cli_runner):

View File

@@ -11,9 +11,101 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
"""Test merge_sharded_fsdp_weights command without accelerate"""
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"]
cli,
["merge-sharded-fsdp-weights", str(config_path), "--launcher", "python"],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_launcher_args_torchrun(
cli_runner, config_path
):
"""Test merge-sharded-fsdp-weights with torchrun launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd
def test_merge_sharded_fsdp_weights_with_launcher_args_accelerate(
cli_runner, config_path
):
"""Test merge-sharded-fsdp-weights with accelerate launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd
def test_merge_sharded_fsdp_weights_backward_compatibility_no_launcher_args(
cli_runner, config_path
):
"""Test that existing merge-sharded-fsdp-weights commands work without launcher args"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--launcher",
"accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert len(launcher_section) == 0 # No launcher args between 'launch' and '-m'

View File

@@ -2,7 +2,7 @@
unit tests for generating sweep configurations
"""
from axolotl.cli.main import generate_sweep_configs
from axolotl.cli.utils import generate_sweep_configs
def test_generate_sweep_configs_no_pairs():

View File

@@ -1,5 +1,7 @@
"""Tests for train CLI command."""
# pylint: disable=duplicate-code
from unittest.mock import MagicMock, patch
from axolotl.cli.main import cli
@@ -18,7 +20,9 @@ class TestTrainCommand(BaseCliTest):
def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train")
self._test_basic_execution(
cli_runner, tmp_path, valid_test_config, "train", train=True
)
def test_train_basic_execution_no_accelerate(
self, cli_runner, tmp_path, valid_test_config
@@ -37,7 +41,8 @@ class TestTrainCommand(BaseCliTest):
[
"train",
str(config_path),
"--no-accelerate",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -59,11 +64,10 @@ class TestTrainCommand(BaseCliTest):
[
"train",
str(config_path),
"--learning-rate",
"1e-4",
"--micro-batch-size",
"2",
"--no-accelerate",
"--learning-rate=1e-4",
"--micro-batch-size=2",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -73,3 +77,174 @@ class TestTrainCommand(BaseCliTest):
cfg = mock_train.call_args[1]["cfg"]
assert cfg["learning_rate"] == 1e-4
assert cfg["micro_batch_size"] == 2
def test_train_with_launcher_args_torchrun(
self, cli_runner, tmp_path, valid_test_config
):
"""Test train with torchrun launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.train" in called_cmd
def test_train_with_launcher_args_accelerate(
self, cli_runner, tmp_path, valid_test_config
):
"""Test train with accelerate launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.train" in called_cmd
def test_train_backward_compatibility_no_launcher_args(
    self, cli_runner, tmp_path, valid_test_config
):
    """Existing train invocations (no `--` section) must keep working unchanged."""
    cfg_file = tmp_path / "config.yml"
    cfg_file.write_text(valid_test_config)

    with patch("subprocess.run") as run_mock:
        invocation = cli_runner.invoke(
            cli,
            [
                "train",
                str(cfg_file),
                "--launcher",
                "accelerate",
                "--learning-rate",
                "1e-4",
            ],
            catch_exceptions=False,
        )

    assert invocation.exit_code == 0
    run_mock.assert_called_once()

    cmd = run_mock.call_args.args[0]
    assert cmd[0] == "accelerate"
    assert cmd[1] == "launch"
    # Nothing may sit between "launch" and "-m" when no launcher args were given.
    assert not cmd[2 : cmd.index("-m")]
def test_train_mixed_args_with_launcher_args(
    self, cli_runner, tmp_path, valid_test_config
):
    """Regular CLI overrides and launcher args must both survive the round trip."""
    cfg_file = tmp_path / "config.yml"
    cfg_file.write_text(valid_test_config)

    with patch("subprocess.run") as run_mock:
        invocation = cli_runner.invoke(
            cli,
            [
                "train",
                str(cfg_file),
                "--launcher",
                "torchrun",
                "--learning-rate",
                "2e-4",
                "--micro-batch-size",
                "4",
                "--",
                "--nproc_per_node=8",
            ],
            catch_exceptions=False,
        )

    assert invocation.exit_code == 0
    run_mock.assert_called_once()

    cmd = run_mock.call_args.args[0]
    # Launcher-side argument made it through the `--` separator.
    assert "--nproc_per_node=8" in cmd
    # Axolotl-side overrides are forwarded as key=value tokens.
    assert "--learning-rate=2e-4" in cmd
    assert "--micro-batch-size=4" in cmd
def test_train_cloud_with_launcher_args(
    self, cli_runner, tmp_path, valid_test_config
):
    """Launcher choice and launcher args must be forwarded to cloud training."""
    cfg_file = tmp_path / "config.yml"
    cfg_file.write_text(valid_test_config)

    cloud_file = tmp_path / "cloud.yml"
    cloud_file.write_text("provider: modal\ngpu: a100")

    with patch("axolotl.cli.cloud.do_cli_train") as cloud_train_mock:
        invocation = cli_runner.invoke(
            cli,
            [
                "train",
                str(cfg_file),
                "--cloud",
                str(cloud_file),
                "--launcher",
                "torchrun",
                "--",
                "--nproc_per_node=4",
                "--nnodes=2",
            ],
            catch_exceptions=False,
        )

    assert invocation.exit_code == 0
    cloud_train_mock.assert_called_once()

    # Cloud path receives the launcher and its extra args as keyword arguments.
    kwargs = cloud_train_mock.call_args.kwargs
    assert kwargs["launcher"] == "torchrun"
    assert kwargs["launcher_args"] == ["--nproc_per_node=4", "--nnodes=2"]

View File

@@ -72,3 +72,160 @@ def test_fetch_from_github_network_error():
with patch("requests.get", side_effect=requests.RequestException):
with pytest.raises(requests.RequestException):
fetch_from_github("examples/", None)
def assert_launcher_args_in_command(
    mock_subprocess_call,
    launcher: str,
    expected_launcher_args: list[str],
    command_module: str,
):
    """
    Verify that launcher arguments were forwarded in the subprocess command.

    Args:
        mock_subprocess_call: The mock subprocess.run call
        launcher: Expected launcher ("accelerate", "torchrun", etc.)
        expected_launcher_args: List of expected launcher arguments
        command_module: Expected module name (e.g., "axolotl.cli.train")
    """
    assert mock_subprocess_call.called, "subprocess.run should have been called"
    called_cmd = mock_subprocess_call.call_args.args[0]

    # The launcher binary is always the first token of the command line.
    assert called_cmd[0] == launcher, f"Expected launcher {launcher}, got {called_cmd[0]}"

    # Every requested launcher argument must appear somewhere in the command.
    for arg in expected_launcher_args:
        assert arg in called_cmd, f"Expected launcher arg '{arg}' not found in command: {called_cmd}"

    # The command must run the target module via `-m`.
    assert "-m" in called_cmd, "Expected -m flag for module execution"
    assert command_module in called_cmd, f"Expected module {command_module} not found in command: {called_cmd}"
def assert_no_launcher_args_contamination(mock_subprocess_call, launcher: str):
    """
    Verify no unwanted launcher arguments leaked into the subprocess command.

    Args:
        mock_subprocess_call: The mock subprocess.run call
        launcher: Expected launcher ("accelerate", "torchrun", etc.)
    """
    assert mock_subprocess_call.called, "subprocess.run should have been called"
    called_cmd = mock_subprocess_call.call_args.args[0]

    # Launcher args live between the launcher's entry token and the "-m" flag:
    # "accelerate launch <args> -m ..." vs "torchrun <args> -m ...".
    section_start_token = {"accelerate": "launch", "torchrun": "torchrun"}.get(launcher)
    if section_start_token is None:
        # Unknown launchers are not checked (matches prior behavior).
        return

    start = called_cmd.index(section_start_token) + 1
    launcher_section = called_cmd[start : called_cmd.index("-m")]
    assert (
        len(launcher_section) == 0
    ), f"Unexpected launcher args found: {launcher_section}"
@pytest.fixture
def common_launcher_args():
    """Common launcher argument combinations shared across launcher tests."""
    torchrun_args = ["--nproc_per_node=2", "--nnodes=1"]
    accelerate_args = ["--config_file=accelerate_config.yml", "--num_processes=4"]
    return {"torchrun": torchrun_args, "accelerate": accelerate_args}
def test_add_default_rdzv_args_with_endpoint():
    """A default RDZV backend is injected when rdzv_endpoint is given."""
    from axolotl.cli.utils.train import _add_default_rdzv_args

    args_out = _add_default_rdzv_args(
        ["--nnodes=2", "--rdzv_endpoint=127.0.0.1:29400"]
    )

    # The c10d backend default should have been appended.
    assert "--rdzv_backend" in args_out
    assert "c10d" in args_out
    # The caller's arguments must survive untouched.
    assert "--nnodes=2" in args_out
    assert "--rdzv_endpoint=127.0.0.1:29400" in args_out
def test_add_default_rdzv_args_with_existing_backend():
    """A user-supplied rdzv_backend is kept and not duplicated by the default."""
    from axolotl.cli.utils.train import _add_default_rdzv_args

    args_out = _add_default_rdzv_args(
        [
            "--nnodes=2",
            "--rdzv_endpoint=127.0.0.1:29400",
            "--rdzv_backend=static",
        ]
    )

    # Exactly one backend flag: the explicit one, no injected default on top.
    assert sum("--rdzv_backend" in arg for arg in args_out) == 1
    assert "--rdzv_backend=static" in args_out
def test_add_default_rdzv_args_with_existing_id():
    """A user-supplied rdzv_id is kept, while the backend default still applies."""
    from axolotl.cli.utils.train import _add_default_rdzv_args

    args_out = _add_default_rdzv_args(
        [
            "--nnodes=2",
            "--rdzv_endpoint=127.0.0.1:29400",
            "--rdzv_id=my_job_123",
        ]
    )

    # Exactly one rdzv_id flag: the explicit one, no duplicate.
    assert sum("--rdzv_id" in arg for arg in args_out) == 1
    assert "--rdzv_id=my_job_123" in args_out

    # The backend default is injected independently of the id.
    assert "--rdzv_backend" in args_out
    assert "c10d" in args_out
def test_add_default_rdzv_args_without_endpoint():
    """Without rdzv_endpoint the args pass through entirely unchanged."""
    from axolotl.cli.utils.train import _add_default_rdzv_args

    args_in = ["--nnodes=2", "--nproc_per_node=4"]
    args_out = _add_default_rdzv_args(args_in)

    # No RDZV defaults should be injected.
    assert "--rdzv_backend" not in args_out
    assert args_out == args_in
def test_add_default_rdzv_args_with_all_existing():
    """A fully specified RDZV config gets no defaults appended at all."""
    from axolotl.cli.utils.train import _add_default_rdzv_args

    args_in = [
        "--nnodes=2",
        "--rdzv_endpoint=127.0.0.1:29400",
        "--rdzv_backend=static",
        "--rdzv_id=existing_job",
    ]
    args_out = _add_default_rdzv_args(args_in)

    # Same length and content: nothing was added or removed.
    assert len(args_out) == len(args_in)
    assert args_out == args_in