CLI: add --launcher option, support launcher args, cleanup, refactor (#2924)
* add --launcher option; explicit True/False bool args; small cleanup
* refactor
* add torchrun, accelerate cli args
* add rdzv arg default + tests
* update _quarto
* coderabbit
* fix
* we can't set rdzv_id independently across nodes
* coderabbit
* fix tests
_quarto.yml (12 changes)
@@ -35,18 +35,24 @@ quartodoc:
|
||||
- cli.train
|
||||
- cli.evaluate
|
||||
- cli.args
|
||||
- cli.art
|
||||
- cli.checks
|
||||
- cli.config
|
||||
- cli.delinearize_llama4
|
||||
- cli.inference
|
||||
- cli.merge_lora
|
||||
- cli.merge_sharded_fsdp_weights
|
||||
- cli.preprocess
|
||||
- cli.sweeps
|
||||
- cli.utils
|
||||
- cli.quantize
|
||||
- cli.vllm_serve
|
||||
- cli.cloud.base
|
||||
- cli.cloud.modal_
|
||||
- cli.quantize
|
||||
- cli.utils
|
||||
- cli.utils.args
|
||||
- cli.utils.fetch
|
||||
- cli.utils.load
|
||||
- cli.utils.sweeps
|
||||
- cli.utils.train
|
||||
- title: Trainers
|
||||
desc: Training implementations
|
||||
contents:
|
||||
|
||||
docs/cli.qmd (26 changes)
@@ -23,6 +23,20 @@ axolotl <command> [config.yml] [options]
The config file can be local or a URL to a raw YAML file.

### Launcher Arguments

For commands that support multi-GPU (`train`, `evaluate`, ...), you can pass launcher-specific arguments using the `--` separator:

```bash
# Pass torchrun arguments
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1

# Pass accelerate arguments
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml --num_processes=4
```

Arguments after `--` are passed directly to the launcher (torchrun, accelerate launch, etc.).
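Under the hood, the chosen launcher maps to a base command and the pass-through arguments are spliced in ahead of the `-m axolotl.cli.train` module invocation; the `python` launcher instead runs training in-process. A minimal editorial sketch of that assembly (the helper name `assemble_train_command` is made up for illustration; the mapping mirrors `LAUNCHER_COMMAND_MAPPING` added in `src/axolotl/cli/main.py`):

```python
# Mirrors LAUNCHER_COMMAND_MAPPING from src/axolotl/cli/main.py in this change.
LAUNCHER_COMMAND_MAPPING = {
    "accelerate": ["accelerate", "launch"],
    "torchrun": ["torchrun"],
}

def assemble_train_command(config: str, launcher: str, launcher_args: list[str]) -> list[str]:
    # The "python" launcher has no entry here; it calls axolotl.cli.train.do_cli()
    # directly instead of spawning a subprocess.
    base_cmd = LAUNCHER_COMMAND_MAPPING[launcher]
    return base_cmd + launcher_args + ["-m", "axolotl.cli.train", config]

print(assemble_train_command("config.yml", "torchrun", ["--nproc_per_node=2", "--nnodes=1"]))
# ['torchrun', '--nproc_per_node=2', '--nnodes=1', '-m', 'axolotl.cli.train', 'config.yml']
```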
## Command Reference
|
||||
|
||||
### fetch
|
||||
@@ -80,7 +94,11 @@ axolotl train config.yml \
|
||||
--num-epochs 3
|
||||
|
||||
# Training without accelerate
|
||||
axolotl train config.yml --no-accelerate
|
||||
axolotl train config.yml --launcher python
|
||||
|
||||
# Pass launcher-specific arguments using -- separator
|
||||
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
|
||||
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml
|
||||
|
||||
# Resume training from checkpoint
|
||||
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
|
||||
@@ -175,6 +193,9 @@ Evaluates a model's performance (loss etc) on the train and eval datasets.
|
||||
```bash
|
||||
# Basic evaluation
|
||||
axolotl evaluate config.yml
|
||||
|
||||
# Evaluation with launcher arguments
|
||||
axolotl evaluate config.yml --launcher torchrun -- --nproc_per_node=2
|
||||
```
|
||||
|
||||
### lm-eval
|
||||
@@ -287,9 +308,6 @@ axolotl preprocess config.yml --cloud cloud_config.yml
|
||||
# Train on cloud
|
||||
axolotl train config.yml --cloud cloud_config.yml
|
||||
|
||||
# Train without accelerate on cloud
|
||||
axolotl train config.yml --cloud cloud_config.yml --no-accelerate
|
||||
|
||||
# Run lm-eval on cloud
|
||||
axolotl lm-eval config.yml --cloud cloud_config.yml
|
||||
```
|
||||
|
||||
@@ -69,11 +69,19 @@ export NCCL_BUFFSIZE=2097152
Run the following on each node:

### Option 1: New Axolotl CLI with launcher args (Recommended)

```bash
axolotl train config.yaml --launcher torchrun -- --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port"
```

### Option 2: Direct torchrun (Legacy)

```bash
torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml
```

Please make sure to substitute the placeholder variables.
Please make sure to substitute the placeholder variables:

- `num_nodes`: Number of nodes (containing GPUs)
- `gpu_per_node`: Number of gpus per node
@@ -81,8 +89,6 @@ Please make sure to substitute the placeholder variables.
- `head_node_port`: Port of the head node (make sure other machines can connect to this. Default 29400)
- `rdzv_id`: A unique job ID that is used by the job across nodes.

::: {.callout-note}
You need to call `axolotl.cli.train` instead of `axolotl train` as the latter calls accelerate under the hood
:::
The new CLI approach (Option 1) is recommended as it provides consistent argument handling and works seamlessly with other Axolotl CLI features.

More info on the available configs can be found on the Pytorch docs [here](https://pytorch.org/docs/stable/elastic/run.html)
@@ -30,8 +30,6 @@ class TrainerCliArgs:
|
||||
debug_num_examples: int = field(default=0)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
shard: bool = field(default=False)
|
||||
main_process_port: Optional[int] = field(default=None)
|
||||
num_processes: Optional[int] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -3,7 +3,7 @@ launch axolotl in supported cloud platforms
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
|
||||
@@ -11,7 +11,7 @@ from axolotl.cli.cloud.modal_ import ModalCloud
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
|
||||
def load_cloud_cfg(cloud_config: Path | str) -> DictDefault:
|
||||
"""Load and validate cloud configuration."""
|
||||
# Load cloud configuration.
|
||||
with open(cloud_config, encoding="utf-8") as file:
|
||||
@@ -20,8 +20,8 @@ def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
|
||||
|
||||
|
||||
def do_cli_preprocess(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
cloud_config: Path | str,
|
||||
config: Path | str,
|
||||
) -> None:
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
@@ -31,9 +31,10 @@ def do_cli_preprocess(
|
||||
|
||||
|
||||
def do_cli_train(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
accelerate: bool = True,
|
||||
cloud_config: Path | str,
|
||||
config: Path | str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
cwd=None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@@ -44,12 +45,18 @@ def do_cli_train(
|
||||
local_dirs = {}
|
||||
if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
|
||||
local_dirs = {"/workspace/mounts": cwd}
|
||||
cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
|
||||
cloud.train(
|
||||
config_yaml,
|
||||
launcher=launcher,
|
||||
launcher_args=launcher_args,
|
||||
local_dirs=local_dirs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def do_cli_lm_eval(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str],
|
||||
cloud_config: Path | str,
|
||||
config: Path | str,
|
||||
) -> None:
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
|
||||
@@ -3,6 +3,7 @@ base class for cloud platforms from cli
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Literal
|
||||
|
||||
|
||||
class Cloud(ABC):
|
||||
@@ -15,5 +16,12 @@ class Cloud(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def train(self, config_yaml: str, accelerate: bool = True) -> str:
|
||||
def train(
|
||||
self,
|
||||
config_yaml: str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
local_dirs: dict[str, str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -8,7 +8,7 @@ import os
|
||||
import subprocess # nosec B404
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
from typing import Optional
|
||||
from typing import Literal
|
||||
|
||||
import modal
|
||||
|
||||
@@ -230,8 +230,9 @@ class ModalCloud(Cloud):
|
||||
def train(
|
||||
self,
|
||||
config_yaml: str,
|
||||
accelerate: bool = True,
|
||||
local_dirs: Optional[dict[str, str]] = None,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
local_dirs: dict[str, str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
modal_fn = self.get_train_env(local_dirs)(_train)
|
||||
@@ -239,7 +240,8 @@ class ModalCloud(Cloud):
|
||||
with self.app.run(detach=True):
|
||||
modal_fn.remote(
|
||||
config_yaml,
|
||||
accelerate=accelerate,
|
||||
launcher=launcher,
|
||||
launcher_args=launcher_args,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
**kwargs,
|
||||
)
|
||||
@@ -270,20 +272,35 @@ def _preprocess(config_yaml: str, volumes=None):
|
||||
)
|
||||
|
||||
|
||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||
def _train(
|
||||
config_yaml: str,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
launcher_args: list[str] | None = None,
|
||||
volumes=None,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/mounts"
|
||||
if accelerate:
|
||||
accelerate_args = "--accelerate"
|
||||
|
||||
launcher_args = launcher_args or []
|
||||
|
||||
# Build the base command
|
||||
if launcher == "accelerate":
|
||||
launcher_arg = "--launcher accelerate"
|
||||
elif launcher == "torchrun":
|
||||
launcher_arg = "--launcher torchrun"
|
||||
else:
|
||||
accelerate_args = "--no-accelerate"
|
||||
num_processes_args = ""
|
||||
if num_processes := kwargs.pop("num_processes", None):
|
||||
num_processes_args = f"--num-processes {num_processes}"
|
||||
launcher_arg = "--launcher python"
|
||||
|
||||
# Build launcher args string
|
||||
launcher_args_str = ""
|
||||
if launcher_args:
|
||||
launcher_args_str = "-- " + " ".join(launcher_args)
|
||||
|
||||
run_cmd(
|
||||
f"axolotl train {accelerate_args} {num_processes_args} /workspace/mounts/config.yaml",
|
||||
f"axolotl train {launcher_arg} /workspace/mounts/config.yaml {launcher_args_str}".strip(),
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
@@ -197,14 +197,13 @@ def load_cfg(
|
||||
# If there are any options passed in the cli, if it is something that seems valid
|
||||
# from the yaml, then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
for k, _ in kwargs.items():
|
||||
# if not strict, allow writing to cfg even if it's not in the yml already
|
||||
if k in cfg_keys or not cfg.strict:
|
||||
# handle booleans
|
||||
if isinstance(cfg[k], bool):
|
||||
cfg[k] = bool(kwargs[k])
|
||||
for key, value in kwargs.items():
|
||||
# If not strict, allow writing to cfg even if it's not in the yml already
|
||||
if key in cfg_keys or not cfg.strict:
|
||||
if isinstance(cfg[key], bool):
|
||||
cfg[key] = bool(value)
|
||||
else:
|
||||
cfg[k] = kwargs[k]
|
||||
cfg[key] = value
|
||||
|
||||
try:
|
||||
device_props = torch.cuda.get_device_properties("cuda")
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Generator, Union
|
||||
import fire
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from dotenv import load_dotenv
|
||||
from transformers import AutoProcessor
|
||||
|
||||
|
||||
@@ -152,5 +151,4 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -5,7 +5,6 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
@@ -13,7 +12,6 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.evaluate import evaluate
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
@@ -30,9 +28,6 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: CLI arguments.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
check_accelerate_default_config()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
@@ -64,5 +59,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Union
|
||||
import fire
|
||||
import torch
|
||||
import transformers
|
||||
from dotenv import load_dotenv
|
||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
|
||||
from axolotl.cli.args import InferenceCliArgs
|
||||
@@ -268,5 +267,4 @@ def do_cli(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -4,12 +4,9 @@
|
||||
|
||||
import os
|
||||
import subprocess # nosec B404
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
import click
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import axolotl
|
||||
@@ -21,13 +18,14 @@ from axolotl.cli.args import (
|
||||
VllmServeCliArgs,
|
||||
)
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.sweeps import generate_sweep_configs
|
||||
from axolotl.cli.utils import (
|
||||
add_options_from_config,
|
||||
add_options_from_dataclass,
|
||||
build_command,
|
||||
fetch_from_github,
|
||||
filter_none_kwargs,
|
||||
generate_config_files,
|
||||
launch_training,
|
||||
)
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import patch_optimized_env
|
||||
@@ -36,12 +34,19 @@ from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
LAUNCHER_COMMAND_MAPPING = {
|
||||
"accelerate": ["accelerate", "launch"],
|
||||
"torchrun": ["torchrun"],
|
||||
}
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||
def cli():
|
||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||
print_axolotl_text_art()
|
||||
load_dotenv()
|
||||
patch_optimized_env()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@@ -50,7 +55,7 @@ def cli():
|
||||
@add_options_from_dataclass(PreprocessCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
Preprocess datasets before training.
|
||||
|
||||
@@ -60,7 +65,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
patch_optimized_env()
|
||||
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_preprocess
|
||||
@@ -72,12 +76,15 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=True,
|
||||
help="Use accelerate launch for multi-GPU training",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for multi-GPU training",
|
||||
)
|
||||
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
@@ -88,126 +95,81 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
@click.pass_context
|
||||
def train(
|
||||
ctx: click.Context,
|
||||
config: str,
|
||||
accelerate: bool,
|
||||
cloud: Optional[str] = None,
|
||||
sweep: Optional[str] = None,
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
cloud: str | None = None,
|
||||
sweep: str | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Train or fine-tune a model.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for multi-GPU training ("accelerate", "torchrun", or "python").
|
||||
cloud: Path to a cloud accelerator configuration file
|
||||
sweep: Path to YAML config for sweeping hyperparameters.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if "use_ray" in kwargs and kwargs["use_ray"]:
|
||||
accelerate = False
|
||||
if sweep:
|
||||
# load the sweep configuration yaml file
|
||||
with open(sweep, "r", encoding="utf-8") as fin:
|
||||
sweep_config: dict[str, list] = yaml.safe_load(fin)
|
||||
with open(config, "r", encoding="utf-8") as fin:
|
||||
base_config: dict[str, list] = yaml.safe_load(fin)
|
||||
# Handle Ray launcher override
|
||||
_launcher = None if kwargs.get("use_ray") else launcher
|
||||
|
||||
# generate all possible configurations
|
||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||
|
||||
def iter_configs():
|
||||
for perm in permutations:
|
||||
# open temp directory for temporary configurations
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
with open(
|
||||
Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
|
||||
) as fout:
|
||||
yaml.dump(perm, fout)
|
||||
yield str(Path(temp_dir) / "config.yaml")
|
||||
|
||||
else:
|
||||
|
||||
def iter_configs():
|
||||
yield config
|
||||
|
||||
for cfg_file in iter_configs():
|
||||
# handle errors from subprocess so we can continue rest of sweeps
|
||||
# Process each configuration
|
||||
for cfg_file in generate_config_files(config, sweep):
|
||||
try:
|
||||
if accelerate:
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
cwd = os.getcwd()
|
||||
do_cli_train(
|
||||
cloud_config=cloud,
|
||||
config=config,
|
||||
accelerate=True,
|
||||
cwd=cwd,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
accelerate_args = []
|
||||
if "main_process_port" in kwargs:
|
||||
main_process_port = kwargs.pop("main_process_port", None)
|
||||
accelerate_args.append("--main_process_port")
|
||||
accelerate_args.append(str(main_process_port))
|
||||
if "num_processes" in kwargs:
|
||||
num_processes = kwargs.pop("num_processes", None)
|
||||
accelerate_args.append("--num_processes")
|
||||
accelerate_args.append(str(num_processes))
|
||||
|
||||
base_cmd = ["accelerate", "launch"]
|
||||
base_cmd.extend(accelerate_args)
|
||||
base_cmd.extend(["-m", "axolotl.cli.train"])
|
||||
if cfg_file:
|
||||
base_cmd.append(cfg_file)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
else:
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
do_cli_train(
|
||||
cloud_config=cloud, config=config, accelerate=False, **kwargs
|
||||
)
|
||||
else:
|
||||
from axolotl.cli.train import do_cli
|
||||
|
||||
do_cli(config=cfg_file, **kwargs)
|
||||
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args)
|
||||
except subprocess.CalledProcessError as exc:
|
||||
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
||||
if not sweep:
|
||||
raise exc
|
||||
finally:
|
||||
# Only delete temp files, not the original config
|
||||
if cfg_file != config:
|
||||
os.unlink(cfg_file)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=True,
|
||||
help="Use accelerate launch for multi-GPU training",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for multi-GPU evaluation",
|
||||
)
|
||||
@add_options_from_dataclass(EvaluateCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def evaluate(config: str, accelerate: bool, **kwargs) -> None:
|
||||
@click.pass_context
|
||||
def evaluate(ctx: click.Context, config: str, launcher: str, **kwargs):
|
||||
"""
|
||||
Evaluate a model.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for multi-GPU evaluation ("accelerate", "torchrun", or "python").
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
if accelerate:
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if launcher in LAUNCHER_COMMAND_MAPPING:
|
||||
base_cmd = (
|
||||
LAUNCHER_COMMAND_MAPPING[launcher]
|
||||
+ launcher_args
|
||||
+ ["-m", "axolotl.cli.evaluate"]
|
||||
)
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
@@ -218,30 +180,42 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None:
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=False,
|
||||
help="Use accelerate launch for multi-GPU inference",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for multi-GPU inference",
|
||||
)
|
||||
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
|
||||
@click.pass_context
|
||||
def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kwargs):
|
||||
"""
|
||||
Run inference with a trained model.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for multi-GPU inference ("accelerate", "torchrun", or "python").
|
||||
gradio: Whether to use Gradio browser interface or command line for inference.
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
if accelerate:
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if launcher in LAUNCHER_COMMAND_MAPPING:
|
||||
base_cmd = (
|
||||
LAUNCHER_COMMAND_MAPPING[launcher]
|
||||
+ launcher_args
|
||||
+ ["-m", "axolotl.cli.inference"]
|
||||
)
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
if gradio:
|
||||
@@ -254,33 +228,42 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
|
||||
do_cli(config=config, gradio=gradio, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(
|
||||
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
|
||||
)
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--accelerate/--no-accelerate",
|
||||
default=True,
|
||||
help="Use accelerate launch for weight merging",
|
||||
"--launcher",
|
||||
type=click.Choice(["accelerate", "torchrun", "python"]),
|
||||
default="accelerate",
|
||||
help="Launcher to use for weight merging",
|
||||
)
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
|
||||
@click.pass_context
|
||||
def merge_sharded_fsdp_weights(
|
||||
ctx: click.Context, config: str, launcher: str, **kwargs
|
||||
):
|
||||
"""
|
||||
Merge sharded FSDP model weights.
|
||||
|
||||
Args:
|
||||
ctx: Click context for extra args.
|
||||
config: Path to `axolotl` config YAML file.
|
||||
accelerate: Whether to use `accelerate` launcher.
|
||||
launcher: Launcher to use for weight merging ("accelerate", "torchrun", or "python").
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
if accelerate:
|
||||
base_cmd = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"-m",
|
||||
"axolotl.cli.merge_sharded_fsdp_weights",
|
||||
]
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
if launcher in LAUNCHER_COMMAND_MAPPING:
|
||||
base_cmd = (
|
||||
LAUNCHER_COMMAND_MAPPING[launcher]
|
||||
+ launcher_args
|
||||
+ ["-m", "axolotl.cli.merge_sharded_fsdp_weights"]
|
||||
)
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
@@ -296,7 +279,7 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
def merge_lora(config: str, **kwargs) -> None:
|
||||
def merge_lora(config: str, **kwargs):
|
||||
"""
|
||||
Merge trained LoRA adapters into a base model.
|
||||
|
||||
@@ -313,7 +296,7 @@ def merge_lora(config: str, **kwargs) -> None:
|
||||
@cli.command()
|
||||
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
|
||||
@click.option("--dest", help="Destination directory")
|
||||
def fetch(directory: str, dest: Optional[str]) -> None:
|
||||
def fetch(directory: str, dest: Optional[str]):
|
||||
"""
|
||||
Fetch example configs or other resources.
|
||||
|
||||
@@ -351,7 +334,7 @@ def quantize(config: str, **cli_args: QuantizeCliArgs):
|
||||
@cli.command()
|
||||
@click.argument("model", type=click.Path(exists=True, path_type=str))
|
||||
@click.argument("output", type=click.Path(exists=False, path_type=str))
|
||||
def delinearize_llama4(model: str, output: str) -> None:
|
||||
def delinearize_llama4(model: str, output: str):
|
||||
from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4
|
||||
|
||||
do_delinearize_llama4(model, output)
|
||||
@@ -365,5 +348,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
main()
|
||||
|
||||
@@ -4,7 +4,6 @@ from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
@@ -88,5 +87,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -17,7 +17,6 @@ from accelerate.utils import (
|
||||
WEIGHTS_NAME,
|
||||
is_torch_version,
|
||||
)
|
||||
from dotenv import load_dotenv
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from safetensors.torch import save_file as safe_save_file
|
||||
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||
@@ -204,5 +203,4 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -9,7 +9,6 @@ import fire
|
||||
import transformers
|
||||
from accelerate import init_empty_weights
|
||||
from colorama import Fore
|
||||
from dotenv import load_dotenv
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.cli.args import PreprocessCliArgs
|
||||
@@ -109,5 +108,4 @@ def do_cli(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Union
|
||||
|
||||
import fire
|
||||
from accelerate import Accelerator
|
||||
from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
@@ -16,7 +15,6 @@ from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.train import train
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -31,9 +29,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Training-specific CLI arguments.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
|
||||
check_accelerate_default_config()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
check_user_token()
|
||||
@@ -122,5 +117,4 @@ def ray_train_func(kwargs: dict):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
fire.Fire(do_cli)
|
||||
|
||||
@@ -1,330 +0,0 @@
|
||||
"""Utility methods for axolotl CLI."""
|
||||
|
||||
import concurrent.futures
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import json
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from types import NoneType
|
||||
from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||
|
||||
import click
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.loaders.model import ModelLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def strip_optional_type(field_type: type | str | None):
|
||||
"""
|
||||
Extracts the non-`None` type from an `Optional` / `Union` type.
|
||||
|
||||
Args:
|
||||
field_type: Type of field for Axolotl CLI command.
|
||||
|
||||
Returns:
|
||||
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
|
||||
returns the input type unchanged.
|
||||
"""
|
||||
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||
field_type = next(
|
||||
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||
)
|
||||
|
||||
return field_type
|
||||
|
||||
|
||||
def filter_none_kwargs(func: Callable) -> Callable:
|
||||
"""
|
||||
Wraps function to remove `None`-valued `kwargs`.
|
||||
|
||||
Args:
|
||||
func: Function to wrap.
|
||||
|
||||
Returns:
|
||||
Wrapped function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Callable:
|
||||
"""Filters out `None`-valued `kwargs`."""
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
return func(*args, **filtered_kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a dataclass.
|
||||
|
||||
Args:
|
||||
config_class: Dataclass with fields to parse from the CLI.
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process dataclass fields in reverse order for correct option ordering
|
||||
for field in reversed(dataclasses.fields(config_class)):
|
||||
field_type = strip_optional_type(field.type)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = field.name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{field.name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
type=field_type,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a Pydantic model.
|
||||
|
||||
Args:
|
||||
config_class: PyDantic model with fields to parse from the CLI
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process model fields in reverse order for correct option ordering
|
||||
for name, field in reversed(config_class.model_fields.items()):
|
||||
field_type = strip_optional_type(field.annotation)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Build command list from base command and options.
|
||||
|
||||
Args:
|
||||
base_cmd: Command without options.
|
||||
options: Options to parse and append to base command.
|
||||
|
||||
Returns:
|
||||
List of strings giving shell command.
|
||||
"""
|
||||
cmd = base_cmd.copy()
|
||||
|
||||
for key, value in options.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
key = key.replace("_", "-")
|
||||
|
||||
if isinstance(value, bool):
|
||||
if value:
|
||||
cmd.append(f"--{key}")
|
||||
else:
|
||||
cmd.extend([f"--{key}", str(value)])
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
def download_file(
|
||||
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Download a single file and return its processing status.
|
||||
|
||||
Args:
|
||||
file_info: Tuple of (file_path, remote_sha).
|
||||
raw_base_url: Base URL for raw GitHub content.
|
||||
dest_path: Local destination directory.
|
||||
dir_prefix: Directory prefix to filter files.
|
||||
|
||||
Returns:
|
||||
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
|
||||
"""
|
||||
file_path, remote_sha = file_info
|
||||
raw_url = f"{raw_base_url}/{file_path}"
|
||||
dest_file = dest_path / file_path.split(dir_prefix)[-1]
|
||||
|
||||
# Check if file exists and needs updating
|
||||
if dest_file.exists():
|
||||
with open(dest_file, "rb") as file:
|
||||
content = file.read()
|
||||
# Calculate git blob SHA
|
||||
blob = b"blob " + str(len(content)).encode() + b"\0" + content
|
||||
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()
|
||||
|
||||
if local_sha == remote_sha:
|
||||
print(f"Skipping {file_path} (unchanged)")
|
||||
return file_path, "unchanged"
|
||||
|
||||
print(f"Updating {file_path}")
|
||||
status = "new"
|
||||
else:
|
||||
print(f"Downloading {file_path}")
|
||||
status = "new"
|
||||
|
||||
# Create directories if needed
|
||||
dest_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Download and save file
|
||||
try:
|
||||
response = requests.get(raw_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(dest_file, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
return file_path, status
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error downloading {file_path}: {str(request_error)}")
|
||||
return file_path, "error"
|
||||
|
||||
|
||||
def fetch_from_github(
|
||||
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
|
||||
) -> None:
|
||||
"""
|
||||
Sync files from a specific directory in the GitHub repository.
|
||||
Only downloads files that don't exist locally or have changed.
|
||||
|
||||
Args:
|
||||
dir_prefix: Directory prefix to filter files (e.g., 'examples/',
|
||||
'deepspeed_configs/').
|
||||
dest_dir: Local destination directory.
|
||||
max_workers: Maximum number of concurrent downloads.
|
||||
"""
|
||||
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
|
||||
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
|
||||
|
||||
# Get repository tree with timeout
|
||||
response = requests.get(api_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
tree = json.loads(response.text)
|
||||
|
||||
# Filter for files and get their SHA
|
||||
files = {
|
||||
item["path"]: item["sha"]
|
||||
for item in tree["tree"]
|
||||
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
|
||||
}
|
||||
|
||||
if not files:
|
||||
raise click.ClickException(f"No files found in {dir_prefix}")
|
||||
|
||||
# Default destination directory is the last part of dir_prefix
|
||||
default_dest = Path(dir_prefix.rstrip("/"))
|
||||
dest_path = Path(dest_dir) if dest_dir else default_dest
|
||||
|
||||
# Keep track of processed files for summary
|
||||
files_processed: dict[str, list[str]] = {
|
||||
"new": [],
|
||||
"updated": [],
|
||||
"unchanged": [],
|
||||
"error": [],
|
||||
}
|
||||
|
||||
# Process files in parallel using ThreadPoolExecutor
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_file = {
|
||||
executor.submit(
|
||||
download_file,
|
||||
(file_path, remote_sha),
|
||||
raw_base_url,
|
||||
dest_path,
|
||||
dir_prefix,
|
||||
): file_path
|
||||
for file_path, remote_sha in files.items()
|
||||
}
|
||||
|
||||
# Process completed tasks as they finish
|
||||
for future in concurrent.futures.as_completed(future_to_file):
|
||||
file_path = future_to_file[future]
|
||||
try:
|
||||
file_path, status = future.result()
|
||||
files_processed[status].append(file_path)
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error processing {file_path}: {str(request_error)}")
|
||||
files_processed["error"].append(file_path)
|
||||
|
||||
# Log summary
|
||||
LOG.info("\nSync Summary:")
|
||||
LOG.info(f"New files: {len(files_processed['new'])}")
|
||||
LOG.info(f"Updated files: {len(files_processed['updated'])}")
|
||||
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
|
||||
if files_processed["error"]:
|
||||
LOG.info(f"Failed files: {len(files_processed['error'])}")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
inference: bool = False,
|
||||
) -> tuple[
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
|
||||
ProcessorMixin | None,
|
||||
]:
|
||||
"""
|
||||
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
|
||||
config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
inference: Boolean denoting inference mode.
|
||||
|
||||
Returns:
|
||||
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
|
||||
"""
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
LOG.info("loading model...")
|
||||
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
|
||||
model, _ = model_loader.load()
|
||||
|
||||
processor = None
|
||||
if cfg.is_multimodal:
|
||||
LOG.info("loading processor...")
|
||||
processor = load_processor(cfg, tokenizer)
|
||||
|
||||
return model, tokenizer, processor
|
||||
src/axolotl/cli/utils/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
|
||||
"""Init for axolotl.cli.utils module."""
|
||||
|
||||
from .args import (
|
||||
add_options_from_config,
|
||||
add_options_from_dataclass,
|
||||
filter_none_kwargs,
|
||||
)
|
||||
from .fetch import fetch_from_github
|
||||
from .load import load_model_and_tokenizer
|
||||
from .sweeps import generate_sweep_configs
|
||||
from .train import build_command, generate_config_files, launch_training
|
||||
|
||||
__all__ = [
|
||||
"filter_none_kwargs",
|
||||
"add_options_from_dataclass",
|
||||
"add_options_from_config",
|
||||
"build_command",
|
||||
"generate_config_files",
|
||||
"generate_sweep_configs",
|
||||
"load_model_and_tokenizer",
|
||||
"launch_training",
|
||||
"fetch_from_github",
|
||||
]
|
||||
src/axolotl/cli/utils/args.py (new file, 120 lines)
@@ -0,0 +1,120 @@
|
||||
"""Utilities for axolotl CLI args."""
|
||||
|
||||
import dataclasses
|
||||
from functools import wraps
|
||||
from types import NoneType
|
||||
from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||
|
||||
import click
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def _strip_optional_type(field_type: type | str | None):
|
||||
"""
|
||||
Extracts the non-`None` type from an `Optional` / `Union` type.
|
||||
|
||||
Args:
|
||||
field_type: Type of field for Axolotl CLI command.
|
||||
|
||||
Returns:
|
||||
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
|
||||
returns the input type unchanged.
|
||||
"""
|
||||
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||
field_type = next(
|
||||
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||
)
|
||||
|
||||
return field_type
|
||||
|
||||
|
||||
def filter_none_kwargs(func: Callable) -> Callable:
|
||||
"""
|
||||
Wraps function to remove `None`-valued `kwargs`.
|
||||
|
||||
Args:
|
||||
func: Function to wrap.
|
||||
|
||||
Returns:
|
||||
Wrapped function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Callable:
|
||||
"""Filters out `None`-valued `kwargs`."""
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
return func(*args, **filtered_kwargs)
|
||||
|
||||
return wrapper
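As a quick illustration of the behavior described above, a hedged usage sketch (the `launch` function below is hypothetical; `filter_none_kwargs` is re-exported from `axolotl.cli.utils` by the new `__init__.py`):

```python
from axolotl.cli.utils import filter_none_kwargs

@filter_none_kwargs
def launch(config: str, **kwargs):
    # Unset click options arrive as None and are stripped before this body runs.
    return kwargs

print(launch("config.yml", learning_rate=2e-5, sweep=None))
# {'learning_rate': 2e-05}
```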
|
||||
|
||||
|
||||
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a dataclass.
|
||||
|
||||
Args:
|
||||
config_class: Dataclass with fields to parse from the CLI.
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process dataclass fields in reverse order for correct option ordering
|
||||
for field in reversed(dataclasses.fields(config_class)):
|
||||
field_type = _strip_optional_type(field.type)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = field.name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{field.name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name,
|
||||
type=field_type,
|
||||
default=field.default,
|
||||
help=field.metadata.get("description"),
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
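A hedged example of how this decorator is intended to be used (the `DemoCliArgs` dataclass and `demo` command are made up, standing in for e.g. `TrainerCliArgs`):

```python
import dataclasses
from typing import Optional

import click

from axolotl.cli.utils import add_options_from_dataclass

@dataclasses.dataclass
class DemoCliArgs:
    debug: bool = False
    prompter: Optional[str] = None

@click.command()
@click.argument("config")
@add_options_from_dataclass(DemoCliArgs)
def demo(config, **kwargs):
    click.echo(kwargs)

# Generated options: --debug/--no-debug (boolean flag pair) and --prompter.
```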
|
||||
|
||||
|
||||
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a Pydantic model.
|
||||
|
||||
Args:
|
||||
config_class: Pydantic model with fields to parse from the CLI.
|
||||
|
||||
Returns:
|
||||
Function decorator for Axolotl CLI command.
|
||||
"""
|
||||
|
||||
def decorator(function: Callable) -> Callable:
|
||||
# Process model fields in reverse order for correct option ordering
|
||||
for name, field in reversed(config_class.model_fields.items()):
|
||||
field_type = _strip_optional_type(field.annotation)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{name.replace('_', '-')}"
|
||||
function = click.option(
|
||||
option_name, default=None, help=field.description
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
return decorator
|
||||
src/axolotl/cli/utils/fetch.py (new file, 142 lines)
@@ -0,0 +1,142 @@
|
||||
"""Utilities for axolotl fetch CLI command."""
|
||||
|
||||
import concurrent.futures
|
||||
import hashlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
import requests
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _download_file(
|
||||
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Download a single file and return its processing status.
|
||||
|
||||
Args:
|
||||
file_info: Tuple of (file_path, remote_sha).
|
||||
raw_base_url: Base URL for raw GitHub content.
|
||||
dest_path: Local destination directory.
|
||||
dir_prefix: Directory prefix to filter files.
|
||||
|
||||
Returns:
|
||||
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
|
||||
"""
|
||||
file_path, remote_sha = file_info
|
||||
raw_url = f"{raw_base_url}/{file_path}"
|
||||
dest_file = dest_path / file_path.split(dir_prefix)[-1]
|
||||
|
||||
# Check if file exists and needs updating
|
||||
if dest_file.exists():
|
||||
with open(dest_file, "rb") as file:
|
||||
content = file.read()
|
||||
# Calculate git blob SHA
|
||||
blob = b"blob " + str(len(content)).encode() + b"\0" + content
|
||||
local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()
|
||||
|
||||
if local_sha == remote_sha:
|
||||
print(f"Skipping {file_path} (unchanged)")
|
||||
return file_path, "unchanged"
|
||||
|
||||
print(f"Updating {file_path}")
|
||||
status = "updated"
|
||||
else:
|
||||
print(f"Downloading {file_path}")
|
||||
status = "new"
|
||||
|
||||
# Create directories if needed
|
||||
dest_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Download and save file
|
||||
try:
|
||||
response = requests.get(raw_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(dest_file, "wb") as file:
|
||||
file.write(response.content)
|
||||
|
||||
return file_path, status
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error downloading {file_path}: {str(request_error)}")
|
||||
return file_path, "error"
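The unchanged-file check above avoids re-downloading by recomputing GitHub's git blob SHA locally. A standalone editorial sketch of that hash, with a well-known value as a sanity check:

```python
import hashlib

def git_blob_sha1(content: bytes) -> str:
    # Git hashes blobs as b"blob <size>\0" + content, then takes SHA-1 of the result.
    blob = b"blob " + str(len(content)).encode() + b"\0" + content
    return hashlib.sha1(blob, usedforsecurity=False).hexdigest()

assert git_blob_sha1(b"hello\n") == "ce013625030ba8dba906f756967f9e9ca394464a"
```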
|
||||
|
||||
|
||||
def fetch_from_github(
|
||||
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
|
||||
) -> None:
|
||||
"""
|
||||
Sync files from a specific directory in the GitHub repository.
|
||||
Only downloads files that don't exist locally or have changed.
|
||||
|
||||
Args:
|
||||
dir_prefix: Directory prefix to filter files (e.g., 'examples/',
|
||||
'deepspeed_configs/').
|
||||
dest_dir: Local destination directory.
|
||||
max_workers: Maximum number of concurrent downloads.
|
||||
"""
|
||||
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
|
||||
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
|
||||
|
||||
# Get repository tree with timeout
|
||||
response = requests.get(api_url, timeout=30)
|
||||
response.raise_for_status()
|
||||
tree = json.loads(response.text)
|
||||
|
||||
# Filter for files and get their SHA
|
||||
files = {
|
||||
item["path"]: item["sha"]
|
||||
for item in tree["tree"]
|
||||
if item["type"] == "blob" and item["path"].startswith(dir_prefix)
|
||||
}
|
||||
|
||||
if not files:
|
||||
raise click.ClickException(f"No files found in {dir_prefix}")
|
||||
|
||||
# Default destination directory is the last part of dir_prefix
|
||||
default_dest = Path(dir_prefix.rstrip("/"))
|
||||
dest_path = Path(dest_dir) if dest_dir else default_dest
|
||||
|
||||
# Keep track of processed files for summary
|
||||
files_processed: dict[str, list[str]] = {
|
||||
"new": [],
|
||||
"updated": [],
|
||||
"unchanged": [],
|
||||
"error": [],
|
||||
}
|
||||
|
||||
# Process files in parallel using ThreadPoolExecutor
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
future_to_file = {
|
||||
executor.submit(
|
||||
_download_file,
|
||||
(file_path, remote_sha),
|
||||
raw_base_url,
|
||||
dest_path,
|
||||
dir_prefix,
|
||||
): file_path
|
||||
for file_path, remote_sha in files.items()
|
||||
}
|
||||
|
||||
# Process completed tasks as they finish
|
||||
for future in concurrent.futures.as_completed(future_to_file):
|
||||
file_path = future_to_file[future]
|
||||
try:
|
||||
file_path, status = future.result()
|
||||
files_processed[status].append(file_path)
|
||||
except (requests.RequestException, IOError) as request_error:
|
||||
print(f"Error processing {file_path}: {str(request_error)}")
|
||||
files_processed["error"].append(file_path)
|
||||
|
||||
# Log summary
|
||||
LOG.info("\nSync Summary:")
|
||||
LOG.info(f"New files: {len(files_processed['new'])}")
|
||||
LOG.info(f"Updated files: {len(files_processed['updated'])}")
|
||||
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
|
||||
if files_processed["error"]:
|
||||
LOG.info(f"Failed files: {len(files_processed['error'])}")
|
||||
src/axolotl/cli/utils/load.py (new file, 52 lines)
@@ -0,0 +1,52 @@
|
||||
"""Utilities for model, tokenizer, etc. loading."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.loaders.model import ModelLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
inference: bool = False,
|
||||
) -> tuple[
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
|
||||
ProcessorMixin | None,
|
||||
]:
|
||||
"""
|
||||
Helper function for loading a model, tokenizer, and processor specified in the
|
||||
given `axolotl` config.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
inference: Boolean denoting inference mode.
|
||||
|
||||
Returns:
|
||||
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
|
||||
"""
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
LOG.info("loading model...")
|
||||
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
|
||||
model, _ = model_loader.load()
|
||||
|
||||
processor = None
|
||||
if cfg.is_multimodal:
|
||||
LOG.info("loading processor...")
|
||||
processor = load_processor(cfg, tokenizer)
|
||||
|
||||
return model, tokenizer, processor
|
||||
src/axolotl/cli/utils/train.py (new file, 188 lines)
@@ -0,0 +1,188 @@
|
||||
"""Utilities for axolotl train CLI command."""
|
||||
|
||||
import os
|
||||
import subprocess # nosec
|
||||
import tempfile
|
||||
from typing import Any, Iterator, Literal
|
||||
|
||||
import yaml
|
||||
|
||||
from axolotl.cli.utils.sweeps import generate_sweep_configs
|
||||
|
||||
|
||||
def _add_default_rdzv_args(launcher_args: list[str]) -> list[str]:
|
||||
"""
|
||||
Add default RDZV arguments if rdzv_endpoint is set but rdzv_backend/rdzv_id are missing.
|
||||
|
||||
Args:
|
||||
launcher_args: List of launcher arguments
|
||||
|
||||
Returns:
|
||||
Updated launcher args with defaults added if needed
|
||||
"""
|
||||
args = launcher_args.copy()
|
||||
|
||||
# Check if rdzv_endpoint is present
|
||||
has_rdzv_endpoint = any("--rdzv_endpoint" in arg for arg in args)
|
||||
|
||||
if has_rdzv_endpoint:
|
||||
# Check if rdzv_backend is already provided
|
||||
has_rdzv_backend = any("--rdzv_backend" in arg for arg in args)
|
||||
if not has_rdzv_backend:
|
||||
args.extend(["--rdzv_backend", "c10d"])
|
||||
|
||||
# Check if rdzv_id is already provided
|
||||
has_rdzv_id = any("--rdzv_id" in arg for arg in args)
|
||||
if not has_rdzv_id:
|
||||
import uuid
|
||||
|
||||
args.extend(["--rdzv_id", str(uuid.uuid4())[:8]])
|
||||
|
||||
return args
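Editorial sketch of the observable behavior (not part of the commit): when only `--rdzv_endpoint` is supplied, the backend defaults to `c10d` and a short random `--rdzv_id` is generated; explicitly provided values are left untouched.

```python
print(_add_default_rdzv_args(["--rdzv_endpoint", "node0:29400"]))
# ['--rdzv_endpoint', 'node0:29400', '--rdzv_backend', 'c10d', '--rdzv_id', '<random 8 chars>']

print(_add_default_rdzv_args(["--rdzv_endpoint=node0:29400", "--rdzv_backend", "etcd"]))
# ['--rdzv_endpoint=node0:29400', '--rdzv_backend', 'etcd', '--rdzv_id', '<random 8 chars>']
```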
|
||||
|
||||
|
||||
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
|
||||
"""
|
||||
Build command list from base command and options.
|
||||
|
||||
Args:
|
||||
base_cmd: Command without options.
|
||||
options: Options to parse and append to base command.
|
||||
|
||||
Returns:
|
||||
List of strings giving shell command.
|
||||
"""
|
||||
cmd = base_cmd.copy()
|
||||
|
||||
for key, value in options.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
key = key.replace("_", "-")
|
||||
cmd.append(f"--{key}={value}")
|
||||
|
||||
return cmd
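Note that this version emits single `--key=value` tokens (the old `build_command` appended key and value separately and special-cased booleans). A hedged example of the expected output:

```python
cmd = build_command(
    ["torchrun", "-m", "axolotl.cli.train", "config.yml"],
    {"num_epochs": 3, "resume_from_checkpoint": None, "learning_rate": 2e-5},
)
# ['torchrun', '-m', 'axolotl.cli.train', 'config.yml', '--num-epochs=3', '--learning-rate=2e-05']
```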
|
||||
|
||||
|
||||
def generate_config_files(config: str, sweep: str | None) -> Iterator[str]:
|
||||
"""Generate list of configuration files to process."""
|
||||
if not sweep:
|
||||
yield config
|
||||
return
|
||||
|
||||
# Load sweep and base configurations
|
||||
with open(sweep, "r", encoding="utf-8") as fin:
|
||||
sweep_config: dict[str, list] = yaml.safe_load(fin)
|
||||
with open(config, "r", encoding="utf-8") as fin:
|
||||
base_config: dict[str, list] = yaml.safe_load(fin)
|
||||
|
||||
# Generate all possible configurations
|
||||
permutations = generate_sweep_configs(base_config, sweep_config)
|
||||
for permutation in permutations:
|
||||
# pylint: disable=consider-using-with
|
||||
temp_file = tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
suffix=".yaml",
|
||||
delete=False,
|
||||
encoding="utf-8",
|
||||
)
|
||||
yaml.dump(permutation, temp_file)
|
||||
temp_file.close()
|
||||
yield temp_file.name
|
||||
|
||||
|
||||
def launch_training(
    cfg_file: str,
    launcher: Literal["accelerate", "torchrun", "python"] | None,
    cloud: str | None,
    kwargs: dict,
    launcher_args: list[str] | None = None,
) -> None:
    """Execute training with the given configuration."""
    launcher_args = launcher_args or []

    if cloud:
        _launch_cloud_training(cloud, cfg_file, launcher, kwargs, launcher_args)
    elif launcher:
        if launcher == "accelerate":
            _launch_accelerate_training(cfg_file, kwargs, launcher_args)
        elif launcher == "torchrun":
            _launch_torchrun_training(cfg_file, kwargs, launcher_args)
        elif launcher == "python":
            _launch_python_training(cfg_file, kwargs)


def _launch_cloud_training(
    cloud: str,
    cfg_file: str,
    launcher: Literal["accelerate", "torchrun", "python"] | None,
    kwargs: dict,
    launcher_args: list[str] | None = None,
) -> None:
    """Execute training via cloud launcher."""
    from axolotl.cli.cloud import do_cli_train

    launcher_args = launcher_args or []
    cwd = os.getcwd() if launcher else None

    do_cli_train(
        cloud_config=cloud,
        config=cfg_file,
        launcher=launcher or "accelerate",
        launcher_args=launcher_args,
        cwd=cwd,
        **kwargs,
    )


def _launch_accelerate_training(
    cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
) -> None:
    """Execute training via accelerate launcher."""
    launcher_args = launcher_args or []
    internal_launcher_args = []

    # Extract launcher-specific arguments from kwargs (legacy support)
    if "main_process_port" in kwargs:
        main_process_port = kwargs.pop("main_process_port")
        internal_launcher_args.extend(["--main_process_port", str(main_process_port)])

    if "num_processes" in kwargs:
        num_processes = kwargs.pop("num_processes")
        internal_launcher_args.extend(["--num_processes", str(num_processes)])

    # Combine internal args with user-provided launcher args
    all_launcher_args = internal_launcher_args + launcher_args

    base_cmd = (
        ["accelerate", "launch"] + all_launcher_args + ["-m", "axolotl.cli.train"]
    )
    if cfg_file:
        base_cmd.append(cfg_file)

    cmd = build_command(base_cmd, kwargs)
    subprocess.run(cmd, check=True)  # nosec B603


def _launch_torchrun_training(
    cfg_file: str, kwargs: dict, launcher_args: list[str] | None = None
) -> None:
    """Execute training via torchrun launcher."""
    launcher_args = launcher_args or []

    # Add default RDZV arguments if rdzv_endpoint is set
    launcher_args = _add_default_rdzv_args(launcher_args)

    base_cmd = ["torchrun"] + launcher_args + ["-m", "axolotl.cli.train"]
    if cfg_file:
        base_cmd.append(cfg_file)

    cmd = build_command(base_cmd, kwargs)
    subprocess.run(cmd, check=True)  # nosec B603


def _launch_python_training(cfg_file: str, kwargs: dict) -> None:
    """Execute training via python launcher."""
    from axolotl.cli.train import do_cli

    do_cli(config=cfg_file, **kwargs)
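For orientation, a sketch of the argv the torchrun path ends up handing to `subprocess.run(..., check=True)` for a typical invocation; the config name and option values are illustrative.

```python
# axolotl train config.yml --launcher torchrun --learning-rate 1e-4 -- --nproc_per_node=2
# builds roughly the following command list:
[
    "torchrun", "--nproc_per_node=2",
    "-m", "axolotl.cli.train", "config.yml",
    "--learning-rate=0.0001",
]
# If --rdzv_endpoint were also given after "--", _add_default_rdzv_args would
# additionally append "--rdzv_backend c10d" and a random "--rdzv_id" before "-m".
```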
@@ -17,16 +17,23 @@ class BaseCliTest:
            command: Command to test (train/evaluate)
        """
        # Test missing config file
        result = cli_runner.invoke(cli, [command, "--no-accelerate"])
        result = cli_runner.invoke(cli, [command, "--launcher", "python"])
        assert result.exit_code != 0

        # Test non-existent config file
        result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"])
        result = cli_runner.invoke(
            cli, [command, "nonexistent.yml", "--launcher", "python"]
        )
        assert result.exit_code != 0
        assert "Error: Invalid value for 'CONFIG'" in result.output

    def _test_basic_execution(
        self, cli_runner, tmp_path: Path, valid_test_config: str, command: str
        self,
        cli_runner,
        tmp_path: Path,
        valid_test_config: str,
        command: str,
        train: bool = True,
    ):
        """Test basic execution with accelerate.

@@ -35,6 +42,7 @@ class BaseCliTest:
            tmp_path: Temporary path fixture
            valid_test_config: Valid config fixture
            command: Command to test (train/evaluate)
            train: Whether to test training (default) or evaluation
        """
        config_path = tmp_path / "config.yml"
        config_path.write_text(valid_test_config)
@@ -43,15 +51,21 @@ class BaseCliTest:
            result = cli_runner.invoke(cli, [command, str(config_path)])

        assert mock.called
        assert mock.call_args.args[0] == [

        expected = [
            "accelerate",
            "launch",
            "-m",
            f"axolotl.cli.{command}",
            str(config_path),
            "--debug-num-examples",
            "0",
            "--debug=False",
            "--debug-text-only=False",
            "--debug-num-examples=0",
        ]
        if train:
            expected.append("--shard=False")

        assert mock.call_args.args[0] == expected
        assert mock.call_args.kwargs == {"check": True}
        assert result.exit_code == 0
@@ -1,5 +1,7 @@
"""Tests for evaluate CLI command."""

# pylint: disable=duplicate-code

from unittest.mock import patch

from axolotl.cli.main import cli
@@ -18,7 +20,9 @@ class TestEvaluateCommand(BaseCliTest):

    def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config):
        """Test basic successful execution"""
        self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate")
        self._test_basic_execution(
            cli_runner, tmp_path, valid_test_config, "evaluate", train=False
        )

    def test_evaluate_basic_execution_no_accelerate(
        self, cli_runner, tmp_path, valid_test_config
@@ -27,13 +31,15 @@ class TestEvaluateCommand(BaseCliTest):
        config_path = tmp_path / "config.yml"
        config_path.write_text(valid_test_config)

        # pylint: disable=duplicate-code
        with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate:
            result = cli_runner.invoke(
                cli,
                [
                    "evaluate",
                    str(config_path),
                    "--no-accelerate",
                    "--launcher",
                    "python",
                ],
                catch_exceptions=False,
            )
@@ -55,7 +61,8 @@ class TestEvaluateCommand(BaseCliTest):
                    "2",
                    "--sequence-len",
                    "128",
                    "--no-accelerate",
                    "--launcher",
                    "python",
                ],
                catch_exceptions=False,
            )
@@ -65,3 +72,104 @@ class TestEvaluateCommand(BaseCliTest):
        cfg = mock_evaluate.call_args[0][0]
        assert cfg.micro_batch_size == 2
        assert cfg.sequence_len == 128

    def test_evaluate_with_launcher_args_torchrun(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """Test evaluate with torchrun launcher arguments"""
        config_path = tmp_path / "config.yml"
        config_path.write_text(valid_test_config)

        with patch("subprocess.run") as mock_subprocess:
            result = cli_runner.invoke(
                cli,
                [
                    "evaluate",
                    str(config_path),
                    "--launcher",
                    "torchrun",
                    "--",
                    "--nproc_per_node=2",
                    "--nnodes=1",
                ],
                catch_exceptions=False,
            )

        assert result.exit_code == 0
        mock_subprocess.assert_called_once()

        # Verify launcher args are passed to torchrun
        called_cmd = mock_subprocess.call_args.args[0]
        assert called_cmd[0] == "torchrun"
        assert "--nproc_per_node=2" in called_cmd
        assert "--nnodes=1" in called_cmd
        assert "-m" in called_cmd
        assert "axolotl.cli.evaluate" in called_cmd

    def test_evaluate_with_launcher_args_accelerate(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """Test evaluate with accelerate launcher arguments"""
        config_path = tmp_path / "config.yml"
        config_path.write_text(valid_test_config)

        with patch("subprocess.run") as mock_subprocess:
            result = cli_runner.invoke(
                cli,
                [
                    "evaluate",
                    str(config_path),
                    "--launcher",
                    "accelerate",
                    "--",
                    "--config_file=accelerate_config.yml",
                    "--num_processes=4",
                ],
                catch_exceptions=False,
            )

        assert result.exit_code == 0
        mock_subprocess.assert_called_once()

        # Verify launcher args are passed to accelerate
        called_cmd = mock_subprocess.call_args.args[0]
        assert called_cmd[0] == "accelerate"
        assert called_cmd[1] == "launch"
        assert "--config_file=accelerate_config.yml" in called_cmd
        assert "--num_processes=4" in called_cmd
        assert "-m" in called_cmd
        assert "axolotl.cli.evaluate" in called_cmd

    def test_evaluate_backward_compatibility_no_launcher_args(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """Test that existing evaluate commands work without launcher args"""
        config_path = tmp_path / "config.yml"
        config_path.write_text(valid_test_config)

        with patch("subprocess.run") as mock_subprocess:
            result = cli_runner.invoke(
                cli,
                [
                    "evaluate",
                    str(config_path),
                    "--launcher",
                    "accelerate",
                    "--micro-batch-size",
                    "2",
                ],
                catch_exceptions=False,
            )

        assert result.exit_code == 0
        mock_subprocess.assert_called_once()

        # Verify no launcher args contamination
        called_cmd = mock_subprocess.call_args.args[0]
        assert called_cmd[0] == "accelerate"
        assert called_cmd[1] == "launch"
        # Should not contain any extra launcher args
        launcher_section = called_cmd[2 : called_cmd.index("-m")]
        assert (
            len(launcher_section) == 0
        )  # No launcher args between 'launch' and '-m'
@@ -1,5 +1,7 @@
"""pytest tests for axolotl CLI inference command."""

# pylint: disable=duplicate-code

from unittest.mock import patch

from axolotl.cli.main import cli
@@ -10,7 +12,7 @@ def test_inference_basic(cli_runner, config_path):
    with patch("axolotl.cli.inference.do_inference") as mock:
        result = cli_runner.invoke(
            cli,
            ["inference", str(config_path), "--no-accelerate"],
            ["inference", str(config_path), "--launcher", "python"],
            catch_exceptions=False,
        )

@@ -23,9 +25,124 @@ def test_inference_gradio(cli_runner, config_path):
    with patch("axolotl.cli.inference.do_inference_gradio") as mock:
        result = cli_runner.invoke(
            cli,
            ["inference", str(config_path), "--no-accelerate", "--gradio"],
            ["inference", str(config_path), "--launcher", "python", "--gradio"],
            catch_exceptions=False,
        )

    assert mock.called
    assert result.exit_code == 0


def test_inference_with_launcher_args_torchrun(cli_runner, config_path):
    """Test inference with torchrun launcher arguments"""
    with patch("subprocess.run") as mock_subprocess:
        result = cli_runner.invoke(
            cli,
            [
                "inference",
                str(config_path),
                "--launcher",
                "torchrun",
                "--",
                "--nproc_per_node=2",
                "--nnodes=1",
            ],
            catch_exceptions=False,
        )

    assert result.exit_code == 0
    mock_subprocess.assert_called_once()

    # Verify launcher args are passed to torchrun
    called_cmd = mock_subprocess.call_args.args[0]
    assert called_cmd[0] == "torchrun"
    assert "--nproc_per_node=2" in called_cmd
    assert "--nnodes=1" in called_cmd
    assert "-m" in called_cmd
    assert "axolotl.cli.inference" in called_cmd


def test_inference_with_launcher_args_accelerate(cli_runner, config_path):
    """Test inference with accelerate launcher arguments"""
    with patch("subprocess.run") as mock_subprocess:
        result = cli_runner.invoke(
            cli,
            [
                "inference",
                str(config_path),
                "--launcher",
                "accelerate",
                "--",
                "--config_file=accelerate_config.yml",
                "--num_processes=4",
            ],
            catch_exceptions=False,
        )

    assert result.exit_code == 0
    mock_subprocess.assert_called_once()

    # Verify launcher args are passed to accelerate
    called_cmd = mock_subprocess.call_args.args[0]
    assert called_cmd[0] == "accelerate"
    assert called_cmd[1] == "launch"
    assert "--config_file=accelerate_config.yml" in called_cmd
    assert "--num_processes=4" in called_cmd
    assert "-m" in called_cmd
    assert "axolotl.cli.inference" in called_cmd


def test_inference_gradio_with_launcher_args(cli_runner, config_path):
    """Test inference with gradio and launcher arguments"""
    with patch("subprocess.run") as mock_subprocess:
        result = cli_runner.invoke(
            cli,
            [
                "inference",
                str(config_path),
                "--launcher",
                "accelerate",
                "--gradio",
                "--",
                "--num_processes=2",
            ],
            catch_exceptions=False,
        )

    assert result.exit_code == 0
    mock_subprocess.assert_called_once()

    # Verify both gradio flag and launcher args are present
    called_cmd = mock_subprocess.call_args.args[0]
    assert called_cmd[0] == "accelerate"
    assert called_cmd[1] == "launch"
    assert "--num_processes=2" in called_cmd
    assert "--gradio" in called_cmd
    assert "-m" in called_cmd
    assert "axolotl.cli.inference" in called_cmd


def test_inference_backward_compatibility_no_launcher_args(cli_runner, config_path):
    """Test that existing inference commands work without launcher args"""
    with patch("subprocess.run") as mock_subprocess:
        result = cli_runner.invoke(
            cli,
            [
                "inference",
                str(config_path),
                "--launcher",
                "accelerate",
            ],
            catch_exceptions=False,
        )

    assert result.exit_code == 0
    mock_subprocess.assert_called_once()

    # Verify no launcher args contamination
    called_cmd = mock_subprocess.call_args.args[0]
    assert called_cmd[0] == "accelerate"
    assert called_cmd[1] == "launch"
    # Should not contain any extra launcher args
    launcher_section = called_cmd[2 : called_cmd.index("-m")]
    assert len(launcher_section) == 0  # No launcher args between 'launch' and '-m'
@@ -18,11 +18,10 @@ def test_build_command():
    assert result == [
        "accelerate",
        "launch",
        "--learning-rate",
        "0.0001",
        "--batch-size",
        "8",
        "--debug",
        "--learning-rate=0.0001",
        "--batch-size=8",
        "--debug=True",
        "--use-fp16=False",
    ]


@@ -38,7 +37,7 @@ def test_invalid_command_options(cli_runner):
        ],
    )
    assert result.exit_code != 0
    assert "No such option" in result.output
    assert "does not exist" in result.output


def test_required_config_argument(cli_runner):
@@ -11,9 +11,101 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
    """Test merge_sharded_fsdp_weights command without accelerate"""
    with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
        result = cli_runner.invoke(
            cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"]
            cli,
            ["merge-sharded-fsdp-weights", str(config_path), "--launcher", "python"],
        )

    assert mock.called
    assert mock.call_args.kwargs["config"] == str(config_path)
    assert result.exit_code == 0


def test_merge_sharded_fsdp_weights_with_launcher_args_torchrun(
    cli_runner, config_path
):
    """Test merge-sharded-fsdp-weights with torchrun launcher arguments"""
    with patch("subprocess.run") as mock_subprocess:
        result = cli_runner.invoke(
            cli,
            [
                "merge-sharded-fsdp-weights",
                str(config_path),
                "--launcher",
                "torchrun",
                "--",
                "--nproc_per_node=2",
                "--nnodes=1",
            ],
            catch_exceptions=False,
        )

    assert result.exit_code == 0
    mock_subprocess.assert_called_once()

    # Verify launcher args are passed to torchrun
    called_cmd = mock_subprocess.call_args.args[0]
    assert called_cmd[0] == "torchrun"
    assert "--nproc_per_node=2" in called_cmd
    assert "--nnodes=1" in called_cmd
    assert "-m" in called_cmd
    assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd


def test_merge_sharded_fsdp_weights_with_launcher_args_accelerate(
    cli_runner, config_path
):
    """Test merge-sharded-fsdp-weights with accelerate launcher arguments"""
    with patch("subprocess.run") as mock_subprocess:
        result = cli_runner.invoke(
            cli,
            [
                "merge-sharded-fsdp-weights",
                str(config_path),
                "--launcher",
                "accelerate",
                "--",
                "--config_file=accelerate_config.yml",
                "--num_processes=4",
            ],
            catch_exceptions=False,
        )

    assert result.exit_code == 0
    mock_subprocess.assert_called_once()

    # Verify launcher args are passed to accelerate
    called_cmd = mock_subprocess.call_args.args[0]
    assert called_cmd[0] == "accelerate"
    assert called_cmd[1] == "launch"
    assert "--config_file=accelerate_config.yml" in called_cmd
    assert "--num_processes=4" in called_cmd
    assert "-m" in called_cmd
    assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd


def test_merge_sharded_fsdp_weights_backward_compatibility_no_launcher_args(
    cli_runner, config_path
):
    """Test that existing merge-sharded-fsdp-weights commands work without launcher args"""
    with patch("subprocess.run") as mock_subprocess:
        result = cli_runner.invoke(
            cli,
            [
                "merge-sharded-fsdp-weights",
                str(config_path),
                "--launcher",
                "accelerate",
            ],
            catch_exceptions=False,
        )

    assert result.exit_code == 0
    mock_subprocess.assert_called_once()

    # Verify no launcher args contamination
    called_cmd = mock_subprocess.call_args.args[0]
    assert called_cmd[0] == "accelerate"
    assert called_cmd[1] == "launch"
    # Should not contain any extra launcher args
    launcher_section = called_cmd[2 : called_cmd.index("-m")]
    assert len(launcher_section) == 0  # No launcher args between 'launch' and '-m'
@@ -2,7 +2,7 @@
unit tests for generating sweep configurations
"""

from axolotl.cli.main import generate_sweep_configs
from axolotl.cli.utils import generate_sweep_configs


def test_generate_sweep_configs_no_pairs():
@@ -1,5 +1,7 @@
"""Tests for train CLI command."""

# pylint: disable=duplicate-code

from unittest.mock import MagicMock, patch

from axolotl.cli.main import cli
@@ -18,7 +20,9 @@ class TestTrainCommand(BaseCliTest):

    def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config):
        """Test basic successful execution"""
        self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train")
        self._test_basic_execution(
            cli_runner, tmp_path, valid_test_config, "train", train=True
        )

    def test_train_basic_execution_no_accelerate(
        self, cli_runner, tmp_path, valid_test_config
@@ -37,7 +41,8 @@ class TestTrainCommand(BaseCliTest):
                [
                    "train",
                    str(config_path),
                    "--no-accelerate",
                    "--launcher",
                    "python",
                ],
                catch_exceptions=False,
            )
@@ -59,11 +64,10 @@ class TestTrainCommand(BaseCliTest):
                [
                    "train",
                    str(config_path),
                    "--learning-rate",
                    "1e-4",
                    "--micro-batch-size",
                    "2",
                    "--no-accelerate",
                    "--learning-rate=1e-4",
                    "--micro-batch-size=2",
                    "--launcher",
                    "python",
                ],
                catch_exceptions=False,
            )
@@ -73,3 +77,174 @@ class TestTrainCommand(BaseCliTest):
        cfg = mock_train.call_args[1]["cfg"]
        assert cfg["learning_rate"] == 1e-4
        assert cfg["micro_batch_size"] == 2

    def test_train_with_launcher_args_torchrun(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """Test train with torchrun launcher arguments"""
        config_path = tmp_path / "config.yml"
        config_path.write_text(valid_test_config)

        with patch("subprocess.run") as mock_subprocess:
            result = cli_runner.invoke(
                cli,
                [
                    "train",
                    str(config_path),
                    "--launcher",
                    "torchrun",
                    "--",
                    "--nproc_per_node=2",
                    "--nnodes=1",
                ],
                catch_exceptions=False,
            )

        assert result.exit_code == 0
        mock_subprocess.assert_called_once()

        # Verify launcher args are passed to torchrun
        called_cmd = mock_subprocess.call_args.args[0]
        assert called_cmd[0] == "torchrun"
        assert "--nproc_per_node=2" in called_cmd
        assert "--nnodes=1" in called_cmd
        assert "-m" in called_cmd
        assert "axolotl.cli.train" in called_cmd

    def test_train_with_launcher_args_accelerate(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """Test train with accelerate launcher arguments"""
        config_path = tmp_path / "config.yml"
        config_path.write_text(valid_test_config)

        with patch("subprocess.run") as mock_subprocess:
            result = cli_runner.invoke(
                cli,
                [
                    "train",
                    str(config_path),
                    "--launcher",
                    "accelerate",
                    "--",
                    "--config_file=accelerate_config.yml",
                    "--num_processes=4",
                ],
                catch_exceptions=False,
            )

        assert result.exit_code == 0
        mock_subprocess.assert_called_once()

        # Verify launcher args are passed to accelerate
        called_cmd = mock_subprocess.call_args.args[0]
        assert called_cmd[0] == "accelerate"
        assert called_cmd[1] == "launch"
        assert "--config_file=accelerate_config.yml" in called_cmd
        assert "--num_processes=4" in called_cmd
        assert "-m" in called_cmd
        assert "axolotl.cli.train" in called_cmd

    def test_train_backward_compatibility_no_launcher_args(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """Test that existing train commands work without launcher args"""
        config_path = tmp_path / "config.yml"
        config_path.write_text(valid_test_config)

        with patch("subprocess.run") as mock_subprocess:
            result = cli_runner.invoke(
                cli,
                [
                    "train",
                    str(config_path),
                    "--launcher",
                    "accelerate",
                    "--learning-rate",
                    "1e-4",
                ],
                catch_exceptions=False,
            )

        assert result.exit_code == 0
        mock_subprocess.assert_called_once()

        # Verify no launcher args contamination
        called_cmd = mock_subprocess.call_args.args[0]
        assert called_cmd[0] == "accelerate"
        assert called_cmd[1] == "launch"
        # Should not contain any extra launcher args
        launcher_section = called_cmd[2 : called_cmd.index("-m")]
        assert (
            len(launcher_section) == 0
        )  # No launcher args between 'launch' and '-m'

    def test_train_mixed_args_with_launcher_args(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """Test train with both regular CLI args and launcher args"""
        config_path = tmp_path / "config.yml"
        config_path.write_text(valid_test_config)

        with patch("subprocess.run") as mock_subprocess:
            result = cli_runner.invoke(
                cli,
                [
                    "train",
                    str(config_path),
                    "--launcher",
                    "torchrun",
                    "--learning-rate",
                    "2e-4",
                    "--micro-batch-size",
                    "4",
                    "--",
                    "--nproc_per_node=8",
                ],
                catch_exceptions=False,
            )

        assert result.exit_code == 0
        mock_subprocess.assert_called_once()

        called_cmd = mock_subprocess.call_args.args[0]
        # Verify launcher args
        assert "--nproc_per_node=8" in called_cmd
        # Verify axolotl args are also present
        assert "--learning-rate=2e-4" in called_cmd
        assert "--micro-batch-size=4" in called_cmd

    def test_train_cloud_with_launcher_args(
        self, cli_runner, tmp_path, valid_test_config
    ):
        """Test train with cloud and launcher arguments"""
        config_path = tmp_path / "config.yml"
        config_path.write_text(valid_test_config)

        cloud_path = tmp_path / "cloud.yml"
        cloud_path.write_text("provider: modal\ngpu: a100")

        with patch("axolotl.cli.cloud.do_cli_train") as mock_cloud_train:
            result = cli_runner.invoke(
                cli,
                [
                    "train",
                    str(config_path),
                    "--cloud",
                    str(cloud_path),
                    "--launcher",
                    "torchrun",
                    "--",
                    "--nproc_per_node=4",
                    "--nnodes=2",
                ],
                catch_exceptions=False,
            )

        assert result.exit_code == 0
        mock_cloud_train.assert_called_once()

        # Verify cloud training was called with launcher args
        call_kwargs = mock_cloud_train.call_args.kwargs
        assert call_kwargs["launcher"] == "torchrun"
        assert call_kwargs["launcher_args"] == ["--nproc_per_node=4", "--nnodes=2"]
@@ -72,3 +72,160 @@ def test_fetch_from_github_network_error():
    with patch("requests.get", side_effect=requests.RequestException):
        with pytest.raises(requests.RequestException):
            fetch_from_github("examples/", None)


def assert_launcher_args_in_command(
    mock_subprocess_call,
    launcher: str,
    expected_launcher_args: list[str],
    command_module: str,
):
    """
    Helper function to verify launcher arguments are properly passed in subprocess calls.

    Args:
        mock_subprocess_call: The mock subprocess.run call
        launcher: Expected launcher ("accelerate", "torchrun", etc.)
        expected_launcher_args: List of expected launcher arguments
        command_module: Expected module name (e.g., "axolotl.cli.train")
    """
    assert mock_subprocess_call.called, "subprocess.run should have been called"
    called_cmd = mock_subprocess_call.call_args.args[0]

    # Verify launcher
    assert (
        called_cmd[0] == launcher
    ), f"Expected launcher {launcher}, got {called_cmd[0]}"

    # Verify launcher args are present
    for arg in expected_launcher_args:
        assert (
            arg in called_cmd
        ), f"Expected launcher arg '{arg}' not found in command: {called_cmd}"

    # Verify module is present
    assert "-m" in called_cmd, "Expected -m flag for module execution"
    assert (
        command_module in called_cmd
    ), f"Expected module {command_module} not found in command: {called_cmd}"


def assert_no_launcher_args_contamination(mock_subprocess_call, launcher: str):
    """
    Helper function to verify no unwanted launcher arguments are present.

    Args:
        mock_subprocess_call: The mock subprocess.run call
        launcher: Expected launcher ("accelerate", "torchrun", etc.)
    """
    assert mock_subprocess_call.called, "subprocess.run should have been called"
    called_cmd = mock_subprocess_call.call_args.args[0]

    if launcher == "accelerate":
        # For accelerate, launcher args should be between 'launch' and '-m'
        launch_idx = called_cmd.index("launch")
        m_idx = called_cmd.index("-m")
        launcher_section = called_cmd[launch_idx + 1 : m_idx]
        assert (
            len(launcher_section) == 0
        ), f"Unexpected launcher args found: {launcher_section}"
    elif launcher == "torchrun":
        # For torchrun, launcher args should be between 'torchrun' and '-m'
        torchrun_idx = called_cmd.index("torchrun")
        m_idx = called_cmd.index("-m")
        launcher_section = called_cmd[torchrun_idx + 1 : m_idx]
        assert (
            len(launcher_section) == 0
        ), f"Unexpected launcher args found: {launcher_section}"


@pytest.fixture
def common_launcher_args():
    """Fixture providing common launcher argument combinations for testing."""
    return {
        "torchrun": ["--nproc_per_node=2", "--nnodes=1"],
        "accelerate": ["--config_file=accelerate_config.yml", "--num_processes=4"],
    }

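A hedged usage sketch of how these helpers and the fixture could be combined in a test; `cli_runner`, `config_path`, `cli`, and `patch` are assumed to come from the existing fixtures and imports used throughout these test modules, and the test name is hypothetical.

```python
def test_train_launcher_args_via_helpers(cli_runner, config_path, common_launcher_args):
    """Illustrative only: exercise the torchrun path and check it with the helpers."""
    with patch("subprocess.run") as mock_subprocess:
        cli_runner.invoke(
            cli,
            ["train", str(config_path), "--launcher", "torchrun", "--"]
            + common_launcher_args["torchrun"],
            catch_exceptions=False,
        )

    # Both the launcher args and the target module should appear in the argv
    assert_launcher_args_in_command(
        mock_subprocess, "torchrun", common_launcher_args["torchrun"], "axolotl.cli.train"
    )
```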
def test_add_default_rdzv_args_with_endpoint():
    """Test that default RDZV args are added when rdzv_endpoint is present."""
    from axolotl.cli.utils.train import _add_default_rdzv_args

    launcher_args = ["--nnodes=2", "--rdzv_endpoint=127.0.0.1:29400"]
    result = _add_default_rdzv_args(launcher_args)

    # Should have added rdzv_backend
    assert "--rdzv_backend" in result
    assert "c10d" in result

    # Original args should still be present
    assert "--nnodes=2" in result
    assert "--rdzv_endpoint=127.0.0.1:29400" in result


def test_add_default_rdzv_args_with_existing_backend():
    """Test that existing rdzv_backend is not overridden."""
    from axolotl.cli.utils.train import _add_default_rdzv_args

    launcher_args = [
        "--nnodes=2",
        "--rdzv_endpoint=127.0.0.1:29400",
        "--rdzv_backend=static",
    ]
    result = _add_default_rdzv_args(launcher_args)

    # Should not add another rdzv_backend
    backend_count = sum(1 for arg in result if "--rdzv_backend" in arg)
    assert backend_count == 1
    assert "--rdzv_backend=static" in result


def test_add_default_rdzv_args_with_existing_id():
    """Test that existing rdzv_id is not overridden."""
    from axolotl.cli.utils.train import _add_default_rdzv_args

    launcher_args = [
        "--nnodes=2",
        "--rdzv_endpoint=127.0.0.1:29400",
        "--rdzv_id=my_job_123",
    ]
    result = _add_default_rdzv_args(launcher_args)

    # Should not add another rdzv_id
    id_count = sum(1 for arg in result if "--rdzv_id" in arg)
    assert id_count == 1
    assert "--rdzv_id=my_job_123" in result

    # Should still add rdzv_backend
    assert "--rdzv_backend" in result
    assert "c10d" in result


def test_add_default_rdzv_args_without_endpoint():
    """Test that no RDZV args are added when rdzv_endpoint is not present."""
    from axolotl.cli.utils.train import _add_default_rdzv_args

    launcher_args = ["--nnodes=2", "--nproc_per_node=4"]
    result = _add_default_rdzv_args(launcher_args)

    # Should not add any rdzv args
    assert "--rdzv_backend" not in result
    assert result == launcher_args


def test_add_default_rdzv_args_with_all_existing():
    """Test that no defaults are added when all RDZV args are present."""
    from axolotl.cli.utils.train import _add_default_rdzv_args

    launcher_args = [
        "--nnodes=2",
        "--rdzv_endpoint=127.0.0.1:29400",
        "--rdzv_backend=static",
        "--rdzv_id=existing_job",
    ]
    result = _add_default_rdzv_args(launcher_args)

    # Should not add any additional args
    assert len(result) == len(launcher_args)
    assert result == launcher_args