Resolve merge conflicts: unify pretraining utils imports, add alias handling; fix rl.py per new RL dataset API; resolve config schema conflict and add sequence_len_overflow_handling field

This commit is contained in:
mhenrhcsen
2025-08-12 20:45:26 +02:00
603 changed files with 37614 additions and 14002 deletions

View File

@@ -17,16 +17,23 @@ class BaseCliTest:
command: Command to test (train/evaluate)
"""
# Test missing config file
result = cli_runner.invoke(cli, [command, "--no-accelerate"])
result = cli_runner.invoke(cli, [command, "--launcher", "python"])
assert result.exit_code != 0
# Test non-existent config file
result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"])
result = cli_runner.invoke(
cli, [command, "nonexistent.yml", "--launcher", "python"]
)
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
def _test_basic_execution(
self, cli_runner, tmp_path: Path, valid_test_config: str, command: str
self,
cli_runner,
tmp_path: Path,
valid_test_config: str,
command: str,
train: bool = True,
):
"""Test basic execution with accelerate.
@@ -35,24 +42,37 @@ class BaseCliTest:
tmp_path: Temporary path fixture
valid_test_config: Valid config fixture
command: Command to test (train/evaluate)
train: Whether to test training (default) or evaluation
"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock:
mock_fn = "os.execvpe" if command == "train" else "subprocess.run"
with patch(mock_fn) as mock:
result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called
assert mock.call_args.args[0] == [
expected = [
"accelerate",
"launch",
"-m",
f"axolotl.cli.{command}",
str(config_path),
"--debug-num-examples",
"0",
"--debug=False",
"--debug-text-only=False",
"--debug-num-examples=0",
]
assert mock.call_args.kwargs == {"check": True}
if train:
expected.append("--shard=False")
if command == "train":
assert mock.call_args.args[0] == "accelerate"
assert mock.call_args.args[1] == expected
else:
assert mock.call_args.args[0] == expected
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):

View File

@@ -1,5 +1,7 @@
"""Tests for evaluate CLI command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
@@ -18,7 +20,9 @@ class TestEvaluateCommand(BaseCliTest):
def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate")
self._test_basic_execution(
cli_runner, tmp_path, valid_test_config, "evaluate", train=False
)
def test_evaluate_basic_execution_no_accelerate(
self, cli_runner, tmp_path, valid_test_config
@@ -27,13 +31,15 @@ class TestEvaluateCommand(BaseCliTest):
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
# pylint: disable=duplicate-code
with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--no-accelerate",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -55,7 +61,8 @@ class TestEvaluateCommand(BaseCliTest):
"2",
"--sequence-len",
"128",
"--no-accelerate",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -65,3 +72,104 @@ class TestEvaluateCommand(BaseCliTest):
cfg = mock_evaluate.call_args[0][0]
assert cfg.micro_batch_size == 2
assert cfg.sequence_len == 128
def test_evaluate_with_launcher_args_torchrun(
self, cli_runner, tmp_path, valid_test_config
):
"""Test evaluate with torchrun launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.evaluate" in called_cmd
def test_evaluate_with_launcher_args_accelerate(
self, cli_runner, tmp_path, valid_test_config
):
"""Test evaluate with accelerate launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.evaluate" in called_cmd
def test_evaluate_backward_compatibility_no_launcher_args(
self, cli_runner, tmp_path, valid_test_config
):
"""Test that existing evaluate commands work without launcher args"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"evaluate",
str(config_path),
"--launcher",
"accelerate",
"--micro-batch-size",
"2",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert (
len(launcher_section) == 0
) # No launcher args between 'launch' and '-m'

View File

@@ -1,5 +1,7 @@
"""pytest tests for axolotl CLI inference command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
@@ -10,7 +12,7 @@ def test_inference_basic(cli_runner, config_path):
with patch("axolotl.cli.inference.do_inference") as mock:
result = cli_runner.invoke(
cli,
["inference", str(config_path), "--no-accelerate"],
["inference", str(config_path), "--launcher", "python"],
catch_exceptions=False,
)
@@ -23,9 +25,124 @@ def test_inference_gradio(cli_runner, config_path):
with patch("axolotl.cli.inference.do_inference_gradio") as mock:
result = cli_runner.invoke(
cli,
["inference", str(config_path), "--no-accelerate", "--gradio"],
["inference", str(config_path), "--launcher", "python", "--gradio"],
catch_exceptions=False,
)
assert mock.called
assert result.exit_code == 0
def test_inference_with_launcher_args_torchrun(cli_runner, config_path):
"""Test inference with torchrun launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.inference" in called_cmd
def test_inference_with_launcher_args_accelerate(cli_runner, config_path):
"""Test inference with accelerate launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.inference" in called_cmd
def test_inference_gradio_with_launcher_args(cli_runner, config_path):
"""Test inference with gradio and launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"accelerate",
"--gradio",
"--",
"--num_processes=2",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify both gradio flag and launcher args are present
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--num_processes=2" in called_cmd
assert "--gradio" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.inference" in called_cmd
def test_inference_backward_compatibility_no_launcher_args(cli_runner, config_path):
"""Test that existing inference commands work without launcher args"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"inference",
str(config_path),
"--launcher",
"accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert len(launcher_section) == 0 # No launcher args between 'launch' and '-m'

View File

@@ -18,11 +18,10 @@ def test_build_command():
assert result == [
"accelerate",
"launch",
"--learning-rate",
"0.0001",
"--batch-size",
"8",
"--debug",
"--learning-rate=0.0001",
"--batch-size=8",
"--debug=True",
"--use-fp16=False",
]
@@ -38,7 +37,7 @@ def test_invalid_command_options(cli_runner):
],
)
assert result.exit_code != 0
assert "No such option" in result.output
assert "does not exist" in result.output
def test_required_config_argument(cli_runner):

View File

@@ -11,9 +11,101 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
"""Test merge_sharded_fsdp_weights command without accelerate"""
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"]
cli,
["merge-sharded-fsdp-weights", str(config_path), "--launcher", "python"],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_launcher_args_torchrun(
cli_runner, config_path
):
"""Test merge-sharded-fsdp-weights with torchrun launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd
def test_merge_sharded_fsdp_weights_with_launcher_args_accelerate(
cli_runner, config_path
):
"""Test merge-sharded-fsdp-weights with accelerate launcher arguments"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd
def test_merge_sharded_fsdp_weights_backward_compatibility_no_launcher_args(
cli_runner, config_path
):
"""Test that existing merge-sharded-fsdp-weights commands work without launcher args"""
with patch("subprocess.run") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--launcher",
"accelerate",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
called_cmd = mock_subprocess.call_args.args[0]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert len(launcher_section) == 0 # No launcher args between 'launch' and '-m'

View File

@@ -2,7 +2,7 @@
unit tests for generating sweep configurations
"""
from axolotl.cli.main import generate_sweep_configs
from axolotl.cli.utils import generate_sweep_configs
def test_generate_sweep_configs_no_pairs():

View File

@@ -1,5 +1,7 @@
"""Tests for train CLI command."""
# pylint: disable=duplicate-code
from unittest.mock import MagicMock, patch
from axolotl.cli.main import cli
@@ -18,7 +20,9 @@ class TestTrainCommand(BaseCliTest):
def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config):
"""Test basic successful execution"""
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train")
self._test_basic_execution(
cli_runner, tmp_path, valid_test_config, "train", train=True
)
def test_train_basic_execution_no_accelerate(
self, cli_runner, tmp_path, valid_test_config
@@ -37,7 +41,8 @@ class TestTrainCommand(BaseCliTest):
[
"train",
str(config_path),
"--no-accelerate",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -59,11 +64,10 @@ class TestTrainCommand(BaseCliTest):
[
"train",
str(config_path),
"--learning-rate",
"1e-4",
"--micro-batch-size",
"2",
"--no-accelerate",
"--learning-rate=1e-4",
"--micro-batch-size=2",
"--launcher",
"python",
],
catch_exceptions=False,
)
@@ -73,3 +77,177 @@ class TestTrainCommand(BaseCliTest):
cfg = mock_train.call_args[1]["cfg"]
assert cfg["learning_rate"] == 1e-4
assert cfg["micro_batch_size"] == 2
def test_train_with_launcher_args_torchrun(
self, cli_runner, tmp_path, valid_test_config
):
"""Test train with torchrun launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=2",
"--nnodes=1",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to torchrun
called_cmd = mock_subprocess.call_args.args[1]
assert called_cmd[0] == "torchrun"
assert "--nproc_per_node=2" in called_cmd
assert "--nnodes=1" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.train" in called_cmd
def test_train_with_launcher_args_accelerate(
self, cli_runner, tmp_path, valid_test_config
):
"""Test train with accelerate launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--launcher",
"accelerate",
"--",
"--config_file=accelerate_config.yml",
"--num_processes=4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify launcher args are passed to accelerate
assert mock_subprocess.call_args.args[0] == "accelerate"
called_cmd = mock_subprocess.call_args.args[1]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
assert "--config_file=accelerate_config.yml" in called_cmd
assert "--num_processes=4" in called_cmd
assert "-m" in called_cmd
assert "axolotl.cli.train" in called_cmd
def test_train_backward_compatibility_no_launcher_args(
self, cli_runner, tmp_path, valid_test_config
):
"""Test that existing train commands work without launcher args"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--launcher",
"accelerate",
"--learning-rate",
"1e-4",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
# Verify no launcher args contamination
assert mock_subprocess.call_args.args[0] == "accelerate"
called_cmd = mock_subprocess.call_args.args[1]
assert called_cmd[0] == "accelerate"
assert called_cmd[1] == "launch"
# Should not contain any extra launcher args
launcher_section = called_cmd[2 : called_cmd.index("-m")]
assert (
len(launcher_section) == 0
) # No launcher args between 'launch' and '-m'
def test_train_mixed_args_with_launcher_args(
self, cli_runner, tmp_path, valid_test_config
):
"""Test train with both regular CLI args and launcher args"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch("os.execvpe") as mock_subprocess:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--launcher",
"torchrun",
"--learning-rate",
"2e-4",
"--micro-batch-size",
"4",
"--",
"--nproc_per_node=8",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_subprocess.assert_called_once()
assert mock_subprocess.call_args.args[0] == "torchrun"
called_cmd = mock_subprocess.call_args.args[1]
# Verify launcher args
assert "--nproc_per_node=8" in called_cmd
# Verify axolotl args are also present
assert "--learning-rate=2e-4" in called_cmd
assert "--micro-batch-size=4" in called_cmd
def test_train_cloud_with_launcher_args(
self, cli_runner, tmp_path, valid_test_config
):
"""Test train with cloud and launcher arguments"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
cloud_path = tmp_path / "cloud.yml"
cloud_path.write_text("provider: modal\ngpu: a100")
with patch("axolotl.cli.cloud.do_cli_train") as mock_cloud_train:
result = cli_runner.invoke(
cli,
[
"train",
str(config_path),
"--cloud",
str(cloud_path),
"--launcher",
"torchrun",
"--",
"--nproc_per_node=4",
"--nnodes=2",
],
catch_exceptions=False,
)
assert result.exit_code == 0
mock_cloud_train.assert_called_once()
# Verify cloud training was called with launcher args
call_kwargs = mock_cloud_train.call_args.kwargs
assert call_kwargs["launcher"] == "torchrun"
assert call_kwargs["launcher_args"] == ["--nproc_per_node=4", "--nnodes=2"]

View File

@@ -72,3 +72,160 @@ def test_fetch_from_github_network_error():
with patch("requests.get", side_effect=requests.RequestException):
with pytest.raises(requests.RequestException):
fetch_from_github("examples/", None)
def assert_launcher_args_in_command(
mock_subprocess_call,
launcher: str,
expected_launcher_args: list[str],
command_module: str,
):
"""
Helper function to verify launcher arguments are properly passed in subprocess calls.
Args:
mock_subprocess_call: The mock subprocess.run call
launcher: Expected launcher ("accelerate", "torchrun", etc.)
expected_launcher_args: List of expected launcher arguments
command_module: Expected module name (e.g., "axolotl.cli.train")
"""
assert mock_subprocess_call.called, "subprocess.run should have been called"
called_cmd = mock_subprocess_call.call_args.args[0]
# Verify launcher
assert (
called_cmd[0] == launcher
), f"Expected launcher {launcher}, got {called_cmd[0]}"
# Verify launcher args are present
for arg in expected_launcher_args:
assert (
arg in called_cmd
), f"Expected launcher arg '{arg}' not found in command: {called_cmd}"
# Verify module is present
assert "-m" in called_cmd, "Expected -m flag for module execution"
assert (
command_module in called_cmd
), f"Expected module {command_module} not found in command: {called_cmd}"
def assert_no_launcher_args_contamination(mock_subprocess_call, launcher: str):
"""
Helper function to verify no unwanted launcher arguments are present.
Args:
mock_subprocess_call: The mock subprocess.run call
launcher: Expected launcher ("accelerate", "torchrun", etc.)
"""
assert mock_subprocess_call.called, "subprocess.run should have been called"
called_cmd = mock_subprocess_call.call_args.args[0]
if launcher == "accelerate":
# For accelerate, launcher args should be between 'launch' and '-m'
launch_idx = called_cmd.index("launch")
m_idx = called_cmd.index("-m")
launcher_section = called_cmd[launch_idx + 1 : m_idx]
assert (
len(launcher_section) == 0
), f"Unexpected launcher args found: {launcher_section}"
elif launcher == "torchrun":
# For torchrun, launcher args should be between 'torchrun' and '-m'
torchrun_idx = called_cmd.index("torchrun")
m_idx = called_cmd.index("-m")
launcher_section = called_cmd[torchrun_idx + 1 : m_idx]
assert (
len(launcher_section) == 0
), f"Unexpected launcher args found: {launcher_section}"
@pytest.fixture
def common_launcher_args():
"""Fixture providing common launcher argument combinations for testing."""
return {
"torchrun": ["--nproc_per_node=2", "--nnodes=1"],
"accelerate": ["--config_file=accelerate_config.yml", "--num_processes=4"],
}
def test_add_default_rdzv_args_with_endpoint():
"""Test that default RDZV args are added when rdzv_endpoint is present."""
from axolotl.cli.utils.train import _add_default_rdzv_args
launcher_args = ["--nnodes=2", "--rdzv_endpoint=127.0.0.1:29400"]
result = _add_default_rdzv_args(launcher_args)
# Should have added rdzv_backend
assert "--rdzv_backend" in result
assert "c10d" in result
# Original args should still be present
assert "--nnodes=2" in result
assert "--rdzv_endpoint=127.0.0.1:29400" in result
def test_add_default_rdzv_args_with_existing_backend():
"""Test that existing rdzv_backend is not overridden."""
from axolotl.cli.utils.train import _add_default_rdzv_args
launcher_args = [
"--nnodes=2",
"--rdzv_endpoint=127.0.0.1:29400",
"--rdzv_backend=static",
]
result = _add_default_rdzv_args(launcher_args)
# Should not add another rdzv_backend
backend_count = sum(1 for arg in result if "--rdzv_backend" in arg)
assert backend_count == 1
assert "--rdzv_backend=static" in result
def test_add_default_rdzv_args_with_existing_id():
"""Test that existing rdzv_id is not overridden."""
from axolotl.cli.utils.train import _add_default_rdzv_args
launcher_args = [
"--nnodes=2",
"--rdzv_endpoint=127.0.0.1:29400",
"--rdzv_id=my_job_123",
]
result = _add_default_rdzv_args(launcher_args)
# Should not add another rdzv_id
id_count = sum(1 for arg in result if "--rdzv_id" in arg)
assert id_count == 1
assert "--rdzv_id=my_job_123" in result
# Should still add rdzv_backend
assert "--rdzv_backend" in result
assert "c10d" in result
def test_add_default_rdzv_args_without_endpoint():
"""Test that no RDZV args are added when rdzv_endpoint is not present."""
from axolotl.cli.utils.train import _add_default_rdzv_args
launcher_args = ["--nnodes=2", "--nproc_per_node=4"]
result = _add_default_rdzv_args(launcher_args)
# Should not add any rdzv args
assert "--rdzv_backend" not in result
assert result == launcher_args
def test_add_default_rdzv_args_with_all_existing():
"""Test that no defaults are added when all RDZV args are present."""
from axolotl.cli.utils.train import _add_default_rdzv_args
launcher_args = [
"--nnodes=2",
"--rdzv_endpoint=127.0.0.1:29400",
"--rdzv_backend=static",
"--rdzv_id=existing_job",
]
result = _add_default_rdzv_args(launcher_args)
# Should not add any additional args
assert len(result) == len(launcher_args)
assert result == launcher_args

View File

@@ -4,26 +4,33 @@ shared pytest fixtures
import functools
import importlib
import logging
import os
import shutil
import sys
import tempfile
import time
from pathlib import Path
from typing import Generator
import datasets
import pytest
import requests
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.errors import LocalEntryNotFoundError
from tokenizers import AddedToken
from transformers import AutoTokenizer
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import (
enable_hf_offline,
hf_offline_context,
)
logging.getLogger("filelock").setLevel(logging.CRITICAL)
def retry_on_request_exceptions(max_retries=3, delay=1):
# pylint: disable=duplicate-code
@@ -411,7 +418,7 @@ def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
@pytest.fixture
def temp_dir():
def temp_dir() -> Generator[str, None, None]:
# Create a temporary directory
_temp_dir = tempfile.mkdtemp()
yield _temp_dir
@@ -419,6 +426,11 @@ def temp_dir():
shutil.rmtree(_temp_dir)
@pytest.fixture(scope="function", autouse=True)
def torch_manual_seed():
torch.manual_seed(42)
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer
@@ -529,6 +541,22 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
return datasets.load_from_disk(ds_path)["train"]
@pytest.fixture(name="min_base_cfg")
def fixture_min_base_cfg():
return DictDefault(
base_model="HuggingFaceTB/SmolLM2-135M",
learning_rate=1e-3,
datasets=[
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
micro_batch_size=1,
gradient_accumulation_steps=1,
)
# # pylint: disable=redefined-outer-name,unused-argument
@pytest.mark.skipif(
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",

604
tests/core/test_builders.py Normal file
View File

@@ -0,0 +1,604 @@
"""Unit tests for axolotl.core.builders"""
# pylint: disable=protected-access
import sys
from pathlib import Path
from unittest.mock import patch
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config
from axolotl.utils.data import prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RLType
from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
@pytest.fixture(name="base_cfg")
def fixture_base_cfg():
"""
Base config with all common arguments between SFT and RLHF
"""
cfg = DictDefault(
{
# Model and tokenizer settings
"base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
"sequence_len": 2048,
"model_config_type": "llama", # example type
# Basic training settings
"micro_batch_size": 2,
"eval_batch_size": 2,
"num_epochs": 1,
"gradient_accumulation_steps": 1,
"max_steps": 100,
"val_set_size": 0,
# Optimizer settings
"optimizer": "adamw_torch_fused",
"learning_rate": 0.00005,
"weight_decay": 0.01,
"adam_beta1": 0.998,
"adam_beta2": 0.9,
"adam_epsilon": 0.00001,
"max_grad_norm": 1.0,
# LR scheduler settings
"lr_scheduler": "cosine",
"lr_scheduler_kwargs": {"foo": "bar"},
"warmup_steps": 10,
"warmup_ratio": None,
"cosine_min_lr_ratio": 0.1,
"cosine_constant_lr_ratio": 0.2,
# Checkpointing and saving
"save_steps": 100,
"output_dir": "./model-out",
"save_safetensors": True,
"save_total_limit": 4,
"save_only_model": False,
# Hardware/performance settings
"gradient_checkpointing": False,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
"dataloader_prefetch_factor": 2,
"context_parallel_size": 1,
"tensor_parallel_size": 1,
# Dtype
"fp16": False,
"bf16": False,
"tf32": False,
# Logging and evaluation
"logging_steps": 10,
"eval_steps": 50,
"eval_strategy": "steps",
"save_strategy": "steps",
"include_tokens_per_second": True,
# Other common settings
"seed": 42,
"remove_unused_columns": True,
"ddp_timeout": 1800,
"ddp_bucket_cap_mb": 25,
"ddp_broadcast_buffers": False,
"dataset_processes": 4,
}
)
normalize_config(cfg)
return cfg
@pytest.fixture(name="dpo_cfg")
def fixture_dpo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.DPO,
"dpo_use_weighting": True,
"dpo_use_logits_to_keep": True,
"dpo_label_smoothing": 0.1,
"beta": 0.1, # DPO beta
}
)
return cfg
@pytest.fixture(name="orpo_cfg")
def fixture_orpo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.ORPO,
"orpo_alpha": 0.1,
"max_prompt_len": 512,
}
)
return cfg
@pytest.fixture(name="kto_cfg")
def fixture_kto_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.KTO,
"kto_desirable_weight": 1.0,
"kto_undesirable_weight": 1.0,
"max_prompt_len": 512,
}
)
return cfg
@pytest.fixture(name="grpo_cfg")
def fixture_grpo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.GRPO,
"trl": DictDefault(
{
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": False, # run on CPU
# "vllm_device": "auto",
# "vllm_gpu_memory_utilization": 0.15,
"num_generations": 4,
"reward_funcs": ["rewards.rand_reward_func"],
}
),
# Must be evenly divisible by num_generations
"micro_batch_size": 4,
}
)
return cfg
@pytest.fixture(name="ipo_cfg")
def fixture_ipo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.IPO,
"dpo_label_smoothing": 0,
"beta": 0.1,
}
)
return cfg
@pytest.fixture(name="simpo_cfg")
def fixture_simpo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": RLType.SIMPO,
"rl_beta": 0.2,
"cpo_alpha": 0.9,
"simpo_gamma": 0.4,
}
)
return cfg
@pytest.fixture(name="sft_cfg")
def fixture_sft_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": None,
"sample_packing": False,
"eval_sample_packing": False,
"flash_attention": False,
}
)
return cfg
@pytest.fixture(name="rm_cfg")
def fixture_rm_cfg(sft_cfg):
cfg = sft_cfg.copy()
cfg.update(
DictDefault(
{
"reward_model": True,
"datasets": [
{
"path": "argilla/distilabel-intel-orca-dpo-pairs",
"type": "bradley_terry.chat_template",
"split": "train[:1%]",
}
],
}
)
)
return cfg
@pytest.fixture(name="prm_cfg")
def fixture_prm_cfg(sft_cfg):
cfg = sft_cfg.copy()
cfg.update(
DictDefault(
{
"process_reward_model": True,
"datasets": [
{
"path": "trl-lib/math_shepherd",
"type": "stepwise_supervised",
"split": "train[:1%]",
}
],
}
)
)
return cfg
@pytest.fixture(name="tokenizer")
def fixture_tokenizer(base_cfg):
return load_tokenizer(base_cfg)
@pytest.fixture(name="model")
def fixture_model(base_cfg, tokenizer):
model, _ = ModelLoader(base_cfg, tokenizer).load()
return model
class TestHFRLTrainerBuilder:
"""
TestCase class for RLHF trainer builders
"""
def _test_common_training_arguments(self, training_arguments, rl: str):
"""Helper to test common arguments across all variants"""
# Basic training settings
if rl == "grpo":
# grpo_cfg's micro_batch_size is diff from others
assert training_arguments.per_device_train_batch_size == 4
else:
assert training_arguments.per_device_train_batch_size == 2
assert training_arguments.gradient_accumulation_steps == 1
assert training_arguments.max_steps == 100
# Optimizer settings
assert training_arguments.learning_rate == 0.00005
assert training_arguments.weight_decay == 0.01
assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.9
assert training_arguments.adam_epsilon == 0.00001
assert training_arguments.max_grad_norm == 1.0
# LR scheduler settings
assert training_arguments.lr_scheduler_type == "cosine"
assert training_arguments.warmup_steps == 10
assert training_arguments.cosine_min_lr_ratio == 0.1
assert training_arguments.cosine_constant_lr_ratio == 0.2
# Other settings
assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True
# TODO(wing): restore once trl releases 0.22.0
# assert training_arguments.gradient_checkpointing is True
def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
training_arguments, _ = builder._build_training_arguments(100)
self._test_common_training_arguments(training_arguments, rl=dpo_cfg.rl)
# DPO specific
assert training_arguments.beta == 0.1
assert hasattr(training_arguments, "use_weighting")
assert training_arguments.use_weighting is True
assert training_arguments.label_smoothing == 0.1
def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
training_arguments, _ = builder._build_training_arguments(100)
self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)
# ORPO specific
assert training_arguments.beta == 0.1 # maps from orpo_alpha
assert training_arguments.max_prompt_length == 512
def test_kto_training_arguments(self, kto_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)
training_arguments, _ = builder._build_training_arguments(100)
self._test_common_training_arguments(training_arguments, rl=kto_cfg.rl)
# KTO specific
assert training_arguments.desirable_weight == 1.0
assert training_arguments.undesirable_weight == 1.0
assert training_arguments.max_prompt_length == 512
def _write_rewards_file(self, rewards_dir: Path):
"""
Writes reward function to local tmp path to be loaded on trainer building
"""
# Create rewards.py in a directory we can import from
rewards_dir.mkdir()
rewards_file = rewards_dir / "rewards.py"
rewards_file.write_text(
"""import random
def rand_reward_func(prompts, completions) -> list[float]:
return [random.uniform(0, 1) for _ in completions]
"""
)
def test_grpo_training_arguments(self, grpo_cfg, model, tokenizer, tmp_path):
rewards_dir = tmp_path / "rewards_test"
self._write_rewards_file(rewards_dir)
# Add the directory to Python path so we can import the module
sys.path.insert(0, str(rewards_dir))
try:
builder = HFRLTrainerBuilder(grpo_cfg, model, tokenizer)
training_arguments, _ = builder._build_training_arguments(100)
self._test_common_training_arguments(training_arguments, rl=grpo_cfg.rl)
# GRPO specific
assert training_arguments.beta == 0.001
assert training_arguments.max_completion_length == 256
assert training_arguments.use_vllm is False
# assert training_arguments.vllm_device == "auto"
# assert training_arguments.vllm_gpu_memory_utilization == 0.15
assert training_arguments.num_generations == 4
# Test trainer creation to verify reward_funcs
trainer = builder.build(100)
# Verify reward functions are properly loaded
assert len(trainer.reward_funcs) == 1
assert trainer.reward_funcs[0].__module__ == "rewards"
assert trainer.reward_funcs[0].__name__ == "rand_reward_func"
finally:
# remove imported module from path
if str(rewards_dir) in sys.path:
sys.path.remove(str(rewards_dir))
def test_ipo_training_arguments(self, ipo_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(ipo_cfg, model, tokenizer)
training_arguments, _ = builder._build_training_arguments(100)
self._test_common_training_arguments(training_arguments, rl=ipo_cfg.rl)
# IPO specific
assert training_arguments.beta == 0.1
assert training_arguments.loss_type == "ipo"
assert training_arguments.label_smoothing == 0
def test_simpo_training_arguments(self, simpo_cfg, model, tokenizer):
builder = HFRLTrainerBuilder(simpo_cfg, model, tokenizer)
training_arguments, _ = builder._build_training_arguments(100)
self._test_common_training_arguments(training_arguments, rl=simpo_cfg.rl)
# SIMPO specific
assert training_arguments.beta == 0.2
assert training_arguments.cpo_alpha == 0.9
assert training_arguments.simpo_gamma == 0.4
@pytest.mark.parametrize(
("cfg_string", "dataset_name"),
[
(
"dpo_cfg",
"dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
),
(
"ipo_cfg",
"dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
),
(
"grpo_cfg",
"dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
),
("orpo_cfg", None), # don't use fixture for orpo to use smaller split
("kto_cfg", None), # no fixture for kto
(
"simpo_cfg",
"dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
),
],
)
def test_custom_optimizer_cls_and_kwargs(
self,
request,
cfg_string,
dataset_name,
tmp_path,
model,
tokenizer,
):
cfg = request.getfixturevalue(cfg_string)
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
cfg["optimizer"] = "muon"
if cfg_string in ["dpo_cfg", "ipo_cfg", "grpo_cfg", "simpo_cfg"]:
cfg["datasets"] = [DictDefault(ALPACA_MESSAGES_CONFIG_REVISION)]
elif cfg_string == "kto_cfg":
cfg["datasets"] = [
DictDefault(
{
"path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
"type": "llama3.ultra",
"split": "train[:1%]",
}
)
]
elif cfg_string == "orpo_cfg":
cfg["datasets"] = [
DictDefault(
{
"path": "argilla/ultrafeedback-binarized-preferences-cleaned",
"type": "chat_template.argilla",
"split": "train[:1%]",
}
)
]
else:
raise ValueError(f"Unhandled cfg_string: {cfg_string}")
cfg["dataset_processes"] = 4
if cfg_string == "grpo_cfg":
rewards_dir = tmp_path / "rewards_test"
self._write_rewards_file(rewards_dir)
# Add the directory to Python path so we can import the module
sys.path.insert(0, str(rewards_dir))
try:
# Only use mock for the commented out configs
if dataset_name is not None:
with patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset:
mock_load_dataset.return_value = request.getfixturevalue(
dataset_name
)
train_dataset, eval_dataset = prepare_preference_datasets(
cfg, tokenizer
)
else:
# Load actual datasets for orpo_cfg and kto_cfg
train_dataset, eval_dataset = prepare_preference_datasets(
cfg, tokenizer
)
builder.train_dataset = train_dataset
builder.eval_dataset = eval_dataset
trainer = builder.build(100)
assert trainer.optimizer_cls_and_kwargs is not None
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
Muon,
MuonOptimizerFactory,
)
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
assert optimizer_cls is MuonOptimizerFactory
assert optimizer_kwargs["lr"] == 0.00005
assert optimizer_kwargs["weight_decay"] == 0.01
assert optimizer_kwargs["betas"] == (0.998, 0.9)
assert optimizer_kwargs["eps"] == 0.00001
# Ensure optimizer is created with correct class
optim = trainer.create_optimizer()
assert isinstance(optim, Muon)
finally:
# remove imported module from path
if cfg_string == "grpo_cfg" and str(rewards_dir) in sys.path:
sys.path.remove(str(rewards_dir))
class TestHFCausalTrainerBuilder:
"""
TestCase class for SFT trainer builder
"""
def test_training_arguments(self, sft_cfg, model, tokenizer):
builder = HFCausalTrainerBuilder(sft_cfg, model, tokenizer)
trainer = builder.build(100)
training_arguments = trainer.args
# Test common arguments
assert training_arguments.per_device_train_batch_size == 2
assert training_arguments.gradient_accumulation_steps == 1
assert training_arguments.max_steps == 100
assert training_arguments.learning_rate == 0.00005
assert training_arguments.weight_decay == 0.01
assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.9
assert training_arguments.adam_epsilon == 0.00001
assert training_arguments.max_grad_norm == 1.0
assert training_arguments.lr_scheduler_type == "cosine"
assert training_arguments.warmup_steps == 10
assert training_arguments.cosine_min_lr_ratio == 0.1
assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True
assert training_arguments.gradient_checkpointing is False
# SFT specific
assert training_arguments.sample_packing is False
assert training_arguments.eval_sample_packing is False
@pytest.mark.parametrize(
"cfg_string",
[
"sft_cfg",
"rm_cfg",
"prm_cfg",
],
)
def test_custom_optimizer_cls_and_kwargs(
self, request, cfg_string, model, tokenizer
):
cfg = request.getfixturevalue(cfg_string)
builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
cfg["optimizer"] = "muon"
# need to load datasets for reward model and process reward model trainer
if cfg_string in ["rm_cfg", "prm_cfg"]:
dataset_meta = load_datasets(cfg=cfg)
builder.train_dataset = dataset_meta.train_dataset
builder.eval_dataset = dataset_meta.eval_dataset
trainer = builder.build(100)
assert trainer.optimizer_cls_and_kwargs is not None
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
Muon,
MuonOptimizerFactory,
)
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
assert optimizer_cls is MuonOptimizerFactory
assert optimizer_kwargs["lr"] == 0.00005
assert optimizer_kwargs["weight_decay"] == 0.01
assert optimizer_kwargs["betas"] == (0.998, 0.9)
assert optimizer_kwargs["eps"] == 0.00001
# Ensure optimizer is created with correct class
optim = trainer.create_optimizer()
assert isinstance(optim, Muon)
class TestTrainerClsPlugin:
"""
TestCase class for trainer builder with plugin
"""
def test_trainer_cls_is_not_none_with_plugin(self, kto_cfg, model, tokenizer):
"""
Test that the trainer cls is not none with plugin
Fixes #2693
"""
cfg = kto_cfg.copy()
cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
# Expected AttributeError as we don't pass regular model configs to RL trainer builder
# If it throws `TypeError: None is not a callable object`, trainer_cls could be None
try:
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
builder.build(100)
except TypeError as e:
# Error raised if trainer_cls is None
assert "'tuple' object has no attribute 'config'" not in str(e)
except Exception: # pylint: disable=broad-exception-caught
# Another error happens, so we passed trainer_cls to builder
pass

View File

@@ -1,90 +0,0 @@
"""Unit tests for axolotl.core.trainer_builder"""
import pytest
from axolotl.core.trainer_builder import HFRLTrainerBuilder
from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RLType
@pytest.fixture(name="cfg")
def fixture_cfg():
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 0.00005,
"save_steps": 100,
"output_dir": "./model-out",
"warmup_steps": 10,
"gradient_checkpointing": False,
"optimizer": "adamw_torch_fused",
"sequence_len": 2048,
"rl": True,
"adam_beta1": 0.998,
"adam_beta2": 0.9,
"adam_epsilon": 0.00001,
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
"model_config_type": "llama",
"special_tokens": {
"pad_token": "<|endoftext|>",
},
}
)
normalize_config(cfg)
return cfg
@pytest.fixture(name="tokenizer")
def fixture_tokenizer(cfg):
return load_tokenizer(cfg)
@pytest.fixture(name="model")
def fixture_model(cfg, tokenizer):
return ModelLoader(cfg, tokenizer).load()
class TestHFRLTrainerBuilder:
"""
TestCase class for DPO trainer builder
"""
def test_build_training_arguments(self, cfg, model, tokenizer):
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
training_arguments = builder.build_training_arguments(100)
assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.9
assert training_arguments.adam_epsilon == 0.00001
assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True
class TestTrainerClsPlugin:
"""
TestCase class for trainer builder with plugin
"""
def test_trainer_cls_is_not_none_with_plugin(self, cfg, model, tokenizer):
"""
Test that the trainer cls is not none with plugin
Fixes #2693
"""
cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
cfg.rl = RLType.KTO
# Expected AttributeError as we don't pass regular model configs to RL trainer builder
# If it throws `TypeError: None is not a callable object`, trainer_cls could be None
with pytest.raises(
AttributeError, match=r".*'tuple' object has no attribute 'config'.*"
):
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
builder.build(100)

View File

@@ -4,7 +4,6 @@ Simple end-to-end test for Cut Cross Entropy integration
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils import get_pytorch_version
@@ -45,6 +44,7 @@ def min_cfg(temp_dir):
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
"save_first_step": False,
}
@@ -59,8 +59,7 @@ class TestCutCrossEntropyIntegration:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
@@ -100,13 +99,13 @@ class TestCutCrossEntropyIntegration:
"save_safetensors": True,
"max_steps": 10,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
@@ -134,8 +133,7 @@ class TestCutCrossEntropyIntegration:
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):

View File

@@ -0,0 +1,62 @@
"""
Simple end-to-end smoke tests for FP8 mixed precision training
"""
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists, require_torch_2_7_0
class FP8IntegrationTestCase:
"""
e2e smoke tests for FP8 mixed precision training with Axolotl
"""
@require_torch_2_7_0
def test_fp8_single_gpu_smoke(self, temp_dir):
"""Smoke test for single GPU FP8 + torch.compile training"""
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"trust_remote_code": True,
"sequence_len": 512,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 3, # Very short smoke test
"micro_batch_size": 1,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"sdp_attention": True,
"pad_to_seq_len": True,
"sample_packing": True,
"fp8": True,
"torch_compile": True,
"save_safetensors": True,
"save_first_step": False,
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,6 @@ e2e tests to make sure all the hooks are fired on the plugin
import os
from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.integrations.base import BasePlugin
from axolotl.train import train
@@ -154,14 +153,14 @@ class TestPluginHooks:
"max_steps": 5,
"flash_attention": True,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,11 +5,9 @@ e2e tests for kd trainer support in Axolotl
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
@@ -18,8 +16,8 @@ from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
@pytest.fixture(name="kd_min_cfg")
def min_cfg(temp_dir):
return {
"base_model": "osllmai-community/Llama-3.2-1B",
"tokenizer_config": "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
"base_model": "Qwen/Qwen3-0.6B",
"tokenizer_config": "winglian/qwen3-14b-math",
"plugins": [
"axolotl.integrations.kd.KDPlugin",
"axolotl.integrations.liger.LigerPlugin",
@@ -32,20 +30,22 @@ def min_cfg(temp_dir):
"kd_ce_alpha": 0.1,
"kd_alpha": 0.9,
"kd_temperature": 1.0,
"kd_beta": 0.0,
"kd_normalize_topk": True,
"dataloader_prefetch_factor": 8,
"dataloader_num_workers": 4,
"dataloader_pin_memory": True,
"datasets": [
{
"path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample",
"type": "axolotl.integrations.kd.chat_template",
"field_messages": "messages_combined",
"path": "winglian/OpenThoughts-114k-math-correct-qwen3-14b-math-prepared-topk128-normalized",
"type": "chat_template",
"split": "train",
"logprobs_field": "llm_text_generation_vllm_logprobs",
"temperature": 1.0,
"preprocess_shards": 2,
"split_thinking": True,
"eot_tokens": ["<|im_end|>"],
"data_files": ["train/batch-000000.parquet"],
},
],
"skip_prepare_dataset": True,
"val_set_size": 0.0,
"sequence_len": 2048,
"sample_packing": True,
@@ -67,6 +67,7 @@ def min_cfg(temp_dir):
"output_dir": temp_dir,
"save_safetensors": True,
"use_tensorboard": True,
"save_first_step": False,
}
@@ -81,18 +82,29 @@ class TestKnowledgeDistillation:
def test_llama_kd(self, temp_dir, kd_min_cfg):
cfg = DictDefault(kd_min_cfg)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"1",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high"
)
@pytest.mark.skip(reason="Chunked KD loss doesn't support PEFT/LoRA")
@pytest.mark.parametrize(
"load_in_8bit",
[True, False],
@@ -112,13 +124,22 @@ class TestKnowledgeDistillation:
| kd_min_cfg
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
train(cfg=cfg, dataset_meta=dataset_meta)
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"1",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"

View File

@@ -2,7 +2,6 @@
Simple end-to-end test for Liger integration
"""
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -51,14 +50,14 @@ class LigerIntegrationTestCase:
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
"save_first_step": False,
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -98,14 +97,14 @@ class LigerIntegrationTestCase:
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
"save_first_step": False,
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -82,14 +81,14 @@ class TestLLMCompressorIntegration:
},
"save_compressed": save_compressed,
},
"save_first_step": False,
}
)
prepare_plugins(cfg)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
try:
train(cfg=cfg, dataset_meta=dataset_meta)

View File

@@ -1,7 +1,6 @@
"""Tests for GEGLU activation function Triton kernels."""
# pylint: disable=duplicate-code
import pytest
import torch
import torch.nn.functional as F
@@ -20,8 +19,15 @@ def test_geglu_forward_shape():
assert out.device == gate.device
def test_geglu_forward_values():
@pytest.mark.flaky(retries=1, delay=5)
@pytest.mark.parametrize(
"torch_seed",
[0, 42],
)
def test_geglu_forward_values(torch_seed):
"""Test GEGLU forward pass matches PyTorch reference implementation."""
torch.manual_seed(torch_seed)
gate = torch.randn(2, 3, 64, device="cuda")
up = torch.randn(2, 3, 64, device="cuda")
@@ -34,8 +40,15 @@ def test_geglu_forward_values():
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
def test_geglu_backward():
@pytest.mark.flaky(retries=1, delay=5)
@pytest.mark.parametrize(
"torch_seed",
[0, 42],
)
def test_geglu_backward(torch_seed):
"""Test GEGLU backward pass matches PyTorch autograd."""
torch.manual_seed(torch_seed)
gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
grad_output = torch.randn(2, 3, 64, device="cuda")

View File

@@ -64,6 +64,7 @@ def sample_tensors():
batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
),
"W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
"b": torch.randn(out_dim, device="cuda", dtype=torch.float16),
"scale": 0.5,
"shapes": {
"batch": batch_size,
@@ -103,23 +104,24 @@ def mock_proj():
def test_get_lora_parameters(mock_proj):
"""Tests get_lora_parameters function"""
# Test with LoRA enabled
W, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s = get_lora_parameters(mock_proj)
assert isinstance(W, torch.Tensor)
assert W.shape == (128, 64)
assert b.shape == (128,)
assert A.shape == (8, 64)
assert B.shape == (128, 8)
assert s == 0.5
# Test with LoRA disabled
mock_proj.disable_adapters = True
W, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
# Test with merged state
mock_proj.disable_adapters = False
mock_proj.merged = True
W, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
@@ -127,6 +129,7 @@ def test_matmul_lora(sample_tensors):
"""Tests matmul_lora function"""
X = sample_tensors["X"]
W = sample_tensors["W"]
b = sample_tensors["b"]
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
@@ -138,19 +141,20 @@ def test_matmul_lora(sample_tensors):
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test base matmul
out1 = matmul_lora(X, W, None, None, None, None)
expected1 = torch.matmul(X, W.t())
out1 = matmul_lora(X, W, b, None, None, None, None)
matmul = torch.matmul(X, W.t())
expected1 = matmul + b
assert torch.allclose(out1, expected1, rtol=1e-3)
# Test with LoRA
out2 = matmul_lora(X, W, None, A, B, scale)
out2 = matmul_lora(X, W, b, None, A, B, scale)
lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
expected2 = expected1 + lora_term
expected2 = matmul + lora_term + b
assert torch.allclose(out2, expected2, rtol=1e-3)
# Test 3D input reshaping
X_3d = X.clone()
out3 = matmul_lora(X_3d, W, None, A, B, scale)
out3 = matmul_lora(X_3d, W, b, None, A, B, scale)
assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
@@ -175,16 +179,19 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
output = LoRA_MLP.apply(
X,
gate_proj.weight,
gate_proj.bias,
None, # gate_quant
None, # gate_A
None, # gate_B
None, # gate_scale
up_proj.weight,
up_proj.bias,
None, # up_quant
None, # up_A
None, # up_B
None, # up_scale
down_proj.weight,
down_proj.bias,
None, # down_quant
None, # down_A
None, # down_B
@@ -243,16 +250,19 @@ def test_lora_mlp_with_adapters(
output = LoRA_MLP.apply(
X,
gate_proj.weight,
gate_proj.bias,
None,
gate_A,
gate_B,
scale,
up_proj.weight,
up_proj.bias,
None,
up_A,
up_B,
scale,
down_proj.weight,
down_proj.bias,
None,
down_A,
down_B,
@@ -323,6 +333,7 @@ def test_lora_qkv(sample_tensors):
X.requires_grad = True
# Test without LoRA adapters
# pylint: disable=duplicate-code
Q1, K1, V1 = LoRA_QKV.apply(
X,
q_weight,
@@ -330,16 +341,19 @@ def test_lora_qkv(sample_tensors):
None,
None,
None,
None,
k_weight,
None,
None,
None,
None,
None,
v_weight,
None,
None,
None,
None,
None,
True,
)
@@ -356,16 +370,19 @@ def test_lora_qkv(sample_tensors):
X,
q_weight,
None,
None,
q_A,
q_B,
scale,
k_weight,
None,
None,
k_A,
k_B,
scale,
v_weight,
None,
None,
v_A,
v_B,
scale,
@@ -399,6 +416,7 @@ def test_lora_o(sample_tensors):
"""Tests LoRA output projection"""
X = sample_tensors["X"]
W = sample_tensors["W"]
b = sample_tensors["b"]
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
@@ -411,7 +429,7 @@ def test_lora_o(sample_tensors):
# Test forward pass
X.requires_grad = True
output = LoRA_O.apply(X, W, None, A, B, scale)
output = LoRA_O.apply(X, W, b, None, A, B, scale)
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
@@ -425,6 +443,7 @@ def test_with_quantization(sample_tensors, mock_quantstate):
"""Tests LoRA with quantized weights"""
X = sample_tensors["X"] # [batch, seq, hidden]
W = sample_tensors["W"] # [out, hidden]
b = sample_tensors["b"] # [out]
scale = 0.5
shapes = sample_tensors["shapes"]
@@ -436,13 +455,13 @@ def test_with_quantization(sample_tensors, mock_quantstate):
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
# Test matmul with quantization
out = matmul_lora(X, W, mock_quantstate, A, B, scale)
out = matmul_lora(X, W, b, mock_quantstate, A, B, scale)
assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
assert not torch.isnan(out).any()
# Test with different batch sizes
X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale)
out2 = matmul_lora(X2, W, b, mock_quantstate, A, B, scale)
assert out2.shape == (4, 6, W.shape[0])
assert not torch.isnan(out2).any()
@@ -459,11 +478,12 @@ def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
"""Tests various input shapes and dimensions"""
X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
b = torch.randn(out, device="cuda", dtype=torch.float16)
A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
scale = 0.5
result = matmul_lora(X, W, None, A, B, scale)
result = matmul_lora(X, W, b, None, A, B, scale)
assert result.shape == (batch, seq, out)
@@ -471,6 +491,7 @@ def test_gradient_flow(sample_tensors):
"""Tests gradient flow through LoRA layers"""
X = sample_tensors["X"].clone()
W = sample_tensors["W"].clone()
b = sample_tensors["b"].clone()
scale = sample_tensors["scale"]
shapes = sample_tensors["shapes"]
@@ -486,7 +507,7 @@ def test_gradient_flow(sample_tensors):
B.requires_grad = True
# Forward pass
out = matmul_lora(X, W, None, A, B, scale)
out = matmul_lora(X, W, b, None, A, B, scale)
loss = out.sum()
# Backward pass

View File

@@ -1,6 +1,5 @@
"""E2E tests for sequence parallelism"""
import os
from pathlib import Path
import pytest
@@ -12,8 +11,6 @@ from axolotl.utils.dict import DictDefault
from ...utils import check_tensorboard
os.environ["WANDB_DISABLED"] = "true"
class TestSequenceParallelism:
"""Test case for training with sequence parallelism enabled"""
@@ -57,6 +54,7 @@ class TestSequenceParallelism:
"micro_batch_size": micro_batch_size,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
@@ -69,8 +67,9 @@ class TestSequenceParallelism:
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
"sequence_parallel_degree": 2,
"context_parallel_size": 2,
"ring_attn_func": ring_attn_func,
"save_first_step": False,
}
)
@@ -94,7 +93,10 @@ class TestSequenceParallelism:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", threshold, "Train Loss is too high"
temp_dir + "/runs",
"train/train_loss",
threshold,
"Train Loss (%s) is too high",
)
@pytest.mark.parametrize(
@@ -103,13 +105,13 @@ class TestSequenceParallelism:
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
# (False, 2, True, "batch_zigzag", 2.5),
(False, 2, False, None, 2.5), # defaults to batch_ring ring_attn_func
# (False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
],
ids=[
"sample_packing, varlen_llama3 ring_attn_func",
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
# "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
],
)
def test_sequence_parallel_training(

View File

@@ -2,8 +2,6 @@
E2E tests for multigpu lora tinyllama
"""
import logging
import os
from pathlib import Path
import pytest
@@ -17,9 +15,6 @@ from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -59,12 +54,14 @@ class TestPackedFlex:
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 2,
"use_tensorboard": True,
"save_strategy": "no",
"save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -90,5 +87,5 @@ class TestPackedFlex:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
)

View File

@@ -4,7 +4,6 @@ GRPO test suite
import os
import random
import shutil
import subprocess # nosec B404
import sys
import tempfile
@@ -106,7 +105,7 @@ def start_vllm(
print(f"{i}: VLLM server failed to start: {str(exc)}")
# also check if the process.pid is still running
if not process.poll() is None:
if process.poll() is not None:
break
time.sleep(period_seconds)
@@ -118,7 +117,10 @@ def start_vllm(
recursive_kill(process)
with open("/tmp/vllm.log", "r", encoding="utf-8") as log_file:
print(log_file.read())
shutil.rmtree("/tmp/vllm.log")
try:
os.remove("/tmp/vllm.log")
except FileNotFoundError:
pass
raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
# return the process
@@ -139,6 +141,7 @@ def recursive_kill(process: subprocess.Popen):
os.kill(process.pid, 9)
@pytest.mark.skip(reason="flaky vllm tests in modal")
class TestGRPO:
"""
Test case for GRPO training using multilpe GPUs
@@ -220,6 +223,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
@@ -260,6 +264,101 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
**current_env,
},
)
finally:
(recursive_kill(vllm_process))
@require_vllm
def test_llama_lora_sp(self, temp_dir):
rnd_reward_suffix = str(random.randint(1000, 9999))
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"chat_template": "llama3",
"rl": "grpo",
"trl": {
"beta": 0.001,
"max_completion_length": 256,
"use_vllm": True,
"num_generations": 4,
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
},
"vllm": {
"max_model_len": 800,
"enable_prefix_caching": True,
},
"datasets": [
{
"path": "openai/gsm8k",
"name": "main",
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"context_parallel_size": 2,
"flash_attention": True,
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"max_steps": 3,
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
current_env = os.environ.copy()
env = {
"NCCL_P2P_LEVEL": "LOC",
**current_env,
"CUDA_VISIBLE_DEVICES": "1",
}
vllm_process = start_vllm(
cfg.base_model,
env=env,
quiet=True,
wait=300,
gpu_memory_utilization=0.15,
max_model_len=cfg.vllm.max_model_len,
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
host="0.0.0.0",
port=8000,
)
try:
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
str(2),
"--main-process-port",
f"{get_torch_dist_unique_port()}",
],
env={
"NCCL_P2P_LEVEL": "LOC",
"NCCL_DEBUG": "INFO",
**current_env,
},
)
finally:
recursive_kill(vllm_process)
@@ -305,12 +404,14 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
"warmup_steps": 10,
"val_set_size": 0.0,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)

View File

@@ -2,8 +2,6 @@
E2E tests for multigpu eval
"""
import logging
import os
from pathlib import Path
import yaml
@@ -14,9 +12,6 @@ from axolotl.utils.dict import DictDefault
from ..utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -43,12 +38,13 @@ class TestMultiGPUEval:
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.004,
"val_set_size": 0.05,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
"split": "train[:5%]",
},
],
"num_epochs": 1,
@@ -56,6 +52,7 @@ class TestMultiGPUEval:
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
@@ -70,6 +67,7 @@ class TestMultiGPUEval:
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
"save_first_step": False,
}
)
@@ -112,12 +110,13 @@ class TestMultiGPUEval:
"lora_dropout": 0.05,
"lora_target_linear": True,
"lora_modules_to_save": ["embed_tokens", "lm_head"],
"val_set_size": 0.0004,
"val_set_size": 0.01,
"special_tokens": {"pad_token": "<|endoftext|>"},
"datasets": [
{
"path": "teknium/GPT4-LLM-Cleaned",
"type": "alpaca",
"split": "train[:5%]",
},
],
"num_epochs": 1,
@@ -125,6 +124,7 @@ class TestMultiGPUEval:
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
@@ -139,6 +139,7 @@ class TestMultiGPUEval:
"logging_steps": 1,
"weight_decay": 0.0,
"use_tensorboard": True,
"save_first_step": False,
}
)

View File

@@ -0,0 +1,121 @@
"""Test module for FP8 mixed precision with FSDP2 multi-GPU functionality."""
# pylint: disable=duplicate-code
import os
from pathlib import Path
import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import most_recent_subdir, require_hopper, require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def verify_fp8_training_success(temp_dir):
"""Verify that FP8 training completed successfully by checking artifacts and loss."""
output_path = Path(temp_dir)
model_files = list(output_path.glob("*.bin")) + list(
output_path.glob("*.safetensors")
)
assert len(model_files) > 0, "No model files found - training may have failed"
checkpoint_files = list(output_path.glob("checkpoint-*"))
assert (
len(checkpoint_files) > 0
), "No checkpoint files found - training may have failed"
tb_log_path = most_recent_subdir(temp_dir + "/runs")
if tb_log_path:
event_files = sorted(os.listdir(tb_log_path))
if event_files:
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(
torch.tensor(final_loss)
), f"Training loss is NaN: {final_loss}"
class TestFP8FSDP2:
"""Test class for FP8 mixed precision with FSDP2 functionality."""
@require_torch_2_7_0
@require_hopper
def test_fp8_fsdp2_smoke(self, temp_dir):
"""Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"trust_remote_code": True,
"sequence_len": 512,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 3, # Very short smoke test
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused", # Use standard optimizer for stability
"lr_scheduler": "cosine",
"sdp_attention": True,
"pad_to_seq_len": True,
"sample_packing": True,
# FP8 configuration
"fp8": True,
"fp8_enable_fsdp_float8_all_gather": True,
"torch_compile": True,
# FSDP2 configuration
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"save_safetensors": True,
"save_first_step": False,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_fp8_training_success(temp_dir)

View File

@@ -0,0 +1,326 @@
"""Test module for FSDP1 multi-GPU functionality."""
# pylint: disable=duplicate-code
import os
from pathlib import Path
import pytest
import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import most_recent_subdir
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def verify_training_success(temp_dir):
"""Verify that training completed successfully by checking artifacts and loss."""
output_path = Path(temp_dir)
model_files = list(output_path.glob("*.bin")) + list(
output_path.glob("*.safetensors")
)
assert len(model_files) > 0, "No model files found - training may have failed"
checkpoint_files = list(output_path.glob("checkpoint-*"))
assert (
len(checkpoint_files) > 0
), "No checkpoint files found - training may have failed"
tb_log_path = most_recent_subdir(temp_dir + "/runs")
if tb_log_path:
event_files = sorted(os.listdir(tb_log_path))
if event_files:
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(
torch.tensor(final_loss)
), f"Training loss is NaN: {final_loss}"
class TestFSDP1:
"""Test class for FSDP1 functionality."""
@pytest.mark.parametrize(
"fsdp_cpu_ram_efficient_loading",
[True, False],
)
def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": "1",
"fsdp_config": {
"fsdp_offload_params": False,
"fsdp_cpu_ram_efficient_loading": fsdp_cpu_ram_efficient_loading,
"fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_sharding_strategy": "FULL_SHARD",
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@pytest.mark.parametrize(
"adapter_config",
[
{
"adapter": "lora",
"load_in_4bit": False,
},
{
"adapter": "qlora",
"load_in_4bit": True,
},
],
)
def test_lora_sft(self, temp_dir, adapter_config):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"adapter": adapter_config["adapter"],
"load_in_4bit": adapter_config["load_in_4bit"],
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": "1",
"fsdp_config": {
"fsdp_offload_params": False,
"fsdp_cpu_ram_efficient_loading": True,
"fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_sharding_strategy": "FULL_SHARD",
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
def test_dpo_fft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"rl": "dpo",
"chat_template": "chatml",
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"split": "train",
"type": "chatml.intel",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": "1",
"fsdp_config": {
"fsdp_offload_params": False,
"fsdp_cpu_ram_efficient_loading": True,
"fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_sharding_strategy": "FULL_SHARD",
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@pytest.mark.parametrize(
"adapter_config",
[
{
"adapter": "lora",
"load_in_4bit": False,
},
{
"adapter": "qlora",
"load_in_4bit": True,
},
],
)
def test_dpo_lora(self, temp_dir, adapter_config):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"load_in_4bit": adapter_config["load_in_4bit"],
"rl": "dpo",
"chat_template": "chatml",
"sequence_len": 2048,
"adapter": adapter_config["adapter"],
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.01,
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"split": "train",
"type": "chatml.intel",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": "1",
"fsdp_config": {
"fsdp_offload_params": False,
"fsdp_cpu_ram_efficient_loading": True,
"fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_sharding_strategy": "FULL_SHARD",
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
},
"use_tensorboard": True,
"bf16": "auto",
"tf32": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)

View File

@@ -0,0 +1,482 @@
"""Test module for FSDP2 multi-GPU functionality."""
# pylint: disable=duplicate-code
import os
from pathlib import Path
import pytest
import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def verify_training_success(temp_dir):
"""Verify that training completed successfully by checking artifacts and loss."""
output_path = Path(temp_dir)
model_files = list(output_path.glob("*.bin")) + list(
output_path.glob("*.safetensors")
)
assert len(model_files) > 0, "No model files found - training may have failed"
checkpoint_files = list(output_path.glob("checkpoint-*"))
assert (
len(checkpoint_files) > 0
), "No checkpoint files found - training may have failed"
tb_log_path = most_recent_subdir(temp_dir + "/runs")
if tb_log_path:
event_files = sorted(os.listdir(tb_log_path))
if event_files:
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(
torch.tensor(final_loss)
), f"Training loss is NaN: {final_loss}"
class TestFSDP2:
"""Test class for FSDP2 functionality."""
@require_torch_2_7_0
@pytest.mark.parametrize(
"fsdp_cpu_ram_efficient_loading",
[True, False],
)
def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": fsdp_cpu_ram_efficient_loading,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@require_torch_2_7_0
@pytest.mark.parametrize("peft_use_dora", [True, False])
def test_lora_sft(self, temp_dir, peft_use_dora):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"peft_use_dora": peft_use_dora,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@require_torch_2_7_0
def test_lora_sft_kernels(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@require_torch_2_7_0
def test_qlora_sft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@require_torch_2_7_0
def test_qlora_sft_kernels(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@require_torch_2_7_0
def test_dpo_fft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"rl": "dpo",
"chat_template": "chatml",
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"split": "train",
"type": "chatml.intel",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
@require_torch_2_7_0
def test_dpo_lora(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"rl": "dpo",
"chat_template": "chatml",
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"split": "train",
"type": "chatml.intel",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)

View File

@@ -2,8 +2,6 @@
E2E tests for multigpu lora tinyllama
"""
import logging
import os
from pathlib import Path
import pytest
@@ -16,9 +14,6 @@ from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -69,12 +64,14 @@ class TestMultiGPUGemma3:
},
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.0001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
"save_first_step": False,
}
)
@@ -96,5 +93,5 @@ class TestMultiGPUGemma3:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss (%s) is too high"
)

View File

@@ -2,8 +2,6 @@
E2E tests for multigpu lora tinyllama
"""
import logging
import os
from pathlib import Path
import pytest
@@ -18,9 +16,6 @@ from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -67,12 +62,14 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
"save_first_step": False,
}
)
@@ -94,7 +91,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.8, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -132,12 +129,14 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": gradient_accumulation_steps,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
"save_first_step": False,
}
)
@@ -159,7 +158,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
def test_dpo_lora_ddp(self, temp_dir):
@@ -205,6 +204,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -212,6 +212,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
"save_first_step": False,
}
)
@@ -237,7 +238,7 @@ class TestMultiGPULlama:
temp_dir + "/runs",
"train/train_loss",
loss_threshold,
"Train Loss is too high",
"Train Loss (%s) is too high",
)
def test_dpo_qlora_ddp(self, temp_dir):
@@ -283,6 +284,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
@@ -290,6 +292,7 @@ class TestMultiGPULlama:
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
"save_first_step": False,
}
)
@@ -315,7 +318,7 @@ class TestMultiGPULlama:
temp_dir + "/runs",
"train/train_loss",
loss_threshold,
"Train Loss is too high",
"Train Loss (%s) is too high",
)
@pytest.mark.parametrize(
@@ -345,6 +348,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": gradient_accumulation_steps,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
@@ -364,6 +368,8 @@ class TestMultiGPULlama:
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
"seed": 42,
"save_first_step": False,
}
)
@@ -385,12 +391,15 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
"fsdp_state_dict_type",
["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
[
"FULL_STATE_DICT",
# "SHARDED_STATE_DICT", # not supported since intermediate checkpoints fail with fsdp1
],
)
def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
# pylint: disable=duplicate-code
@@ -412,11 +421,13 @@ class TestMultiGPULlama:
},
],
"num_epochs": 1,
"max_steps": 2,
"max_steps": 3,
"save_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
@@ -436,6 +447,7 @@ class TestMultiGPULlama:
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
"save_first_step": False,
}
)
@@ -457,7 +469,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_2_6_0
@@ -496,6 +508,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
"gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_8bit",
"lr_scheduler": "cosine",
@@ -513,6 +526,7 @@ class TestMultiGPULlama:
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
},
"use_tensorboard": True,
"save_first_step": False,
}
)
if attention_backend == "flash":
@@ -538,7 +552,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
)
def test_fsdp_qlora_prequant_packed(self, temp_dir):
@@ -578,6 +592,7 @@ class TestMultiGPULlama:
"gradient_accumulation_steps": 2,
# "gradient_checkpointing": True,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
@@ -593,10 +608,11 @@ class TestMultiGPULlama:
"fsdp_use_orig_params": False,
"fsdp_cpu_ram_efficient_loading": True,
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
"use_tensorboard": True,
"save_first_step": False,
}
)
@@ -618,7 +634,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -674,12 +690,14 @@ class TestMultiGPULlama:
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / deepspeed),
"use_tensorboard": True,
"save_first_step": False,
**adapter,
}
)
@@ -702,7 +720,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.45, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -748,12 +766,15 @@ class TestMultiGPULlama:
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
"use_tensorboard": True,
"seed": 42,
"save_first_step": False,
**adapter,
}
)
@@ -776,7 +797,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -822,12 +843,14 @@ class TestMultiGPULlama:
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
"save_first_step": False,
**adapter,
}
)
@@ -850,7 +873,7 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
)
@pytest.mark.skip(
@@ -896,6 +919,7 @@ class TestMultiGPULlama:
"save_safetensors": True,
# "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"use_tensorboard": True,
"save_first_step": False,
}
)
@@ -917,5 +941,5 @@ class TestMultiGPULlama:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss (%s) is too high"
)

View File

@@ -0,0 +1,192 @@
"""Tests for FileLockLoader class."""
import tempfile
import threading
import time
from pathlib import Path
from unittest.mock import MagicMock, Mock, patch
import pytest
from axolotl.utils.data.lock import FileLockLoader
from axolotl.utils.dict import DictDefault
class TestFileLockLoader:
"""Class with tests for FileLockLoader."""
@pytest.fixture
def temp_dir(self):
"""Create a temporary directory for testing."""
with tempfile.TemporaryDirectory() as tmp_dir:
yield Path(tmp_dir)
@pytest.fixture
def cfg(self, temp_dir):
"""Create a test configuration."""
return DictDefault({"dataset_prepared_path": str(temp_dir)})
@pytest.fixture
def loader(self, cfg):
"""Create a FileLockLoader instance for testing."""
return FileLockLoader(cfg)
def test_load_first_process(self, loader):
"""Test load() when no ready flag exists (first process)."""
mock_load_fn = Mock(return_value="test_data")
result = loader.load(mock_load_fn)
# Should call the load function
mock_load_fn.assert_called_once()
assert result == "test_data"
# Should create the ready flag
assert loader.ready_flag_path.exists()
def test_load_subsequent_process(self, loader):
"""Test load() when ready flag already exists (subsequent process)."""
# Create ready flag first
loader.ready_flag_path.touch()
mock_load_fn = Mock(return_value="loaded_data")
result = loader.load(mock_load_fn)
# Should still call load function (to load the prepared data)
mock_load_fn.assert_called_once()
assert result == "loaded_data"
def test_load_concurrent_processes(self, cfg):
"""Test that concurrent processes coordinate correctly."""
results = []
call_count = 0
def slow_load_fn():
nonlocal call_count
call_count += 1
time.sleep(0.1) # Simulate slow loading
return f"data_{call_count}"
def worker():
loader = FileLockLoader(cfg)
result = loader.load(slow_load_fn)
results.append(result)
# Start multiple threads simultaneously
threads = [threading.Thread(target=worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
# Only one thread should have done the initial loading
# All should return data, but the load function should be called
# once by the first process and once by each subsequent process
assert len(results) == 3
assert all(result.startswith("data_") for result in results)
@patch("time.sleep")
def test_load_waiting_for_ready_flag(self, mock_sleep, loader):
"""Test that processes wait for the ready flag to appear."""
mock_load_fn = Mock(return_value="waiting_data")
mock_ready_flag_path = Mock()
exists_call_count = 0
def mock_exists():
nonlocal exists_call_count
exists_call_count += 1
if exists_call_count == 1:
# First check: ready flag exists (not first process)
return True
if exists_call_count <= 3:
# While loop checks: flag doesn't exist yet
return False
return True
mock_ready_flag_path.exists.side_effect = mock_exists
# Replace the ready_flag_path with our mock
original_path = loader.ready_flag_path
loader.ready_flag_path = mock_ready_flag_path
try:
result = loader.load(mock_load_fn)
finally:
# Restore original path
loader.ready_flag_path = original_path
# Should have slept twice while waiting
assert mock_sleep.call_count == 2
mock_sleep.assert_called_with(1)
# Should eventually call load function
mock_load_fn.assert_called_once()
assert result == "waiting_data"
def test_complete_workflow_with_cleanup(self, loader):
"""Test the complete load -> cleanup workflow."""
mock_load_fn = Mock(return_value="test_data")
# First process calls load (this should set up counter)
result = loader.load(mock_load_fn)
assert result == "test_data"
assert loader.ready_flag_path.exists()
assert loader.counter_path.exists()
# Cleanup should remove everything since there's only one process
loader.cleanup()
assert not loader.ready_flag_path.exists()
assert not loader.counter_path.exists()
def test_multiple_processes_workflow(self, loader):
"""Test workflow with multiple processes."""
# Simulate multiple processes by manually setting up counter
loader.ready_flag_path.touch()
loader.counter_path.write_text("3") # 3 processes
# First process cleanup
loader.cleanup()
assert loader.ready_flag_path.exists()
assert loader.counter_path.read_text().strip() == "2"
# Second process cleanup
loader.cleanup()
assert loader.ready_flag_path.exists()
assert loader.counter_path.read_text().strip() == "1"
# Last process cleanup
loader.cleanup()
assert not loader.ready_flag_path.exists()
assert not loader.counter_path.exists()
def test_load_exception_handling(self, loader):
"""Test behavior when load_fn raises an exception."""
def failing_load_fn():
raise ValueError("Load failed")
with pytest.raises(ValueError, match="Load failed"):
loader.load(failing_load_fn)
# Ready flag should not be created on failure
assert not loader.ready_flag_path.exists()
def test_file_lock_called(self, loader):
"""Test that FileLock is properly used."""
mock_load_fn = Mock(return_value="locked_data")
with patch("axolotl.utils.data.lock.FileLock") as mock_filelock:
mock_context = MagicMock()
mock_filelock.return_value.__enter__ = Mock(return_value=mock_context)
mock_filelock.return_value.__exit__ = Mock(return_value=None)
loader.load(mock_load_fn)
# Verify FileLock was called with correct path
mock_filelock.assert_called_once_with(str(loader.lock_file_path))
# Verify context manager was used
mock_filelock.return_value.__enter__.assert_called_once()
mock_filelock.return_value.__exit__.assert_called_once()

View File

@@ -1,97 +0,0 @@
"""
E2E tests for multigpu qwen2
"""
import logging
import os
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
class TestMultiGPUQwen2:
"""
Test case for Llama models using LoRA
"""
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
def test_qlora_fsdp_dpo(self, base_model, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": base_model,
"load_in_4bit": True,
"rl": "dpo",
"chat_template": "chatml",
"sequence_len": 2048,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.01,
"datasets": [
{
"path": "Intel/orca_dpo_pairs",
"split": "train",
"type": "chatml.intel",
},
],
"num_epochs": 1,
"max_steps": 2,
"warmup_steps": 20,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"bf16": "auto",
"tf32": True,
# "gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {
"use_reentrant": False,
},
"fsdp": [
"full_shard",
"auto_wrap",
],
"fsdp_config": {
"fsdp_limit_all_gathers": True,
"fsdp_offload_params": False,
"fsdp_sync_module_states": True,
"fsdp_use_orig_params": False,
"fsdp_cpu_ram_efficient_loading": False,
"fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_sharding_strategy": "FULL_SHARD",
},
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)

View File

@@ -2,8 +2,6 @@
E2E tests for multigpu post-training use Ray Train
"""
import logging
import os
from pathlib import Path
import pytest
@@ -12,10 +10,11 @@ from accelerate.test_utils import execute_subprocess_async
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
LOG = logging.getLogger(__name__)
os.environ["WANDB_DISABLED"] = "true"
from tests.e2e.utils import (
check_tensorboard,
require_torch_2_7_0,
require_torch_lt_2_6_0,
)
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -53,6 +52,7 @@ class TestMultiGPURay:
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
@@ -60,6 +60,7 @@ class TestMultiGPURay:
"use_tensorboard": True,
"use_ray": True,
"ray_num_workers": 2,
"save_first_step": False,
}
)
@@ -80,7 +81,7 @@ class TestMultiGPURay:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_lt_2_6_0
@@ -112,12 +113,14 @@ class TestMultiGPURay:
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
"use_tensorboard": True,
"save_first_step": False,
}
)
@@ -138,5 +141,73 @@ class TestMultiGPURay:
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)
@require_torch_2_7_0
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],
)
def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 1024,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"dataset_prepared_path": temp_dir + "/last_run_prepared",
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"save_first_step": False,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--use-ray",
"--ray-num-workers",
"2",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
)

View File

@@ -0,0 +1,69 @@
"""multigpu e2e test for tensor parallelism."""
from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_7_0
class TestTensorParallel:
"""Test class for Tensor Parallel functionality."""
@pytest.mark.skip(
reason="TP doesn't work with models with tied weights (embeddings)"
)
@require_torch_2_7_0
def test_fft_sft(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"tensor_parallel_size": 2,
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
)

View File

@@ -21,8 +21,13 @@ from axolotl.kernels.lora import (
apply_lora_o,
apply_lora_qkv,
)
from axolotl.loaders.model import ModelLoader
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.monkeypatch.lora_kernels import (
apply_lora_kernel_patches,
find_self_attn_in_layer,
get_attention_cls_from_config,
get_layers,
patch_self_attn_lora,
)
from axolotl.utils.dict import DictDefault
@@ -80,7 +85,7 @@ def small_llama_model():
)
def test_attention_patching_integration(model_name, attention_cls):
"""Test attention patching in integration context."""
cfg = {"base_model": model_name}
cfg = DictDefault({"base_model": model_name})
# Store the original implementation
original_forward = getattr(attention_cls, "forward")
@@ -391,7 +396,7 @@ def test_model_architecture(model_config):
# pylint: disable=duplicate-code
def test_kernel_training_integration():
def test_kernel_training_integration(temp_dir):
"""Test model loading with kernel patches enabled."""
from axolotl.cli.utils import load_model_and_tokenizer
@@ -421,6 +426,14 @@ def test_kernel_training_integration():
}
)
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Load model
model, _, _ = load_model_and_tokenizer(cfg=cfg)
@@ -466,3 +479,103 @@ def test_kernel_training_integration_auto_enable(temp_dir):
assert cfg.lora_mlp_kernel is True
assert cfg.lora_qkv_kernel is True
assert cfg.lora_o_kernel is True
# Get the attention class before patching to check for side effects
attention_cls = get_attention_cls_from_config(cfg)
# Store original state before patching
original_forward_method = attention_cls.forward
# Load the model (this should trigger the patches)
tokenizer = load_tokenizer(cfg)
model, _ = ModelLoader(cfg, tokenizer).load()
# Test side effects of patch_self_attn_lora
assert hasattr(attention_cls, "_original_forward")
assert attention_cls.forward != original_forward_method
# Find at least one self-attention module and verify it has the patched methods
found_patched_attn = False
for layer in model.model.model.layers:
if hasattr(layer, "self_attn"):
self_attn = layer.self_attn
if all(
hasattr(self_attn, proj)
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
):
# These methods should be added by apply_lora_kernel_patches
assert hasattr(self_attn, "apply_qkv") and callable(self_attn.apply_qkv)
assert hasattr(self_attn, "apply_o") and callable(self_attn.apply_o)
found_patched_attn = True
break
assert found_patched_attn
def test_kernel_training_integration_dropout_non_zero(temp_dir):
"""Test model loading with dropout non-zero should not patch."""
from axolotl.cli.utils import load_model_and_tokenizer
# Create minimal config
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.1,
"lora_target_linear": True,
"sequence_len": 1024,
}
)
# Write cfg to yaml file
path = Path(temp_dir) / "config.yaml"
with open(path, "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Load config
cfg = load_cfg(str(path))
# Get original attention class
attention_cls = get_attention_cls_from_config(cfg)
# Store original state before patching
original_forward_method = attention_cls.forward
# Load model
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
# We call modelloader as that's where the patches are applied
# despite the fact that we're not using it to load the model
model_loader = ModelLoader(cfg, tokenizer)
# Apply patch
model_loader.patch_manager._apply_self_attention_lora_patch() # pylint: disable=protected-access
# Verify patch was not applied
assert attention_cls.forward == original_forward_method
# Apply apply_lora_kernel_patches
model_loader.patch_manager._apply_lora_kernel_patch( # pylint: disable=protected-access
model
)
# Verify patch was not applied
layers = get_layers(model)
for layer in layers:
for self_attn in find_self_attn_in_layer(layer):
assert not hasattr(self_attn, "apply_qkv")
assert not hasattr(self_attn, "apply_o")

View File

@@ -2,11 +2,8 @@
E2E tests for multipack fft llama using 4d attention masks
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class Test4dMultipackLlama(unittest.TestCase):
"""
@@ -61,12 +55,12 @@ class Test4dMultipackLlama(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -109,12 +103,12 @@ class Test4dMultipackLlama(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"fp16": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import pytest
import transformers
from torch.utils.checkpoint import checkpoint
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -70,13 +69,14 @@ class TestActivationCheckpointing:
"bf16": True,
"save_safetensors": True,
"gradient_checkpointing": gradient_checkpointing,
"save_first_step": False,
"dataset_processes": 4,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -2,13 +2,9 @@
E2E tests for lora llama
"""
import logging
import os
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -16,9 +12,6 @@ from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestFAXentropyLlama:
"""
@@ -69,6 +62,7 @@ class TestFAXentropyLlama:
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
"save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -79,12 +73,11 @@ class TestFAXentropyLlama:
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high"
)

View File

@@ -2,13 +2,10 @@
E2E tests for falcon
"""
import logging
import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestFalconPatched(unittest.TestCase):
"""
@@ -64,12 +58,12 @@ class TestFalconPatched(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -106,12 +100,12 @@ class TestFalconPatched(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -0,0 +1,82 @@
"""
E2E tests for flattening batches
"""
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard
class TestFAFlattening:
"""
Test case for Llama models using LoRA w batch flattening
"""
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 4],
)
def test_lora_packing_flattening(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"batch_flattening": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"chat_template": "chatml",
"datasets": [
{
"path": "mlabonne/FineTome-100k",
"field_messages": "conversations",
"message_field_content": "value",
"message_field_role": "from",
"type": "chat_template",
"split": "train[:2%]",
},
],
"num_epochs": 1,
"max_steps": 5,
"save_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": gradient_accumulation_steps,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"use_tensorboard": True,
"save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high"
)

View File

@@ -0,0 +1,131 @@
"""Integration tests for FSDP Params4bit patches."""
from unittest.mock import Mock, patch
import bitsandbytes as bnb
import pytest
import torch
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
from axolotl.monkeypatch.fsdp2_qlora import (
apply_bnb_torch_function_patch,
patched_torch_function,
)
@pytest.fixture
def mock_params4bit():
"""Create a mock Params4bit instance with test attributes."""
mock_instance = Mock()
mock_instance.requires_grad = True
mock_instance.quant_state = "test_state"
mock_instance.blocksize = 128
mock_instance.compress_statistics = True
mock_instance.quant_type = "fp4"
mock_instance.quant_storage = "test_storage"
mock_instance.module = "test_module"
mock_instance.bnb_quantized = True
return mock_instance
class TestBnbTorchFunctionPatch:
"""Test the Params4bit.__torch_function__ patch."""
def test_apply_patch(self):
"""Test that the patch can be applied."""
with patch("bitsandbytes.nn.modules.Params4bit") as mock_cls:
apply_bnb_torch_function_patch()
assert hasattr(mock_cls, "__torch_function__")
assert isinstance(mock_cls.__torch_function__, classmethod)
# pylint: disable=redefined-outer-name
def test_torch_chunk_preserves_attributes(self, mock_params4bit):
"""Test that torch.chunk preserves Params4bit attributes."""
mock_cls = Mock()
chunks = (torch.tensor([1, 2]), torch.tensor([3, 4]))
with patch("torch.nn.Parameter.__torch_function__", return_value=chunks):
result = patched_torch_function(
mock_cls,
torch.chunk,
(type(mock_params4bit),),
args=(mock_params4bit, 2),
)
assert isinstance(result, tuple)
assert len(result) == 2
# Check that Params4bit constructor was called with preserved attributes
assert mock_cls.call_count == 2
for call in mock_cls.call_args_list:
kwargs = call[1]
assert kwargs["requires_grad"] == mock_params4bit.requires_grad
assert kwargs["quant_state"] == mock_params4bit.quant_state
assert kwargs["blocksize"] == mock_params4bit.blocksize
# pylint: disable=redefined-outer-name
def test_other_functions_fallback(self, mock_params4bit):
"""Test that non-chunk/split functions use Parameter fallback."""
mock_cls = Mock()
fallback_result = torch.tensor([5, 6, 7])
with patch(
"torch.nn.Parameter.__torch_function__", return_value=fallback_result
) as mock_fallback:
result = patched_torch_function(
mock_cls, torch.add, (type(mock_params4bit),), args=(mock_params4bit, 1)
)
# Should call Parameter.__torch_function__ and return its result
mock_fallback.assert_called_once()
assert result is fallback_result
mock_cls.assert_not_called()
class TestFSDPPatchIntegration:
"""Test FSDP patch integration."""
@pytest.mark.integration
def test_all_patches_together(self):
"""Test that all patches can be applied together."""
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
)
# Store original methods before patching
original_torch_function = getattr(
bnb.nn.modules.Params4bit, "__torch_function__", None
)
# pylint: disable=protected-access
original_init_sharded = FSDPParam._init_sharded_param
original_init_unsharded = FSDPParam.init_unsharded_param
# Apply patches
apply_bnb_torch_function_patch()
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()
# Verify patches were applied
current_torch_function = getattr(
bnb.nn.modules.Params4bit, "__torch_function__", None
)
if original_torch_function is not None:
assert (
current_torch_function != original_torch_function
), "Params4bit.__torch_function__ was not patched"
else:
assert (
current_torch_function is not None
), "Params4bit.__torch_function__ was not added"
# Check that FSDP methods were patched
assert (
# pylint: disable=protected-access
FSDPParam._init_sharded_param
!= original_init_sharded
), "_init_sharded_param was not patched"
assert (
FSDPParam.init_unsharded_param != original_init_unsharded
), "init_unsharded_param was not patched"

View File

@@ -2,14 +2,11 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -17,9 +14,6 @@ from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.mark.skip("FIXME, mostly underused functionality")
class TestFusedLlama(unittest.TestCase):
@@ -35,7 +29,6 @@ class TestFusedLlama(unittest.TestCase):
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"pad_to_sequence_len": True,
"flash_attn_fuse_qkv": True,
"flash_attn_fuse_mlp": True,
"sample_packing": True,
"sequence_len": 1024,
@@ -59,6 +52,7 @@ class TestFusedLlama(unittest.TestCase):
"max_steps": 10,
"save_steps": 5,
"eval_steps": 5,
"save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -67,8 +61,7 @@ class TestFusedLlama(unittest.TestCase):
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -2,13 +2,10 @@
E2E tests for llama w/ S2 attn
"""
import logging
import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.mark.skip(reason="FIXME?")
class TestLlamaShiftedSparseAttention(unittest.TestCase):
@@ -64,13 +58,13 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
"save_steps": 5,
"eval_steps": 5,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -107,13 +101,13 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
"save_steps": 5,
"eval_steps": 5,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -2,14 +2,11 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
import pytest
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -17,9 +14,6 @@ from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestLoraLlama(unittest.TestCase):
"""
@@ -61,6 +55,7 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -70,8 +65,7 @@ class TestLoraLlama(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -115,12 +109,12 @@ class TestLoraLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -2,20 +2,14 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
from ..utils import check_model_output_exists, require_torch_2_6_0, with_temp_dir
class TestMistral(unittest.TestCase):
@@ -23,6 +17,7 @@ class TestMistral(unittest.TestCase):
Test case for Llama models using LoRA
"""
@require_torch_2_6_0
@with_temp_dir
def test_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code
@@ -61,12 +56,12 @@ class TestMistral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -103,12 +98,12 @@ class TestMistral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -2,11 +2,8 @@
E2E tests for mixtral
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestMixtral(unittest.TestCase):
"""
@@ -58,12 +52,12 @@ class TestMixtral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -97,12 +91,12 @@ class TestMixtral(unittest.TestCase):
"save_steps": 3,
"eval_steps": 4,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -45,6 +45,7 @@ class TestModelPatches(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
@@ -78,6 +79,7 @@ class TestModelPatches(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
}
)
cfg = validate_config(cfg)

View File

@@ -49,6 +49,7 @@ class TestLlamaPeftEmbeddings:
"bf16": "auto",
"save_safetensors": True,
"embeddings_skip_upcast": True,
"save_first_step": False,
}
)

View File

@@ -2,11 +2,8 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPhiMultipack(unittest.TestCase):
"""
@@ -60,13 +54,13 @@ class TestPhiMultipack(unittest.TestCase):
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -112,13 +106,13 @@ class TestPhiMultipack(unittest.TestCase):
"eval_steps": 3,
"save_steps": 4,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -2,14 +2,11 @@
E2E tests for resuming training
"""
import logging
import os
import re
import subprocess
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -17,9 +14,6 @@ from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, most_recent_subdir, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestResumeLlama:
"""
@@ -64,6 +58,7 @@ class TestResumeLlama:
"max_steps": 15,
"use_tensorboard": True,
"save_safetensors": True,
"save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -72,8 +67,7 @@ class TestResumeLlama:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
@@ -83,7 +77,6 @@ class TestResumeLlama:
}
)
normalize_config(resume_cfg)
cli_args = TrainerCliArgs()
train(cfg=resume_cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -1,480 +0,0 @@
"""Tests for sequence parallelism functionality."""
# pylint: disable=redefined-outer-name,unused-argument
import functools
import sys
from unittest.mock import MagicMock, patch
import pytest
import torch
from accelerate.state import PartialState
from axolotl.monkeypatch.ring_attn import (
get_ring_attn_group,
register_ring_attn,
set_ring_attn_group,
)
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.enums import RingAttnFunc
from axolotl.utils.schemas.trl import TRLConfig
@pytest.fixture
def partial_state():
"""Create a real PartialState instance for testing."""
state = PartialState()
return state
@pytest.fixture(name="cfg")
def fixture_cfg():
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-3,
"output_dir": "./model-out",
"sequence_len": 512,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
}
)
return cfg
@pytest.fixture
def sequence_parallel_batch():
"""Create a test batch for sequence parallelism tests."""
batch_size = 1
seq_len = 8
# Create test tensors
input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
attention_mask = torch.ones(batch_size, seq_len)
position_ids = torch.arange(seq_len).expand(batch_size, seq_len)
labels = input_ids.clone()
# Create test batch
batch = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"labels": labels,
}
return batch
class TestRingAttention:
"""Tests for the ring attention functionality."""
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_get_ring_attn_group_no_registration(
self, mock_world_size, mock_rank, partial_state
):
"""Test that get_ring_attn_group raises RuntimeError when no group has been registered."""
# Setup mocks
mock_world_size.return_value = 4
mock_rank.return_value = 0
# Verify that RuntimeError is raised when no group is registered
with pytest.raises(
RuntimeError, match="register_ring_attn\\(\\) not yet called"
):
get_ring_attn_group()
@patch("torch.distributed.new_group")
@patch("torch.distributed.get_rank")
@patch("torch.distributed.get_world_size")
def test_register_ring_attn(
self, mock_world_size, mock_rank, mock_new_group, partial_state
):
"""Test that ring attention groups are created correctly."""
# Setup mocks
mock_world_size.return_value = 8 # 8 GPUs total
mock_rank.return_value = 3 # GPU #3
mock_group = MagicMock()
mock_new_group.return_value = mock_group
# Call register_ring_attn with size 4
register_ring_attn(
sequence_parallel_degree=4,
heads_k_stride=1,
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
)
# Verify the number of calls without examining the arguments
assert mock_new_group.call_count == 2
# Verify that new_group was called
mock_new_group.assert_called()
# Clean up
set_ring_attn_group(None)
class TestConfigValidation:
"""Tests for validating sequence parallelism configurations."""
@pytest.fixture(autouse=True)
def setup_mocks(self, monkeypatch):
"""Set up mocks for all tests in this class."""
# Mock the ring_flash_attn module
monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
@pytest.fixture
def base_cfg(self):
"""Create a base configuration for testing."""
return DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-3,
"output_dir": "./model-out",
"sequence_len": 512,
"special_tokens": {"pad_token": "<|endoftext|>"},
}
)
@pytest.mark.parametrize(
"config_updates, expected_values, should_pass, error_msg",
[
# Valid configuration
(
{"sequence_parallel_degree": 2, "flash_attention": True},
{"sequence_parallel_degree": 2, "flash_attention": True},
True,
None,
),
# Default sequence_parallel_degree
({}, {"sequence_parallel_degree": 1}, True, None),
# Invalid: sequence_parallel_degree > 1 without flash_attention
(
{"sequence_parallel_degree": 2, "flash_attention": False},
None,
False,
"flash_attention: true must be set",
),
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
(
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": True,
"micro_batch_size": 2,
"pad_to_sequence_len": True,
},
None,
False,
"micro_batch_size must be set to 1",
),
# Valid: Basic GRPO config
(
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": {"use_liger_loss": True},
},
{
"sequence_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": TRLConfig(use_liger_loss=True),
},
True,
"GRPO + SP + Liger not currently supported",
),
# Invalid: GRPO config with Liger loss
(
{
"rl": "grpo",
"sequence_parallel_degree": 2,
"flash_attention": True,
"micro_batch_size": 2,
"trl": {"use_liger_loss": True},
},
None,
False,
"GRPO + SP + Liger not currently supported",
),
],
ids=[
"valid_config",
"default_sp_degree",
"without_flash_attention",
"sample_packing_with_large_batch",
"valid_grpo",
"grpo_with_liger_loss",
],
)
def test_sequence_parallel_config_validation(
self, base_cfg, config_updates, expected_values, should_pass, error_msg
):
"""Test various sequence parallelism configuration scenarios."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
cfg = base_cfg
cfg.update(config_updates)
if should_pass:
# Should validate without errors
config = AxolotlInputConfig(**cfg)
# Check expected values
for key, value in expected_values.items():
assert getattr(config, key) == value
else:
# Should raise exception
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
assert error_msg in str(excinfo.value)
@pytest.mark.parametrize(
"ring_attn_func, sample_packing, expected_func",
[
(None, True, RingAttnFunc.VARLEN_LLAMA3),
(None, False, RingAttnFunc.BATCH_RING),
],
ids=["default_with_sample_packing", "default_without_sample_packing"],
)
def test_ring_attn_func_validation(
self, base_cfg, ring_attn_func, sample_packing, expected_func
):
"""Test ring_attn_func validation and defaults."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Apply updates to base config
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
"sample_packing": sample_packing,
}
if ring_attn_func is not None:
cfg["ring_attn_func"] = ring_attn_func
# Should validate without errors
config = AxolotlInputConfig(**cfg)
# Check ring_attn_func value
assert config.ring_attn_func.value == expected_func
def test_invalid_ring_attn_func(self, base_cfg):
"""Test that an invalid ring_attn_func is rejected."""
from axolotl.utils.schemas.config import AxolotlInputConfig
# Invalid configuration with invalid ring_attn_func
cfg = base_cfg | {
"sequence_parallel_degree": 2,
"flash_attention": True,
"ring_attn_func": "INVALID_FUNC",
}
# Should raise ValidationError
with pytest.raises(ValueError) as excinfo:
AxolotlInputConfig(**cfg)
# Verify error message
assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
class TestApplySequenceParallelism:
"""Tests for the apply_sequence_parallelism function."""
@pytest.fixture(autouse=True)
def mock_distributed(self, monkeypatch):
"""Mock torch.distributed functions for testing."""
# Mock is_initialized to return True
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
# Mock get_rank to return 0 by default
monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0)
# Mock get_world_size to return 2 by default
monkeypatch.setattr(
torch.distributed, "get_world_size", lambda *args, **kwargs: 2
)
# Mock the process group
monkeypatch.setattr(
"axolotl.monkeypatch.ring_attn.get_ring_attn_group",
MagicMock,
)
# Mock update_ring_attn_params
monkeypatch.setattr(
"axolotl.monkeypatch.ring_attn.update_ring_attn_params",
lambda **kwargs: None,
)
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
"""Test that function returns original batch when world size is 1."""
mock_get_ring_attn_group.return_value = 0
result, _, _ = apply_sequence_parallelism(
batch=sequence_parallel_batch,
local_rank=0,
local_world_size=1,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Should return the original batch unchanged
assert result == sequence_parallel_batch
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
result, _, _ = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Check that sequence dimension was sharded correctly
assert result["input_ids"].shape[1] == seq_len // 2
assert result["attention_mask"].shape[1] == seq_len // 2
# Verify content: rank 0 should get the first half of the sequence
assert torch.equal(result["input_ids"], batch["input_ids"][:, : seq_len // 2])
assert torch.equal(
result["position_ids"], batch["position_ids"][:, : seq_len // 2]
)
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
seq_len = batch["input_ids"].size(1)
original_input_ids = batch["input_ids"].clone()
result, _, _ = apply_sequence_parallelism(
batch=batch,
local_rank=1,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verify content: rank 1 should get the second half of the sequence
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
# TODO(djsaunde): add back once implemented.
# def test_batch_zigzag(self, sequence_parallel_batch):
# """Test BATCH_ZIGZAG sharding pattern."""
# batch = sequence_parallel_batch
# original_input_ids = batch["input_ids"].clone()
# seq_len = batch["input_ids"].size(1)
# # Test rank 0
# result_rank0 = apply_sequence_parallelism(
# batch={k: v.clone() for k, v in batch.items()},
# local_rank=0,
# local_world_size=2,
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
# )
# # Test rank 1
# result_rank1 = apply_sequence_parallelism(
# batch={k: v.clone() for k, v in batch.items()},
# local_rank=1,
# local_world_size=2,
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
# )
# # Checks for both ranks
# assert result_rank0["input_ids"].shape[1] == seq_len // 2
# assert result_rank1["input_ids"].shape[1] == seq_len // 2
# # For a 2-rank system with 8 tokens, check specific zigzag pattern
# # Rank 0 should get chunks [0, 1] and [6, 7]
# # Rank 1 should get chunks [2, 3] and [4, 5]
# if seq_len == 8:
# # Create expected tensors for comparison
# rank0_expected = torch.cat(
# [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
# )
# rank1_expected = torch.cat(
# [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
# )
# assert torch.equal(result_rank0["input_ids"], rank0_expected)
# assert torch.equal(result_rank1["input_ids"], rank1_expected)
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
def test_partial_application(
self, mock_get_ring_attn_group, sequence_parallel_batch
):
"""Test that we can create a partially applied version of the function."""
mock_get_ring_attn_group.return_value = 0
batch = sequence_parallel_batch
original_input_ids = batch["input_ids"].clone()
# Create a partially applied function
rank0_ring_parallel = functools.partial(
apply_sequence_parallelism,
local_rank=0,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Use the partially applied function
result, _, _ = rank0_ring_parallel(batch=batch)
# Verify it works as expected
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
assert torch.equal(
result["input_ids"],
original_input_ids[:, : original_input_ids.shape[1] // 2],
)
def test_missing_position_ids(self, sequence_parallel_batch):
"""Test handling of batch without position_ids."""
# Create a batch without position_ids
batch = {
k: v for k, v in sequence_parallel_batch.items() if k != "position_ids"
}
original_input_ids = batch["input_ids"].clone()
# This should run without error even though position_ids is missing
result, _, _ = apply_sequence_parallelism(
batch=batch,
local_rank=0,
local_world_size=2,
gradient_accumulation_steps=1,
ring_attn_func=RingAttnFunc.BATCH_RING,
)
# Verification should pass
assert "position_ids" in result
assert result["input_ids"].shape[1] == result["position_ids"].shape[1]
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2

View File

@@ -2,12 +2,8 @@
e2e tests for unsloth qlora
"""
import logging
import os
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -15,9 +11,6 @@ from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
# pylint: disable=duplicate-code
@pytest.mark.skip(
@@ -69,19 +62,19 @@ class TestUnslothQLoRA:
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
@@ -120,19 +113,19 @@ class TestUnslothQLoRA:
"lr_scheduler": "cosine",
"use_tensorboard": True,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)
@pytest.mark.parametrize(
@@ -176,17 +169,17 @@ class TestUnslothQLoRA:
"lr_scheduler": "cosine",
"use_tensorboard": True,
"fp16": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)

View File

@@ -2,13 +2,10 @@
E2E tests for packed training w/ flex attention
"""
import logging
import os
import unittest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
from ..utils import check_tensorboard, require_torch_2_6_0, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPackedFlex(unittest.TestCase):
"""
@@ -55,6 +49,7 @@ class TestPackedFlex(unittest.TestCase):
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
"save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -64,11 +59,10 @@ class TestPackedFlex(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
)

View File

@@ -2,12 +2,9 @@
E2E tests for relora llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -15,9 +12,6 @@ from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestReLoraLlama(unittest.TestCase):
"""
@@ -40,9 +34,10 @@ class TestReLoraLlama(unittest.TestCase):
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": ["q_proj", "v_proj"],
"relora_steps": 50,
"relora_warmup_steps": 10,
"relora_anneal_steps": 10,
"relora": True,
"jagged_restart_steps": 50,
"jagged_restart_warmup_steps": 10,
"jagged_restart_anneal_steps": 10,
"relora_prune_ratio": 0.9,
"relora_cpu_offload": True,
"val_set_size": 0.0,
@@ -71,13 +66,13 @@ class TestReLoraLlama(unittest.TestCase):
"lr_scheduler": "cosine",
"save_safetensors": True,
"use_tensorboard": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)

View File

@@ -0,0 +1,83 @@
"""
E2E tests for activation offloading
"""
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists
# pylint: disable=duplicate-code
class TestActivationOffloading:
"""
E2E test cases for activation offloading
"""
@pytest.mark.parametrize(
"adapter",
["lora", "qlora", None],
)
def test_activation_offloading(
self,
temp_dir,
adapter,
):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
"eos_token": "<|im_end|>",
},
"datasets": [
{
"chat_template": "chatml",
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"split": "train[:10%]",
"field_messages": "conversations",
"message_field_role": "from",
"message_field_content": "value",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": "auto",
"save_safetensors": True,
"gradient_checkpointing": True,
"activation_offloading": True,
"save_first_step": False,
"lora_r": 8,
"lora_alpha": 16,
"lora_target_linear": True,
}
)
if adapter == "lora":
cfg["adapter"] = "lora"
if adapter == "qlora":
cfg["adapter"] = "qlora"
cfg["load_in_4bit"] = True
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -2,13 +2,10 @@
E2E tests for deepseekv3
"""
import logging
import os
from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestDeepseekV3:
"""
@@ -73,12 +67,12 @@ class TestDeepseekV3:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@@ -123,12 +117,12 @@ class TestDeepseekV3:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -1,9 +1,5 @@
"""
E2E tests for lora llama
"""
"""E2E tests for lora llama"""
import logging
import os
import unittest
from pathlib import Path
@@ -17,9 +13,6 @@ from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestDPOLlamaLora(unittest.TestCase):
"""
@@ -63,6 +56,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
"save_first_step": False,
}
)
@@ -112,6 +106,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
"save_first_step": False,
}
)
@@ -161,6 +156,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
"save_first_step": False,
}
)
@@ -210,6 +206,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
"save_first_step": False,
}
)
@@ -258,6 +255,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
"save_first_step": False,
}
)
@@ -309,6 +307,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
"save_first_step": False,
}
)
@@ -377,6 +376,7 @@ class TestDPOLlamaLora(unittest.TestCase):
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
"save_first_step": False,
}
)

View File

@@ -2,11 +2,8 @@
E2E tests for llama pretrain
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestEmbeddingsLrScale(unittest.TestCase):
"""
@@ -54,13 +48,13 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -100,12 +94,12 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -1,6 +1,5 @@
"""E2E smoke test for evaluate CLI command"""
import os
from pathlib import Path
import yaml
@@ -9,8 +8,6 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
os.environ["WANDB_DISABLED"] = "true"
class TestE2eEvaluate:
"""Test cases for evaluate CLI"""
@@ -39,6 +36,7 @@ class TestE2eEvaluate:
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_first_step": False,
}
)

View File

@@ -2,13 +2,10 @@
E2E tests for falcon
"""
import logging
import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestFalcon(unittest.TestCase):
"""
@@ -66,13 +60,13 @@ class TestFalcon(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -122,13 +116,13 @@ class TestFalcon(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -164,13 +158,13 @@ class TestFalcon(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -2,21 +2,15 @@
E2E tests for gemma2
"""
import logging
import os
from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestGemma2:
"""
@@ -74,8 +68,7 @@ class TestGemma2:
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@@ -126,8 +119,7 @@ class TestGemma2:
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -2,21 +2,15 @@
E2E tests for gemma3_text
"""
import logging
import os
from pathlib import Path
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestGemma3Text:
"""
@@ -69,12 +63,12 @@ class TestGemma3Text:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@@ -120,12 +114,12 @@ class TestGemma3Text:
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -11,11 +11,11 @@ class TestImports(unittest.TestCase):
"""
def test_import_causal_trainer(self):
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
from axolotl.core.builders import ( # pylint: disable=unused-import # noqa: F401
HFCausalTrainerBuilder,
)
def test_import_rl_trainer(self):
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
from axolotl.core.builders import ( # pylint: disable=unused-import # noqa: F401
HFRLTrainerBuilder,
)

View File

@@ -2,10 +2,6 @@
E2E tests for llama
"""
import logging
import os
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -13,9 +9,6 @@ from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestLlama:
"""
@@ -52,13 +45,13 @@ class TestLlama:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -100,13 +93,13 @@ class TestLlama:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -145,13 +138,13 @@ class TestLlama:
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -186,13 +179,13 @@ class TestLlama:
"batch_flattening": True,
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -1,13 +1,7 @@
"""
E2E tests for llama pretrain
"""
import logging
import os
"""E2E tests for llama pretrain"""
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -15,27 +9,19 @@ from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPretrainLlama:
"""
Test case for Llama models w pretraining
"""
"""Test case for Llama models w pretraining"""
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
@pytest.mark.parametrize(
"pretrain_multipack_attn",
[True, False],
("sample_packing", "pretrain_multipack_attn"),
[
(False, False),
(True, True),
(True, False),
],
)
def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn):
if not sample_packing and pretrain_multipack_attn:
return
# pylint: disable=duplicate-code
cfg = DictDefault(
{
@@ -67,22 +53,22 @@ class TestPretrainLlama:
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
loss_threshold = 3.5
loss_threshold = 3.6
if sample_packing and not pretrain_multipack_attn:
loss_threshold = 6.5
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
loss_threshold,
"Train Loss is too high",
"Train Loss (%s) is too high",
)

View File

@@ -2,11 +2,8 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestLlamaVision(unittest.TestCase):
"""
@@ -38,7 +32,7 @@ class TestLlamaVision(unittest.TestCase):
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
"lora_target_modules": r"model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
"val_set_size": 0,
"chat_template": "llama3_2_vision",
"datasets": [
@@ -60,13 +54,13 @@ class TestLlamaVision(unittest.TestCase):
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -86,7 +80,7 @@ class TestLlamaVision(unittest.TestCase):
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
"lora_target_modules": r"model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
"val_set_size": 0,
"chat_template": "llama3_2_vision",
"datasets": [
@@ -107,12 +101,12 @@ class TestLlamaVision(unittest.TestCase):
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -52,6 +52,8 @@ class TestLoadModelUtils:
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"tensor_parallel_size": 1,
"context_parallel_size": 1,
}
)
self.model_loader = ( # pylint: disable=attribute-defined-outside-init

View File

@@ -2,11 +2,8 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestLoraLlama(unittest.TestCase):
"""
@@ -55,13 +49,13 @@ class TestLoraLlama(unittest.TestCase):
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -2,13 +2,10 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.mark.skip(reason="skipping until upstreamed into transformers")
class TestMamba(unittest.TestCase):
@@ -57,13 +51,13 @@ class TestMamba(unittest.TestCase):
"save_steps": 10,
"eval_steps": None,
"save_safetensors": False,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -2,13 +2,10 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestMistral(unittest.TestCase):
"""
@@ -61,13 +55,13 @@ class TestMistral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -102,6 +96,7 @@ class TestMistral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -111,8 +106,7 @@ class TestMistral(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -2,14 +2,11 @@
E2E tests for mixtral
"""
import logging
import os
import unittest
import torch
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -17,9 +14,6 @@ from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestMixtral(unittest.TestCase):
"""
@@ -67,13 +61,13 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
@@ -123,13 +117,13 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
@@ -178,6 +172,7 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -187,8 +182,7 @@ class TestMixtral(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
@@ -237,6 +231,7 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
}
)
@@ -246,8 +241,7 @@ class TestMixtral(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
@@ -283,6 +277,7 @@ class TestMixtral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -292,8 +287,7 @@ class TestMixtral(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -2,20 +2,20 @@
E2E tests for custom optimizers using Llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
from .utils import (
check_model_output_exists,
require_torch_2_5_1,
require_torch_2_6_0,
require_torch_2_7_0,
with_temp_dir,
)
class TestCustomOptimizers(unittest.TestCase):
@@ -56,13 +56,13 @@ class TestCustomOptimizers(unittest.TestCase):
"optimizer": "optimi_adamw",
"max_steps": 5,
"lr_scheduler": "cosine",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -102,13 +102,13 @@ class TestCustomOptimizers(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adopt_adamw",
"lr_scheduler": "cosine",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -149,18 +149,61 @@ class TestCustomOptimizers(unittest.TestCase):
"optimizer": "muon",
"lr_scheduler": "cosine",
"weight_decay": 0.01,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
@require_torch_2_7_0
def test_dion(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"val_set_size": 0.0,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "dion",
"dion_lr": 0.01,
"dion_momentum": 0.95,
"lr_scheduler": "cosine",
"weight_decay": 0.01,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert "Dion" in trainer.optimizer.optimizer.__class__.__name__
@with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir):
# pylint: disable=duplicate-code
@@ -188,19 +231,20 @@ class TestCustomOptimizers(unittest.TestCase):
"lr_scheduler": "constant",
"save_safetensors": True,
"max_steps": 10,
"save_first_step": False,
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@require_torch_2_6_0
def test_came_pytorch(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -236,13 +280,13 @@ class TestCustomOptimizers(unittest.TestCase):
"adam_epsilon2": 1e-16,
"max_steps": 5,
"lr_scheduler": "cosine",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -2,13 +2,10 @@
E2E tests for packed training
"""
import logging
import os
import unittest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPackedLlama(unittest.TestCase):
"""
@@ -54,6 +48,7 @@ class TestPackedLlama(unittest.TestCase):
"lr_scheduler": "cosine",
"max_steps": 5,
"use_tensorboard": True,
"save_first_step": False,
}
)
if is_torch_bf16_gpu_available():
@@ -63,11 +58,10 @@ class TestPackedLlama(unittest.TestCase):
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
)

View File

@@ -2,11 +2,8 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPhi(unittest.TestCase):
"""
@@ -59,12 +53,12 @@ class TestPhi(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@@ -109,12 +103,12 @@ class TestPhi(unittest.TestCase):
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -0,0 +1,58 @@
"""E2E Test the preprocess cli"""
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from axolotl.utils.dict import DictDefault
AXOLOTL_ROOT = Path(__file__).parent.parent.parent
class TestPreprocess:
"""test cases for preprocess"""
def test_w_deepspeed(self, temp_dir):
"""make sure preproces doesn't choke when using deepspeed in the config"""
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"bf16": "auto",
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
"dataset_prepared_path": temp_dir + "/last_run_prepared",
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"preprocess",
str(Path(temp_dir) / "config.yaml"),
]
)
assert (Path(temp_dir) / "last_run_prepared").exists()

View File

@@ -2,11 +2,8 @@
E2E tests for process reward model w/ lora llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestProcessRewardSmolLM2(unittest.TestCase):
"""
@@ -55,12 +49,12 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
"use_tensorboard": True,
"special_tokens": {"pad_token": "<|endoftext|>"},
"seed": 42,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(

113
tests/e2e/test_profiler.py Normal file
View File

@@ -0,0 +1,113 @@
"""
e2e gpu test for the pytorch profiler callback
"""
from pathlib import Path
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="profiler_base_cfg")
def fixture_profiler_base_cfg():
cfg = DictDefault(
base_model="HuggingFaceTB/SmolLM2-135M",
tokenizer_type="AutoTokenizer",
sequence_len=1024,
load_in_8bit=True,
adapter="lora",
lora_r=8,
lora_alpha=16,
lora_dropout=0.05,
lora_target_linear=True,
val_set_size=0.02,
special_tokens={"pad_token": "<|endoftext|>"},
datasets=[
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
num_epochs=1,
micro_batch_size=2,
gradient_accumulation_steps=1,
learning_rate=0.00001,
optimizer="adamw_torch_fused",
lr_scheduler="cosine",
)
return cfg
class TestProfiler:
"""
test cases for the pytorch profiler callback
"""
def test_profiler_saves(self, profiler_base_cfg, temp_dir):
cfg = profiler_base_cfg | DictDefault(
output_dir=temp_dir,
max_steps=5,
profiler_steps=3,
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "snapshot.pickle").exists()
def test_profiler_saves_w_start(self, profiler_base_cfg, temp_dir):
cfg = profiler_base_cfg | DictDefault(
output_dir=temp_dir,
max_steps=5,
profiler_steps=3,
profiler_steps_start=1,
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "snapshot.pickle").exists()
@pytest.mark.parametrize(
"profiler_steps_start",
[3, 5],
)
def test_profiler_saves_past_end(
self, profiler_base_cfg, temp_dir, profiler_steps_start
):
cfg = profiler_base_cfg | DictDefault(
output_dir=temp_dir,
max_steps=5,
profiler_steps=3,
profiler_steps_start=profiler_steps_start,
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "snapshot.pickle").exists()
def test_profiler_never_started(self, profiler_base_cfg, temp_dir):
cfg = profiler_base_cfg | DictDefault(
output_dir=temp_dir,
max_steps=5,
profiler_steps=3,
profiler_steps_start=6,
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
assert not (Path(temp_dir) / "snapshot.pickle").exists()

135
tests/e2e/test_qat.py Normal file
View File

@@ -0,0 +1,135 @@
"""
E2E tests for QAT
"""
from pathlib import Path
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard
class TestQATLlama:
"""
Test case for QAT Llama models
"""
def test_qat(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mlabonne/FineTome-100k",
"type": "chat_template",
"field_messages": "conversations",
"message_property_mappings": {
"role": "from",
"content": "value",
},
"drop_system_message": True,
"split": "train[:1%]",
},
],
"chat_template": "chatml",
"qat": {
"quantize_embedding": True,
"activation_dtype": "int8",
"weight_dtype": "int8",
"group_size": 8,
},
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_safetensors": True,
"bf16": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)
def test_qat_dpo(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"rl": "dpo",
"chat_template": "chatml",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"use_tensorboard": True,
"bf16": True,
"qat": {
"quantize_embedding": True,
"activation_dtype": "int8",
"weight_dtype": "int8",
"group_size": 8,
},
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_preference_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)
loss_threshold = 2.3
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
loss_threshold,
"Train Loss (%s) is too high",
)

View File

@@ -0,0 +1,350 @@
"""
Tests for axolotl.utils.quantization
"""
import pytest
import torch
from torch import nn
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.linear_activation_quantized_tensor import (
LinearActivationQuantizedTensor,
)
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
from torchao.quantization.quant_api import (
Int4DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
UIntXWeightOnlyConfig,
)
from transformers import AutoModelForCausalLM
from transformers.trainer_callback import TrainerState
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.quantization import (
convert_qat_model_for_ptq,
get_ptq_config,
prepare_model_for_qat,
quantize_model_for_ptq,
)
from axolotl.utils.schemas.enums import TorchIntDType
from axolotl.utils.schemas.quantization import QATConfig
from tests.e2e.utils import require_torch_2_6_0
@pytest.fixture()
def model():
dummy_model = AutoModelForCausalLM.from_pretrained(
"HuggingFaceTB/SmolLM2-135M",
device_map="cuda",
torch_dtype=torch.bfloat16,
)
with torch.device(dummy_model.device):
dummy_model.model.embed_tokens = torch.nn.Embedding(
dummy_model.model.embed_tokens.weight.shape[0],
dummy_model.model.embed_tokens.weight.shape[1],
dtype=dummy_model.model.embed_tokens.weight.dtype,
)
return dummy_model
ptq_config_test_cases = [
# weight_dtype, activation_dtype, group_size, expected_type, expected_params
(
TorchIntDType.uint4,
None,
None,
UIntXWeightOnlyConfig,
{"dtype": torch.uint4, "group_size": None},
),
(TorchIntDType.int8, None, 32, Int8WeightOnlyConfig, {"group_size": 32}),
(TorchIntDType.int4, None, 4, Int4WeightOnlyConfig, {"group_size": 4}),
(
TorchIntDType.int4,
TorchIntDType.int4,
None,
Int4DynamicActivationInt4WeightConfig,
{},
),
(
TorchIntDType.int8,
TorchIntDType.int8,
None,
Int8DynamicActivationInt8WeightConfig,
{},
),
]
ptq_test_cases = [
# weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception
(TorchIntDType.int8, None, 8, False, None),
(TorchIntDType.int4, None, 4, True, None),
(TorchIntDType.uint4, None, 8, False, None),
(TorchIntDType.int4, TorchIntDType.int4, 8, False, None),
(TorchIntDType.int8, TorchIntDType.int8, 8, True, None),
(TorchIntDType.int8, None, None, False, ValueError),
(TorchIntDType.int4, None, None, False, ValueError),
]
class TestQuantization:
"""
Test quantization utilities
"""
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,expected_type,expected_params",
ptq_config_test_cases,
)
@require_torch_2_6_0
def test_get_ptq_config(
self, weight_dtype, activation_dtype, group_size, expected_type, expected_params
):
config = get_ptq_config(weight_dtype, activation_dtype, group_size)
assert isinstance(config, expected_type)
for param_name, param_value in expected_params.items():
if isinstance(param_value, (PerAxis, PerGroup)):
if isinstance(param_value, PerAxis):
assert isinstance(getattr(config, param_name), PerAxis)
assert getattr(config, param_name).axis == param_value.axis
else:
assert isinstance(getattr(config, param_name), PerGroup)
assert (
getattr(config, param_name).group_size == param_value.group_size
)
else:
assert getattr(config, param_name) == param_value
@pytest.mark.parametrize(
"weight_dtype", [TorchIntDType.int8, TorchIntDType.int4, TorchIntDType.uint4]
)
@pytest.mark.parametrize(
"activation_dtype", [None, TorchIntDType.int4, TorchIntDType.int8]
)
@pytest.mark.parametrize("group_size", [4, 8])
@pytest.mark.parametrize("quantize_embedding", [False, True])
@require_torch_2_6_0
def test_prepare_model_for_qat(
self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
): # pylint: disable=redefined-outer-name
prepare_model_for_qat(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
if quantize_embedding:
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert hasattr(model.model.embed_tokens, "weight_fake_quantizer")
assert (
model.model.embed_tokens.weight_fake_quantizer.config.dtype
== weight_dtype.value
)
assert (
model.model.embed_tokens.weight_fake_quantizer.config.group_size
== group_size
)
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
assert isinstance(child, FakeQuantizedLinear)
assert hasattr(child, "weight_fake_quantizer")
assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
assert child.weight_fake_quantizer.config.group_size == group_size
if activation_dtype:
assert hasattr(child, "activation_fake_quantizer")
assert (
child.activation_fake_quantizer.config.dtype
== activation_dtype.value
)
else:
assert child.activation_fake_quantizer is None
@pytest.mark.parametrize(
"weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception",
ptq_test_cases,
)
@require_torch_2_6_0
def test_quantize_model_for_ptq(
self,
model,
weight_dtype,
activation_dtype,
group_size,
quantize_embedding,
expected_exception,
): # pylint: disable=redefined-outer-name
if expected_exception:
with pytest.raises(expected_exception):
quantize_model_for_ptq(
model,
weight_dtype,
group_size,
activation_dtype,
quantize_embedding,
)
else:
quantize_model_for_ptq(
model, weight_dtype, group_size, activation_dtype, quantize_embedding
)
if quantize_embedding:
assert isinstance(
model.model.embed_tokens.weight, AffineQuantizedTensor
), "Embedding weight should be quantized"
for child in list(model.children()):
if isinstance(child, torch.nn.Linear):
if activation_dtype:
assert isinstance(
child.weight, LinearActivationQuantizedTensor
), "Linear weight should be quantized with activation quantization"
else:
assert isinstance(
child.weight, AffineQuantizedTensor
), "Linear weight should be quantized without activation quantization"
class TestQuantizationCallback:
"""
Test QATCallback
"""
@pytest.fixture()
def trainer_state(self):
return TrainerState(
global_step=0,
)
@require_torch_2_6_0
def test_qat_callback_fake_quant_after_n_steps(
self, model, trainer_state
): # pylint: disable=redefined-outer-name
cfg = QATConfig(
weight_dtype="int8",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
fake_quant_after_n_steps=100,
)
prepare_model_for_qat(
model,
cfg.weight_dtype,
cfg.group_size,
cfg.activation_dtype,
cfg.quantize_embedding,
)
# ensure model has been quantized
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert isinstance(model.lm_head, FakeQuantizedLinear)
assert model.lm_head.weight_fake_quantizer.enabled
qat_callback = QATCallback(cfg)
# simulate first training step
qat_callback.on_step_begin(
args=None,
state=trainer_state,
control=None,
model=model,
)
# quantization should have been disabled
assert not model.model.embed_tokens.weight_fake_quantizer.enabled
assert not model.lm_head.weight_fake_quantizer.enabled
trainer_state.global_step = 100
qat_callback.on_step_begin(
args=None,
state=trainer_state,
control=None,
model=model,
)
# quantization should have been enabled
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
@require_torch_2_6_0
def test_qat_callback_fake_quant_after_n_steps_is_none(
self, model, trainer_state
): # pylint: disable=redefined-outer-name
cfg = QATConfig(
weight_dtype="int8",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
fake_quant_after_n_steps=None,
)
prepare_model_for_qat(
model,
cfg.weight_dtype,
cfg.group_size,
cfg.activation_dtype,
cfg.quantize_embedding,
)
# ensure model has been quantized
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert isinstance(model.lm_head, FakeQuantizedLinear)
assert model.lm_head.weight_fake_quantizer.enabled
qat_callback = QATCallback(cfg)
# simulate first training step
qat_callback.on_step_begin(
args=None,
state=trainer_state,
control=None,
model=model,
)
# quantization should be enabled from the get-go
assert model.model.embed_tokens.weight_fake_quantizer.enabled
assert model.lm_head.weight_fake_quantizer.enabled
class TestConvertQATModelForPTQ:
"""
Test convert_qat_model_for_ptq
"""
@require_torch_2_6_0
def test_convert_qat_model_for_ptq(
self, model
): # pylint: disable=redefined-outer-name
config = QATConfig(
weight_dtype="int8",
activation_dtype="int8",
group_size=8,
quantize_embedding=True,
)
# quantize model for qat
prepare_model_for_qat(
model,
config.weight_dtype,
config.group_size,
config.activation_dtype,
config.quantize_embedding,
)
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert isinstance(model.lm_head, FakeQuantizedLinear)
# apply conversion
convert_qat_model_for_ptq(
model,
quantize_embedding=config.quantize_embedding,
)
# ensure modules have been swapped out
assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
assert not isinstance(model.lm_head, FakeQuantizedLinear)
# ensure weights have been quantized
assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
assert isinstance(model.lm_head.weight, nn.Parameter)

View File

@@ -2,8 +2,6 @@
E2E tests for qwen
"""
import logging
import os
from pathlib import Path
import pytest
@@ -13,9 +11,6 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.qwen")
os.environ["WANDB_DISABLED"] = "true"
class TestE2eQwen:
"""
@@ -64,6 +59,7 @@ class TestE2eQwen:
"bf16": "auto",
"tf32": True,
"gradient_checkpointing": True,
"save_first_step": False,
}
)

View File

@@ -2,11 +2,8 @@
E2E tests for reward model lora llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestRewardModelLoraSmolLM2(unittest.TestCase):
"""
@@ -64,15 +58,15 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
"gradient_checkpointing": True,
"warmup_ratio": 0.1,
"use_tensorboard": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
)
check_model_output_exists(temp_dir, cfg)

View File

@@ -0,0 +1,102 @@
"""
E2E tests for relora llama
"""
import unittest
from pathlib import Path
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
class TestSaveFirstStepCallback(unittest.TestCase):
"""Test cases for save_first_step callback config."""
@with_temp_dir
def test_save_first_step(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 512,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 3,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"save_first_step": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)
@with_temp_dir
def test_no_save_first_step(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 512,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 3,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": True,
"bf16": True,
"save_safetensors": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
with pytest.raises(AssertionError):
check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)

View File

@@ -2,11 +2,8 @@
E2E tests for custom schedulers using Llama
"""
import logging
import os
import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestCustomSchedulers(unittest.TestCase):
"""
@@ -57,13 +51,13 @@ class TestCustomSchedulers(unittest.TestCase):
"lr_scheduler": "rex",
"warmup_steps": 5,
"cosine_min_lr_ratio": 0.05,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -10,8 +10,6 @@ from functools import wraps
from pathlib import Path
import torch
# from importlib.metadata import version
from packaging import version
from tbparse import SummaryReader
@@ -79,6 +77,18 @@ def require_torch_2_6_0(test_case):
return unittest.skipUnless(is_min_2_6_0(), "test requires torch>=2.6.0")(test_case)
def require_torch_2_7_0(test_case):
"""
Decorator marking a test that requires torch >= 2.7.0
"""
def is_min_2_7_0():
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.7.0")
return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case)
def require_torch_lt_2_6_0(test_case):
"""
Decorator marking a test that requires torch < 2.6.0
@@ -132,6 +142,10 @@ def is_hopper():
return compute_capability == (9, 0)
def require_hopper(test_case):
return unittest.skipUnless(is_hopper(), "test requires h100/hopper GPU")(test_case)
def check_tensorboard(
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
) -> None:

View File

@@ -2,8 +2,6 @@
config validation tests for swiglu args
"""
# pylint: disable=duplicate-code
import logging
from typing import Optional
import pytest
@@ -12,6 +10,7 @@ from axolotl.utils.config import prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
# pylint: disable=duplicate-code
@pytest.fixture(name="minimal_liger_cfg")
def fixture_cfg():
return DictDefault(
@@ -41,7 +40,7 @@ class TestValidation:
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
caplog.set_level(logging.WARNING)
caplog.set_level("WARNING")
self._caplog = caplog
def test_deprecated_swiglu(self, minimal_liger_cfg):
@@ -52,9 +51,7 @@ class TestValidation:
| minimal_liger_cfg
)
with self._caplog.at_level(
logging.WARNING, logger="axolotl.integrations.liger.args"
):
with self._caplog.at_level("WARNING", logger="axolotl.integrations.liger.args"):
prepare_plugins(test_cfg)
updated_cfg = validate_config(test_cfg)
# TODO this test is brittle in CI

View File

@@ -0,0 +1,26 @@
"""
Unit tests for trainer accelerator args monkeypatch
"""
import unittest
from axolotl.monkeypatch.trainer_accelerator_args import (
check_create_accelerate_code_is_patchable,
)
class TestTrainerAcceleratorArgs(unittest.TestCase):
"""
Unit test class for trainer accelerator args monkeypatch
"""
def test_check_create_accelerate_code_is_patchable(self):
"""
Test that the upstream transformers code is still patchable.
This will fail if the patched code changes upstream.
"""
assert check_create_accelerate_code_is_patchable()
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,28 @@
"""Unit tests for trainer loss calc monkeypatch."""
import unittest
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
check_evaluation_loop_is_fsdp2_patchable,
check_evaluation_loop_is_patchable,
check_maybe_log_save_evaluate_is_patchable,
)
class TestTrainerLossCalc(unittest.TestCase):
"""
Unit test class for trainer loss calc monkeypatch
"""
def test_trainer_loss_calc_is_patchable(self):
"""
Test that the upstream transformers code is still patchable. This will fail if
the patched code changes upstream.
"""
assert check_evaluation_loop_is_patchable()
assert check_evaluation_loop_is_fsdp2_patchable()
assert check_maybe_log_save_evaluate_is_patchable()
if __name__ == "__main__":
unittest.main()

View File

@@ -1,7 +1,6 @@
# pylint: disable=too-many-lines
"""Module for testing the validation module"""
import logging
import os
import warnings
from typing import Optional
@@ -80,7 +79,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(test_cfg)
assert (
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
@@ -218,7 +217,7 @@ class TestValidation(BaseValidation):
}
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert "batch_size is not recommended" in self._caplog.records[0].message
@@ -513,7 +512,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert any(
"BetterTransformers probably doesn't work with PEFT adapters"
@@ -531,7 +530,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert any(
"probably set bfloat16 or float16" in record.message
@@ -577,7 +576,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
@@ -595,7 +594,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert any(
"adamw hyperparameters found, but no adamw optimizer set"
@@ -654,7 +653,7 @@ class TestValidation(BaseValidation):
)
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert any(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
@@ -673,7 +672,7 @@ class TestValidation(BaseValidation):
)
| minimal_cfg
)
with self._caplog.at_level(logging.INFO):
with self._caplog.at_level("INFO"):
cfg = validate_config(cfg)
assert any(
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
@@ -693,7 +692,7 @@ class TestValidation(BaseValidation):
"bf16": True,
"capabilities": {"bf16": False},
"env_capabilities": {
"torch_version": "2.5.1",
"torch_version": "2.6.0",
},
}
)
@@ -1109,7 +1108,7 @@ class TestValidation(BaseValidation):
def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 1
@@ -1118,7 +1117,7 @@ class TestValidation(BaseValidation):
DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 1
@@ -1128,7 +1127,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 0
@@ -1138,28 +1137,28 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_hub_model_id_save_value_none(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_dpo_beta_deprecation(self, minimal_cfg):
cfg = DictDefault({"dpo_beta": 0.2}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
new_cfg = validate_config(cfg)
assert new_cfg["rl_beta"] == 0.2
assert new_cfg["dpo_beta"] is None
@@ -1175,7 +1174,7 @@ class TestValidation(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
new_cfg = validate_config(cfg)
assert new_cfg.eval_strategy == "steps"
assert (
@@ -1203,7 +1202,7 @@ class TestValidation(BaseValidation):
cfg, capabilities=capabilities, env_capabilities=env_capabilities
)
env_capabilities = {"torch_version": "2.5.1"}
env_capabilities = {"torch_version": "2.6.0"}
capabilities = {"bf16": False}
_ = validate_config(
cfg, capabilities=capabilities, env_capabilities=env_capabilities
@@ -1245,7 +1244,7 @@ class TestTorchCompileValidation(BaseValidation):
| minimal_cfg
)
env_capabilities = {"torch_version": "2.5.1"}
env_capabilities = {"torch_version": "2.6.0"}
capabilities = {"bf16": True}
updated_cfg = validate_config(
cfg, capabilities=capabilities, env_capabilities=env_capabilities
@@ -1455,7 +1454,7 @@ class TestValidationWandb(BaseValidation):
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level("WARNING"):
new_cfg = validate_config(cfg)
assert any(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
@@ -1505,7 +1504,6 @@ class TestValidationWandb(BaseValidation):
assert os.environ.get("WANDB_MODE", "") == "online"
assert os.environ.get("WANDB_WATCH", "") == "false"
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
assert os.environ.get("WANDB_DISABLED", "") != "true"
os.environ.pop("WANDB_PROJECT", None)
os.environ.pop("WANDB_NAME", None)
@@ -1514,16 +1512,12 @@ class TestValidationWandb(BaseValidation):
os.environ.pop("WANDB_MODE", None)
os.environ.pop("WANDB_WATCH", None)
os.environ.pop("WANDB_LOG_MODEL", None)
os.environ.pop("WANDB_DISABLED", None)
def test_wandb_set_disabled(self, minimal_cfg):
cfg = DictDefault({}) | minimal_cfg
new_cfg = validate_config(cfg)
setup_wandb_env_vars(new_cfg)
assert os.environ.get("WANDB_DISABLED", "") == "true"
assert new_cfg.use_wandb is None
cfg = (
DictDefault(
@@ -1535,13 +1529,10 @@ class TestValidationWandb(BaseValidation):
)
new_cfg = validate_config(cfg)
setup_wandb_env_vars(new_cfg)
assert os.environ.get("WANDB_DISABLED", "") != "true"
assert new_cfg.use_wandb is True
os.environ.pop("WANDB_PROJECT", None)
os.environ.pop("WANDB_DISABLED", None)
@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed")
@@ -1699,3 +1690,18 @@ class TestValidationMLflow(BaseValidation):
assert new_cfg.use_mlflow is True
os.environ.pop("MLFLOW_EXPERIMENT_NAME", None)
class TestDataloaderValidation(BaseValidation):
"""
tests for dataloader_* sane defaults
"""
def test_dataloader_auto_defaults(self, minimal_cfg):
cfg = minimal_cfg
new_cfg = validate_config(cfg, {"n_gpu": 8}, {"torch_version": "2.6.0"})
assert new_cfg.dataloader_num_workers == 8
assert new_cfg.dataloader_pin_memory is True
assert new_cfg.dataloader_prefetch_factor == 256

View File

@@ -143,6 +143,12 @@ def fixture_phi35_tokenizer():
return tokenizer
@pytest.fixture(name="phi4_tokenizer", scope="session", autouse=True)
def fixture_phi4_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-reasoning")
return tokenizer
@pytest.fixture(name="gemma2_tokenizer", scope="session", autouse=True)
def fixture_gemma2_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("mlx-community/gemma-2-9b-it-4bit")
@@ -150,6 +156,30 @@ def fixture_gemma2_tokenizer():
return tokenizer
@pytest.fixture(name="magistral_tokenizer")
def fixture_magistral_tokenizer():
from axolotl.utils.mistral import HFMistralTokenizer
tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Magistral-Small-2506")
return tokenizer
@pytest.fixture(name="devstral_tokenizer")
def fixture_devstral_tokenizer():
from axolotl.utils.mistral import HFMistralTokenizer
tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Devstral-Small-2505")
return tokenizer
@pytest.fixture(name="devstral_1_1_tokenizer")
def fixture_devstral_1_1_tokenizer():
from axolotl.utils.mistral import HFMistralTokenizer
tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Devstral-Small-2507")
return tokenizer
@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
def fixture_mistralv03_chat_template_jinja_w_system() -> str:
return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'

View File

@@ -3,14 +3,13 @@ tests for chat_template prompt strategy
"""
# pylint: disable=duplicate-code
import logging
import unittest
from axolotl.prompt_strategies.messages.chat import load
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__, log_level="DEBUG")
class TestMessagesChatLlama3:

View File

@@ -0,0 +1,75 @@
"""
Tests for chat template prompt strategy with schema unification for none fields
"""
import json
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
from axolotl.prompt_strategies.chat_template import StrategyLoader
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="messages_w_tools")
def fixture_messages_w_tools():
jsons = """
{"messages":[{"role":"user","content":"move to (0, 1)"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"move","arguments":{"x":0,"y":1}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
{"messages":[{"role":"user","content":"turn 270 degree"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"turn","arguments":{"theta": 270}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
{"messages":[{"role":"user","content":"jump high"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"invalid_prompt","arguments":{"message": "jump is not a valid action"}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
""".strip().split(
"\n"
)
rows = [json.loads(row) for row in jsons]
return Dataset.from_list(rows)
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
): # pylint: disable=unused-argument
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
return tokenizer
@pytest.fixture(name="qwen3_prompt_strategy")
def qwen3_chat_template_strategy(qwen3_tokenizer):
cfg = DictDefault(
sequence_len=2048,
chat_template="qwen3",
eot_tokens=["<|im_end|>"],
)
ds_cfg = DictDefault(
type="chat_template",
)
load = StrategyLoader()
strat = load(qwen3_tokenizer, cfg, ds_cfg)
return strat
class TestSchemaUnification:
"""
Test class on handling null fields for tool calling
"""
def test_schema_unification_single_prompt(
self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer
):
for row in messages_w_tools:
inputs = qwen3_prompt_strategy.tokenize_prompt(row)
decoded = qwen3_tokenizer.decode(inputs["input_ids"])
tool_call = decoded.split("<tool_call>")[-1].split("</tool_call>")[0]
assert '"message": null' not in tool_call
assert '"theta": null' not in tool_call
def test_schema_unification_batched(
self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer
):
rows = messages_w_tools.map(qwen3_prompt_strategy.tokenize_prompt, batched=True)
for row in rows:
decoded = qwen3_tokenizer.decode(row["input_ids"])
tool_call = decoded.split("<tool_call>")[-1].split("</tool_call>")[0]
assert '"message": null' not in tool_call
assert '"theta": null' not in tool_call

View File

@@ -2,7 +2,6 @@
tests for chat_template prompt strategy
"""
import logging
import unittest
from axolotl.prompt_strategies.chat_template import (
@@ -13,9 +12,9 @@ from axolotl.prompt_strategies.chat_template import (
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)
class TestAssistantChatTemplateLlama3:

View File

@@ -4,7 +4,6 @@ tests for chat_template prompt strategy
# pylint: disable=too-many-lines
import logging
from copy import deepcopy
import pytest
@@ -18,11 +17,11 @@ from axolotl.prompt_strategies.chat_template import (
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.logging import get_logger
from tests.hf_offline_utils import enable_hf_offline
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)
PARAMETRIZE_KEYS = "tokenizer, chat_template, chat_template_jinja, eos_token"
PARAMETRIZE_PARAMS = [
@@ -34,15 +33,14 @@ PARAMETRIZE_PARAMS = [
"mistralv03_tokenizer_chat_template_jinja",
"[/INST]",
),
# TODO: temporarily skip gemma due to gemma3 template
# Re-enable on new chat_template implementation for perf
# (
# "gemma2_tokenizer",
# "jinja",
# "gemma2_tokenizer_chat_template_jinja",
# "<end_of_turn>",
# ),
(
"gemma2_tokenizer",
"jinja",
"gemma2_tokenizer_chat_template_jinja",
"<end_of_turn>",
),
("phi35_tokenizer", "phi_35", None, "<|end|>"),
("phi4_tokenizer", "phi_4", None, "<|im_end|>"),
]
@@ -96,11 +94,7 @@ class TestChatTemplateConfigurations:
if (
turn_idx == 0
and turn.get("from") in ["system", "context"]
and (
"mistral" in tokenizer.name_or_path.lower()
or "gemma"
in tokenizer.name_or_path.lower() # temporarily skip gemma due to gemma3 template
)
and ("mistral" in tokenizer.name_or_path.lower())
):
assert (
start_idx == -1 and end_idx == -1
@@ -936,36 +930,14 @@ class TestChatTemplateConfigurations:
"messages",
)
if chat_template == "llama3":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "chatml":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
assert variables == {"role", "content", "tool_call_id", "tool_calls"}, (
f"Expected variables: {'role', 'content', 'tool_call_id', 'tool_calls'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "jinja" and tokenizer == "gemma2_tokenizer":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "phi_35":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
# Special case for Mistral with additional tool variables
if chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
expected_variables = {"role", "content", "tool_call_id", "tool_calls"}
# Most chat templates use the standard role and content variables
elif chat_template in ["llama3", "chatml", "phi_35", "phi_4"] or (
chat_template == "jinja" and tokenizer == "gemma2_tokenizer"
):
expected_variables = {"role", "content"}
else:
LOG.warning(
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
@@ -974,6 +946,12 @@ class TestChatTemplateConfigurations:
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
)
assert variables == expected_variables, (
f"Expected variables: {expected_variables} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
def test_eot_tokens_conflict_with_eos_token(
self,
tokenizer,
@@ -1281,3 +1259,162 @@ class TestChatTemplateConfigurations:
assert (
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
class TestChatTemplateToolCalling:
"""
Test class for tool calling functionality with chat templates.
"""
def test_tool_calling_with_llama4_template(
self,
llama3_tokenizer,
):
LOG.info("Testing tool calling with llama3 tokenizer and llama4 chat template")
# Create tool calling dataset
tool_calling_dataset = [
{
"tools": [
{
"type": "function",
"function": {
"name": "xml_escape",
"description": 'Replaces any "<", ">", or "&" characters in the input string with their corresponding XML entities.',
"parameters": {
"type": "object",
"properties": {
"s": {
"type": "string",
"description": "The input string to be XML-escaped.",
}
},
"required": ["s"],
},
},
},
{
"type": "function",
"function": {
"name": "multiples",
"description": "Generates a list of all the multiples of a number that are less than a given limit.",
"parameters": {
"type": "object",
"properties": {
"number": {
"type": "integer",
"description": "The number to find multiples of.",
},
"limit": {
"type": "integer",
"description": "The upper limit for the multiples.",
},
},
"required": ["number", "limit"],
},
},
},
],
"messages": [
{
"role": "user",
"content": "Can you help me find multiples of 5 that are less than 20?",
},
{
"role": "assistant",
"tool_calls": [
{
"type": "function",
"function": {
"name": "multiples",
"arguments": {
"number": 5,
"limit": 20,
},
},
}
],
},
{"role": "tool", "name": "multiples", "content": "5,10,15"},
{
"role": "assistant",
"content": "The multiples of 5 less than 20 are: 5, 10, and 15.",
},
],
}
]
# Setup tokenizer with llama4 chat template
tokenizer = deepcopy(llama3_tokenizer)
# Add EOS token to the tokenizer
eot_token = "<|eot_id|>"
tokenizer.add_special_tokens({"additional_special_tokens": [eot_token]})
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_template=get_chat_template("llama4"),
message_property_mappings={"role": "role", "content": "content"},
field_messages="messages",
field_tools="tools",
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
eot_tokens=[eot_token],
)
res = strategy.tokenize_prompt(tool_calling_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
# Verify that the input_ids contain expected tokens
assert len(input_ids) > 0, "Input IDs should not be empty"
assert len(labels) == len(input_ids), "Labels should match input_ids length"
# Decode the full conversation to verify structure
decoded_conversation = tokenizer.decode(input_ids)
# Verify tool calling structure is present in the decoded conversation
assert (
'"type": "function",' in decoded_conversation
), "Tool type function should be in conversation"
assert (
'"name": "multiples",' in decoded_conversation
), "Tool function name should be in conversation"
assert (
'<|python_start|><|python_end|>{"name": "multiples", "parameters": {"number": 5, "limit": 20}}<|eot|>'
in decoded_conversation
), "Assistant tool call should be in conversation"
assert (
"<|header_start|>ipython<|header_end|>" in decoded_conversation
), "IPython header should be in conversation"
assert (
'"5,10,15"' in decoded_conversation
), "Tool response should be in conversation"
# Get conversation turns to verify labeling
turns = strategy.get_conversation_thread(tool_calling_dataset[0])
tools = strategy._get_tools( # pylint: disable=protected-access
tool_calling_dataset[0]
)
# Check that assistant responses are properly labeled
for i, turn in enumerate(tool_calling_dataset[0]["messages"]):
if turn["role"] == "assistant":
start_idx, end_idx = strategy.find_turn(
turns=turns, turn_idx=i, tools=tools
)
assert (
start_idx != -1 and end_idx != -1
), f"Assistant turn {i} should be found"
# Verify that assistant responses have proper labels
turn_labels = labels[start_idx:end_idx]
assert all(
label != IGNORE_TOKEN_ID for label in turn_labels
), f"Assistant turn {i} should be unmasked"

View File

@@ -0,0 +1,851 @@
"""Test chat templates for mistral-common wrapper tokenizer"""
import unittest
from typing import TYPE_CHECKING
import pytest
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from axolotl.utils.mistral import HFMistralTokenizer
# fmt: off
@pytest.mark.parametrize(
("tokenizer_str", "assistant_toolcall_ids", "tool_result_ids"),
(
("magistral_tokenizer", (9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2), (7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8)),
("devstral_tokenizer", (9, 1091, 19227, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 61906, 2811, 16753, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 4179, 1429, 1327, 2811, 1429, 19881, 1049, 1050, 1051, 1052, 1053, 1034, 27028, 2), (7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8)),
("devstral_1_1_tokenizer", (9, 44627, 3684, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2,), (7, 1049, 1044, 1050, 8)),
)
)
# fmt: on
def test_mistral_chat_template(
tokenizer_str: str,
assistant_toolcall_ids: tuple[int, ...],
tool_result_ids: tuple[int, ...],
request: pytest.FixtureRequest,
):
"""Test chat template with the Magistral/Devstral tokenizer"""
# pylint: disable=duplicate-code
from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy
tokenizer: HFMistralTokenizer = request.getfixturevalue(tokenizer_str)
# check bos, eos, pad, unk are accessible properties
assert tokenizer.bos_token_id == 1
assert tokenizer.eos_token_id == 2
assert tokenizer.pad_token_id == 11
assert tokenizer.unk_token_id == 0
assert tokenizer.pad_token == "<pad>"
assert tokenizer.eos_token == "</s>"
assert tokenizer.bos_token == "<s>"
assert tokenizer.unk_token == "<unk>"
strategy = MistralStrategy(
MistralPrompter(
tokenizer,
chat_template=None,
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=tokenizer,
train_on_inputs=False,
train_on_eos="turn",
sequence_len=512,
roles_to_train=["assistant"],
)
# test chat template masking without system prompt
res = strategy.tokenize_prompt(
{
"messages": [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great, thank you!"},
]
}
)
assert res["input_ids"] == [
1, # bos
3, # [INST]
22177, # Hello
1044, # ,
2606, # how
1584, # are
1636, # you
1063, # ?
4, # [/INST]
1073, # I
4525, # 'm
6965, # doing
4824, # great
1044, # ,
15412, # thank
1636, # you
1033, # !
2, # </s>
]
assert res["labels"] == [
-100, # bos
-100, # [INST]
-100, # Hello
-100, # ,
-100, # how
-100, # are
-100, # you
-100, # ?
-100, # [/INST]
1073, # I
4525, # 'm
6965, # doing
4824, # great
1044, # ,
15412, # thank
1636, # you
1033, # !
2, # </s>
]
# test chat template masking with system prompt
res = strategy.tokenize_prompt(
{
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great, thank you!"},
]
}
)
assert res["input_ids"] == [
1, # bos
17, # [SYSTEM_PROMPT]
4568, # You
1584, # are
1261, # a
20351, # helpful
27089, # assistant
1046, # .
18, # [/SYSTEM_PROMPT]
3, # [INST]
22177, # Hello
1044, # ,
2606, # how
1584, # are
1636, # you
1063, # ?
4, # [/INST]
1073, # I
4525, # 'm
6965, # doing
4824, # great
1044, # ,
15412, # thank
1636, # you
1033, # !
2, # </s>
]
assert res["labels"] == [
-100, # bos
-100, # [SYSTEM_PROMPT]
-100, # You
-100, # are
-100, # a
-100, # helpful
-100, # assistant
-100, # .
-100, # [/SYSTEM_PROMPT]
-100, # [INST]
-100, # Hello
-100, # ,
-100, # how
-100, # are
-100, # you
-100, # ?
-100, # [/INST]
1073, # I
4525, # 'm
6965, # doing
4824, # great
1044, # ,
15412, # thank
1636, # you
1033, # !
2, # </s>
]
# test chat template with tools
res = strategy.tokenize_prompt(
{
"tools": [
{
"type": "function",
"function": {
"name": "multiples",
"description": "Generates a list of all the multiples of a number that are less than a given limit.",
"parameters": {
"type": "object",
"properties": {
"number": {
"type": "integer",
"description": "The number to find multiples of.",
},
"limit": {
"type": "integer",
"description": "The upper limit for the multiples.",
},
},
"required": ["number", "limit"],
},
},
},
],
"messages": [
{
"role": "user",
"content": "Hey, can you give me a breakdown of how to throw an awesome themed party? Like, what themes work best, and how can I set everything up to really wow my guests? I want some ideas on decorations, food, and activities that will make the party unforgettable!",
},
{
"role": "assistant",
"tool_calls": [
{
"id": "call12345",
"type": "function",
"function": {
"name": "multiples",
"arguments": {
"number": 16,
"limit": 2,
},
},
}
],
},
{
"role": "tool",
"tool_call_id": "call12345",
"name": "multiples",
"content": "1,2",
},
{"role": "assistant", "content": "The multiples of 16 is 1 and 2."},
],
}
)
# fmt: off
assert res["input_ids"] == [
1, # bos
5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 14653, 2811, 1429, 10639, 2130, 1261, 2951, 1307, 1747, 1278, 60092, 1307, 1261, 2782, 1455, 1584, 4289, 2224, 1261, 4265, 6139, 39249, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649, 2811, 16753, 12856, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 2782, 1317, 3081, 60092, 1307, 2613, 4179, 1429, 33319, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 9229, 6139, 1394, 1278, 60092, 2613, 47579, 1429, 15760, 2811, 12161, 12856, 1897, 1429, 33319, 4964, 2821, 27028, 6, # tool prompt
3, 46634, 1044, 1710, 1636, 5628, 1639, 1261, 44433, 1307, 2606, 1317, 5388, 1420, 54191, 2424, 1286, 8967, 1063, 15621, 1044, 2549, 30305, 2196, 3560, 1044, 1321, 2606, 1710, 1362, 2016, 8605, 2015, 1317, 5524, 118931, 2036, 32951, 1063, 1362, 2933, 2269, 12106, 1408, 101987, 1044, 6939, 1044, 1321, 9216, 1455, 2084, 3180, 1278, 8967, 119141, 1689, 5935, 1033, 4, # user
*assistant_toolcall_ids, # assistant tool calling
*tool_result_ids, # tool result
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
2 # eos
]
assert res["labels"] == [
-100, # bos
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool prompt
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # user prompt
*assistant_toolcall_ids, # assistant tool calling
*([-100] * len(tool_result_ids)), # tool result
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
2 # eos
]
# fmt: on
# test chat template with tokenize=False
res = tokenizer.apply_chat_template(
[
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing great, thank you!"},
],
tokenize=False,
)
assert res == "<s>[INST]Hello, how are you?[/INST]I'm doing great, thank you!</s>"
# test encode
res = tokenizer.encode("Hello, how are you?", add_special_tokens=True)
assert res == [
1, # bos
22177, # Hello
1044, # ,
2606, # how
1584, # are
1636, # you
1063, # ?
2, # eos
]
# test decode no skip special tokens
decoded_res = tokenizer.decode(res, skip_special_tokens=False)
assert decoded_res == "<s>Hello, how are you?</s>"
# test decode skip special tokens
decoded_res = tokenizer.decode(res, skip_special_tokens=True)
assert decoded_res == "Hello, how are you?"
# test encode no special tokens
res = tokenizer.encode("Hello, how are you?", add_special_tokens=False)
assert res == [
22177, # Hello
1044, # ,
2606, # how
1584, # are
1636, # you
1063, # ?
]
# test convert ids to tokens
res = tokenizer.convert_ids_to_tokens(res)
# spacing are needed as we are converting without decoding
assert res == ["Hello", ",", " how", " are", " you", "?"]
@pytest.mark.skip(reason="TODO, fix for new HF wrapper call")
def test_magistral_tokenizer_pad_method(magistral_tokenizer: "HFMistralTokenizer"):
"""Test the MistralTokenizer pad method"""
from axolotl.utils.collators.core import IGNORE_INDEX
magistral_pad_token_id = 11 # taken from tokenizer.pad_token_id
# Test padding with input_ids and labels only
features = [
{"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
{"input_ids": [7, 8], "labels": [9, 10]},
]
result = magistral_tokenizer.pad(features, padding=True, return_tensors="pt")
# Check that input_ids are padded correctly
assert result["input_ids"].shape == (2, 3)
assert result["input_ids"].tolist() == [[1, 2, 3], [7, 8, magistral_pad_token_id]]
# Check that labels are padded correctly
assert result["labels"].shape == (2, 3)
assert result["labels"].tolist() == [[4, 5, 6], [9, 10, IGNORE_INDEX]]
# Check that attention_mask and position_ids are NOT created
assert "attention_mask" not in result
assert "position_ids" not in result
# Test padding with attention_mask
features_with_attention = [
{"input_ids": [1, 2, 3], "labels": [4, 5, 6], "attention_mask": [1, 1, 1]},
{"input_ids": [7, 8], "labels": [9, 10], "attention_mask": [1, 1]},
]
result = magistral_tokenizer.pad(
features_with_attention, padding=True, return_tensors="pt"
)
# Check that attention_mask is padded correctly
assert result["attention_mask"].shape == (2, 3)
assert result["attention_mask"].tolist() == [[1, 1, 1], [1, 1, 0]]
# Test padding with position_ids
features_with_position = [
{"input_ids": [1, 2, 3], "labels": [4, 5, 6], "position_ids": [0, 1, 2]},
{"input_ids": [7, 8], "labels": [9, 10], "position_ids": [0, 1]},
]
result = magistral_tokenizer.pad(
features_with_position, padding=True, return_tensors="pt"
)
# Check that position_ids are padded correctly (continuing sequence)
assert result["position_ids"].shape == (2, 3)
assert result["position_ids"].tolist() == [[0, 1, 2], [0, 1, 2]]
# Test padding with all fields
features_all = [
{
"input_ids": [1, 2, 3],
"labels": [4, 5, 6],
"attention_mask": [1, 1, 1],
"position_ids": [0, 1, 2],
},
{
"input_ids": [7, 8],
"labels": [9, 10],
"attention_mask": [1, 1],
"position_ids": [0, 1],
},
]
result = magistral_tokenizer.pad(features_all, padding=True, return_tensors="pt")
# All fields should be present and correctly padded
assert "input_ids" in result
assert "labels" in result
assert "attention_mask" in result
assert "position_ids" in result
# Test padding with all sequences same length
features_same_length = [
{"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
{"input_ids": [7, 8, 9], "labels": [10, 11, 12]},
]
result = magistral_tokenizer.pad(
features_same_length, padding=True, return_tensors="pt"
)
# Check match when no padding is needed
assert result["input_ids"][0].tolist() == features_same_length[0]["input_ids"]
assert result["labels"][0].tolist() == features_same_length[0]["labels"]
assert result["input_ids"][1].tolist() == features_same_length[1]["input_ids"]
assert result["labels"][1].tolist() == features_same_length[1]["labels"]
# Test padding with max_length parameter
result = magistral_tokenizer.pad(
features, padding="max_length", max_length=5, return_tensors="pt"
)
# Should pad to max_length
assert result["input_ids"].shape == (2, 5)
assert result["labels"].shape == (2, 5)
# Test numpy return type
result = magistral_tokenizer.pad(features, padding=True, return_tensors="np")
# Should return numpy arrays
import numpy as np
assert isinstance(result["input_ids"], np.ndarray)
assert isinstance(result["labels"], np.ndarray)
# Test unsupported field rejection
features_unsupported = [
{"input_ids": [1, 2, 3], "labels": [4, 5, 6], "unsupported_field": [7, 8, 9]},
]
with pytest.raises(NotImplementedError, match="unsupported_field"):
magistral_tokenizer.pad(features_unsupported, padding=True, return_tensors="pt")
# Test token_type_ids rejection
features_token_type = [
{"input_ids": [1, 2, 3], "labels": [4, 5, 6], "token_type_ids": [0, 0, 0]},
]
with pytest.raises(ValueError, match="token_type_ids is not supported"):
magistral_tokenizer.pad(features_token_type, padding=True, return_tensors="pt")
def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"):
"""Test tool calling with the Magistral tokenizer"""
from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy
strategy = MistralStrategy(
MistralPrompter(
magistral_tokenizer,
chat_template=None,
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=magistral_tokenizer,
train_on_inputs=False,
train_on_eos="turn",
sequence_len=512,
roles_to_train=["assistant"],
)
# Test basic tool calling with single function
basic_tool_calling = {
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather for a location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
},
"required": ["location"],
},
},
},
],
"messages": [
{
"role": "user",
"content": "What's the weather like in San Francisco?",
},
{
"role": "assistant",
"tool_calls": [
{
"id": "call12345",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {
"location": "San Francisco, CA",
},
},
}
],
},
{
"role": "tool",
"tool_call_id": "call12345",
"name": "get_weather",
"content": "Sunny, 72°F",
},
{
"role": "assistant",
"content": "The weather in San Francisco is sunny and 72°F.",
},
],
}
res = strategy.tokenize_prompt(basic_tool_calling)
# Basic validation
assert "input_ids" in res
assert "labels" in res
assert len(res["input_ids"]) > 0
assert len(res["labels"]) == len(res["input_ids"])
# Decode and verify structure
decoded = magistral_tokenizer.decode(res["input_ids"])
assert (
'<s>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}}, "required": ["location"]}}}][/AVAILABLE_TOOLS]'
in decoded
)
assert (
'[TOOL_CALLS]get_weather[CALL_ID]call12345[ARGS]{"location": "San Francisco, CA"}</s>'
in decoded
)
assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]Sunny, 72°F[/TOOL_RESULTS]" in decoded
assert "The weather in San Francisco is sunny and 72°F.</s>" in decoded
# Test multiple tool calls in sequence
multi_tool_calling = {
"tools": [
{
"type": "function",
"function": {
"name": "add_numbers",
"description": "Add two numbers together",
"parameters": {
"type": "object",
"properties": {
"a": {"type": "number", "description": "First number"},
"b": {"type": "number", "description": "Second number"},
},
"required": ["a", "b"],
},
},
},
{
"type": "function",
"function": {
"name": "multiply_numbers",
"description": "Multiply two numbers",
"parameters": {
"type": "object",
"properties": {
"x": {"type": "number", "description": "First number"},
"y": {"type": "number", "description": "Second number"},
},
"required": ["x", "y"],
},
},
},
],
"messages": [
{
"role": "user",
"content": "Add 5 and 3, then multiply the result by 2",
},
{
"role": "assistant",
"tool_calls": [
{
"id": "call12345",
"type": "function",
"function": {
"name": "add_numbers",
"arguments": {"a": 5, "b": 3},
},
}
],
},
{
"role": "tool",
"tool_call_id": "call12345",
"name": "add_numbers",
"content": "8",
},
{
"role": "assistant",
"tool_calls": [
{
"id": "call23456",
"type": "function",
"function": {
"name": "multiply_numbers",
"arguments": {"x": 8, "y": 2},
},
}
],
},
{
"role": "tool",
"tool_call_id": "call23456",
"name": "multiply_numbers",
"content": "16",
},
{
"role": "assistant",
"content": "The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16.",
},
],
}
res = strategy.tokenize_prompt(multi_tool_calling)
# Validation
assert len(res["input_ids"]) > 0
assert len(res["labels"]) == len(res["input_ids"])
decoded = magistral_tokenizer.decode(res["input_ids"])
assert (
'<s>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "add_numbers", "description": "Add two numbers together", "parameters": {"type": "object", "properties": {"a": {"type": "number", "description": "First number"}, "b": {"type": "number", "description": "Second number"}}, "required": ["a", "b"]}}}, {"type": "function", "function": {"name": "multiply_numbers", "description": "Multiply two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "First number"}, "y": {"type": "number", "description": "Second number"}}, "required": ["x", "y"]}}}][/AVAILABLE_TOOLS]'
in decoded
)
assert (
'[TOOL_CALLS]add_numbers[CALL_ID]call12345[ARGS]{"a": 5, "b": 3}</s>' in decoded
)
assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]8[/TOOL_RESULTS]" in decoded
assert (
'[TOOL_CALLS]multiply_numbers[CALL_ID]call23456[ARGS]{"x": 8, "y": 2}</s>'
in decoded
)
assert "[TOOL_RESULTS]call23456[TOOL_CONTENT]16[/TOOL_RESULTS]" in decoded
assert (
"The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16.</s>"
in decoded
)
# Test tool calling with system message
system_tool_calling = {
"tools": [
{
"type": "function",
"function": {
"name": "search_database",
"description": "Search for information in database",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query"},
},
"required": ["query"],
},
},
},
],
"messages": [
{
"role": "system",
"content": "You are a helpful assistant with access to a database.",
},
{
"role": "user",
"content": "Find information about Python programming",
},
{
"role": "assistant",
"tool_calls": [
{
"id": "search123",
"type": "function",
"function": {
"name": "search_database",
"arguments": {"query": "Python programming"},
},
}
],
},
{
"role": "tool",
"tool_call_id": "search123",
"name": "search_database",
"content": "Python is a high-level programming language known for its simplicity.",
},
{
"role": "assistant",
"content": "Based on the database search, Python is a high-level programming language known for its simplicity and readability.",
},
],
}
res = strategy.tokenize_prompt(system_tool_calling)
# Validation
assert len(res["input_ids"]) > 0
assert len(res["labels"]) == len(res["input_ids"])
decoded = magistral_tokenizer.decode(res["input_ids"])
assert (
'<s>[SYSTEM_PROMPT]You are a helpful assistant with access to a database.[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "search_database", "description": "Search for information in database", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}}, "required": ["query"]}}}][/AVAILABLE_TOOLS]'
in decoded
)
# Test error handling - missing tool response
incomplete_tool_calling = {
"tools": [
{
"type": "function",
"function": {
"name": "get_time",
"description": "Get current time",
"parameters": {"type": "object", "properties": {}},
},
},
],
"messages": [
{
"role": "user",
"content": "What time is it?",
},
{
"role": "assistant",
"tool_calls": [
{
"id": "time12345",
"type": "function",
"function": {
"name": "get_time",
"arguments": {},
},
}
],
},
{
"role": "assistant",
"content": "The current time is 12:00 PM.",
},
],
}
from mistral_common.exceptions import InvalidMessageStructureException
try:
strategy.tokenize_prompt(incomplete_tool_calling)
except InvalidMessageStructureException as e:
assert "Not the same number of function calls and responses" in str(e)
@pytest.mark.skip(reason="TODO, fix for new HF wrapper call")
def test_magistral_tokenizer_call_method(
magistral_tokenizer: "HFMistralTokenizer", llama3_tokenizer: "PreTrainedTokenizer"
):
"""Test the __call__ method behavior matches HuggingFace standards"""
from copy import deepcopy
import numpy as np
import torch
hf_tokenizer = deepcopy(llama3_tokenizer)
hf_tokenizer.pad_token = hf_tokenizer.eos_token
test_text = "Hello, how are you?"
batch_texts = ["Hello world", "How are you?"]
# Test single string with return_tensors=None
hf_result: dict[str, list[int]] = hf_tokenizer(test_text, return_tensors=None)
mistral_result: dict[str, list[int]] = magistral_tokenizer(
test_text, return_tensors=None
)
assert isinstance(mistral_result, dict)
assert set(mistral_result.keys()) == {"input_ids", "attention_mask"}
assert isinstance(mistral_result["input_ids"], type(hf_result["input_ids"])) # list
assert isinstance(
mistral_result["attention_mask"], type(hf_result["attention_mask"])
)
assert len(mistral_result["input_ids"]) == len(mistral_result["attention_mask"])
assert np.all(mistral_result["attention_mask"])
assert len(np.array(mistral_result["input_ids"]).shape) == 1 # 1D array
# Test single string with return_tensors='pt'
hf_result_pt: dict[str, torch.Tensor] = hf_tokenizer(test_text, return_tensors="pt")
mistral_result_pt: dict[str, torch.Tensor] = magistral_tokenizer(
test_text, return_tensors="pt"
)
# Check structure and types
assert isinstance(mistral_result_pt["input_ids"], torch.Tensor)
assert isinstance(mistral_result_pt["attention_mask"], torch.Tensor)
# Check shapes match (don't compare token dimension)
assert len(hf_result_pt["input_ids"].shape) == len(
mistral_result_pt["input_ids"].shape
)
assert hf_result_pt["input_ids"].shape[0] == mistral_result_pt["input_ids"].shape[0]
assert (
mistral_result_pt["attention_mask"].shape
== mistral_result_pt["input_ids"].shape
)
assert torch.all(mistral_result_pt["attention_mask"] == 1)
# Test batch input with padding
hf_batch: dict[str, torch.Tensor] = hf_tokenizer(
batch_texts, return_tensors="pt", padding=True
)
mistral_batch: dict[str, torch.Tensor] = magistral_tokenizer(
batch_texts, return_tensors="pt", padding=True
)
# Check batch behavior
assert len(hf_batch["input_ids"].shape) == len(mistral_batch["input_ids"].shape)
assert hf_batch["input_ids"].shape[0] == mistral_batch["input_ids"].shape[0]
assert mistral_batch["attention_mask"].shape == mistral_batch["input_ids"].shape
assert torch.any(
mistral_batch["attention_mask"][0] == 0
) # padding in shorter sequence
assert torch.all(
mistral_batch["attention_mask"][1] == 1
) # no padding in longer sequence
# Test numpy tensors
mistral_result_np: dict[str, np.ndarray] = magistral_tokenizer(
test_text, return_tensors="np"
)
assert isinstance(mistral_result_np["input_ids"], np.ndarray)
assert isinstance(mistral_result_np["attention_mask"], np.ndarray)
# Test consistency with encode()
encoded: list[int] = magistral_tokenizer.encode(test_text, add_special_tokens=True)
called: dict[str, torch.Tensor] = magistral_tokenizer(
test_text, return_tensors="pt"
)
assert encoded == called["input_ids"][0].tolist()
# Test Error handling
with pytest.raises(ValueError, match="Unsupported kwargs"):
magistral_tokenizer(test_text, unsupported_param=True)
with pytest.raises(
ValueError, match="return_tensors='pt' or 'np' requires padding or truncation"
):
magistral_tokenizer(batch_texts, return_tensors="pt")
if __name__ == "__main__":
unittest.main()

View File

@@ -2,8 +2,6 @@
Tests for splitting reasoning/thinking from content into separate field
"""
import logging
import pytest
from datasets import Dataset
from transformers import AutoTokenizer
@@ -13,11 +11,6 @@ from axolotl.prompt_strategies.chat_template import (
)
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
@pytest.fixture(name="messages_w_reasoning")
def messages_w_reasoning_fixture():
@@ -64,7 +57,6 @@ def messages_w_reasoning_fixture():
@pytest.fixture(name="qwen3_tokenizer")
@enable_hf_offline
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
): # pylint: disable=unused-argument

View File

@@ -103,7 +103,7 @@ class TestAssistantDPOChatTemplateLlama3:
def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
transform_fn, _ = default(
DictDefault(
{
"chat_template": "llama3",
@@ -128,7 +128,7 @@ class TestAssistantDPOChatTemplateLlama3:
def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
transform_fn, _ = default(
DictDefault(
{
"chat_template": "llama3",
@@ -169,7 +169,7 @@ class TestAssistantDPOChatTemplatePhi3:
def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
transform_fn, _ = default(
DictDefault(
{
"chat_template": "tokenizer_default",
@@ -199,7 +199,7 @@ class TestAssistantDPOChatTemplateGemma:
def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn = default(
transform_fn, _ = default(
DictDefault(
{
"chat_template": "tokenizer_default",

View File

@@ -6,8 +6,9 @@ import unittest
import pytest
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from tests.hf_offline_utils import enable_hf_offline
@@ -55,7 +56,8 @@ class TestDPOChatml:
# test that dpo.load works
load_dpo("chatml", cfg)
# now actually load the datasets with the strategy
train_ds, _ = load_prepare_preference_datasets(cfg)
tokenizer = load_tokenizer(cfg)
train_ds, _ = prepare_preference_datasets(cfg, tokenizer)
assert train_ds[0]["prompt"].startswith("<|im_start|>")
assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
assert "chosen" in train_ds[0]

View File

@@ -2,14 +2,12 @@
tests for jinja_template_analyzer
"""
import logging
import pytest
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.utils.logging import get_logger
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__, log_level="DEBUG")
class TestJinjaTemplateAnalyzer:

View File

@@ -0,0 +1,40 @@
"""
test suite for chunked cross entropy
"""
import pytest
import torch
from torch import nn
from axolotl.monkeypatch.loss.chunked import get_causal_lm_loss
@pytest.fixture
def chunked_fixtures():
model_dim = 512
vocab_size = 1024 * 256
seq_len = 2048
batch_size = 1
lm_head = nn.Linear(model_dim, vocab_size)
hidden_state = torch.randn(batch_size, seq_len, model_dim)
labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
return lm_head, hidden_state, labels, vocab_size
def test_chunked_forward(chunked_fixtures): # pylint: disable=redefined-outer-name
lm_head, hidden_state, labels, vocab_size = chunked_fixtures
lm_loss = get_causal_lm_loss()
logits = lm_head(hidden_state)
chunked_lm_loss = lm_loss(logits, labels)
logits_flattened = logits.view(-1, vocab_size)
labels_flattened = labels.view(-1)
loss = nn.functional.cross_entropy(
logits_flattened.float(), labels_flattened, reduction="mean"
)
assert torch.allclose(chunked_lm_loss, loss, atol=1e-2, rtol=1e-2)

View File

@@ -1,10 +1,9 @@
"""
Test dataset loading under various conditions.
"""
"""Test dataset loading under various conditions."""
import shutil
import tempfile
from pathlib import Path
from typing import Any, Generator
from unittest.mock import patch
import pytest
@@ -12,8 +11,9 @@ from datasets import Dataset
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizer
from axolotl.utils.data import load_tokenized_prepared_datasets
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
from axolotl.utils.dict import DictDefault
from tests.constants import (
@@ -28,7 +28,9 @@ class TestDatasetPreparation:
"""Test a configured dataloader."""
@pytest.fixture
def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
def tokenizer(
self, tokenizer_huggyllama
) -> Generator[PreTrainedTokenizer, Any, Any]:
tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
yield tokenizer_huggyllama
@@ -63,7 +65,10 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -107,7 +112,10 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -133,10 +141,14 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 1
assert "input_ids" in dataset.features
@@ -145,7 +157,7 @@ class TestDatasetPreparation:
@enable_hf_offline
def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
"""Usual use case. Verify a directory of parquet files can be loaded."""
"""Usual use case. Verify a directory of parquet files can be loaded."""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
tmp_ds_dir.mkdir()
@@ -168,10 +180,14 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 1
assert "input_ids" in dataset.features
@@ -203,10 +219,14 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 1
assert "input_ids" in dataset.features
@@ -232,10 +252,14 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 1
assert "input_ids" in dataset.features
@@ -261,10 +285,14 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 1
assert "input_ids" in dataset.features
@@ -286,10 +314,14 @@ class TestDatasetPreparation:
}
)
train_dataset, _ = load_prepare_preference_datasets(cfg)
tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
assert "conversation" not in train_dataset.features
assert "chosen" in train_dataset.features
assert "rejected" in train_dataset.features
assert "prompt" in train_dataset.features
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
@@ -315,7 +347,10 @@ class TestDatasetPreparation:
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -335,20 +370,27 @@ class TestDatasetPreparation:
"rl": "dpo",
"chat_template": "llama3",
"datasets": [ALPACA_MESSAGES_CONFIG_REVISION],
"dataset_processes": 4,
}
)
# pylint: disable=duplicate-code
with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset:
with patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
)
train_dataset, _ = load_prepare_preference_datasets(cfg)
tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
assert len(train_dataset) == 1800
assert "conversation" in train_dataset.features
assert "conversation" not in train_dataset.features
assert "chosen" in train_dataset.features
assert "rejected" in train_dataset.features
assert "prompt" in train_dataset.features
@enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline")
@@ -387,16 +429,18 @@ class TestDatasetPreparation:
)
with patch(
"axolotl.utils.data.shared.load_dataset_w_config"
"axolotl.utils.data.shared.load_dataset_with_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
)
dataset, _ = load_tokenized_prepared_datasets(
tokenizer, cfg, prepared_path
)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH",
str(prepared_path),
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
@@ -428,10 +472,14 @@ class TestDatasetPreparation:
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
with patch(
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
):
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
assert len(dataset) == 2000
assert "input_ids" in dataset.features

View File

@@ -5,7 +5,6 @@ Additionally, this test suite includes tests for functions that indirectly call
`deduplicate_and_log_datasets` during the execution of the preprocess command.
"""
import hashlib
import unittest
from unittest.mock import patch
@@ -14,8 +13,7 @@ from datasets import Dataset
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets
from axolotl.utils.dict import DictDefault
@@ -71,36 +69,14 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
self.expected_dataset = Dataset.from_dict(self.expected_data)
def test_deduplication(self):
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=self.dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=self.dataset)
train_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
eval_dataset, _ = deduplicate_and_log_datasets(
dataset=self.dataset, dataset_name="eval"
)
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
def test_datasets_are_none(self):
# Test when both datasets are None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=None, eval_dataset=None
)
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
def test_only_train_is_none(self):
# Test when only train_dataset is None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=None, eval_dataset=self.dataset
)
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
def test_only_eval_is_none(self):
# Test when only eval_dataset is None
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=self.dataset, eval_dataset=None
)
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
def test_exact_duplicates(self):
# Test when datasets are exact duplicates
duplicate_data = {
@@ -115,8 +91,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset = Dataset.from_dict(expected_data)
# Run deduplication
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
eval_dataset, _ = deduplicate_and_log_datasets(
dataset=dataset, dataset_name="eval"
)
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
@@ -139,8 +117,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset = Dataset.from_dict(expected_data)
# Run deduplication
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
eval_dataset, _ = deduplicate_and_log_datasets(
dataset=dataset, dataset_name="eval"
)
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
@@ -169,8 +149,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
# Run deduplication
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=dataset, eval_dataset=dataset
train_dataset, eval_dataset = deduplicate_and_log_datasets(
dataset=dataset, other_dataset=dataset
)
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
@@ -206,8 +186,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
# Run deduplication
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=dataset_train, eval_dataset=dataset_eval
train_dataset, eval_dataset = deduplicate_and_log_datasets(
dataset=dataset_train, other_dataset=dataset_eval
)
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
@@ -230,6 +210,7 @@ class TestDeduplicateRLDataset:
ALPACA_MESSAGES_CONFIG_REVISION,
ALPACA_MESSAGES_CONFIG_REVISION,
],
"dataset_processes": 4,
}
)
yield fixture
@@ -245,7 +226,9 @@ class TestDeduplicateRLDataset:
# pylint: disable=duplicate-code
with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset,
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
@@ -255,7 +238,8 @@ class TestDeduplicateRLDataset:
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
train_dataset, _ = load_prepare_preference_datasets(cfg)
tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
# Verify that the dataset has been deduplicated
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
@@ -269,7 +253,9 @@ class TestDeduplicateRLDataset:
):
# pylint: disable=duplicate-code
with (
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset,
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
@@ -279,9 +265,10 @@ class TestDeduplicateRLDataset:
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
cfg.dataset_exact_deduplication = False
# Load the dataset without deduplication
train_dataset, _ = load_prepare_preference_datasets(cfg)
cfg.dataset_exact_deduplication = False
tokenizer = load_tokenizer(cfg)
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
# Verify that the dataset retains duplicates
assert (
@@ -335,7 +322,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
)
# Prepare dataset using the prepare_dataset function
train_dataset, _, _, _ = prepare_dataset(
train_dataset, _, _, _ = prepare_datasets(
self.cfg_1,
tokenizer,
processor=processor,
@@ -362,7 +349,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
)
# Prepare dataset using the prepare_dataset function
_, eval_dataset, _, _ = prepare_dataset(
_, eval_dataset, _, _ = prepare_datasets(
self.cfg_1,
tokenizer,
processor=processor,
@@ -389,7 +376,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
)
# Prepare dataset using the prepare_dataset function
train_dataset, eval_dataset, _, _ = prepare_dataset(
train_dataset, eval_dataset, _, _ = prepare_datasets(
self.cfg_1,
tokenizer,
processor=processor,
@@ -428,41 +415,8 @@ class TestWrongCollisions(unittest.TestCase):
self.eval_dataset = Dataset.from_dict(self.eval_data)
self.dataset = Dataset.from_dict(self.dataset_data)
@patch(
"axolotl.utils.data.utils.sha256",
side_effect=lambda x: (
hashlib.sha256("forced_collision_hash".encode("utf-8")).hexdigest()
if "sample 5" in x
else hashlib.sha256(x.encode("utf-8")).hexdigest()
),
)
def test_deduplication_wrong_collision_train_eval(self, _mock_sha256):
dedup_train, dedup_eval, _ = deduplicate_and_log_datasets(
train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
)
self.assertEqual(
len(dedup_train),
2,
"train dataset should not deduplicate rows with forced hash collisions but different labels.",
)
self.assertEqual(
len(dedup_eval),
2,
"Eval dataset should not deduplicate rows with forced hash collisions but different labels.",
)
self.assertEqual(
len(dedup_eval),
len(self.eval_dataset),
"The output eval dataset should have the same number of rows as the input eval dataset.",
)
self.assertEqual(
str(dedup_eval),
str(self.eval_dataset),
"The string representation of the output eval dataset should be identical to the input eval dataset.",
)
def test_deduplication_dataset_only(self):
_, _, dedup_dataset = deduplicate_and_log_datasets(dataset=self.dataset)
dedup_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
self.assertEqual(
len(dedup_dataset), 3, "Dataset should have all original values"
)

View File

@@ -9,6 +9,7 @@ from transformers.utils.import_utils import is_torch_mps_available
from axolotl.loaders import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import _get_parallel_config_kwargs
class TestModelsUtils:
@@ -171,3 +172,42 @@ class TestModelsUtils:
message_property_mappings={"content": "different_content"},
)
assert "Conflicting message content fields" in str(exc_info.value)
@pytest.mark.parametrize(
"world_size, tensor_parallel_size, context_parallel_size, dp_shard_size, dp_replicate_size, is_fsdp, expected",
[
(16, 2, 2, 2, 2, True, (2, 2, 2, 2)),
(16, 1, 1, None, None, True, (0, 0, 16, 1)),
(16, 2, 2, 2, None, True, (2, 2, 2, 2)),
(16, 2, 2, None, 2, True, (2, 2, 2, 2)),
(16, 1, 1, None, 2, True, (0, 0, 8, 2)),
(2, 1, 1, None, None, True, (0, 0, 2, 1)),
],
)
def test_get_parallel_config_kwargs(
self,
world_size,
tensor_parallel_size,
context_parallel_size,
dp_shard_size,
dp_replicate_size,
is_fsdp,
expected,
):
res = _get_parallel_config_kwargs( # pylint: disable=protected-access
world_size,
tensor_parallel_size,
context_parallel_size,
dp_shard_size,
dp_replicate_size,
is_fsdp,
)
if expected[0] > 1:
assert res["tp_size"] == expected[0]
if expected[1] > 1:
assert res["cp_size"] == expected[1]
if expected[2] > 1:
assert res["dp_shard_size"] == expected[2]
if expected[3] > 1:
assert res["dp_replicate_size"] == expected[3]

Some files were not shown because too many files have changed in this diff Show More