CLI: add --launcher option, support launcher args, cleanup, refactor (#2924)

* add --launcher option; explicit True/False bool args; small cleanup

* refactor

* add torchrun, accelerate cli args

* add rdzv arg default + tests

* update _quarto

* coderabbit

* fix

* we can't set rdvz_id independently across nodes

* coderabbit

* fix tests
This commit is contained in:
Dan Saunders
2025-07-30 15:46:56 -04:00
committed by GitHub
parent 22810c97b7
commit bb1cae1a20
31 changed files with 1417 additions and 541 deletions

View File

@@ -17,16 +17,23 @@ class BaseCliTest:
command: Command to test (train/evaluate)
"""
# Test missing config file
result = cli_runner.invoke(cli, [command, "--no-accelerate"])
result = cli_runner.invoke(cli, [command, "--launcher", "python"])
assert result.exit_code != 0
# Test non-existent config file
result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"])
result = cli_runner.invoke(
cli, [command, "nonexistent.yml", "--launcher", "python"]
)
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
def _test_basic_execution(
self, cli_runner, tmp_path: Path, valid_test_config: str, command: str
self,
cli_runner,
tmp_path: Path,
valid_test_config: str,
command: str,
train: bool = True,
):
"""Test basic execution with accelerate.
@@ -35,6 +42,7 @@ class BaseCliTest:
tmp_path: Temporary path fixture
valid_test_config: Valid config fixture
command: Command to test (train/evaluate)
train: Whether to test training (default) or evaluation
"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
@@ -43,15 +51,21 @@ class BaseCliTest:
result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called
assert mock.call_args.args[0] == [
expected = [
"accelerate",
"launch",
"-m",
f"axolotl.cli.{command}",
str(config_path),
"--debug-num-examples",
"0",
"--debug=False",
"--debug-text-only=False",
"--debug-num-examples=0",
]
if train:
expected.append("--shard=False")
assert mock.call_args.args[0] == expected
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0

View File

@@ -1,5 +1,7 @@
"""Tests for evaluate CLI command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
@@ -18,7 +20,9 @@ class TestEvaluateCommand(BaseCliTest):
def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate")
self._test_basic_execution(
cli_runner, tmp_path, valid_test_config, "evaluate", train=False
)
def test_evaluate_basic_execution_no_accelerate(
self, cli_runner, tmp_path, valid_test_config
@@ -27,13 +31,15 @@ class TestEvaluateCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
# pylint: disable=duplicate-code
with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--no-accelerate",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -55,7 +61,8 @@ class TestEvaluateCommand(BaseCliTest):
"2",
"--sequence-len",
"128",
"--no-accelerate",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -65,3 +72,104 @@ class TestEvaluateCommand(BaseCliTest):
cfg = mock_evaluate.call_args[0][0]
assert cfg.micro_batch_size == 2
assert cfg.sequence_len == 128
def test_evaluate_with_launcher_args_torchrun(
self, cli_runner, tmp_path, valid_test_config
):
"""Test evaluate with torchrun launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.evaluate" in called_cmd
def test_evaluate_with_launcher_args_accelerate(
self, cli_runner, tmp_path, valid_test_config
):
"""Test evaluate with accelerate launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.evaluate" in called_cmd
def test_evaluate_backward_compatibility_no_launcher_args(
self, cli_runner, tmp_path, valid_test_config
):
"""Test that existing evaluate commands work without launcher args"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--launcher",
"accelerate",
"--micro-batch-size",
"2",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert (
len(launcher_section) == 0
) # No launcher args between 'launch' and '-m'

View File

@@ -1,5 +1,7 @@
"""pytest tests for axolotl CLI inference command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
@@ -10,7 +12,7 @@ def test_inference_basic(cli_runner, config_path):
with patch("axolotl.cli.inference.do_inference") as mock:
result = cli_runner.invoke(
cli,
["inference", str(config_path), "--no-accelerate"],
["inference", str(config_path), "--launcher", "python"],
catch_exceptions=False,
)
@@ -23,9 +25,124 @@ def test_inference_gradio(cli_runner, config_path):
with patch("axolotl.cli.inference.do_inference_gradio") as mock:
result = cli_runner.invoke(
cli,
["inference", str(config_path), "--no-accelerate", "--gradio"],
["inference", str(config_path), "--launcher", "python", "--gradio"],
catch_exceptions=False,
)
assert mock.called
assert result.exit_code == 0
def test_inference_with_launcher_args_torchrun(cli_runner, config_path):
"""Test inference with torchrun launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.inference" in called_cmd
def test_inference_with_launcher_args_accelerate(cli_runner, config_path):
"""Test inference with accelerate launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.inference" in called_cmd
def test_inference_gradio_with_launcher_args(cli_runner, config_path):
"""Test inference with gradio and launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"accelerate",
"--gradio",
"--",
"--num_processes=2",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify both gradio flag and launcher args are present
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--num_processes=2" in called_cmd
assert "--gradio" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.inference" in called_cmd
def test_inference_backward_compatibility_no_launcher_args(cli_runner, config_path):
"""Test that existing inference commands work without launcher args"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert len(launcher_section) == 0 # No launcher args between 'launch' and '-m'

View File

@@ -18,11 +18,10 @@ def test_build_command():
assert result == [
"accelerate",
"launch",
"--learning-rate",
"0.0001",
"--batch-size",
"8",
"--debug",
"--learning-rate=0.0001",
"--batch-size=8",
"--debug=True",
"--use-fp16=False",
]
@@ -38,7 +37,7 @@ def test_invalid_command_options(cli_runner):
],
)
assert result.exit_code != 0
assert "No such option" in result.output
assert "does not exist" in result.output
def test_required_config_argument(cli_runner):

View File

@@ -11,9 +11,101 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
"""Test merge_sharded_fsdp_weights command without accelerate"""
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"]
cli,
["merge-sharded-fsdp-weights", str(config_path), "--launcher", "python"],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_launcher_args_torchrun(
cli_runner, config_path
):
"""Test merge-sharded-fsdp-weights with torchrun launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd
def test_merge_sharded_fsdp_weights_with_launcher_args_accelerate(
cli_runner, config_path
):
"""Test merge-sharded-fsdp-weights with accelerate launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd
def test_merge_sharded_fsdp_weights_backward_compatibility_no_launcher_args(
cli_runner, config_path
):
"""Test that existing merge-sharded-fsdp-weights commands work without launcher args"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--launcher",
"accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert len(launcher_section) == 0 # No launcher args between 'launch' and '-m'

View File

@@ -2,7 +2,7 @@
unit tests for generating sweep configurations
"""
from axolotl.cli.main import generate_sweep_configs
from axolotl.cli.utils import generate_sweep_configs
def test_generate_sweep_configs_no_pairs():

View File

@@ -1,5 +1,7 @@
"""Tests for train CLI command."""
# pylint: disable=duplicate-code
from unittest.mock import MagicMock, patch
from axolotl.cli.main import cli
@@ -18,7 +20,9 @@ class TestTrainCommand(BaseCliTest):
def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train")
self._test_basic_execution(
cli_runner, tmp_path, valid_test_config, "train", train=True
)
def test_train_basic_execution_no_accelerate(
self, cli_runner, tmp_path, valid_test_config
@@ -37,7 +41,8 @@ class TestTrainCommand(BaseCliTest):
[
"train",
str(config_path),
"--no-accelerate",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -59,11 +64,10 @@ class TestTrainCommand(BaseCliTest):
[
"train",
str(config_path),
"--learning-rate",
"1e-4",
"--micro-batch-size",
"2",
"--no-accelerate",
"--learning-rate=1e-4",
"--micro-batch-size=2",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -73,3 +77,174 @@ class TestTrainCommand(BaseCliTest):
cfg = mock_train.call_args[1]["cfg"]
assert cfg["learning_rate"] == 1e-4
assert cfg["micro_batch_size"] == 2
def test_train_with_launcher_args_torchrun(
self, cli_runner, tmp_path, valid_test_config
):
"""Test train with torchrun launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.train" in called_cmd
def test_train_with_launcher_args_accelerate(
self, cli_runner, tmp_path, valid_test_config
):
"""Test train with accelerate launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.train" in called_cmd
def test_train_backward_compatibility_no_launcher_args(
self, cli_runner, tmp_path, valid_test_config
):
"""Test that existing train commands work without launcher args"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--launcher",
"accelerate",
"--learning-rate",
"1e-4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert (
len(launcher_section) == 0
) # No launcher args between 'launch' and '-m'
def test_train_mixed_args_with_launcher_args(
self, cli_runner, tmp_path, valid_test_config
):
"""Test train with both regular CLI args and launcher args"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--launcher",
"torchrun",
"--learning-rate",
"2e-4",
"--micro-batch-size",
"4",
"--",
"--nproc_per_node=8",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
called_cmd = mock_subprocess.call_args.args[0]
# Verify launcher args
assert "--nproc_per_node=8" in called_cmd
# Verify axolotl args are also present
assert "--learning-rate=2e-4" in called_cmd
assert "--micro-batch-size=4" in called_cmd
def test_train_cloud_with_launcher_args(
self, cli_runner, tmp_path, valid_test_config
):
"""Test train with cloud and launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
cloud_path = tmp_path / "cloud.yml"
cloud_path.write_text("provider: modal\ngpu: a100")
with patch("axolotl.cli.cloud.do_cli_train") as mock_cloud_train:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--cloud",
str(cloud_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=4",
"--nnodes=2",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_cloud_train.assert_called_once()
# Verify cloud training was called with launcher args
call_kwargs = mock_cloud_train.call_args.kwargs
assert call_kwargs["launcher"] == "torchrun"
assert call_kwargs["launcher_args"] == ["--nproc_per_node=4", "--nnodes=2"]

View File

@@ -72,3 +72,160 @@ def test_fetch_from_github_network_error():
with patch("requests.get", side_effect=requests.RequestException):
with pytest.raises(requests.RequestException):
fetch_from_github("examples/", None)
def assert_launcher_args_in_command(
mock_subprocess_call,
launcher: str,
expected_launcher_args: list[str],
command_module: str,
):
"""
Helper function to verify launcher arguments are properly passed in subprocess calls.
Args:
mock_subprocess_call: The mock subprocess.run call
launcher: Expected launcher ("accelerate", "torchrun", etc.)
expected_launcher_args: List of expected launcher arguments
command_module: Expected module name (e.g., "axolotl.cli.train")
"""
assert mock_subprocess_call.called, "subprocess.run should have been called"
called_cmd = mock_subprocess_call.call_args.args[0]
# Verify launcher
assert (
called_cmd[0] == launcher
), f"Expected launcher {launcher}, got {called_cmd[0]}"
# Verify launcher args are present
for arg in expected_launcher_args:
assert (
arg in called_cmd
), f"Expected launcher arg '{arg}' not found in command: {called_cmd}"
# Verify module is present
assert "-m" in called_cmd, "Expected -m flag for module execution"
assert (
command_module in called_cmd
), f"Expected module {command_module} not found in command: {called_cmd}"
def assert_no_launcher_args_contamination(mock_subprocess_call, launcher: str):
"""
Helper function to verify no unwanted launcher arguments are present.
Args:
mock_subprocess_call: The mock subprocess.run call
launcher: Expected launcher ("accelerate", "torchrun", etc.)
"""
assert mock_subprocess_call.called, "subprocess.run should have been called"
called_cmd = mock_subprocess_call.call_args.args[0]
if launcher == "accelerate":
# For accelerate, launcher args should be between 'launch' and '-m'
launch_idx = called_cmd.index("launch")
m_idx = called_cmd.index("-m")
launcher_section = called_cmd[launch_idx + 1 : m_idx]
assert (
len(launcher_section) == 0
), f"Unexpected launcher args found: {launcher_section}"
elif launcher == "torchrun":
# For torchrun, launcher args should be between 'torchrun' and '-m'
torchrun_idx = called_cmd.index("torchrun")
m_idx = called_cmd.index("-m")
launcher_section = called_cmd[torchrun_idx + 1 : m_idx]
assert (
len(launcher_section) == 0
), f"Unexpected launcher args found: {launcher_section}"
@pytest.fixture
def common_launcher_args():
"""Fixture providing common launcher argument combinations for testing."""
return {
"torchrun": ["--nproc_per_node=2", "--nnodes=1"],
"accelerate": ["--config_file=accelerate_config.yml", "--num_processes=4"],
}
def test_add_default_rdzv_args_with_endpoint():
"""Test that default RDZV args are added when rdzv_endpoint is present."""
from axolotl.cli.utils.train import _add_default_rdzv_args
launcher_args = ["--nnodes=2", "--rdzv_endpoint=127.0.0.1:29400"]
result = _add_default_rdzv_args(launcher_args)
# Should have added rdzv_backend
assert "--rdzv_backend" in result
assert "c10d" in result
# Original args should still be present
assert "--nnodes=2" in result
assert "--rdzv_endpoint=127.0.0.1:29400" in result
def test_add_default_rdzv_args_with_existing_backend():
"""Test that existing rdzv_backend is not overridden."""
from axolotl.cli.utils.train import _add_default_rdzv_args
launcher_args = [
"--nnodes=2",
"--rdzv_endpoint=127.0.0.1:29400",
"--rdzv_backend=static",
]
result = _add_default_rdzv_args(launcher_args)
# Should not add another rdzv_backend
backend_count = sum(1 for arg in result if "--rdzv_backend" in arg)
assert backend_count == 1
assert "--rdzv_backend=static" in result
def test_add_default_rdzv_args_with_existing_id():
"""Test that existing rdzv_id is not overridden."""
from axolotl.cli.utils.train import _add_default_rdzv_args
launcher_args = [
"--nnodes=2",
"--rdzv_endpoint=127.0.0.1:29400",
"--rdzv_id=my_job_123",
]
result = _add_default_rdzv_args(launcher_args)
# Should not add another rdzv_id
id_count = sum(1 for arg in result if "--rdzv_id" in arg)
assert id_count == 1
assert "--rdzv_id=my_job_123" in result
# Should still add rdzv_backend
assert "--rdzv_backend" in result
assert "c10d" in result
def test_add_default_rdzv_args_without_endpoint():
"""Test that no RDZV args are added when rdzv_endpoint is not present."""
from axolotl.cli.utils.train import _add_default_rdzv_args
launcher_args = ["--nnodes=2", "--nproc_per_node=4"]
result = _add_default_rdzv_args(launcher_args)
# Should not add any rdzv args
assert "--rdzv_backend" not in result
assert result == launcher_args
def test_add_default_rdzv_args_with_all_existing():
"""Test that no defaults are added when all RDZV args are present."""
from axolotl.cli.utils.train import _add_default_rdzv_args
launcher_args = [
"--nnodes=2",
"--rdzv_endpoint=127.0.0.1:29400",
"--rdzv_backend=static",
"--rdzv_id=existing_job",
]
result = _add_default_rdzv_args(launcher_args)
# Should not add any additional args
assert len(result) == len(launcher_args)
assert result == launcher_args