Resolve merge conflicts: unify pretraining utils imports, add alias handling; fix rl.py per new RL dataset API; resolve config schema conflict and add sequence_len_overflow_handling field
This commit is contained in:
@@ -17,16 +17,23 @@ class BaseCliTest:
|
||||
command: Command to test (train/evaluate)
|
||||
"""
|
||||
# Test missing config file
|
||||
result = cli_runner.invoke(cli, [command, "--no-accelerate"])
|
||||
result = cli_runner.invoke(cli, [command, "--launcher", "python"])
|
||||
assert result.exit_code != 0
|
||||
|
||||
# Test non-existent config file
|
||||
result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"])
|
||||
result = cli_runner.invoke(
|
||||
cli, [command, "nonexistent.yml", "--launcher", "python"]
|
||||
)
|
||||
assert result.exit_code != 0
|
||||
assert "Error: Invalid value for 'CONFIG'" in result.output
|
||||
|
||||
def _test_basic_execution(
|
||||
self, cli_runner, tmp_path: Path, valid_test_config: str, command: str
|
||||
self,
|
||||
cli_runner,
|
||||
tmp_path: Path,
|
||||
valid_test_config: str,
|
||||
command: str,
|
||||
train: bool = True,
|
||||
):
|
||||
"""Test basic execution with accelerate.
|
||||
|
||||
@@ -35,24 +42,37 @@ class BaseCliTest:
|
||||
tmp_path: Temporary path fixture
|
||||
valid_test_config: Valid config fixture
|
||||
command: Command to test (train/evaluate)
|
||||
train: Whether to test training (default) or evaluation
|
||||
"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock:
|
||||
mock_fn = "os.execvpe" if command == "train" else "subprocess.run"
|
||||
|
||||
with patch(mock_fn) as mock:
|
||||
result = cli_runner.invoke(cli, [command, str(config_path)])
|
||||
|
||||
assert mock.called
|
||||
assert mock.call_args.args[0] == [
|
||||
|
||||
expected = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"-m",
|
||||
f"axolotl.cli.{command}",
|
||||
str(config_path),
|
||||
"--debug-num-examples",
|
||||
"0",
|
||||
"--debug=False",
|
||||
"--debug-text-only=False",
|
||||
"--debug-num-examples=0",
|
||||
]
|
||||
assert mock.call_args.kwargs == {"check": True}
|
||||
if train:
|
||||
expected.append("--shard=False")
|
||||
|
||||
if command == "train":
|
||||
assert mock.call_args.args[0] == "accelerate"
|
||||
assert mock.call_args.args[1] == expected
|
||||
else:
|
||||
assert mock.call_args.args[0] == expected
|
||||
assert mock.call_args.kwargs == {"check": True}
|
||||
assert result.exit_code == 0
|
||||
|
||||
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Tests for evaluate CLI command."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
@@ -18,7 +20,9 @@ class TestEvaluateCommand(BaseCliTest):
|
||||
|
||||
def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config):
|
||||
"""Test basic successful execution"""
|
||||
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate")
|
||||
self._test_basic_execution(
|
||||
cli_runner, tmp_path, valid_test_config, "evaluate", train=False
|
||||
)
|
||||
|
||||
def test_evaluate_basic_execution_no_accelerate(
|
||||
self, cli_runner, tmp_path, valid_test_config
|
||||
@@ -27,13 +31,15 @@ class TestEvaluateCommand(BaseCliTest):
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"evaluate",
|
||||
str(config_path),
|
||||
"--no-accelerate",
|
||||
"--launcher",
|
||||
"python",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
@@ -55,7 +61,8 @@ class TestEvaluateCommand(BaseCliTest):
|
||||
"2",
|
||||
"--sequence-len",
|
||||
"128",
|
||||
"--no-accelerate",
|
||||
"--launcher",
|
||||
"python",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
@@ -65,3 +72,104 @@ class TestEvaluateCommand(BaseCliTest):
|
||||
cfg = mock_evaluate.call_args[0][0]
|
||||
assert cfg.micro_batch_size == 2
|
||||
assert cfg.sequence_len == 128
|
||||
|
||||
def test_evaluate_with_launcher_args_torchrun(
|
||||
self, cli_runner, tmp_path, valid_test_config
|
||||
):
|
||||
"""Test evaluate with torchrun launcher arguments"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"evaluate",
|
||||
str(config_path),
|
||||
"--launcher",
|
||||
"torchrun",
|
||||
"--",
|
||||
"--nproc_per_node=2",
|
||||
"--nnodes=1",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify launcher args are passed to torchrun
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert called_cmd[0] == "torchrun"
|
||||
assert "--nproc_per_node=2" in called_cmd
|
||||
assert "--nnodes=1" in called_cmd
|
||||
assert "-m" in called_cmd
|
||||
assert "axolotl.cli.evaluate" in called_cmd
|
||||
|
||||
def test_evaluate_with_launcher_args_accelerate(
|
||||
self, cli_runner, tmp_path, valid_test_config
|
||||
):
|
||||
"""Test evaluate with accelerate launcher arguments"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"evaluate",
|
||||
str(config_path),
|
||||
"--launcher",
|
||||
"accelerate",
|
||||
"--",
|
||||
"--config_file=accelerate_config.yml",
|
||||
"--num_processes=4",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify launcher args are passed to accelerate
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert called_cmd[0] == "accelerate"
|
||||
assert called_cmd[1] == "launch"
|
||||
assert "--config_file=accelerate_config.yml" in called_cmd
|
||||
assert "--num_processes=4" in called_cmd
|
||||
assert "-m" in called_cmd
|
||||
assert "axolotl.cli.evaluate" in called_cmd
|
||||
|
||||
def test_evaluate_backward_compatibility_no_launcher_args(
|
||||
self, cli_runner, tmp_path, valid_test_config
|
||||
):
|
||||
"""Test that existing evaluate commands work without launcher args"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"evaluate",
|
||||
str(config_path),
|
||||
"--launcher",
|
||||
"accelerate",
|
||||
"--micro-batch-size",
|
||||
"2",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify no launcher args contamination
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert called_cmd[0] == "accelerate"
|
||||
assert called_cmd[1] == "launch"
|
||||
# Should not contain any extra launcher args
|
||||
launcher_section = called_cmd[2 : called_cmd.index("-m")]
|
||||
assert (
|
||||
len(launcher_section) == 0
|
||||
) # No launcher args between 'launch' and '-m'
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""pytest tests for axolotl CLI inference command."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
@@ -10,7 +12,7 @@ def test_inference_basic(cli_runner, config_path):
|
||||
with patch("axolotl.cli.inference.do_inference") as mock:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
["inference", str(config_path), "--no-accelerate"],
|
||||
["inference", str(config_path), "--launcher", "python"],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
@@ -23,9 +25,124 @@ def test_inference_gradio(cli_runner, config_path):
|
||||
with patch("axolotl.cli.inference.do_inference_gradio") as mock:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
["inference", str(config_path), "--no-accelerate", "--gradio"],
|
||||
["inference", str(config_path), "--launcher", "python", "--gradio"],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert mock.called
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_inference_with_launcher_args_torchrun(cli_runner, config_path):
|
||||
"""Test inference with torchrun launcher arguments"""
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"inference",
|
||||
str(config_path),
|
||||
"--launcher",
|
||||
"torchrun",
|
||||
"--",
|
||||
"--nproc_per_node=2",
|
||||
"--nnodes=1",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify launcher args are passed to torchrun
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert called_cmd[0] == "torchrun"
|
||||
assert "--nproc_per_node=2" in called_cmd
|
||||
assert "--nnodes=1" in called_cmd
|
||||
assert "-m" in called_cmd
|
||||
assert "axolotl.cli.inference" in called_cmd
|
||||
|
||||
|
||||
def test_inference_with_launcher_args_accelerate(cli_runner, config_path):
|
||||
"""Test inference with accelerate launcher arguments"""
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"inference",
|
||||
str(config_path),
|
||||
"--launcher",
|
||||
"accelerate",
|
||||
"--",
|
||||
"--config_file=accelerate_config.yml",
|
||||
"--num_processes=4",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify launcher args are passed to accelerate
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert called_cmd[0] == "accelerate"
|
||||
assert called_cmd[1] == "launch"
|
||||
assert "--config_file=accelerate_config.yml" in called_cmd
|
||||
assert "--num_processes=4" in called_cmd
|
||||
assert "-m" in called_cmd
|
||||
assert "axolotl.cli.inference" in called_cmd
|
||||
|
||||
|
||||
def test_inference_gradio_with_launcher_args(cli_runner, config_path):
|
||||
"""Test inference with gradio and launcher arguments"""
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"inference",
|
||||
str(config_path),
|
||||
"--launcher",
|
||||
"accelerate",
|
||||
"--gradio",
|
||||
"--",
|
||||
"--num_processes=2",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify both gradio flag and launcher args are present
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert called_cmd[0] == "accelerate"
|
||||
assert called_cmd[1] == "launch"
|
||||
assert "--num_processes=2" in called_cmd
|
||||
assert "--gradio" in called_cmd
|
||||
assert "-m" in called_cmd
|
||||
assert "axolotl.cli.inference" in called_cmd
|
||||
|
||||
|
||||
def test_inference_backward_compatibility_no_launcher_args(cli_runner, config_path):
|
||||
"""Test that existing inference commands work without launcher args"""
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"inference",
|
||||
str(config_path),
|
||||
"--launcher",
|
||||
"accelerate",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify no launcher args contamination
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert called_cmd[0] == "accelerate"
|
||||
assert called_cmd[1] == "launch"
|
||||
# Should not contain any extra launcher args
|
||||
launcher_section = called_cmd[2 : called_cmd.index("-m")]
|
||||
assert len(launcher_section) == 0 # No launcher args between 'launch' and '-m'
|
||||
|
||||
@@ -18,11 +18,10 @@ def test_build_command():
|
||||
assert result == [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--learning-rate",
|
||||
"0.0001",
|
||||
"--batch-size",
|
||||
"8",
|
||||
"--debug",
|
||||
"--learning-rate=0.0001",
|
||||
"--batch-size=8",
|
||||
"--debug=True",
|
||||
"--use-fp16=False",
|
||||
]
|
||||
|
||||
|
||||
@@ -38,7 +37,7 @@ def test_invalid_command_options(cli_runner):
|
||||
],
|
||||
)
|
||||
assert result.exit_code != 0
|
||||
assert "No such option" in result.output
|
||||
assert "does not exist" in result.output
|
||||
|
||||
|
||||
def test_required_config_argument(cli_runner):
|
||||
|
||||
@@ -11,9 +11,101 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
|
||||
"""Test merge_sharded_fsdp_weights command without accelerate"""
|
||||
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
|
||||
result = cli_runner.invoke(
|
||||
cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"]
|
||||
cli,
|
||||
["merge-sharded-fsdp-weights", str(config_path), "--launcher", "python"],
|
||||
)
|
||||
|
||||
assert mock.called
|
||||
assert mock.call_args.kwargs["config"] == str(config_path)
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_merge_sharded_fsdp_weights_with_launcher_args_torchrun(
|
||||
cli_runner, config_path
|
||||
):
|
||||
"""Test merge-sharded-fsdp-weights with torchrun launcher arguments"""
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"merge-sharded-fsdp-weights",
|
||||
str(config_path),
|
||||
"--launcher",
|
||||
"torchrun",
|
||||
"--",
|
||||
"--nproc_per_node=2",
|
||||
"--nnodes=1",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify launcher args are passed to torchrun
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert called_cmd[0] == "torchrun"
|
||||
assert "--nproc_per_node=2" in called_cmd
|
||||
assert "--nnodes=1" in called_cmd
|
||||
assert "-m" in called_cmd
|
||||
assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd
|
||||
|
||||
|
||||
def test_merge_sharded_fsdp_weights_with_launcher_args_accelerate(
|
||||
cli_runner, config_path
|
||||
):
|
||||
"""Test merge-sharded-fsdp-weights with accelerate launcher arguments"""
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"merge-sharded-fsdp-weights",
|
||||
str(config_path),
|
||||
"--launcher",
|
||||
"accelerate",
|
||||
"--",
|
||||
"--config_file=accelerate_config.yml",
|
||||
"--num_processes=4",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify launcher args are passed to accelerate
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert called_cmd[0] == "accelerate"
|
||||
assert called_cmd[1] == "launch"
|
||||
assert "--config_file=accelerate_config.yml" in called_cmd
|
||||
assert "--num_processes=4" in called_cmd
|
||||
assert "-m" in called_cmd
|
||||
assert "axolotl.cli.merge_sharded_fsdp_weights" in called_cmd
|
||||
|
||||
|
||||
def test_merge_sharded_fsdp_weights_backward_compatibility_no_launcher_args(
|
||||
cli_runner, config_path
|
||||
):
|
||||
"""Test that existing merge-sharded-fsdp-weights commands work without launcher args"""
|
||||
with patch("subprocess.run") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"merge-sharded-fsdp-weights",
|
||||
str(config_path),
|
||||
"--launcher",
|
||||
"accelerate",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify no launcher args contamination
|
||||
called_cmd = mock_subprocess.call_args.args[0]
|
||||
assert called_cmd[0] == "accelerate"
|
||||
assert called_cmd[1] == "launch"
|
||||
# Should not contain any extra launcher args
|
||||
launcher_section = called_cmd[2 : called_cmd.index("-m")]
|
||||
assert len(launcher_section) == 0 # No launcher args between 'launch' and '-m'
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
unit tests for generating sweep configurations
|
||||
"""
|
||||
|
||||
from axolotl.cli.main import generate_sweep_configs
|
||||
from axolotl.cli.utils import generate_sweep_configs
|
||||
|
||||
|
||||
def test_generate_sweep_configs_no_pairs():
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Tests for train CLI command."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
@@ -18,7 +20,9 @@ class TestTrainCommand(BaseCliTest):
|
||||
|
||||
def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config):
|
||||
"""Test basic successful execution"""
|
||||
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train")
|
||||
self._test_basic_execution(
|
||||
cli_runner, tmp_path, valid_test_config, "train", train=True
|
||||
)
|
||||
|
||||
def test_train_basic_execution_no_accelerate(
|
||||
self, cli_runner, tmp_path, valid_test_config
|
||||
@@ -37,7 +41,8 @@ class TestTrainCommand(BaseCliTest):
|
||||
[
|
||||
"train",
|
||||
str(config_path),
|
||||
"--no-accelerate",
|
||||
"--launcher",
|
||||
"python",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
@@ -59,11 +64,10 @@ class TestTrainCommand(BaseCliTest):
|
||||
[
|
||||
"train",
|
||||
str(config_path),
|
||||
"--learning-rate",
|
||||
"1e-4",
|
||||
"--micro-batch-size",
|
||||
"2",
|
||||
"--no-accelerate",
|
||||
"--learning-rate=1e-4",
|
||||
"--micro-batch-size=2",
|
||||
"--launcher",
|
||||
"python",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
@@ -73,3 +77,177 @@ class TestTrainCommand(BaseCliTest):
|
||||
cfg = mock_train.call_args[1]["cfg"]
|
||||
assert cfg["learning_rate"] == 1e-4
|
||||
assert cfg["micro_batch_size"] == 2
|
||||
|
||||
def test_train_with_launcher_args_torchrun(
|
||||
self, cli_runner, tmp_path, valid_test_config
|
||||
):
|
||||
"""Test train with torchrun launcher arguments"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("os.execvpe") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"train",
|
||||
str(config_path),
|
||||
"--launcher",
|
||||
"torchrun",
|
||||
"--",
|
||||
"--nproc_per_node=2",
|
||||
"--nnodes=1",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify launcher args are passed to torchrun
|
||||
called_cmd = mock_subprocess.call_args.args[1]
|
||||
assert called_cmd[0] == "torchrun"
|
||||
assert "--nproc_per_node=2" in called_cmd
|
||||
assert "--nnodes=1" in called_cmd
|
||||
assert "-m" in called_cmd
|
||||
assert "axolotl.cli.train" in called_cmd
|
||||
|
||||
def test_train_with_launcher_args_accelerate(
|
||||
self, cli_runner, tmp_path, valid_test_config
|
||||
):
|
||||
"""Test train with accelerate launcher arguments"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("os.execvpe") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"train",
|
||||
str(config_path),
|
||||
"--launcher",
|
||||
"accelerate",
|
||||
"--",
|
||||
"--config_file=accelerate_config.yml",
|
||||
"--num_processes=4",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify launcher args are passed to accelerate
|
||||
assert mock_subprocess.call_args.args[0] == "accelerate"
|
||||
called_cmd = mock_subprocess.call_args.args[1]
|
||||
assert called_cmd[0] == "accelerate"
|
||||
assert called_cmd[1] == "launch"
|
||||
assert "--config_file=accelerate_config.yml" in called_cmd
|
||||
assert "--num_processes=4" in called_cmd
|
||||
assert "-m" in called_cmd
|
||||
assert "axolotl.cli.train" in called_cmd
|
||||
|
||||
def test_train_backward_compatibility_no_launcher_args(
|
||||
self, cli_runner, tmp_path, valid_test_config
|
||||
):
|
||||
"""Test that existing train commands work without launcher args"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("os.execvpe") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"train",
|
||||
str(config_path),
|
||||
"--launcher",
|
||||
"accelerate",
|
||||
"--learning-rate",
|
||||
"1e-4",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
# Verify no launcher args contamination
|
||||
assert mock_subprocess.call_args.args[0] == "accelerate"
|
||||
called_cmd = mock_subprocess.call_args.args[1]
|
||||
assert called_cmd[0] == "accelerate"
|
||||
assert called_cmd[1] == "launch"
|
||||
# Should not contain any extra launcher args
|
||||
launcher_section = called_cmd[2 : called_cmd.index("-m")]
|
||||
assert (
|
||||
len(launcher_section) == 0
|
||||
) # No launcher args between 'launch' and '-m'
|
||||
|
||||
def test_train_mixed_args_with_launcher_args(
|
||||
self, cli_runner, tmp_path, valid_test_config
|
||||
):
|
||||
"""Test train with both regular CLI args and launcher args"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("os.execvpe") as mock_subprocess:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"train",
|
||||
str(config_path),
|
||||
"--launcher",
|
||||
"torchrun",
|
||||
"--learning-rate",
|
||||
"2e-4",
|
||||
"--micro-batch-size",
|
||||
"4",
|
||||
"--",
|
||||
"--nproc_per_node=8",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_subprocess.assert_called_once()
|
||||
|
||||
assert mock_subprocess.call_args.args[0] == "torchrun"
|
||||
called_cmd = mock_subprocess.call_args.args[1]
|
||||
# Verify launcher args
|
||||
assert "--nproc_per_node=8" in called_cmd
|
||||
# Verify axolotl args are also present
|
||||
assert "--learning-rate=2e-4" in called_cmd
|
||||
assert "--micro-batch-size=4" in called_cmd
|
||||
|
||||
def test_train_cloud_with_launcher_args(
|
||||
self, cli_runner, tmp_path, valid_test_config
|
||||
):
|
||||
"""Test train with cloud and launcher arguments"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
cloud_path = tmp_path / "cloud.yml"
|
||||
cloud_path.write_text("provider: modal\ngpu: a100")
|
||||
|
||||
with patch("axolotl.cli.cloud.do_cli_train") as mock_cloud_train:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"train",
|
||||
str(config_path),
|
||||
"--cloud",
|
||||
str(cloud_path),
|
||||
"--launcher",
|
||||
"torchrun",
|
||||
"--",
|
||||
"--nproc_per_node=4",
|
||||
"--nnodes=2",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_cloud_train.assert_called_once()
|
||||
|
||||
# Verify cloud training was called with launcher args
|
||||
call_kwargs = mock_cloud_train.call_args.kwargs
|
||||
assert call_kwargs["launcher"] == "torchrun"
|
||||
assert call_kwargs["launcher_args"] == ["--nproc_per_node=4", "--nnodes=2"]
|
||||
|
||||
@@ -72,3 +72,160 @@ def test_fetch_from_github_network_error():
|
||||
with patch("requests.get", side_effect=requests.RequestException):
|
||||
with pytest.raises(requests.RequestException):
|
||||
fetch_from_github("examples/", None)
|
||||
|
||||
|
||||
def assert_launcher_args_in_command(
|
||||
mock_subprocess_call,
|
||||
launcher: str,
|
||||
expected_launcher_args: list[str],
|
||||
command_module: str,
|
||||
):
|
||||
"""
|
||||
Helper function to verify launcher arguments are properly passed in subprocess calls.
|
||||
|
||||
Args:
|
||||
mock_subprocess_call: The mock subprocess.run call
|
||||
launcher: Expected launcher ("accelerate", "torchrun", etc.)
|
||||
expected_launcher_args: List of expected launcher arguments
|
||||
command_module: Expected module name (e.g., "axolotl.cli.train")
|
||||
"""
|
||||
assert mock_subprocess_call.called, "subprocess.run should have been called"
|
||||
called_cmd = mock_subprocess_call.call_args.args[0]
|
||||
|
||||
# Verify launcher
|
||||
assert (
|
||||
called_cmd[0] == launcher
|
||||
), f"Expected launcher {launcher}, got {called_cmd[0]}"
|
||||
|
||||
# Verify launcher args are present
|
||||
for arg in expected_launcher_args:
|
||||
assert (
|
||||
arg in called_cmd
|
||||
), f"Expected launcher arg '{arg}' not found in command: {called_cmd}"
|
||||
|
||||
# Verify module is present
|
||||
assert "-m" in called_cmd, "Expected -m flag for module execution"
|
||||
assert (
|
||||
command_module in called_cmd
|
||||
), f"Expected module {command_module} not found in command: {called_cmd}"
|
||||
|
||||
|
||||
def assert_no_launcher_args_contamination(mock_subprocess_call, launcher: str):
|
||||
"""
|
||||
Helper function to verify no unwanted launcher arguments are present.
|
||||
|
||||
Args:
|
||||
mock_subprocess_call: The mock subprocess.run call
|
||||
launcher: Expected launcher ("accelerate", "torchrun", etc.)
|
||||
"""
|
||||
assert mock_subprocess_call.called, "subprocess.run should have been called"
|
||||
called_cmd = mock_subprocess_call.call_args.args[0]
|
||||
|
||||
if launcher == "accelerate":
|
||||
# For accelerate, launcher args should be between 'launch' and '-m'
|
||||
launch_idx = called_cmd.index("launch")
|
||||
m_idx = called_cmd.index("-m")
|
||||
launcher_section = called_cmd[launch_idx + 1 : m_idx]
|
||||
assert (
|
||||
len(launcher_section) == 0
|
||||
), f"Unexpected launcher args found: {launcher_section}"
|
||||
elif launcher == "torchrun":
|
||||
# For torchrun, launcher args should be between 'torchrun' and '-m'
|
||||
torchrun_idx = called_cmd.index("torchrun")
|
||||
m_idx = called_cmd.index("-m")
|
||||
launcher_section = called_cmd[torchrun_idx + 1 : m_idx]
|
||||
assert (
|
||||
len(launcher_section) == 0
|
||||
), f"Unexpected launcher args found: {launcher_section}"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def common_launcher_args():
|
||||
"""Fixture providing common launcher argument combinations for testing."""
|
||||
return {
|
||||
"torchrun": ["--nproc_per_node=2", "--nnodes=1"],
|
||||
"accelerate": ["--config_file=accelerate_config.yml", "--num_processes=4"],
|
||||
}
|
||||
|
||||
|
||||
def test_add_default_rdzv_args_with_endpoint():
|
||||
"""Test that default RDZV args are added when rdzv_endpoint is present."""
|
||||
from axolotl.cli.utils.train import _add_default_rdzv_args
|
||||
|
||||
launcher_args = ["--nnodes=2", "--rdzv_endpoint=127.0.0.1:29400"]
|
||||
result = _add_default_rdzv_args(launcher_args)
|
||||
|
||||
# Should have added rdzv_backend
|
||||
assert "--rdzv_backend" in result
|
||||
assert "c10d" in result
|
||||
|
||||
# Original args should still be present
|
||||
assert "--nnodes=2" in result
|
||||
assert "--rdzv_endpoint=127.0.0.1:29400" in result
|
||||
|
||||
|
||||
def test_add_default_rdzv_args_with_existing_backend():
|
||||
"""Test that existing rdzv_backend is not overridden."""
|
||||
from axolotl.cli.utils.train import _add_default_rdzv_args
|
||||
|
||||
launcher_args = [
|
||||
"--nnodes=2",
|
||||
"--rdzv_endpoint=127.0.0.1:29400",
|
||||
"--rdzv_backend=static",
|
||||
]
|
||||
result = _add_default_rdzv_args(launcher_args)
|
||||
|
||||
# Should not add another rdzv_backend
|
||||
backend_count = sum(1 for arg in result if "--rdzv_backend" in arg)
|
||||
assert backend_count == 1
|
||||
assert "--rdzv_backend=static" in result
|
||||
|
||||
|
||||
def test_add_default_rdzv_args_with_existing_id():
|
||||
"""Test that existing rdzv_id is not overridden."""
|
||||
from axolotl.cli.utils.train import _add_default_rdzv_args
|
||||
|
||||
launcher_args = [
|
||||
"--nnodes=2",
|
||||
"--rdzv_endpoint=127.0.0.1:29400",
|
||||
"--rdzv_id=my_job_123",
|
||||
]
|
||||
result = _add_default_rdzv_args(launcher_args)
|
||||
|
||||
# Should not add another rdzv_id
|
||||
id_count = sum(1 for arg in result if "--rdzv_id" in arg)
|
||||
assert id_count == 1
|
||||
assert "--rdzv_id=my_job_123" in result
|
||||
|
||||
# Should still add rdzv_backend
|
||||
assert "--rdzv_backend" in result
|
||||
assert "c10d" in result
|
||||
|
||||
|
||||
def test_add_default_rdzv_args_without_endpoint():
|
||||
"""Test that no RDZV args are added when rdzv_endpoint is not present."""
|
||||
from axolotl.cli.utils.train import _add_default_rdzv_args
|
||||
|
||||
launcher_args = ["--nnodes=2", "--nproc_per_node=4"]
|
||||
result = _add_default_rdzv_args(launcher_args)
|
||||
|
||||
# Should not add any rdzv args
|
||||
assert "--rdzv_backend" not in result
|
||||
assert result == launcher_args
|
||||
|
||||
|
||||
def test_add_default_rdzv_args_with_all_existing():
|
||||
"""Test that no defaults are added when all RDZV args are present."""
|
||||
from axolotl.cli.utils.train import _add_default_rdzv_args
|
||||
|
||||
launcher_args = [
|
||||
"--nnodes=2",
|
||||
"--rdzv_endpoint=127.0.0.1:29400",
|
||||
"--rdzv_backend=static",
|
||||
"--rdzv_id=existing_job",
|
||||
]
|
||||
result = _add_default_rdzv_args(launcher_args)
|
||||
|
||||
# Should not add any additional args
|
||||
assert len(result) == len(launcher_args)
|
||||
assert result == launcher_args
|
||||
|
||||
@@ -4,26 +4,33 @@ shared pytest fixtures
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import datasets
|
||||
import pytest
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import LocalEntryNotFoundError
|
||||
from tokenizers import AddedToken
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.hf_offline_utils import (
|
||||
enable_hf_offline,
|
||||
hf_offline_context,
|
||||
)
|
||||
|
||||
logging.getLogger("filelock").setLevel(logging.CRITICAL)
|
||||
|
||||
|
||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -411,7 +418,7 @@ def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
def temp_dir() -> Generator[str, None, None]:
|
||||
# Create a temporary directory
|
||||
_temp_dir = tempfile.mkdtemp()
|
||||
yield _temp_dir
|
||||
@@ -419,6 +426,11 @@ def temp_dir():
|
||||
shutil.rmtree(_temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def torch_manual_seed():
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def cleanup_monkeypatches():
|
||||
from transformers import Trainer
|
||||
@@ -529,6 +541,22 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
@pytest.fixture(name="min_base_cfg")
|
||||
def fixture_min_base_cfg():
|
||||
return DictDefault(
|
||||
base_model="HuggingFaceTB/SmolLM2-135M",
|
||||
learning_rate=1e-3,
|
||||
datasets=[
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
micro_batch_size=1,
|
||||
gradient_accumulation_steps=1,
|
||||
)
|
||||
|
||||
|
||||
# # pylint: disable=redefined-outer-name,unused-argument
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
|
||||
|
||||
604
tests/core/test_builders.py
Normal file
604
tests/core/test_builders.py
Normal file
@@ -0,0 +1,604 @@
|
||||
"""Unit tests for axolotl.core.builders"""
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.loaders import ModelLoader, load_tokenizer
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.data import prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
|
||||
|
||||
|
||||
@pytest.fixture(name="base_cfg")
def fixture_base_cfg():
    """
    Base config with all common arguments between SFT and RLHF

    The settings are declared in themed groups and merged (in declaration order)
    into a single DictDefault, which is then passed through normalize_config.
    """
    model_and_tokenizer = {
        "base_model": "HuggingFaceTB/SmolLM2-135M-Instruct",
        "sequence_len": 2048,
        "model_config_type": "llama",  # example type
    }
    basic_training = {
        "micro_batch_size": 2,
        "eval_batch_size": 2,
        "num_epochs": 1,
        "gradient_accumulation_steps": 1,
        "max_steps": 100,
        "val_set_size": 0,
    }
    optimizer = {
        "optimizer": "adamw_torch_fused",
        "learning_rate": 0.00005,
        "weight_decay": 0.01,
        "adam_beta1": 0.998,
        "adam_beta2": 0.9,
        "adam_epsilon": 0.00001,
        "max_grad_norm": 1.0,
    }
    lr_scheduler = {
        "lr_scheduler": "cosine",
        "lr_scheduler_kwargs": {"foo": "bar"},
        "warmup_steps": 10,
        "warmup_ratio": None,
        "cosine_min_lr_ratio": 0.1,
        "cosine_constant_lr_ratio": 0.2,
    }
    checkpointing = {
        "save_steps": 100,
        "output_dir": "./model-out",
        "save_safetensors": True,
        "save_total_limit": 4,
        "save_only_model": False,
    }
    hardware_and_performance = {
        "gradient_checkpointing": False,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "dataloader_num_workers": 1,
        "dataloader_pin_memory": True,
        "dataloader_prefetch_factor": 2,
        "context_parallel_size": 1,
        "tensor_parallel_size": 1,
    }
    dtype = {
        "fp16": False,
        "bf16": False,
        "tf32": False,
    }
    logging_and_evaluation = {
        "logging_steps": 10,
        "eval_steps": 50,
        "eval_strategy": "steps",
        "save_strategy": "steps",
        "include_tokens_per_second": True,
    }
    other_common = {
        "seed": 42,
        "remove_unused_columns": True,
        "ddp_timeout": 1800,
        "ddp_bucket_cap_mb": 25,
        "ddp_broadcast_buffers": False,
        "dataset_processes": 4,
    }

    cfg = DictDefault(
        {
            **model_and_tokenizer,
            **basic_training,
            **optimizer,
            **lr_scheduler,
            **checkpointing,
            **hardware_and_performance,
            **dtype,
            **logging_and_evaluation,
            **other_common,
        }
    )

    normalize_config(cfg)
    return cfg
|
||||
|
||||
|
||||
@pytest.fixture(name="dpo_cfg")
def fixture_dpo_cfg(base_cfg):
    """
    Copy of the base config with the DPO-specific settings applied
    """
    dpo_overrides = {
        "rl": RLType.DPO,
        "dpo_use_weighting": True,
        "dpo_use_logits_to_keep": True,
        "dpo_label_smoothing": 0.1,
        "beta": 0.1,  # DPO beta
    }
    cfg = base_cfg.copy()
    cfg.update(dpo_overrides)
    return cfg
|
||||
|
||||
|
||||
@pytest.fixture(name="orpo_cfg")
def fixture_orpo_cfg(base_cfg):
    """
    Copy of the base config with the ORPO-specific settings applied
    """
    orpo_overrides = {
        "rl": RLType.ORPO,
        "orpo_alpha": 0.1,
        "max_prompt_len": 512,
    }
    cfg = base_cfg.copy()
    cfg.update(orpo_overrides)
    return cfg
|
||||
|
||||
|
||||
@pytest.fixture(name="kto_cfg")
def fixture_kto_cfg(base_cfg):
    """
    Copy of the base config with the KTO-specific settings applied
    """
    kto_overrides = {
        "rl": RLType.KTO,
        "kto_desirable_weight": 1.0,
        "kto_undesirable_weight": 1.0,
        "max_prompt_len": 512,
    }
    cfg = base_cfg.copy()
    cfg.update(kto_overrides)
    return cfg
|
||||
|
||||
|
||||
@pytest.fixture(name="grpo_cfg")
def fixture_grpo_cfg(base_cfg):
    """
    Copy of the base config with the GRPO-specific settings applied

    The nested `trl` section carries the GRPO options, including the dotted path
    of the reward function that the tests write to a temporary `rewards` module.
    """
    trl_settings = DictDefault(
        {
            "beta": 0.001,
            "max_completion_length": 256,
            "use_vllm": False,  # run on CPU
            # vllm_device / vllm_gpu_memory_utilization are intentionally left unset
            "num_generations": 4,
            "reward_funcs": ["rewards.rand_reward_func"],
        }
    )
    grpo_overrides = {
        "rl": RLType.GRPO,
        "trl": trl_settings,
        # Must be evenly divisible by num_generations
        "micro_batch_size": 4,
    }
    cfg = base_cfg.copy()
    cfg.update(grpo_overrides)
    return cfg
|
||||
|
||||
|
||||
@pytest.fixture(name="ipo_cfg")
def fixture_ipo_cfg(base_cfg):
    """
    Copy of the base config with the IPO-specific settings applied
    """
    ipo_overrides = {
        "rl": RLType.IPO,
        "dpo_label_smoothing": 0,
        "beta": 0.1,
    }
    cfg = base_cfg.copy()
    cfg.update(ipo_overrides)
    return cfg
|
||||
|
||||
|
||||
@pytest.fixture(name="simpo_cfg")
def fixture_simpo_cfg(base_cfg):
    """
    Copy of the base config with the SimPO-specific settings applied
    """
    simpo_overrides = {
        "rl": RLType.SIMPO,
        "rl_beta": 0.2,
        "cpo_alpha": 0.9,
        "simpo_gamma": 0.4,
    }
    cfg = base_cfg.copy()
    cfg.update(simpo_overrides)
    return cfg
|
||||
|
||||
|
||||
@pytest.fixture(name="sft_cfg")
def fixture_sft_cfg(base_cfg):
    """
    Copy of the base config set up for plain SFT (no RL, no packing, no flash attention)
    """
    sft_overrides = {
        "rl": None,
        "sample_packing": False,
        "eval_sample_packing": False,
        "flash_attention": False,
    }
    cfg = base_cfg.copy()
    cfg.update(sft_overrides)
    return cfg
|
||||
|
||||
|
||||
@pytest.fixture(name="rm_cfg")
def fixture_rm_cfg(sft_cfg):
    """
    Copy of the SFT config with reward-model training enabled and its dataset attached
    """
    preference_dataset = {
        "path": "argilla/distilabel-intel-orca-dpo-pairs",
        "type": "bradley_terry.chat_template",
        "split": "train[:1%]",
    }
    rm_overrides = DictDefault(
        {
            "reward_model": True,
            "datasets": [preference_dataset],
        }
    )
    cfg = sft_cfg.copy()
    cfg.update(rm_overrides)
    return cfg
|
||||
|
||||
|
||||
@pytest.fixture(name="prm_cfg")
def fixture_prm_cfg(sft_cfg):
    """
    Copy of the SFT config with process-reward-model training enabled and its dataset attached
    """
    stepwise_dataset = {
        "path": "trl-lib/math_shepherd",
        "type": "stepwise_supervised",
        "split": "train[:1%]",
    }
    prm_overrides = DictDefault(
        {
            "process_reward_model": True,
            "datasets": [stepwise_dataset],
        }
    )
    cfg = sft_cfg.copy()
    cfg.update(prm_overrides)
    return cfg
|
||||
|
||||
|
||||
@pytest.fixture(name="tokenizer")
def fixture_tokenizer(base_cfg):
    """
    Tokenizer loaded via load_tokenizer from the base config
    """
    loaded_tokenizer = load_tokenizer(base_cfg)
    return loaded_tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="model")
def fixture_model(base_cfg, tokenizer):
    """
    Model loaded via ModelLoader from the base config

    ModelLoader.load() is unpacked as a 2-tuple; only the first element (the model)
    is handed to the tests.
    """
    loader = ModelLoader(base_cfg, tokenizer)
    loaded_model, _second = loader.load()
    return loaded_model
|
||||
|
||||
|
||||
class TestHFRLTrainerBuilder:
    """
    TestCase class for RLHF trainer builders
    """

    def _test_common_training_arguments(self, training_arguments, rl: str):
        """Helper to test common arguments across all variants"""
        # Basic training settings
        if rl == "grpo":
            # grpo_cfg's micro_batch_size is diff from others
            assert training_arguments.per_device_train_batch_size == 4
        else:
            assert training_arguments.per_device_train_batch_size == 2
        assert training_arguments.gradient_accumulation_steps == 1
        assert training_arguments.max_steps == 100

        # Optimizer settings
        assert training_arguments.learning_rate == 0.00005
        assert training_arguments.weight_decay == 0.01
        assert training_arguments.adam_beta1 == 0.998
        assert training_arguments.adam_beta2 == 0.9
        assert training_arguments.adam_epsilon == 0.00001
        assert training_arguments.max_grad_norm == 1.0

        # LR scheduler settings
        assert training_arguments.lr_scheduler_type == "cosine"
        assert training_arguments.warmup_steps == 10
        assert training_arguments.cosine_min_lr_ratio == 0.1
        assert training_arguments.cosine_constant_lr_ratio == 0.2

        # Other settings
        assert training_arguments.dataloader_num_workers == 1
        assert training_arguments.dataloader_pin_memory is True

        # TODO(wing): restore once trl releases 0.22.0
        # assert training_arguments.gradient_checkpointing is True

    def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
        """Checks the common and DPO-specific training arguments"""
        builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
        training_arguments, _ = builder._build_training_arguments(100)

        self._test_common_training_arguments(training_arguments, rl=dpo_cfg.rl)
        # DPO specific
        assert training_arguments.beta == 0.1
        assert hasattr(training_arguments, "use_weighting")
        assert training_arguments.use_weighting is True
        assert training_arguments.label_smoothing == 0.1

    def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
        """Checks the common and ORPO-specific training arguments"""
        builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
        training_arguments, _ = builder._build_training_arguments(100)

        self._test_common_training_arguments(training_arguments, rl=orpo_cfg.rl)
        # ORPO specific
        assert training_arguments.beta == 0.1  # maps from orpo_alpha
        assert training_arguments.max_prompt_length == 512

    def test_kto_training_arguments(self, kto_cfg, model, tokenizer):
        """Checks the common and KTO-specific training arguments"""
        builder = HFRLTrainerBuilder(kto_cfg, model, tokenizer)
        training_arguments, _ = builder._build_training_arguments(100)

        self._test_common_training_arguments(training_arguments, rl=kto_cfg.rl)
        # KTO specific
        assert training_arguments.desirable_weight == 1.0
        assert training_arguments.undesirable_weight == 1.0
        assert training_arguments.max_prompt_length == 512

    def _write_rewards_file(self, rewards_dir: Path):
        """
        Writes reward function to local tmp path to be loaded on trainer building
        """
        # Create rewards.py in a directory we can import from
        rewards_dir.mkdir()
        rewards_file = rewards_dir / "rewards.py"
        rewards_file.write_text(
            """import random
def rand_reward_func(prompts, completions) -> list[float]:
    return [random.uniform(0, 1) for _ in completions]
"""
        )

    def _install_rewards_module(self, tmp_path: Path) -> Path:
        """
        Writes the reward module under tmp_path and makes it importable as `rewards`

        Returns the directory that was put on sys.path so the caller can undo it.
        """
        rewards_dir = tmp_path / "rewards_test"
        self._write_rewards_file(rewards_dir)

        # Drop any previously imported `rewards` module so the import performed while
        # building the trainer picks up the file written for *this* test
        sys.modules.pop("rewards", None)

        # Add the directory to Python path so we can import the module
        sys.path.insert(0, str(rewards_dir))
        return rewards_dir

    @staticmethod
    def _uninstall_rewards_module(rewards_dir: Path):
        """
        Reverts the sys.path and sys.modules changes made by _install_rewards_module
        """
        if str(rewards_dir) in sys.path:
            sys.path.remove(str(rewards_dir))
        # The module would otherwise stay cached and leak into later tests
        sys.modules.pop("rewards", None)

    def test_grpo_training_arguments(self, grpo_cfg, model, tokenizer, tmp_path):
        """Checks GRPO training arguments and that the reward function gets loaded"""
        rewards_dir = self._install_rewards_module(tmp_path)

        try:
            builder = HFRLTrainerBuilder(grpo_cfg, model, tokenizer)
            training_arguments, _ = builder._build_training_arguments(100)

            self._test_common_training_arguments(training_arguments, rl=grpo_cfg.rl)
            # GRPO specific
            assert training_arguments.beta == 0.001
            assert training_arguments.max_completion_length == 256
            assert training_arguments.use_vllm is False
            # assert training_arguments.vllm_device == "auto"
            # assert training_arguments.vllm_gpu_memory_utilization == 0.15
            assert training_arguments.num_generations == 4

            # Test trainer creation to verify reward_funcs
            trainer = builder.build(100)

            # Verify reward functions are properly loaded
            assert len(trainer.reward_funcs) == 1
            assert trainer.reward_funcs[0].__module__ == "rewards"
            assert trainer.reward_funcs[0].__name__ == "rand_reward_func"
        finally:
            # remove the temporary module from the import path and module cache
            self._uninstall_rewards_module(rewards_dir)

    def test_ipo_training_arguments(self, ipo_cfg, model, tokenizer):
        """Checks the common and IPO-specific training arguments"""
        builder = HFRLTrainerBuilder(ipo_cfg, model, tokenizer)
        training_arguments, _ = builder._build_training_arguments(100)

        self._test_common_training_arguments(training_arguments, rl=ipo_cfg.rl)
        # IPO specific
        assert training_arguments.beta == 0.1
        assert training_arguments.loss_type == "ipo"
        assert training_arguments.label_smoothing == 0

    def test_simpo_training_arguments(self, simpo_cfg, model, tokenizer):
        """Checks the common and SimPO-specific training arguments"""
        builder = HFRLTrainerBuilder(simpo_cfg, model, tokenizer)
        training_arguments, _ = builder._build_training_arguments(100)

        self._test_common_training_arguments(training_arguments, rl=simpo_cfg.rl)
        # SIMPO specific
        assert training_arguments.beta == 0.2
        assert training_arguments.cpo_alpha == 0.9
        assert training_arguments.simpo_gamma == 0.4

    @pytest.mark.parametrize(
        ("cfg_string", "dataset_name"),
        [
            (
                "dpo_cfg",
                "dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
            ),
            (
                "ipo_cfg",
                "dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
            ),
            (
                "grpo_cfg",
                "dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
            ),
            ("orpo_cfg", None),  # don't use fixture for orpo to use smaller split
            ("kto_cfg", None),  # no fixture for kto
            (
                "simpo_cfg",
                "dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff",
            ),
        ],
    )
    def test_custom_optimizer_cls_and_kwargs(
        self,
        request,
        cfg_string,
        dataset_name,
        tmp_path,
        model,
        tokenizer,
    ):
        """
        Checks that selecting the `muon` optimizer yields the Muon factory, the
        expected optimizer kwargs and a Muon instance for every RL variant
        """
        cfg = request.getfixturevalue(cfg_string)

        builder = HFRLTrainerBuilder(cfg, model, tokenizer)
        cfg["optimizer"] = "muon"

        if cfg_string in ["dpo_cfg", "ipo_cfg", "grpo_cfg", "simpo_cfg"]:
            cfg["datasets"] = [DictDefault(ALPACA_MESSAGES_CONFIG_REVISION)]
        elif cfg_string == "kto_cfg":
            cfg["datasets"] = [
                DictDefault(
                    {
                        "path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
                        "type": "llama3.ultra",
                        "split": "train[:1%]",
                    }
                )
            ]
        elif cfg_string == "orpo_cfg":
            cfg["datasets"] = [
                DictDefault(
                    {
                        "path": "argilla/ultrafeedback-binarized-preferences-cleaned",
                        "type": "chat_template.argilla",
                        "split": "train[:1%]",
                    }
                )
            ]
        else:
            raise ValueError(f"Unhandled cfg_string: {cfg_string}")
        cfg["dataset_processes"] = 4

        # Only the GRPO variant needs the importable reward module; None means
        # there is nothing to clean up afterwards
        rewards_dir = None
        if cfg_string == "grpo_cfg":
            rewards_dir = self._install_rewards_module(tmp_path)

        try:
            if dataset_name is not None:
                # Serve the preloaded dataset fixture instead of loading from the hub
                with patch(
                    "axolotl.utils.data.rl.load_dataset_with_config"
                ) as mock_load_dataset:
                    mock_load_dataset.return_value = request.getfixturevalue(
                        dataset_name
                    )
                    train_dataset, eval_dataset = prepare_preference_datasets(
                        cfg, tokenizer
                    )
            else:
                # Load actual datasets for orpo_cfg and kto_cfg
                train_dataset, eval_dataset = prepare_preference_datasets(
                    cfg, tokenizer
                )

            builder.train_dataset = train_dataset
            builder.eval_dataset = eval_dataset

            trainer = builder.build(100)

            assert trainer.optimizer_cls_and_kwargs is not None

            from axolotl.contribs.mit.muon import (  # pylint: disable=no-name-in-module
                Muon,
                MuonOptimizerFactory,
            )

            optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
            assert optimizer_cls is MuonOptimizerFactory
            assert optimizer_kwargs["lr"] == 0.00005
            assert optimizer_kwargs["weight_decay"] == 0.01
            assert optimizer_kwargs["betas"] == (0.998, 0.9)
            assert optimizer_kwargs["eps"] == 0.00001

            # Ensure optimizer is created with correct class
            optim = trainer.create_optimizer()
            assert isinstance(optim, Muon)

        finally:
            # remove the temporary module from the import path and module cache
            if rewards_dir is not None:
                self._uninstall_rewards_module(rewards_dir)
|
||||
|
||||
|
||||
class TestHFCausalTrainerBuilder:
    """
    TestCase class for SFT trainer builder
    """

    def test_training_arguments(self, sft_cfg, model, tokenizer):
        """Builds the SFT trainer and checks the resulting training arguments"""
        trainer = HFCausalTrainerBuilder(sft_cfg, model, tokenizer).build(100)
        training_arguments = trainer.args

        # Values compared by equality, in the order: common, optimizer, scheduler, dataloader
        expected_values = {
            "per_device_train_batch_size": 2,
            "gradient_accumulation_steps": 1,
            "max_steps": 100,
            "learning_rate": 0.00005,
            "weight_decay": 0.01,
            "adam_beta1": 0.998,
            "adam_beta2": 0.9,
            "adam_epsilon": 0.00001,
            "max_grad_norm": 1.0,
            "lr_scheduler_type": "cosine",
            "warmup_steps": 10,
            "cosine_min_lr_ratio": 0.1,
            "dataloader_num_workers": 1,
        }
        for attr_name, expected in expected_values.items():
            assert getattr(training_arguments, attr_name) == expected

        # Flags that must be the exact boolean singletons
        assert training_arguments.dataloader_pin_memory is True
        assert training_arguments.gradient_checkpointing is False

        # SFT specific
        assert training_arguments.sample_packing is False
        assert training_arguments.eval_sample_packing is False

    @pytest.mark.parametrize(
        "cfg_string",
        [
            "sft_cfg",
            "rm_cfg",
            "prm_cfg",
        ],
    )
    def test_custom_optimizer_cls_and_kwargs(
        self, request, cfg_string, model, tokenizer
    ):
        """
        Checks that selecting the `muon` optimizer yields the Muon factory, the
        expected optimizer kwargs and a Muon instance for SFT, RM and PRM configs
        """
        cfg = request.getfixturevalue(cfg_string)
        builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
        cfg["optimizer"] = "muon"

        # need to load datasets for reward model and process reward model trainer
        needs_datasets = cfg_string in ["rm_cfg", "prm_cfg"]
        if needs_datasets:
            dataset_meta = load_datasets(cfg=cfg)

            builder.train_dataset = dataset_meta.train_dataset
            builder.eval_dataset = dataset_meta.eval_dataset

        trainer = builder.build(100)

        assert trainer.optimizer_cls_and_kwargs is not None

        from axolotl.contribs.mit.muon import (  # pylint: disable=no-name-in-module
            Muon,
            MuonOptimizerFactory,
        )

        optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
        assert optimizer_cls is MuonOptimizerFactory

        expected_kwargs = {
            "lr": 0.00005,
            "weight_decay": 0.01,
            "betas": (0.998, 0.9),
            "eps": 0.00001,
        }
        for kwarg_name, expected in expected_kwargs.items():
            assert optimizer_kwargs[kwarg_name] == expected

        # Ensure optimizer is created with correct class
        created_optimizer = trainer.create_optimizer()
        assert isinstance(created_optimizer, Muon)
|
||||
|
||||
|
||||
class TestTrainerClsPlugin:
    """
    TestCase class for trainer builder with plugin
    """

    def test_trainer_cls_is_not_none_with_plugin(self, kto_cfg, model, tokenizer):
        """
        Test that the trainer cls is not none with plugin

        Fixes #2693
        """
        cfg = kto_cfg.copy()
        cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]

        # The build is allowed to fail for unrelated reasons; the only failure this test
        # guards against is `TypeError: None is not a callable object`, which is the
        # symptom of trainer_cls being None.
        try:
            builder = HFRLTrainerBuilder(cfg, model, tokenizer)

            builder.build(100)
        except TypeError as e:
            # Previously this compared against an AttributeError message, which can never
            # appear in the TypeError of interest, so the check could not fail.
            assert "None is not a callable object" not in str(
                e
            ), "trainer_cls resolved to None when building with a plugin"
        except Exception:  # pylint: disable=broad-exception-caught
            # Another error happens, so we passed trainer_cls to builder
            pass
|
||||
@@ -1,90 +0,0 @@
|
||||
"""Unit tests for axolotl.core.trainer_builder"""
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.core.trainer_builder import HFRLTrainerBuilder
|
||||
from axolotl.loaders import ModelLoader, load_tokenizer
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
|
||||
@pytest.fixture(name="cfg")
def fixture_cfg():
    """
    Normalized RL config for a small base model, shared by the tests in this module
    """
    settings = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "learning_rate": 0.00005,
        "save_steps": 100,
        "output_dir": "./model-out",
        "warmup_steps": 10,
        "gradient_checkpointing": False,
        "optimizer": "adamw_torch_fused",
        "sequence_len": 2048,
        "rl": True,
        "adam_beta1": 0.998,
        "adam_beta2": 0.9,
        "adam_epsilon": 0.00001,
        "dataloader_num_workers": 1,
        "dataloader_pin_memory": True,
        "model_config_type": "llama",
        "special_tokens": {
            "pad_token": "<|endoftext|>",
        },
    }
    cfg = DictDefault(settings)

    normalize_config(cfg)

    return cfg
|
||||
|
||||
|
||||
@pytest.fixture(name="tokenizer")
def fixture_tokenizer(cfg):
    """
    Tokenizer loaded via load_tokenizer from the module's cfg fixture
    """
    loaded_tokenizer = load_tokenizer(cfg)
    return loaded_tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="model")
def fixture_model(cfg, tokenizer):
    """
    Result of ModelLoader(cfg, tokenizer).load(), returned as-is (not unpacked)
    """
    loader = ModelLoader(cfg, tokenizer)
    return loader.load()
|
||||
|
||||
|
||||
class TestHFRLTrainerBuilder:
    """
    TestCase class for DPO trainer builder
    """

    def test_build_training_arguments(self, cfg, model, tokenizer):
        """Checks that optimizer and dataloader settings reach the training arguments"""
        trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer)
        training_arguments = trainer_builder.build_training_arguments(100)

        expected_values = {
            "adam_beta1": 0.998,
            "adam_beta2": 0.9,
            "adam_epsilon": 0.00001,
            "dataloader_num_workers": 1,
        }
        for attr_name, expected in expected_values.items():
            assert getattr(training_arguments, attr_name) == expected
        assert training_arguments.dataloader_pin_memory is True
|
||||
|
||||
|
||||
class TestTrainerClsPlugin:
    """
    TestCase class for trainer builder with plugin
    """

    def test_trainer_cls_is_not_none_with_plugin(self, cfg, model, tokenizer):
        """
        Test that the trainer cls is not none with plugin

        Fixes #2693
        """
        cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
        cfg.rl = RLType.KTO

        # Expected AttributeError as we don't pass regular model configs to RL trainer builder
        # If it throws `TypeError: None is not a callable object`, trainer_cls could be None
        expected_message = r".*'tuple' object has no attribute 'config'.*"
        with pytest.raises(AttributeError, match=expected_message):
            trainer_builder = HFRLTrainerBuilder(cfg, model, tokenizer)

            trainer_builder.build(100)
|
||||
@@ -4,7 +4,6 @@ Simple end-to-end test for Cut Cross Entropy integration
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils import get_pytorch_version
|
||||
@@ -45,6 +44,7 @@ def min_cfg(temp_dir):
|
||||
"save_safetensors": True,
|
||||
"max_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
|
||||
|
||||
@@ -59,8 +59,7 @@ class TestCutCrossEntropyIntegration:
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
@@ -100,13 +99,13 @@ class TestCutCrossEntropyIntegration:
|
||||
"save_safetensors": True,
|
||||
"max_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
@@ -134,8 +133,7 @@ class TestCutCrossEntropyIntegration:
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
|
||||
62
tests/e2e/integrations/test_fp8.py
Normal file
62
tests/e2e/integrations/test_fp8.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Simple end-to-end smoke tests for FP8 mixed precision training
|
||||
"""
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_model_output_exists, require_torch_2_7_0
|
||||
|
||||
|
||||
class FP8IntegrationTestCase:
    """
    e2e smoke tests for FP8 mixed precision training with Axolotl
    """

    # NOTE(review): the class name does not start with `Test`; confirm the project's
    # pytest collection settings pick it up.

    @require_torch_2_7_0
    def test_fp8_single_gpu_smoke(self, temp_dir):
        """Smoke test for single GPU FP8 + torch.compile training"""
        alpaca_dataset = {
            "path": "mhenrichsen/alpaca_2k_test",
            "type": "alpaca",
        }
        # pylint: disable=duplicate-code
        settings = {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "tokenizer_type": "AutoTokenizer",
            "trust_remote_code": True,
            "sequence_len": 512,
            "val_set_size": 0.05,
            "special_tokens": {
                "pad_token": "<|endoftext|>",
            },
            "datasets": [alpaca_dataset],
            "num_epochs": 1,
            "max_steps": 3,  # Very short smoke test
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 2,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch_fused",
            "lr_scheduler": "cosine",
            "sdp_attention": True,
            "pad_to_seq_len": True,
            "sample_packing": True,
            "fp8": True,
            "torch_compile": True,
            "save_safetensors": True,
            "save_first_step": False,
        }
        cfg = DictDefault(settings)

        # pylint: disable=duplicate-code
        # Validate, normalize, load the datasets, then run the short training job
        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
|
||||
@@ -5,7 +5,6 @@ e2e tests to make sure all the hooks are fired on the plugin
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.train import train
|
||||
@@ -154,14 +153,14 @@ class TestPluginHooks:
|
||||
"max_steps": 5,
|
||||
"flash_attention": True,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,11 +5,9 @@ e2e tests for kd trainer support in Axolotl
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
|
||||
@@ -18,8 +16,8 @@ from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
|
||||
@pytest.fixture(name="kd_min_cfg")
|
||||
def min_cfg(temp_dir):
|
||||
return {
|
||||
"base_model": "osllmai-community/Llama-3.2-1B",
|
||||
"tokenizer_config": "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
|
||||
"base_model": "Qwen/Qwen3-0.6B",
|
||||
"tokenizer_config": "winglian/qwen3-14b-math",
|
||||
"plugins": [
|
||||
"axolotl.integrations.kd.KDPlugin",
|
||||
"axolotl.integrations.liger.LigerPlugin",
|
||||
@@ -32,20 +30,22 @@ def min_cfg(temp_dir):
|
||||
"kd_ce_alpha": 0.1,
|
||||
"kd_alpha": 0.9,
|
||||
"kd_temperature": 1.0,
|
||||
"kd_beta": 0.0,
|
||||
"kd_normalize_topk": True,
|
||||
"dataloader_prefetch_factor": 8,
|
||||
"dataloader_num_workers": 4,
|
||||
"dataloader_pin_memory": True,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample",
|
||||
"type": "axolotl.integrations.kd.chat_template",
|
||||
"field_messages": "messages_combined",
|
||||
"path": "winglian/OpenThoughts-114k-math-correct-qwen3-14b-math-prepared-topk128-normalized",
|
||||
"type": "chat_template",
|
||||
"split": "train",
|
||||
"logprobs_field": "llm_text_generation_vllm_logprobs",
|
||||
"temperature": 1.0,
|
||||
"preprocess_shards": 2,
|
||||
"split_thinking": True,
|
||||
"eot_tokens": ["<|im_end|>"],
|
||||
"data_files": ["train/batch-000000.parquet"],
|
||||
},
|
||||
],
|
||||
"skip_prepare_dataset": True,
|
||||
"val_set_size": 0.0,
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": True,
|
||||
@@ -67,6 +67,7 @@ def min_cfg(temp_dir):
|
||||
"output_dir": temp_dir,
|
||||
"save_safetensors": True,
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
|
||||
|
||||
@@ -81,18 +82,29 @@ class TestKnowledgeDistillation:
|
||||
def test_llama_kd(self, temp_dir, kd_min_cfg):
|
||||
cfg = DictDefault(kd_min_cfg)
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"1",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.skip(reason="Chunked KD loss doesn't support PEFT/LoRA")
|
||||
@pytest.mark.parametrize(
|
||||
"load_in_8bit",
|
||||
[True, False],
|
||||
@@ -112,13 +124,22 @@ class TestKnowledgeDistillation:
|
||||
| kd_min_cfg
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"1",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
Simple end-to-end test for Liger integration
|
||||
"""
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
@@ -51,14 +50,14 @@ class LigerIntegrationTestCase:
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"max_steps": 5,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -98,14 +97,14 @@ class LigerIntegrationTestCase:
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"max_steps": 5,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
@@ -82,14 +81,14 @@ class TestLLMCompressorIntegration:
|
||||
},
|
||||
"save_compressed": save_compressed,
|
||||
},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
prepare_plugins(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
try:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Tests for GEGLU activation function Triton kernels."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -20,8 +19,15 @@ def test_geglu_forward_shape():
|
||||
assert out.device == gate.device
|
||||
|
||||
|
||||
def test_geglu_forward_values():
|
||||
@pytest.mark.flaky(retries=1, delay=5)
|
||||
@pytest.mark.parametrize(
|
||||
"torch_seed",
|
||||
[0, 42],
|
||||
)
|
||||
def test_geglu_forward_values(torch_seed):
|
||||
"""Test GEGLU forward pass matches PyTorch reference implementation."""
|
||||
torch.manual_seed(torch_seed)
|
||||
|
||||
gate = torch.randn(2, 3, 64, device="cuda")
|
||||
up = torch.randn(2, 3, 64, device="cuda")
|
||||
|
||||
@@ -34,8 +40,15 @@ def test_geglu_forward_values():
|
||||
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
|
||||
|
||||
|
||||
def test_geglu_backward():
|
||||
@pytest.mark.flaky(retries=1, delay=5)
|
||||
@pytest.mark.parametrize(
|
||||
"torch_seed",
|
||||
[0, 42],
|
||||
)
|
||||
def test_geglu_backward(torch_seed):
|
||||
"""Test GEGLU backward pass matches PyTorch autograd."""
|
||||
torch.manual_seed(torch_seed)
|
||||
|
||||
gate = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
|
||||
up = torch.randn(2, 3, 64, device="cuda", requires_grad=True)
|
||||
grad_output = torch.randn(2, 3, 64, device="cuda")
|
||||
|
||||
@@ -64,6 +64,7 @@ def sample_tensors():
|
||||
batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
|
||||
),
|
||||
"W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
|
||||
"b": torch.randn(out_dim, device="cuda", dtype=torch.float16),
|
||||
"scale": 0.5,
|
||||
"shapes": {
|
||||
"batch": batch_size,
|
||||
@@ -103,23 +104,24 @@ def mock_proj():
|
||||
def test_get_lora_parameters(mock_proj):
|
||||
"""Tests get_lora_parameters function"""
|
||||
# Test with LoRA enabled
|
||||
W, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
|
||||
assert isinstance(W, torch.Tensor)
|
||||
assert W.shape == (128, 64)
|
||||
assert b.shape == (128,)
|
||||
assert A.shape == (8, 64)
|
||||
assert B.shape == (128, 8)
|
||||
assert s == 0.5
|
||||
|
||||
# Test with LoRA disabled
|
||||
mock_proj.disable_adapters = True
|
||||
W, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
assert A is None and B is None and s is None
|
||||
|
||||
# Test with merged state
|
||||
mock_proj.disable_adapters = False
|
||||
mock_proj.merged = True
|
||||
W, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
assert A is None and B is None and s is None
|
||||
|
||||
|
||||
@@ -127,6 +129,7 @@ def test_matmul_lora(sample_tensors):
|
||||
"""Tests matmul_lora function"""
|
||||
X = sample_tensors["X"]
|
||||
W = sample_tensors["W"]
|
||||
b = sample_tensors["b"]
|
||||
scale = sample_tensors["scale"]
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
@@ -138,19 +141,20 @@ def test_matmul_lora(sample_tensors):
|
||||
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
||||
|
||||
# Test base matmul
|
||||
out1 = matmul_lora(X, W, None, None, None, None)
|
||||
expected1 = torch.matmul(X, W.t())
|
||||
out1 = matmul_lora(X, W, b, None, None, None, None)
|
||||
matmul = torch.matmul(X, W.t())
|
||||
expected1 = matmul + b
|
||||
assert torch.allclose(out1, expected1, rtol=1e-3)
|
||||
|
||||
# Test with LoRA
|
||||
out2 = matmul_lora(X, W, None, A, B, scale)
|
||||
out2 = matmul_lora(X, W, b, None, A, B, scale)
|
||||
lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
|
||||
expected2 = expected1 + lora_term
|
||||
expected2 = matmul + lora_term + b
|
||||
assert torch.allclose(out2, expected2, rtol=1e-3)
|
||||
|
||||
# Test 3D input reshaping
|
||||
X_3d = X.clone()
|
||||
out3 = matmul_lora(X_3d, W, None, A, B, scale)
|
||||
out3 = matmul_lora(X_3d, W, b, None, A, B, scale)
|
||||
assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
|
||||
|
||||
@@ -175,16 +179,19 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
|
||||
output = LoRA_MLP.apply(
|
||||
X,
|
||||
gate_proj.weight,
|
||||
gate_proj.bias,
|
||||
None, # gate_quant
|
||||
None, # gate_A
|
||||
None, # gate_B
|
||||
None, # gate_scale
|
||||
up_proj.weight,
|
||||
up_proj.bias,
|
||||
None, # up_quant
|
||||
None, # up_A
|
||||
None, # up_B
|
||||
None, # up_scale
|
||||
down_proj.weight,
|
||||
down_proj.bias,
|
||||
None, # down_quant
|
||||
None, # down_A
|
||||
None, # down_B
|
||||
@@ -243,16 +250,19 @@ def test_lora_mlp_with_adapters(
|
||||
output = LoRA_MLP.apply(
|
||||
X,
|
||||
gate_proj.weight,
|
||||
gate_proj.bias,
|
||||
None,
|
||||
gate_A,
|
||||
gate_B,
|
||||
scale,
|
||||
up_proj.weight,
|
||||
up_proj.bias,
|
||||
None,
|
||||
up_A,
|
||||
up_B,
|
||||
scale,
|
||||
down_proj.weight,
|
||||
down_proj.bias,
|
||||
None,
|
||||
down_A,
|
||||
down_B,
|
||||
@@ -323,6 +333,7 @@ def test_lora_qkv(sample_tensors):
|
||||
X.requires_grad = True
|
||||
|
||||
# Test without LoRA adapters
|
||||
# pylint: disable=duplicate-code
|
||||
Q1, K1, V1 = LoRA_QKV.apply(
|
||||
X,
|
||||
q_weight,
|
||||
@@ -330,16 +341,19 @@ def test_lora_qkv(sample_tensors):
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
k_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
v_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
)
|
||||
|
||||
@@ -356,16 +370,19 @@ def test_lora_qkv(sample_tensors):
|
||||
X,
|
||||
q_weight,
|
||||
None,
|
||||
None,
|
||||
q_A,
|
||||
q_B,
|
||||
scale,
|
||||
k_weight,
|
||||
None,
|
||||
None,
|
||||
k_A,
|
||||
k_B,
|
||||
scale,
|
||||
v_weight,
|
||||
None,
|
||||
None,
|
||||
v_A,
|
||||
v_B,
|
||||
scale,
|
||||
@@ -399,6 +416,7 @@ def test_lora_o(sample_tensors):
|
||||
"""Tests LoRA output projection"""
|
||||
X = sample_tensors["X"]
|
||||
W = sample_tensors["W"]
|
||||
b = sample_tensors["b"]
|
||||
scale = sample_tensors["scale"]
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
@@ -411,7 +429,7 @@ def test_lora_o(sample_tensors):
|
||||
|
||||
# Test forward pass
|
||||
X.requires_grad = True
|
||||
output = LoRA_O.apply(X, W, None, A, B, scale)
|
||||
output = LoRA_O.apply(X, W, b, None, A, B, scale)
|
||||
|
||||
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
|
||||
@@ -425,6 +443,7 @@ def test_with_quantization(sample_tensors, mock_quantstate):
|
||||
"""Tests LoRA with quantized weights"""
|
||||
X = sample_tensors["X"] # [batch, seq, hidden]
|
||||
W = sample_tensors["W"] # [out, hidden]
|
||||
b = sample_tensors["b"] # [out]
|
||||
scale = 0.5
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
@@ -436,13 +455,13 @@ def test_with_quantization(sample_tensors, mock_quantstate):
|
||||
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
||||
|
||||
# Test matmul with quantization
|
||||
out = matmul_lora(X, W, mock_quantstate, A, B, scale)
|
||||
out = matmul_lora(X, W, b, mock_quantstate, A, B, scale)
|
||||
assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
assert not torch.isnan(out).any()
|
||||
|
||||
# Test with different batch sizes
|
||||
X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
|
||||
out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale)
|
||||
out2 = matmul_lora(X2, W, b, mock_quantstate, A, B, scale)
|
||||
assert out2.shape == (4, 6, W.shape[0])
|
||||
assert not torch.isnan(out2).any()
|
||||
|
||||
@@ -459,11 +478,12 @@ def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
|
||||
"""Tests various input shapes and dimensions"""
|
||||
X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
|
||||
W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
|
||||
b = torch.randn(out, device="cuda", dtype=torch.float16)
|
||||
A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
|
||||
B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
|
||||
scale = 0.5
|
||||
|
||||
result = matmul_lora(X, W, None, A, B, scale)
|
||||
result = matmul_lora(X, W, b, None, A, B, scale)
|
||||
assert result.shape == (batch, seq, out)
|
||||
|
||||
|
||||
@@ -471,6 +491,7 @@ def test_gradient_flow(sample_tensors):
|
||||
"""Tests gradient flow through LoRA layers"""
|
||||
X = sample_tensors["X"].clone()
|
||||
W = sample_tensors["W"].clone()
|
||||
b = sample_tensors["b"].clone()
|
||||
scale = sample_tensors["scale"]
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
@@ -486,7 +507,7 @@ def test_gradient_flow(sample_tensors):
|
||||
B.requires_grad = True
|
||||
|
||||
# Forward pass
|
||||
out = matmul_lora(X, W, None, A, B, scale)
|
||||
out = matmul_lora(X, W, b, None, A, B, scale)
|
||||
loss = out.sum()
|
||||
|
||||
# Backward pass
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""E2E tests for sequence parallelism"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -12,8 +11,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ...utils import check_tensorboard
|
||||
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestSequenceParallelism:
|
||||
"""Test case for training with sequence parallelism enabled"""
|
||||
@@ -57,6 +54,7 @@ class TestSequenceParallelism:
|
||||
"micro_batch_size": micro_batch_size,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -69,8 +67,9 @@ class TestSequenceParallelism:
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_size": 2,
|
||||
"ring_attn_func": ring_attn_func,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -94,7 +93,10 @@ class TestSequenceParallelism:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", threshold, "Train Loss is too high"
|
||||
temp_dir + "/runs",
|
||||
"train/train_loss",
|
||||
threshold,
|
||||
"Train Loss (%s) is too high",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -103,13 +105,13 @@ class TestSequenceParallelism:
|
||||
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
|
||||
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
|
||||
# (False, 2, True, "batch_zigzag", 2.5),
|
||||
(False, 2, False, None, 2.5), # defaults to batch_ring ring_attn_func
|
||||
# (False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
|
||||
],
|
||||
ids=[
|
||||
"sample_packing, varlen_llama3 ring_attn_func",
|
||||
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
|
||||
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
# "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
],
|
||||
)
|
||||
def test_sequence_parallel_training(
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
E2E tests for multigpu lora tinyllama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -17,9 +15,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
@@ -59,12 +54,14 @@ class TestPackedFlex:
|
||||
"gradient_accumulation_steps": 2,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 2,
|
||||
"use_tensorboard": True,
|
||||
"save_strategy": "no",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -90,5 +87,5 @@ class TestPackedFlex:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ GRPO test suite
|
||||
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import subprocess # nosec B404
|
||||
import sys
|
||||
import tempfile
|
||||
@@ -106,7 +105,7 @@ def start_vllm(
|
||||
print(f"{i}: VLLM server failed to start: {str(exc)}")
|
||||
|
||||
# also check if the process.pid is still running
|
||||
if not process.poll() is None:
|
||||
if process.poll() is not None:
|
||||
break
|
||||
|
||||
time.sleep(period_seconds)
|
||||
@@ -118,7 +117,10 @@ def start_vllm(
|
||||
recursive_kill(process)
|
||||
with open("/tmp/vllm.log", "r", encoding="utf-8") as log_file:
|
||||
print(log_file.read())
|
||||
shutil.rmtree("/tmp/vllm.log")
|
||||
try:
|
||||
os.remove("/tmp/vllm.log")
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
raise RuntimeError(f"VLLM server process did not start within {wait} seconds.")
|
||||
|
||||
# return the process
|
||||
@@ -139,6 +141,7 @@ def recursive_kill(process: subprocess.Popen):
|
||||
os.kill(process.pid, 9)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="flaky vllm tests in modal")
|
||||
class TestGRPO:
|
||||
"""
|
||||
Test case for GRPO training using multilpe GPUs
|
||||
@@ -220,6 +223,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -260,6 +264,101 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
**current_env,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
(recursive_kill(vllm_process))
|
||||
|
||||
@require_vllm
|
||||
def test_llama_lora_sp(self, temp_dir):
|
||||
rnd_reward_suffix = str(random.randint(1000, 9999))
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"chat_template": "llama3",
|
||||
"rl": "grpo",
|
||||
"trl": {
|
||||
"beta": 0.001,
|
||||
"max_completion_length": 256,
|
||||
"use_vllm": True,
|
||||
"num_generations": 4,
|
||||
"reward_funcs": [f"rewards_{rnd_reward_suffix}.rand_reward_func"],
|
||||
},
|
||||
"vllm": {
|
||||
"max_model_len": 800,
|
||||
"enable_prefix_caching": True,
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "openai/gsm8k",
|
||||
"name": "main",
|
||||
"type": f"rewards_{rnd_reward_suffix}.oai_gsm8k_transform",
|
||||
},
|
||||
],
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"context_parallel_size": 2,
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"max_steps": 3,
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"warmup_steps": 10,
|
||||
"val_set_size": 0.0,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
self._utils_write_yaml_and_rewards(cfg, temp_dir, suffix=rnd_reward_suffix)
|
||||
|
||||
current_env = os.environ.copy()
|
||||
env = {
|
||||
"NCCL_P2P_LEVEL": "LOC",
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
}
|
||||
vllm_process = start_vllm(
|
||||
cfg.base_model,
|
||||
env=env,
|
||||
quiet=True,
|
||||
wait=300,
|
||||
gpu_memory_utilization=0.15,
|
||||
max_model_len=cfg.vllm.max_model_len,
|
||||
enable_prefix_caching=cfg.vllm.enable_prefix_caching,
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
)
|
||||
|
||||
try:
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
str(2),
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
],
|
||||
env={
|
||||
"NCCL_P2P_LEVEL": "LOC",
|
||||
"NCCL_DEBUG": "INFO",
|
||||
**current_env,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
recursive_kill(vllm_process)
|
||||
|
||||
@@ -305,12 +404,14 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"warmup_steps": 10,
|
||||
"val_set_size": 0.0,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
E2E tests for multigpu eval
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
@@ -14,9 +12,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_tensorboard
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
@@ -43,12 +38,13 @@ class TestMultiGPUEval:
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||
"val_set_size": 0.004,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "teknium/GPT4-LLM-Cleaned",
|
||||
"type": "alpaca",
|
||||
"split": "train[:5%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
@@ -56,6 +52,7 @@ class TestMultiGPUEval:
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -70,6 +67,7 @@ class TestMultiGPUEval:
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -112,12 +110,13 @@ class TestMultiGPUEval:
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||
"val_set_size": 0.0004,
|
||||
"val_set_size": 0.01,
|
||||
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "teknium/GPT4-LLM-Cleaned",
|
||||
"type": "alpaca",
|
||||
"split": "train[:5%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
@@ -125,6 +124,7 @@ class TestMultiGPUEval:
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -139,6 +139,7 @@ class TestMultiGPUEval:
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
121
tests/e2e/multigpu/test_fp8_fsdp2.py
Normal file
121
tests/e2e/multigpu/test_fp8_fsdp2.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Test module for FP8 mixed precision with FSDP2 multi-GPU functionality."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from tbparse import SummaryReader
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import most_recent_subdir, require_hopper, require_torch_2_7_0
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
def verify_fp8_training_success(temp_dir):
|
||||
"""Verify that FP8 training completed successfully by checking artifacts and loss."""
|
||||
output_path = Path(temp_dir)
|
||||
|
||||
model_files = list(output_path.glob("*.bin")) + list(
|
||||
output_path.glob("*.safetensors")
|
||||
)
|
||||
assert len(model_files) > 0, "No model files found - training may have failed"
|
||||
|
||||
checkpoint_files = list(output_path.glob("checkpoint-*"))
|
||||
assert (
|
||||
len(checkpoint_files) > 0
|
||||
), "No checkpoint files found - training may have failed"
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
if tb_log_path:
|
||||
event_files = sorted(os.listdir(tb_log_path))
|
||||
if event_files:
|
||||
event_file = os.path.join(tb_log_path, event_files[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars
|
||||
train_loss_df = df[df.tag == "train/train_loss"]
|
||||
if len(train_loss_df) > 0:
|
||||
final_loss = train_loss_df.value.values[-1]
|
||||
assert not torch.isnan(
|
||||
torch.tensor(final_loss)
|
||||
), f"Training loss is NaN: {final_loss}"
|
||||
|
||||
|
||||
class TestFP8FSDP2:
|
||||
"""Test class for FP8 mixed precision with FSDP2 functionality."""
|
||||
|
||||
@require_torch_2_7_0
|
||||
@require_hopper
|
||||
def test_fp8_fsdp2_smoke(self, temp_dir):
|
||||
"""Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training"""
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"trust_remote_code": True,
|
||||
"sequence_len": 512,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 3, # Very short smoke test
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused", # Use standard optimizer for stability
|
||||
"lr_scheduler": "cosine",
|
||||
"sdp_attention": True,
|
||||
"pad_to_seq_len": True,
|
||||
"sample_packing": True,
|
||||
# FP8 configuration
|
||||
"fp8": True,
|
||||
"fp8_enable_fsdp_float8_all_gather": True,
|
||||
"torch_compile": True,
|
||||
# FSDP2 configuration
|
||||
"fsdp_version": 2,
|
||||
"fsdp_config": {
|
||||
"offload_params": False,
|
||||
"cpu_ram_efficient_loading": False,
|
||||
"transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||
"state_dict_type": "FULL_STATE_DICT",
|
||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
verify_fp8_training_success(temp_dir)
|
||||
326
tests/e2e/multigpu/test_fsdp1.py
Normal file
326
tests/e2e/multigpu/test_fsdp1.py
Normal file
@@ -0,0 +1,326 @@
|
||||
"""Test module for FSDP1 multi-GPU functionality."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from tbparse import SummaryReader
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import most_recent_subdir
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
def verify_training_success(temp_dir):
|
||||
"""Verify that training completed successfully by checking artifacts and loss."""
|
||||
output_path = Path(temp_dir)
|
||||
|
||||
model_files = list(output_path.glob("*.bin")) + list(
|
||||
output_path.glob("*.safetensors")
|
||||
)
|
||||
assert len(model_files) > 0, "No model files found - training may have failed"
|
||||
|
||||
checkpoint_files = list(output_path.glob("checkpoint-*"))
|
||||
assert (
|
||||
len(checkpoint_files) > 0
|
||||
), "No checkpoint files found - training may have failed"
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
if tb_log_path:
|
||||
event_files = sorted(os.listdir(tb_log_path))
|
||||
if event_files:
|
||||
event_file = os.path.join(tb_log_path, event_files[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars
|
||||
train_loss_df = df[df.tag == "train/train_loss"]
|
||||
if len(train_loss_df) > 0:
|
||||
final_loss = train_loss_df.value.values[-1]
|
||||
assert not torch.isnan(
|
||||
torch.tensor(final_loss)
|
||||
), f"Training loss is NaN: {final_loss}"
|
||||
|
||||
|
||||
class TestFSDP1:
|
||||
"""Test class for FSDP1 functionality."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fsdp_cpu_ram_efficient_loading",
|
||||
[True, False],
|
||||
)
|
||||
def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"fsdp_version": "1",
|
||||
"fsdp_config": {
|
||||
"fsdp_offload_params": False,
|
||||
"fsdp_cpu_ram_efficient_loading": fsdp_cpu_ram_efficient_loading,
|
||||
"fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
||||
"fsdp_state_dict_type": "FULL_STATE_DICT",
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
"fsdp_sharding_strategy": "FULL_SHARD",
|
||||
"fsdp_sync_module_states": True,
|
||||
"fsdp_use_orig_params": False,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
verify_training_success(temp_dir)
|
||||
|
||||
@pytest.mark.parametrize(
    "adapter_config",
    [
        {"adapter": "lora", "load_in_4bit": False},
        {"adapter": "qlora", "load_in_4bit": True},
    ],
)
def test_lora_sft(self, temp_dir, adapter_config):
    """Run a two-step LoRA / QLoRA SFT job with ``fsdp_version`` "1".

    The config is written to ``temp_dir/config.yaml``, ``axolotl train`` is
    launched with two processes, and the output directory is then checked
    with ``verify_training_success``.
    """
    alpaca_subset = {
        "path": "tatsu-lab/alpaca",
        "type": "alpaca",
        "split": "train[:10%]",
    }
    fsdp_settings = {
        "fsdp_offload_params": False,
        "fsdp_cpu_ram_efficient_loading": True,
        "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
        "fsdp_state_dict_type": "FULL_STATE_DICT",
        "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "fsdp_sharding_strategy": "FULL_SHARD",
        "fsdp_sync_module_states": True,
        "fsdp_use_orig_params": False,
    }
    cfg = DictDefault(
        {
            "base_model": "Qwen/Qwen2.5-0.5B",
            "sequence_len": 2048,
            "val_set_size": 0.01,
            "datasets": [alpaca_subset],
            # adapter type and quantisation come from the parametrisation
            "adapter": adapter_config["adapter"],
            "load_in_4bit": adapter_config["load_in_4bit"],
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.05,
            "lora_target_linear": True,
            "num_epochs": 1,
            "max_steps": 2,
            "micro_batch_size": 2,
            "gradient_accumulation_steps": 1,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch_fused",
            "lr_scheduler": "cosine",
            "flash_attention": True,
            "fsdp_version": "1",
            "fsdp_config": fsdp_settings,
            "use_tensorboard": True,
            "bf16": True,
        }
    )

    # serialise the config next to the training outputs
    config_file = Path(temp_dir) / "config.yaml"
    config_file.parent.mkdir(parents=True, exist_ok=True)
    config_file.write_text(
        yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper), encoding="utf-8"
    )

    launch_cmd = [
        "axolotl",
        "train",
        str(config_file),
        "--num-processes",
        "2",
        "--main-process-port",
        str(get_torch_dist_unique_port()),
    ]
    execute_subprocess_async(launch_cmd)

    verify_training_success(temp_dir)
|
||||
|
||||
def test_dpo_fft(self, temp_dir):
    """Run a two-step full-parameter DPO job with ``fsdp_version`` "1".

    The config is written to ``temp_dir/config.yaml``, ``axolotl train`` is
    launched with two processes, and the output directory is then checked
    with ``verify_training_success``.
    """
    preference_pairs = {
        "path": "Intel/orca_dpo_pairs",
        "split": "train",
        "type": "chatml.intel",
    }
    fsdp_settings = {
        "fsdp_offload_params": False,
        "fsdp_cpu_ram_efficient_loading": True,
        "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
        "fsdp_state_dict_type": "FULL_STATE_DICT",
        "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "fsdp_sharding_strategy": "FULL_SHARD",
        "fsdp_sync_module_states": True,
        "fsdp_use_orig_params": False,
    }
    cfg = DictDefault(
        {
            "base_model": "Qwen/Qwen2.5-0.5B",
            "sequence_len": 2048,
            "val_set_size": 0.01,
            "rl": "dpo",
            "chat_template": "chatml",
            "datasets": [preference_pairs],
            "num_epochs": 1,
            "max_steps": 2,
            "micro_batch_size": 2,
            "gradient_accumulation_steps": 1,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch_fused",
            "lr_scheduler": "cosine",
            "flash_attention": True,
            "fsdp_version": "1",
            "fsdp_config": fsdp_settings,
            "use_tensorboard": True,
        }
    )

    # serialise the config next to the training outputs
    config_file = Path(temp_dir) / "config.yaml"
    config_file.parent.mkdir(parents=True, exist_ok=True)
    config_file.write_text(
        yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper), encoding="utf-8"
    )

    launch_cmd = [
        "axolotl",
        "train",
        str(config_file),
        "--num-processes",
        "2",
        "--main-process-port",
        str(get_torch_dist_unique_port()),
    ]
    execute_subprocess_async(launch_cmd)

    verify_training_success(temp_dir)
|
||||
|
||||
@pytest.mark.parametrize(
    "adapter_config",
    [
        {"adapter": "lora", "load_in_4bit": False},
        {"adapter": "qlora", "load_in_4bit": True},
    ],
)
def test_dpo_lora(self, temp_dir, adapter_config):
    """Run a two-step LoRA / QLoRA DPO job with ``fsdp_version`` "1".

    The config is written to ``temp_dir/config.yaml``, ``axolotl train`` is
    launched with two processes, and the output directory is then checked
    with ``verify_training_success``.
    """
    preference_pairs = {
        "path": "Intel/orca_dpo_pairs",
        "split": "train",
        "type": "chatml.intel",
    }
    fsdp_settings = {
        "fsdp_offload_params": False,
        "fsdp_cpu_ram_efficient_loading": True,
        "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
        "fsdp_state_dict_type": "FULL_STATE_DICT",
        "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "fsdp_sharding_strategy": "FULL_SHARD",
        "fsdp_sync_module_states": True,
        "fsdp_use_orig_params": False,
    }
    cfg = DictDefault(
        {
            "base_model": "Qwen/Qwen2.5-0.5B",
            # adapter type and quantisation come from the parametrisation
            "load_in_4bit": adapter_config["load_in_4bit"],
            "rl": "dpo",
            "chat_template": "chatml",
            "sequence_len": 2048,
            "adapter": adapter_config["adapter"],
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.05,
            "lora_target_linear": True,
            "val_set_size": 0.01,
            "datasets": [preference_pairs],
            "num_epochs": 1,
            "max_steps": 2,
            "micro_batch_size": 2,
            "gradient_accumulation_steps": 1,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch_fused",
            "lr_scheduler": "cosine",
            "flash_attention": True,
            "fsdp_version": "1",
            "fsdp_config": fsdp_settings,
            "use_tensorboard": True,
            "bf16": "auto",
            "tf32": True,
        }
    )

    # serialise the config next to the training outputs
    config_file = Path(temp_dir) / "config.yaml"
    config_file.parent.mkdir(parents=True, exist_ok=True)
    config_file.write_text(
        yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper), encoding="utf-8"
    )

    launch_cmd = [
        "axolotl",
        "train",
        str(config_file),
        "--num-processes",
        "2",
        "--main-process-port",
        str(get_torch_dist_unique_port()),
    ]
    execute_subprocess_async(launch_cmd)

    verify_training_success(temp_dir)
|
||||
482
tests/e2e/multigpu/test_fsdp2.py
Normal file
482
tests/e2e/multigpu/test_fsdp2.py
Normal file
@@ -0,0 +1,482 @@
|
||||
"""Test module for FSDP2 multi-GPU functionality."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from tbparse import SummaryReader
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
def verify_training_success(temp_dir):
    """Verify that training completed successfully by checking artifacts and loss.

    Args:
        temp_dir: Training output directory, given as a ``str`` or a
            ``pathlib.Path``.

    Raises:
        AssertionError: If no ``*.bin`` / ``*.safetensors`` file or no
            ``checkpoint-*`` entry exists directly under ``temp_dir``, or if
            the last logged ``train/train_loss`` value is NaN.
    """
    output_path = Path(temp_dir)

    model_files = [
        *output_path.glob("*.bin"),
        *output_path.glob("*.safetensors"),
    ]
    assert model_files, "No model files found - training may have failed"

    checkpoint_files = list(output_path.glob("checkpoint-*"))
    assert checkpoint_files, "No checkpoint files found - training may have failed"

    # Build the runs path from ``output_path`` rather than ``temp_dir + "/runs"``
    # so that a ``Path`` argument works as well as a ``str``.
    tb_log_path = most_recent_subdir(str(output_path / "runs"))

    # The loss check is best-effort: it is skipped whenever no TensorBoard
    # run directory, event file, or train-loss scalar is available.
    if not tb_log_path:
        return
    event_files = sorted(os.listdir(tb_log_path))
    if not event_files:
        return

    scalars = SummaryReader(os.path.join(tb_log_path, event_files[0])).scalars
    train_loss = scalars[scalars.tag == "train/train_loss"]
    if len(train_loss) == 0:
        return

    final_loss = train_loss.value.values[-1]
    assert not torch.isnan(
        torch.tensor(final_loss)
    ), f"Training loss is NaN: {final_loss}"
|
||||
|
||||
|
||||
class TestFSDP2:
    """Test class for FSDP2 functionality.

    Every test builds an axolotl config with ``fsdp_version`` 2, writes it to
    ``temp_dir/config.yaml``, launches ``axolotl train`` with two processes
    and then checks the output directory with ``verify_training_success``.
    The settings common to all tests live in the private helpers below, so
    each test only spells out what is specific to it.
    """

    @staticmethod
    def _sft_task():
        """Return the dataset settings used by the SFT tests."""
        return {
            "datasets": [
                {
                    "path": "tatsu-lab/alpaca",
                    "type": "alpaca",
                    "split": "train[:10%]",
                },
            ],
        }

    @staticmethod
    def _dpo_task():
        """Return the RL and dataset settings used by the DPO tests."""
        return {
            "rl": "dpo",
            "chat_template": "chatml",
            "datasets": [
                {
                    "path": "Intel/orca_dpo_pairs",
                    "split": "train",
                    "type": "chatml.intel",
                },
            ],
        }

    @staticmethod
    def _base_cfg(temp_dir, task, *, cpu_ram_efficient_loading=False, **overrides):
        """Build a config from the shared settings, ``task`` and ``overrides``.

        Args:
            temp_dir: Output directory for the run.
            task: Mapping with the dataset (and, for DPO, RL) settings.
            cpu_ram_efficient_loading: Value stored under
                ``fsdp_config.cpu_ram_efficient_loading``.
            **overrides: Extra top-level config entries specific to one test.

        Returns:
            A ``DictDefault`` holding the merged configuration.
        """
        return DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "sequence_len": 2048,
                "num_epochs": 1,
                "max_steps": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "fsdp_version": 2,
                "fsdp_config": {
                    "offload_params": False,
                    "cpu_ram_efficient_loading": cpu_ram_efficient_loading,
                    "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
                    "state_dict_type": "FULL_STATE_DICT",
                    "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
                    "reshard_after_forward": True,
                },
                "use_tensorboard": True,
                **task,
                **overrides,
            }
        )

    @staticmethod
    def _train_and_verify(temp_dir, cfg):
        """Write ``cfg`` to YAML, launch a 2-process training run and verify it."""
        # write cfg to yaml file
        Path(temp_dir).mkdir(parents=True, exist_ok=True)
        config_path = Path(temp_dir) / "config.yaml"
        with open(config_path, "w", encoding="utf-8") as fout:
            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

        execute_subprocess_async(
            [
                "axolotl",
                "train",
                str(config_path),
                "--num-processes",
                "2",
                "--main-process-port",
                f"{get_torch_dist_unique_port()}",
            ]
        )

        verify_training_success(temp_dir)

    @require_torch_2_7_0
    @pytest.mark.parametrize(
        "fsdp_cpu_ram_efficient_loading",
        [True, False],
    )
    def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
        """Full fine-tune SFT, with and without CPU-RAM-efficient loading."""
        cfg = self._base_cfg(
            temp_dir,
            self._sft_task(),
            cpu_ram_efficient_loading=fsdp_cpu_ram_efficient_loading,
            val_set_size=0.01,
            bf16=True,
        )
        self._train_and_verify(temp_dir, cfg)

    @require_torch_2_7_0
    @pytest.mark.parametrize("peft_use_dora", [True, False])
    def test_lora_sft(self, temp_dir, peft_use_dora):
        """LoRA SFT, with ``peft_use_dora`` switched on and off."""
        cfg = self._base_cfg(
            temp_dir,
            self._sft_task(),
            val_set_size=0.01,
            peft_use_dora=peft_use_dora,
            adapter="lora",
            lora_r=8,
            lora_alpha=16,
            lora_dropout=0.05,
            lora_target_linear=True,
            bf16=True,
        )
        self._train_and_verify(temp_dir, cfg)

    @require_torch_2_7_0
    def test_lora_sft_kernels(self, temp_dir):
        """LoRA SFT with the LoRA MLP / QKV / O kernels enabled."""
        # lora_dropout is deliberately left unset here, as in test_qlora_sft_kernels
        cfg = self._base_cfg(
            temp_dir,
            self._sft_task(),
            val_set_size=0.01,
            adapter="lora",
            lora_r=8,
            lora_alpha=16,
            lora_target_linear=True,
            bf16=True,
            lora_mlp_kernel=True,
            lora_qkv_kernel=True,
            lora_o_kernel=True,
        )
        self._train_and_verify(temp_dir, cfg)

    @require_torch_2_7_0
    def test_qlora_sft(self, temp_dir):
        """QLoRA SFT with the base model loaded in 4-bit."""
        cfg = self._base_cfg(
            temp_dir,
            self._sft_task(),
            val_set_size=0.01,
            load_in_4bit=True,
            adapter="qlora",
            lora_r=8,
            lora_alpha=16,
            lora_dropout=0.05,
            lora_target_linear=True,
            bf16=True,
        )
        self._train_and_verify(temp_dir, cfg)

    @require_torch_2_7_0
    def test_qlora_sft_kernels(self, temp_dir):
        """QLoRA SFT with the LoRA MLP / QKV / O kernels enabled."""
        # lora_dropout is deliberately left unset here, as in test_lora_sft_kernels
        cfg = self._base_cfg(
            temp_dir,
            self._sft_task(),
            val_set_size=0.01,
            load_in_4bit=True,
            adapter="qlora",
            lora_r=8,
            lora_alpha=16,
            lora_target_linear=True,
            bf16=True,
            lora_mlp_kernel=True,
            lora_qkv_kernel=True,
            lora_o_kernel=True,
        )
        self._train_and_verify(temp_dir, cfg)

    @require_torch_2_7_0
    def test_dpo_fft(self, temp_dir):
        """Full fine-tune DPO."""
        cfg = self._base_cfg(
            temp_dir,
            self._dpo_task(),
            val_set_size=0.01,
        )
        self._train_and_verify(temp_dir, cfg)

    @require_torch_2_7_0
    def test_dpo_lora(self, temp_dir):
        """LoRA DPO."""
        cfg = self._base_cfg(
            temp_dir,
            self._dpo_task(),
            adapter="lora",
            lora_r=8,
            lora_alpha=16,
            lora_dropout=0.05,
            lora_target_linear=True,
        )
        self._train_and_verify(temp_dir, cfg)
|
||||
@@ -2,8 +2,6 @@
|
||||
E2E tests for multigpu lora tinyllama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -16,9 +14,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
@@ -69,12 +64,14 @@ class TestMultiGPUGemma3:
|
||||
},
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -96,5 +93,5 @@ class TestMultiGPUGemma3:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
E2E tests for multigpu lora tinyllama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -18,9 +16,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
@@ -67,12 +62,14 @@ class TestMultiGPULlama:
|
||||
"gradient_accumulation_steps": 2,
|
||||
# "gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -94,7 +91,7 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.8, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -132,12 +129,14 @@ class TestMultiGPULlama:
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
# "gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -159,7 +158,7 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
def test_dpo_lora_ddp(self, temp_dir):
|
||||
@@ -205,6 +204,7 @@ class TestMultiGPULlama:
|
||||
"gradient_accumulation_steps": 2,
|
||||
# "gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"warmup_steps": 0,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
@@ -212,6 +212,7 @@ class TestMultiGPULlama:
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -237,7 +238,7 @@ class TestMultiGPULlama:
|
||||
temp_dir + "/runs",
|
||||
"train/train_loss",
|
||||
loss_threshold,
|
||||
"Train Loss is too high",
|
||||
"Train Loss (%s) is too high",
|
||||
)
|
||||
|
||||
def test_dpo_qlora_ddp(self, temp_dir):
|
||||
@@ -283,6 +284,7 @@ class TestMultiGPULlama:
|
||||
"gradient_accumulation_steps": 2,
|
||||
# "gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"warmup_steps": 0,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
@@ -290,6 +292,7 @@ class TestMultiGPULlama:
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -315,7 +318,7 @@ class TestMultiGPULlama:
|
||||
temp_dir + "/runs",
|
||||
"train/train_loss",
|
||||
loss_threshold,
|
||||
"Train Loss is too high",
|
||||
"Train Loss (%s) is too high",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -345,6 +348,7 @@ class TestMultiGPULlama:
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
# "gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -364,6 +368,8 @@ class TestMultiGPULlama:
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -385,12 +391,15 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fsdp_state_dict_type",
|
||||
["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
|
||||
[
|
||||
"FULL_STATE_DICT",
|
||||
# "SHARDED_STATE_DICT", # not supported since intermediate checkpoints fail with fsdp1
|
||||
],
|
||||
)
|
||||
def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -412,11 +421,13 @@ class TestMultiGPULlama:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 3,
|
||||
"save_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
# "gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -436,6 +447,7 @@ class TestMultiGPULlama:
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -457,7 +469,7 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@require_torch_2_6_0
|
||||
@@ -496,6 +508,7 @@ class TestMultiGPULlama:
|
||||
"gradient_accumulation_steps": 2,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -513,6 +526,7 @@ class TestMultiGPULlama:
|
||||
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if attention_backend == "flash":
|
||||
@@ -538,7 +552,7 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||
@@ -578,6 +592,7 @@ class TestMultiGPULlama:
|
||||
"gradient_accumulation_steps": 2,
|
||||
# "gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -593,10 +608,11 @@ class TestMultiGPULlama:
|
||||
"fsdp_use_orig_params": False,
|
||||
"fsdp_cpu_ram_efficient_loading": True,
|
||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||
"fsdp_state_dict_type": "FULL_STATE_DICT",
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -618,7 +634,7 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -674,12 +690,14 @@ class TestMultiGPULlama:
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"deepspeed": str(AXOLOTL_ROOT / deepspeed),
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
**adapter,
|
||||
}
|
||||
)
|
||||
@@ -702,7 +720,7 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.45, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -748,12 +766,15 @@ class TestMultiGPULlama:
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"save_first_step": False,
|
||||
**adapter,
|
||||
}
|
||||
)
|
||||
@@ -776,7 +797,7 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -822,12 +843,14 @@ class TestMultiGPULlama:
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
**adapter,
|
||||
}
|
||||
)
|
||||
@@ -850,7 +873,7 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.skip(
|
||||
@@ -896,6 +919,7 @@ class TestMultiGPULlama:
|
||||
"save_safetensors": True,
|
||||
# "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -917,5 +941,5 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
192
tests/e2e/multigpu/test_locking.py
Normal file
192
tests/e2e/multigpu/test_locking.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Tests for FileLockLoader class."""
|
||||
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.data.lock import FileLockLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestFileLockLoader:
|
||||
"""Class with tests for FileLockLoader."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir(self):
|
||||
"""Create a temporary directory for testing."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
yield Path(tmp_dir)
|
||||
|
||||
@pytest.fixture
|
||||
def cfg(self, temp_dir):
|
||||
"""Create a test configuration."""
|
||||
return DictDefault({"dataset_prepared_path": str(temp_dir)})
|
||||
|
||||
@pytest.fixture
|
||||
def loader(self, cfg):
|
||||
"""Create a FileLockLoader instance for testing."""
|
||||
return FileLockLoader(cfg)
|
||||
|
||||
def test_load_first_process(self, loader):
|
||||
"""Test load() when no ready flag exists (first process)."""
|
||||
mock_load_fn = Mock(return_value="test_data")
|
||||
|
||||
result = loader.load(mock_load_fn)
|
||||
|
||||
# Should call the load function
|
||||
mock_load_fn.assert_called_once()
|
||||
assert result == "test_data"
|
||||
|
||||
# Should create the ready flag
|
||||
assert loader.ready_flag_path.exists()
|
||||
|
||||
def test_load_subsequent_process(self, loader):
|
||||
"""Test load() when ready flag already exists (subsequent process)."""
|
||||
# Create ready flag first
|
||||
loader.ready_flag_path.touch()
|
||||
|
||||
mock_load_fn = Mock(return_value="loaded_data")
|
||||
|
||||
result = loader.load(mock_load_fn)
|
||||
|
||||
# Should still call load function (to load the prepared data)
|
||||
mock_load_fn.assert_called_once()
|
||||
assert result == "loaded_data"
|
||||
|
||||
def test_load_concurrent_processes(self, cfg):
|
||||
"""Test that concurrent processes coordinate correctly."""
|
||||
results = []
|
||||
call_count = 0
|
||||
|
||||
def slow_load_fn():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
time.sleep(0.1) # Simulate slow loading
|
||||
return f"data_{call_count}"
|
||||
|
||||
def worker():
|
||||
loader = FileLockLoader(cfg)
|
||||
result = loader.load(slow_load_fn)
|
||||
results.append(result)
|
||||
|
||||
# Start multiple threads simultaneously
|
||||
threads = [threading.Thread(target=worker) for _ in range(3)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Only one thread should have done the initial loading
|
||||
# All should return data, but the load function should be called
|
||||
# once by the first process and once by each subsequent process
|
||||
assert len(results) == 3
|
||||
assert all(result.startswith("data_") for result in results)
|
||||
|
||||
@patch("time.sleep")
|
||||
def test_load_waiting_for_ready_flag(self, mock_sleep, loader):
|
||||
"""Test that processes wait for the ready flag to appear."""
|
||||
mock_load_fn = Mock(return_value="waiting_data")
|
||||
mock_ready_flag_path = Mock()
|
||||
exists_call_count = 0
|
||||
|
||||
def mock_exists():
|
||||
nonlocal exists_call_count
|
||||
exists_call_count += 1
|
||||
|
||||
if exists_call_count == 1:
|
||||
# First check: ready flag exists (not first process)
|
||||
return True
|
||||
if exists_call_count <= 3:
|
||||
# While loop checks: flag doesn't exist yet
|
||||
return False
|
||||
return True
|
||||
|
||||
mock_ready_flag_path.exists.side_effect = mock_exists
|
||||
|
||||
# Replace the ready_flag_path with our mock
|
||||
original_path = loader.ready_flag_path
|
||||
loader.ready_flag_path = mock_ready_flag_path
|
||||
|
||||
try:
|
||||
result = loader.load(mock_load_fn)
|
||||
finally:
|
||||
# Restore original path
|
||||
loader.ready_flag_path = original_path
|
||||
|
||||
# Should have slept twice while waiting
|
||||
assert mock_sleep.call_count == 2
|
||||
mock_sleep.assert_called_with(1)
|
||||
|
||||
# Should eventually call load function
|
||||
mock_load_fn.assert_called_once()
|
||||
assert result == "waiting_data"
|
||||
|
||||
def test_complete_workflow_with_cleanup(self, loader):
|
||||
"""Test the complete load -> cleanup workflow."""
|
||||
mock_load_fn = Mock(return_value="test_data")
|
||||
|
||||
# First process calls load (this should set up counter)
|
||||
result = loader.load(mock_load_fn)
|
||||
assert result == "test_data"
|
||||
assert loader.ready_flag_path.exists()
|
||||
assert loader.counter_path.exists()
|
||||
|
||||
# Cleanup should remove everything since there's only one process
|
||||
loader.cleanup()
|
||||
assert not loader.ready_flag_path.exists()
|
||||
assert not loader.counter_path.exists()
|
||||
|
||||
def test_multiple_processes_workflow(self, loader):
|
||||
"""Test workflow with multiple processes."""
|
||||
# Simulate multiple processes by manually setting up counter
|
||||
loader.ready_flag_path.touch()
|
||||
loader.counter_path.write_text("3") # 3 processes
|
||||
|
||||
# First process cleanup
|
||||
loader.cleanup()
|
||||
assert loader.ready_flag_path.exists()
|
||||
assert loader.counter_path.read_text().strip() == "2"
|
||||
|
||||
# Second process cleanup
|
||||
loader.cleanup()
|
||||
assert loader.ready_flag_path.exists()
|
||||
assert loader.counter_path.read_text().strip() == "1"
|
||||
|
||||
# Last process cleanup
|
||||
loader.cleanup()
|
||||
assert not loader.ready_flag_path.exists()
|
||||
assert not loader.counter_path.exists()
|
||||
|
||||
def test_load_exception_handling(self, loader):
|
||||
"""Test behavior when load_fn raises an exception."""
|
||||
|
||||
def failing_load_fn():
|
||||
raise ValueError("Load failed")
|
||||
|
||||
with pytest.raises(ValueError, match="Load failed"):
|
||||
loader.load(failing_load_fn)
|
||||
|
||||
# Ready flag should not be created on failure
|
||||
assert not loader.ready_flag_path.exists()
|
||||
|
||||
def test_file_lock_called(self, loader):
|
||||
"""Test that FileLock is properly used."""
|
||||
mock_load_fn = Mock(return_value="locked_data")
|
||||
|
||||
with patch("axolotl.utils.data.lock.FileLock") as mock_filelock:
|
||||
mock_context = MagicMock()
|
||||
mock_filelock.return_value.__enter__ = Mock(return_value=mock_context)
|
||||
mock_filelock.return_value.__exit__ = Mock(return_value=None)
|
||||
|
||||
loader.load(mock_load_fn)
|
||||
|
||||
# Verify FileLock was called with correct path
|
||||
mock_filelock.assert_called_once_with(str(loader.lock_file_path))
|
||||
|
||||
# Verify context manager was used
|
||||
mock_filelock.return_value.__enter__.assert_called_once()
|
||||
mock_filelock.return_value.__exit__.assert_called_once()
|
||||
@@ -1,97 +0,0 @@
|
||||
"""
|
||||
E2E tests for multigpu qwen2
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestMultiGPUQwen2:
|
||||
"""
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
|
||||
def test_qlora_fsdp_dpo(self, base_model, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": base_model,
|
||||
"load_in_4bit": True,
|
||||
"rl": "dpo",
|
||||
"chat_template": "chatml",
|
||||
"sequence_len": 2048,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "Intel/orca_dpo_pairs",
|
||||
"split": "train",
|
||||
"type": "chatml.intel",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"warmup_steps": 20,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"bf16": "auto",
|
||||
"tf32": True,
|
||||
# "gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {
|
||||
"use_reentrant": False,
|
||||
},
|
||||
"fsdp": [
|
||||
"full_shard",
|
||||
"auto_wrap",
|
||||
],
|
||||
"fsdp_config": {
|
||||
"fsdp_limit_all_gathers": True,
|
||||
"fsdp_offload_params": False,
|
||||
"fsdp_sync_module_states": True,
|
||||
"fsdp_use_orig_params": False,
|
||||
"fsdp_cpu_ram_efficient_loading": False,
|
||||
"fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
||||
"fsdp_state_dict_type": "FULL_STATE_DICT",
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
"fsdp_sharding_strategy": "FULL_SHARD",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
@@ -2,8 +2,6 @@
|
||||
E2E tests for multigpu post-training use Ray Train
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -12,10 +10,11 @@ from accelerate.test_utils import execute_subprocess_async
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
from tests.e2e.utils import (
|
||||
check_tensorboard,
|
||||
require_torch_2_7_0,
|
||||
require_torch_lt_2_6_0,
|
||||
)
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
@@ -53,6 +52,7 @@ class TestMultiGPURay:
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -60,6 +60,7 @@ class TestMultiGPURay:
|
||||
"use_tensorboard": True,
|
||||
"use_ray": True,
|
||||
"ray_num_workers": 2,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -80,7 +81,7 @@ class TestMultiGPURay:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@require_torch_lt_2_6_0
|
||||
@@ -112,12 +113,14 @@ class TestMultiGPURay:
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -138,5 +141,73 @@ class TestMultiGPURay:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@require_torch_2_7_0
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 2],
|
||||
)
|
||||
def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.01,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"fsdp_version": 2,
|
||||
"fsdp_config": {
|
||||
"offload_params": False,
|
||||
"cpu_ram_efficient_loading": False,
|
||||
"transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||
"state_dict_type": "FULL_STATE_DICT",
|
||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--use-ray",
|
||||
"--ray-num-workers",
|
||||
"2",
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
69
tests/e2e/multigpu/test_tp.py
Normal file
69
tests/e2e/multigpu/test_tp.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""multigpu e2e test for tensor parallelism."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard, require_torch_2_7_0
|
||||
|
||||
|
||||
class TestTensorParallel:
|
||||
"""Test class for Tensor Parallel functionality."""
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="TP doesn't work with models with tied weights (embeddings)"
|
||||
)
|
||||
@require_torch_2_7_0
|
||||
def test_fft_sft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"tensor_parallel_size": 2,
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 1.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
@@ -21,8 +21,13 @@ from axolotl.kernels.lora import (
|
||||
apply_lora_o,
|
||||
apply_lora_qkv,
|
||||
)
|
||||
from axolotl.loaders.model import ModelLoader
|
||||
from axolotl.loaders.tokenizer import load_tokenizer
|
||||
from axolotl.monkeypatch.lora_kernels import (
|
||||
apply_lora_kernel_patches,
|
||||
find_self_attn_in_layer,
|
||||
get_attention_cls_from_config,
|
||||
get_layers,
|
||||
patch_self_attn_lora,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -80,7 +85,7 @@ def small_llama_model():
|
||||
)
|
||||
def test_attention_patching_integration(model_name, attention_cls):
|
||||
"""Test attention patching in integration context."""
|
||||
cfg = {"base_model": model_name}
|
||||
cfg = DictDefault({"base_model": model_name})
|
||||
|
||||
# Store the original implementation
|
||||
original_forward = getattr(attention_cls, "forward")
|
||||
@@ -391,7 +396,7 @@ def test_model_architecture(model_config):
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def test_kernel_training_integration():
|
||||
def test_kernel_training_integration(temp_dir):
|
||||
"""Test model loading with kernel patches enabled."""
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
|
||||
@@ -421,6 +426,14 @@ def test_kernel_training_integration():
|
||||
}
|
||||
)
|
||||
|
||||
# Write cfg to yaml file
|
||||
path = Path(temp_dir) / "config.yaml"
|
||||
with open(path, "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
# Load config
|
||||
cfg = load_cfg(str(path))
|
||||
|
||||
# Load model
|
||||
model, _, _ = load_model_and_tokenizer(cfg=cfg)
|
||||
|
||||
@@ -466,3 +479,103 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
||||
assert cfg.lora_mlp_kernel is True
|
||||
assert cfg.lora_qkv_kernel is True
|
||||
assert cfg.lora_o_kernel is True
|
||||
|
||||
# Get the attention class before patching to check for side effects
|
||||
attention_cls = get_attention_cls_from_config(cfg)
|
||||
|
||||
# Store original state before patching
|
||||
original_forward_method = attention_cls.forward
|
||||
|
||||
# Load the model (this should trigger the patches)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
model, _ = ModelLoader(cfg, tokenizer).load()
|
||||
|
||||
# Test side effects of patch_self_attn_lora
|
||||
assert hasattr(attention_cls, "_original_forward")
|
||||
assert attention_cls.forward != original_forward_method
|
||||
|
||||
# Find at least one self-attention module and verify it has the patched methods
|
||||
found_patched_attn = False
|
||||
for layer in model.model.model.layers:
|
||||
if hasattr(layer, "self_attn"):
|
||||
self_attn = layer.self_attn
|
||||
if all(
|
||||
hasattr(self_attn, proj)
|
||||
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
|
||||
):
|
||||
# These methods should be added by apply_lora_kernel_patches
|
||||
assert hasattr(self_attn, "apply_qkv") and callable(self_attn.apply_qkv)
|
||||
assert hasattr(self_attn, "apply_o") and callable(self_attn.apply_o)
|
||||
|
||||
found_patched_attn = True
|
||||
break
|
||||
|
||||
assert found_patched_attn
|
||||
|
||||
|
||||
def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
||||
"""Test model loading with dropout non-zero should not patch."""
|
||||
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
|
||||
# Create minimal config
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||
"learning_rate": 0.000001,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.1,
|
||||
"lora_target_linear": True,
|
||||
"sequence_len": 1024,
|
||||
}
|
||||
)
|
||||
|
||||
# Write cfg to yaml file
|
||||
path = Path(temp_dir) / "config.yaml"
|
||||
with open(path, "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
# Load config
|
||||
cfg = load_cfg(str(path))
|
||||
|
||||
# Get original attention class
|
||||
attention_cls = get_attention_cls_from_config(cfg)
|
||||
|
||||
# Store original state before patching
|
||||
original_forward_method = attention_cls.forward
|
||||
|
||||
# Load model
|
||||
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
|
||||
|
||||
# We call modelloader as that's where the patches are applied
|
||||
# despite the fact that we're not using it to load the model
|
||||
model_loader = ModelLoader(cfg, tokenizer)
|
||||
|
||||
# Apply patch
|
||||
model_loader.patch_manager._apply_self_attention_lora_patch() # pylint: disable=protected-access
|
||||
|
||||
# Verify patch was not applied
|
||||
assert attention_cls.forward == original_forward_method
|
||||
|
||||
# Apply apply_lora_kernel_patches
|
||||
model_loader.patch_manager._apply_lora_kernel_patch( # pylint: disable=protected-access
|
||||
model
|
||||
)
|
||||
|
||||
# Verify patch was not applied
|
||||
layers = get_layers(model)
|
||||
for layer in layers:
|
||||
for self_attn in find_self_attn_in_layer(layer):
|
||||
assert not hasattr(self_attn, "apply_qkv")
|
||||
assert not hasattr(self_attn, "apply_o")
|
||||
|
||||
@@ -2,11 +2,8 @@
|
||||
E2E tests for multipack fft llama using 4d attention masks
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class Test4dMultipackLlama(unittest.TestCase):
|
||||
"""
|
||||
@@ -61,12 +55,12 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"fp16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -109,12 +103,12 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"fp16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ import pytest
|
||||
import transformers
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -70,13 +69,14 @@ class TestActivationCheckpointing:
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"gradient_checkpointing": gradient_checkpointing,
|
||||
"save_first_step": False,
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -2,13 +2,9 @@
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -16,9 +12,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestFAXentropyLlama:
|
||||
"""
|
||||
@@ -69,6 +62,7 @@ class TestFAXentropyLlama:
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -79,12 +73,11 @@ class TestFAXentropyLlama:
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -2,13 +2,10 @@
|
||||
E2E tests for falcon
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestFalconPatched(unittest.TestCase):
|
||||
"""
|
||||
@@ -64,12 +58,12 @@ class TestFalconPatched(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -106,12 +100,12 @@ class TestFalconPatched(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
82
tests/e2e/patched/test_flattening.py
Normal file
82
tests/e2e/patched/test_flattening.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
E2E tests for flattening batches
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
|
||||
class TestFAFlattening:
|
||||
"""
|
||||
Test case for Llama models using LoRA w batch flattening
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 4],
|
||||
)
|
||||
def test_lora_packing_flattening(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"batch_flattening": True,
|
||||
"flash_attention": True,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"chat_template": "chatml",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mlabonne/FineTome-100k",
|
||||
"field_messages": "conversations",
|
||||
"message_field_content": "value",
|
||||
"message_field_role": "from",
|
||||
"type": "chat_template",
|
||||
"split": "train[:2%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"save_steps": 5,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high"
|
||||
)
|
||||
131
tests/e2e/patched/test_fsdp2_qlora.py
Normal file
131
tests/e2e/patched/test_fsdp2_qlora.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Integration tests for FSDP Params4bit patches."""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import pytest
|
||||
import torch
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||
|
||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
||||
apply_bnb_torch_function_patch,
|
||||
patched_torch_function,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_params4bit():
|
||||
"""Create a mock Params4bit instance with test attributes."""
|
||||
mock_instance = Mock()
|
||||
mock_instance.requires_grad = True
|
||||
mock_instance.quant_state = "test_state"
|
||||
mock_instance.blocksize = 128
|
||||
mock_instance.compress_statistics = True
|
||||
mock_instance.quant_type = "fp4"
|
||||
mock_instance.quant_storage = "test_storage"
|
||||
mock_instance.module = "test_module"
|
||||
mock_instance.bnb_quantized = True
|
||||
return mock_instance
|
||||
|
||||
|
||||
class TestBnbTorchFunctionPatch:
|
||||
"""Test the Params4bit.__torch_function__ patch."""
|
||||
|
||||
def test_apply_patch(self):
|
||||
"""Test that the patch can be applied."""
|
||||
with patch("bitsandbytes.nn.modules.Params4bit") as mock_cls:
|
||||
apply_bnb_torch_function_patch()
|
||||
assert hasattr(mock_cls, "__torch_function__")
|
||||
assert isinstance(mock_cls.__torch_function__, classmethod)
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
def test_torch_chunk_preserves_attributes(self, mock_params4bit):
|
||||
"""Test that torch.chunk preserves Params4bit attributes."""
|
||||
mock_cls = Mock()
|
||||
chunks = (torch.tensor([1, 2]), torch.tensor([3, 4]))
|
||||
|
||||
with patch("torch.nn.Parameter.__torch_function__", return_value=chunks):
|
||||
result = patched_torch_function(
|
||||
mock_cls,
|
||||
torch.chunk,
|
||||
(type(mock_params4bit),),
|
||||
args=(mock_params4bit, 2),
|
||||
)
|
||||
|
||||
assert isinstance(result, tuple)
|
||||
assert len(result) == 2
|
||||
|
||||
# Check that Params4bit constructor was called with preserved attributes
|
||||
assert mock_cls.call_count == 2
|
||||
for call in mock_cls.call_args_list:
|
||||
kwargs = call[1]
|
||||
assert kwargs["requires_grad"] == mock_params4bit.requires_grad
|
||||
assert kwargs["quant_state"] == mock_params4bit.quant_state
|
||||
assert kwargs["blocksize"] == mock_params4bit.blocksize
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
def test_other_functions_fallback(self, mock_params4bit):
|
||||
"""Test that non-chunk/split functions use Parameter fallback."""
|
||||
mock_cls = Mock()
|
||||
fallback_result = torch.tensor([5, 6, 7])
|
||||
|
||||
with patch(
|
||||
"torch.nn.Parameter.__torch_function__", return_value=fallback_result
|
||||
) as mock_fallback:
|
||||
result = patched_torch_function(
|
||||
mock_cls, torch.add, (type(mock_params4bit),), args=(mock_params4bit, 1)
|
||||
)
|
||||
|
||||
# Should call Parameter.__torch_function__ and return its result
|
||||
mock_fallback.assert_called_once()
|
||||
assert result is fallback_result
|
||||
mock_cls.assert_not_called()
|
||||
|
||||
|
||||
class TestFSDPPatchIntegration:
|
||||
"""Test FSDP patch integration."""
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_all_patches_together(self):
|
||||
"""Test that all patches can be applied together."""
|
||||
from axolotl.monkeypatch.fsdp2_qlora import (
|
||||
apply_init_sharded_param_patch,
|
||||
apply_init_unsharded_param_patch,
|
||||
)
|
||||
|
||||
# Store original methods before patching
|
||||
original_torch_function = getattr(
|
||||
bnb.nn.modules.Params4bit, "__torch_function__", None
|
||||
)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
original_init_sharded = FSDPParam._init_sharded_param
|
||||
original_init_unsharded = FSDPParam.init_unsharded_param
|
||||
|
||||
# Apply patches
|
||||
apply_bnb_torch_function_patch()
|
||||
apply_init_sharded_param_patch()
|
||||
apply_init_unsharded_param_patch()
|
||||
|
||||
# Verify patches were applied
|
||||
current_torch_function = getattr(
|
||||
bnb.nn.modules.Params4bit, "__torch_function__", None
|
||||
)
|
||||
if original_torch_function is not None:
|
||||
assert (
|
||||
current_torch_function != original_torch_function
|
||||
), "Params4bit.__torch_function__ was not patched"
|
||||
else:
|
||||
assert (
|
||||
current_torch_function is not None
|
||||
), "Params4bit.__torch_function__ was not added"
|
||||
|
||||
# Check that FSDP methods were patched
|
||||
assert (
|
||||
# pylint: disable=protected-access
|
||||
FSDPParam._init_sharded_param
|
||||
!= original_init_sharded
|
||||
), "_init_sharded_param was not patched"
|
||||
assert (
|
||||
FSDPParam.init_unsharded_param != original_init_unsharded
|
||||
), "init_unsharded_param was not patched"
|
||||
@@ -2,14 +2,11 @@
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -17,9 +14,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
@pytest.mark.skip("FIXME, mostly underused functionality")
|
||||
class TestFusedLlama(unittest.TestCase):
|
||||
@@ -35,7 +29,6 @@ class TestFusedLlama(unittest.TestCase):
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"flash_attention": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"flash_attn_fuse_qkv": True,
|
||||
"flash_attn_fuse_mlp": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 1024,
|
||||
@@ -59,6 +52,7 @@ class TestFusedLlama(unittest.TestCase):
|
||||
"max_steps": 10,
|
||||
"save_steps": 5,
|
||||
"eval_steps": 5,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -67,8 +61,7 @@ class TestFusedLlama(unittest.TestCase):
|
||||
cfg.fp16 = True
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -2,13 +2,10 @@
|
||||
E2E tests for llama w/ S2 attn
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="FIXME?")
|
||||
class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||
@@ -64,13 +58,13 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||
"save_steps": 5,
|
||||
"eval_steps": 5,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -107,13 +101,13 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||
"save_steps": 5,
|
||||
"eval_steps": 5,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -2,14 +2,11 @@
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -17,9 +14,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestLoraLlama(unittest.TestCase):
|
||||
"""
|
||||
@@ -61,6 +55,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -70,8 +65,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -115,12 +109,12 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -2,20 +2,14 @@
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
from ..utils import check_model_output_exists, require_torch_2_6_0, with_temp_dir
|
||||
|
||||
|
||||
class TestMistral(unittest.TestCase):
|
||||
@@ -23,6 +17,7 @@ class TestMistral(unittest.TestCase):
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@require_torch_2_6_0
|
||||
@with_temp_dir
|
||||
def test_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -61,12 +56,12 @@ class TestMistral(unittest.TestCase):
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -103,12 +98,12 @@ class TestMistral(unittest.TestCase):
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -2,11 +2,8 @@
|
||||
E2E tests for mixtral
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestMixtral(unittest.TestCase):
|
||||
"""
|
||||
@@ -58,12 +52,12 @@ class TestMixtral(unittest.TestCase):
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -97,12 +91,12 @@ class TestMixtral(unittest.TestCase):
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -45,6 +45,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
@@ -78,6 +79,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
|
||||
@@ -49,6 +49,7 @@ class TestLlamaPeftEmbeddings:
|
||||
"bf16": "auto",
|
||||
"save_safetensors": True,
|
||||
"embeddings_skip_upcast": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -2,11 +2,8 @@
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestPhiMultipack(unittest.TestCase):
|
||||
"""
|
||||
@@ -60,13 +54,13 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
"eval_steps": 3,
|
||||
"save_steps": 4,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -112,13 +106,13 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
"eval_steps": 3,
|
||||
"save_steps": 4,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -2,14 +2,11 @@
|
||||
E2E tests for resuming training
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -17,9 +14,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, most_recent_subdir, require_torch_2_6_0
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestResumeLlama:
|
||||
"""
|
||||
@@ -64,6 +58,7 @@ class TestResumeLlama:
|
||||
"max_steps": 15,
|
||||
"use_tensorboard": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -72,8 +67,7 @@ class TestResumeLlama:
|
||||
cfg.fp16 = True
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
@@ -83,7 +77,6 @@ class TestResumeLlama:
|
||||
}
|
||||
)
|
||||
normalize_config(resume_cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
|
||||
train(cfg=resume_cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -1,480 +0,0 @@
|
||||
"""Tests for sequence parallelism functionality."""
|
||||
|
||||
# pylint: disable=redefined-outer-name,unused-argument
|
||||
|
||||
import functools
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from accelerate.state import PartialState
|
||||
|
||||
from axolotl.monkeypatch.ring_attn import (
|
||||
get_ring_attn_group,
|
||||
register_ring_attn,
|
||||
set_ring_attn_group,
|
||||
)
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def partial_state():
|
||||
"""Create a real PartialState instance for testing."""
|
||||
state = PartialState()
|
||||
return state
|
||||
|
||||
|
||||
@pytest.fixture(name="cfg")
|
||||
def fixture_cfg():
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-3,
|
||||
"output_dir": "./model-out",
|
||||
"sequence_len": 512,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sequence_parallel_batch():
|
||||
"""Create a test batch for sequence parallelism tests."""
|
||||
batch_size = 1
|
||||
seq_len = 8
|
||||
|
||||
# Create test tensors
|
||||
input_ids = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
|
||||
attention_mask = torch.ones(batch_size, seq_len)
|
||||
position_ids = torch.arange(seq_len).expand(batch_size, seq_len)
|
||||
labels = input_ids.clone()
|
||||
|
||||
# Create test batch
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"position_ids": position_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class TestRingAttention:
|
||||
"""Tests for the ring attention functionality."""
|
||||
|
||||
@patch("torch.distributed.get_rank")
|
||||
@patch("torch.distributed.get_world_size")
|
||||
def test_get_ring_attn_group_no_registration(
|
||||
self, mock_world_size, mock_rank, partial_state
|
||||
):
|
||||
"""Test that get_ring_attn_group raises RuntimeError when no group has been registered."""
|
||||
# Setup mocks
|
||||
mock_world_size.return_value = 4
|
||||
mock_rank.return_value = 0
|
||||
|
||||
# Verify that RuntimeError is raised when no group is registered
|
||||
with pytest.raises(
|
||||
RuntimeError, match="register_ring_attn\\(\\) not yet called"
|
||||
):
|
||||
get_ring_attn_group()
|
||||
|
||||
@patch("torch.distributed.new_group")
|
||||
@patch("torch.distributed.get_rank")
|
||||
@patch("torch.distributed.get_world_size")
|
||||
def test_register_ring_attn(
|
||||
self, mock_world_size, mock_rank, mock_new_group, partial_state
|
||||
):
|
||||
"""Test that ring attention groups are created correctly."""
|
||||
# Setup mocks
|
||||
mock_world_size.return_value = 8 # 8 GPUs total
|
||||
mock_rank.return_value = 3 # GPU #3
|
||||
mock_group = MagicMock()
|
||||
mock_new_group.return_value = mock_group
|
||||
|
||||
# Call register_ring_attn with size 4
|
||||
register_ring_attn(
|
||||
sequence_parallel_degree=4,
|
||||
heads_k_stride=1,
|
||||
ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
|
||||
)
|
||||
|
||||
# Verify the number of calls without examining the arguments
|
||||
assert mock_new_group.call_count == 2
|
||||
|
||||
# Verify that new_group was called
|
||||
mock_new_group.assert_called()
|
||||
|
||||
# Clean up
|
||||
set_ring_attn_group(None)
|
||||
|
||||
|
||||
class TestConfigValidation:
|
||||
"""Tests for validating sequence parallelism configurations."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_mocks(self, monkeypatch):
|
||||
"""Set up mocks for all tests in this class."""
|
||||
# Mock the ring_flash_attn module
|
||||
monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())
|
||||
|
||||
@pytest.fixture
|
||||
def base_cfg(self):
|
||||
"""Create a base configuration for testing."""
|
||||
return DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-3,
|
||||
"output_dir": "./model-out",
|
||||
"sequence_len": 512,
|
||||
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_updates, expected_values, should_pass, error_msg",
|
||||
[
|
||||
# Valid configuration
|
||||
(
|
||||
{"sequence_parallel_degree": 2, "flash_attention": True},
|
||||
{"sequence_parallel_degree": 2, "flash_attention": True},
|
||||
True,
|
||||
None,
|
||||
),
|
||||
# Default sequence_parallel_degree
|
||||
({}, {"sequence_parallel_degree": 1}, True, None),
|
||||
# Invalid: sequence_parallel_degree > 1 without flash_attention
|
||||
(
|
||||
{"sequence_parallel_degree": 2, "flash_attention": False},
|
||||
None,
|
||||
False,
|
||||
"flash_attention: true must be set",
|
||||
),
|
||||
# Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
|
||||
(
|
||||
{
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"micro_batch_size": 2,
|
||||
"pad_to_sequence_len": True,
|
||||
},
|
||||
None,
|
||||
False,
|
||||
"micro_batch_size must be set to 1",
|
||||
),
|
||||
# Valid: Basic GRPO config
|
||||
(
|
||||
{
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": {"use_liger_loss": True},
|
||||
},
|
||||
{
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": TRLConfig(use_liger_loss=True),
|
||||
},
|
||||
True,
|
||||
"GRPO + SP + Liger not currently supported",
|
||||
),
|
||||
# Invalid: GRPO config with Liger loss
|
||||
(
|
||||
{
|
||||
"rl": "grpo",
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"micro_batch_size": 2,
|
||||
"trl": {"use_liger_loss": True},
|
||||
},
|
||||
None,
|
||||
False,
|
||||
"GRPO + SP + Liger not currently supported",
|
||||
),
|
||||
],
|
||||
ids=[
|
||||
"valid_config",
|
||||
"default_sp_degree",
|
||||
"without_flash_attention",
|
||||
"sample_packing_with_large_batch",
|
||||
"valid_grpo",
|
||||
"grpo_with_liger_loss",
|
||||
],
|
||||
)
|
||||
def test_sequence_parallel_config_validation(
|
||||
self, base_cfg, config_updates, expected_values, should_pass, error_msg
|
||||
):
|
||||
"""Test various sequence parallelism configuration scenarios."""
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
# Apply updates to base config
|
||||
cfg = base_cfg
|
||||
cfg.update(config_updates)
|
||||
|
||||
if should_pass:
|
||||
# Should validate without errors
|
||||
config = AxolotlInputConfig(**cfg)
|
||||
|
||||
# Check expected values
|
||||
for key, value in expected_values.items():
|
||||
assert getattr(config, key) == value
|
||||
else:
|
||||
# Should raise exception
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
AxolotlInputConfig(**cfg)
|
||||
assert error_msg in str(excinfo.value)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ring_attn_func, sample_packing, expected_func",
|
||||
[
|
||||
(None, True, RingAttnFunc.VARLEN_LLAMA3),
|
||||
(None, False, RingAttnFunc.BATCH_RING),
|
||||
],
|
||||
ids=["default_with_sample_packing", "default_without_sample_packing"],
|
||||
)
|
||||
def test_ring_attn_func_validation(
|
||||
self, base_cfg, ring_attn_func, sample_packing, expected_func
|
||||
):
|
||||
"""Test ring_attn_func validation and defaults."""
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
# Apply updates to base config
|
||||
cfg = base_cfg | {
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"sample_packing": sample_packing,
|
||||
}
|
||||
|
||||
if ring_attn_func is not None:
|
||||
cfg["ring_attn_func"] = ring_attn_func
|
||||
|
||||
# Should validate without errors
|
||||
config = AxolotlInputConfig(**cfg)
|
||||
|
||||
# Check ring_attn_func value
|
||||
assert config.ring_attn_func.value == expected_func
|
||||
|
||||
def test_invalid_ring_attn_func(self, base_cfg):
|
||||
"""Test that an invalid ring_attn_func is rejected."""
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
# Invalid configuration with invalid ring_attn_func
|
||||
cfg = base_cfg | {
|
||||
"sequence_parallel_degree": 2,
|
||||
"flash_attention": True,
|
||||
"ring_attn_func": "INVALID_FUNC",
|
||||
}
|
||||
|
||||
# Should raise ValidationError
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
AxolotlInputConfig(**cfg)
|
||||
|
||||
# Verify error message
|
||||
assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
|
||||
|
||||
|
||||
class TestApplySequenceParallelism:
|
||||
"""Tests for the apply_sequence_parallelism function."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_distributed(self, monkeypatch):
|
||||
"""Mock torch.distributed functions for testing."""
|
||||
# Mock is_initialized to return True
|
||||
monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True)
|
||||
|
||||
# Mock get_rank to return 0 by default
|
||||
monkeypatch.setattr(torch.distributed, "get_rank", lambda *args, **kwargs: 0)
|
||||
|
||||
# Mock get_world_size to return 2 by default
|
||||
monkeypatch.setattr(
|
||||
torch.distributed, "get_world_size", lambda *args, **kwargs: 2
|
||||
)
|
||||
|
||||
# Mock the process group
|
||||
monkeypatch.setattr(
|
||||
"axolotl.monkeypatch.ring_attn.get_ring_attn_group",
|
||||
MagicMock,
|
||||
)
|
||||
|
||||
# Mock update_ring_attn_params
|
||||
monkeypatch.setattr(
|
||||
"axolotl.monkeypatch.ring_attn.update_ring_attn_params",
|
||||
lambda **kwargs: None,
|
||||
)
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
"""Test that function returns original batch when world size is 1."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=sequence_parallel_batch,
|
||||
local_rank=0,
|
||||
local_world_size=1,
|
||||
gradient_accumulation_steps=1,
|
||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
||||
)
|
||||
|
||||
# Should return the original batch unchanged
|
||||
assert result == sequence_parallel_batch
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = sequence_parallel_batch
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=batch,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
||||
)
|
||||
|
||||
# Check that sequence dimension was sharded correctly
|
||||
assert result["input_ids"].shape[1] == seq_len // 2
|
||||
assert result["attention_mask"].shape[1] == seq_len // 2
|
||||
|
||||
# Verify content: rank 0 should get the first half of the sequence
|
||||
assert torch.equal(result["input_ids"], batch["input_ids"][:, : seq_len // 2])
|
||||
assert torch.equal(
|
||||
result["position_ids"], batch["position_ids"][:, : seq_len // 2]
|
||||
)
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = sequence_parallel_batch
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=batch,
|
||||
local_rank=1,
|
||||
local_world_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
||||
)
|
||||
|
||||
# Verify content: rank 1 should get the second half of the sequence
|
||||
assert torch.equal(result["input_ids"], original_input_ids[:, seq_len // 2 :])
|
||||
|
||||
# TODO(djsaunde): add back once implemented.
|
||||
# def test_batch_zigzag(self, sequence_parallel_batch):
|
||||
# """Test BATCH_ZIGZAG sharding pattern."""
|
||||
# batch = sequence_parallel_batch
|
||||
# original_input_ids = batch["input_ids"].clone()
|
||||
# seq_len = batch["input_ids"].size(1)
|
||||
|
||||
# # Test rank 0
|
||||
# result_rank0 = apply_sequence_parallelism(
|
||||
# batch={k: v.clone() for k, v in batch.items()},
|
||||
# local_rank=0,
|
||||
# local_world_size=2,
|
||||
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
|
||||
# )
|
||||
|
||||
# # Test rank 1
|
||||
# result_rank1 = apply_sequence_parallelism(
|
||||
# batch={k: v.clone() for k, v in batch.items()},
|
||||
# local_rank=1,
|
||||
# local_world_size=2,
|
||||
# ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
|
||||
# )
|
||||
|
||||
# # Checks for both ranks
|
||||
# assert result_rank0["input_ids"].shape[1] == seq_len // 2
|
||||
# assert result_rank1["input_ids"].shape[1] == seq_len // 2
|
||||
|
||||
# # For a 2-rank system with 8 tokens, check specific zigzag pattern
|
||||
# # Rank 0 should get chunks [0, 1] and [6, 7]
|
||||
# # Rank 1 should get chunks [2, 3] and [4, 5]
|
||||
# if seq_len == 8:
|
||||
# # Create expected tensors for comparison
|
||||
# rank0_expected = torch.cat(
|
||||
# [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
|
||||
# )
|
||||
|
||||
# rank1_expected = torch.cat(
|
||||
# [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
|
||||
# )
|
||||
|
||||
# assert torch.equal(result_rank0["input_ids"], rank0_expected)
|
||||
# assert torch.equal(result_rank1["input_ids"], rank1_expected)
|
||||
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_partial_application(
|
||||
self, mock_get_ring_attn_group, sequence_parallel_batch
|
||||
):
|
||||
"""Test that we can create a partially applied version of the function."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = sequence_parallel_batch
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
# Create a partially applied function
|
||||
rank0_ring_parallel = functools.partial(
|
||||
apply_sequence_parallelism,
|
||||
local_rank=0,
|
||||
local_world_size=2,
|
||||
gradient_accumulation_steps=1,
|
||||
ring_attn_func=RingAttnFunc.BATCH_RING,
|
||||
)
|
||||
|
||||
# Use the partially applied function
|
||||
result, _, _ = rank0_ring_parallel(batch=batch)
|
||||
|
||||
# Verify it works as expected
|
||||
assert result["input_ids"].shape[1] == original_input_ids.shape[1] // 2
|
||||
assert torch.equal(
|
||||
result["input_ids"],
|
||||
original_input_ids[:, : original_input_ids.shape[1] // 2],
|
||||
)
|
||||
|
||||
def test_missing_position_ids(self, sequence_parallel_batch):
    """Ensure a batch that carries no `position_ids` is still handled.

    The call must not raise, the result must contain `position_ids` with the
    same sequence length as `input_ids`, and that length must be half of the
    original (rank 0 of 2).
    """
    # Copy every entry of the fixture except position_ids.
    trimmed_batch = {}
    for key, value in sequence_parallel_batch.items():
        if key != "position_ids":
            trimmed_batch[key] = value

    # Snapshot the ids before the call for the length comparison below.
    ids_before = trimmed_batch["input_ids"].clone()

    # Should complete without error despite the missing position_ids.
    outcome = apply_sequence_parallelism(
        batch=trimmed_batch,
        local_rank=0,
        local_world_size=2,
        gradient_accumulation_steps=1,
        ring_attn_func=RingAttnFunc.BATCH_RING,
    )
    sharded, _, _ = outcome

    # position_ids must be present in the result and aligned with input_ids.
    assert "position_ids" in sharded
    assert sharded["input_ids"].shape[1] == sharded["position_ids"].shape[1]
    assert sharded["input_ids"].shape[1] == ids_before.shape[1] // 2
|
||||
@@ -2,12 +2,8 @@
|
||||
e2e tests for unsloth qlora
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -15,9 +11,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
@pytest.mark.skip(
|
||||
@@ -69,19 +62,19 @@ class TestUnslothQLoRA:
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
|
||||
@@ -120,19 +113,19 @@ class TestUnslothQLoRA:
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -176,17 +169,17 @@ class TestUnslothQLoRA:
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"fp16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -2,13 +2,10 @@
|
||||
E2E tests for packed training w/ flex attention
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_tensorboard, require_torch_2_6_0, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestPackedFlex(unittest.TestCase):
|
||||
"""
|
||||
@@ -55,6 +49,7 @@ class TestPackedFlex(unittest.TestCase):
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -64,11 +59,10 @@ class TestPackedFlex(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -2,12 +2,9 @@
|
||||
E2E tests for relora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -15,9 +12,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestReLoraLlama(unittest.TestCase):
|
||||
"""
|
||||
@@ -40,9 +34,10 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_modules": ["q_proj", "v_proj"],
|
||||
"relora_steps": 50,
|
||||
"relora_warmup_steps": 10,
|
||||
"relora_anneal_steps": 10,
|
||||
"relora": True,
|
||||
"jagged_restart_steps": 50,
|
||||
"jagged_restart_warmup_steps": 10,
|
||||
"jagged_restart_anneal_steps": 10,
|
||||
"relora_prune_ratio": 0.9,
|
||||
"relora_cpu_offload": True,
|
||||
"val_set_size": 0.0,
|
||||
@@ -71,13 +66,13 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
|
||||
|
||||
83
tests/e2e/test_activation_offloading.py
Normal file
83
tests/e2e/test_activation_offloading.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
E2E tests for activation offloading
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
|
||||
class TestActivationOffloading:
    """
    End-to-end checks for training with activation offloading enabled
    """

    @pytest.mark.parametrize(
        "adapter",
        ["lora", "qlora", None],
    )
    def test_activation_offloading(
        self,
        temp_dir,
        adapter,
    ):
        """Run a two-step training job with `activation_offloading` switched on.

        Parametrized over a LoRA adapter, a QLoRA adapter (which also sets
        `load_in_4bit`) and no adapter at all. The test passes when training
        finishes and `check_model_output_exists` accepts the output directory.
        """
        # pylint: disable=duplicate-code
        chat_dataset = {
            "chat_template": "chatml",
            "path": "mlabonne/FineTome-100k",
            "type": "chat_template",
            "split": "train[:10%]",
            "field_messages": "conversations",
            "message_field_role": "from",
            "message_field_content": "value",
        }
        tokens = {
            "pad_token": "<|endoftext|>",
            "eos_token": "<|im_end|>",
        }
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "sequence_len": 1024,
                "val_set_size": 0.0,
                "special_tokens": tokens,
                "datasets": [chat_dataset],
                "num_epochs": 1,
                "max_steps": 2,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch_8bit",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "sample_packing": True,
                "bf16": "auto",
                "save_safetensors": True,
                "gradient_checkpointing": True,
                "activation_offloading": True,
                "save_first_step": False,
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_target_linear": True,
            }
        )

        # The adapter key is only set for the two adapter variants; the
        # None case leaves the config without an "adapter" entry.
        if adapter in ("lora", "qlora"):
            cfg["adapter"] = adapter
        if adapter == "qlora":
            cfg["load_in_4bit"] = True

        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)

        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)
|
||||
@@ -2,13 +2,10 @@
|
||||
E2E tests for deepseekv3
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestDeepseekV3:
|
||||
"""
|
||||
@@ -73,12 +67,12 @@ class TestDeepseekV3:
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
@@ -123,12 +117,12 @@ class TestDeepseekV3:
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
"""
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
"""E2E tests for lora llama"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
@@ -17,9 +13,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestDPOLlamaLora(unittest.TestCase):
|
||||
"""
|
||||
@@ -63,6 +56,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -112,6 +106,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -161,6 +156,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -210,6 +206,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -258,6 +255,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -309,6 +307,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -377,6 +376,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -2,11 +2,8 @@
|
||||
E2E tests for llama pretrain
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
"""
|
||||
@@ -54,13 +48,13 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -100,12 +94,12 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""E2E smoke test for evaluate CLI command"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
@@ -9,8 +8,6 @@ from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestE2eEvaluate:
|
||||
"""Test cases for evaluate CLI"""
|
||||
@@ -39,6 +36,7 @@ class TestE2eEvaluate:
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -2,13 +2,10 @@
|
||||
E2E tests for falcon
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestFalcon(unittest.TestCase):
|
||||
"""
|
||||
@@ -66,13 +60,13 @@ class TestFalcon(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -122,13 +116,13 @@ class TestFalcon(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -164,13 +158,13 @@ class TestFalcon(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -2,21 +2,15 @@
|
||||
E2E tests for gemma2
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestGemma2:
|
||||
"""
|
||||
@@ -74,8 +68,7 @@ class TestGemma2:
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
@@ -126,8 +119,7 @@ class TestGemma2:
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
@@ -2,21 +2,15 @@
|
||||
E2E tests for gemma3_text
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestGemma3Text:
|
||||
"""
|
||||
@@ -69,12 +63,12 @@ class TestGemma3Text:
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
@@ -120,12 +114,12 @@ class TestGemma3Text:
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
@@ -11,11 +11,11 @@ class TestImports(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def test_import_causal_trainer(self):
|
||||
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.core.builders import ( # pylint: disable=unused-import # noqa: F401
|
||||
HFCausalTrainerBuilder,
|
||||
)
|
||||
|
||||
def test_import_rl_trainer(self):
|
||||
from axolotl.core.trainer_builder import ( # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.core.builders import ( # pylint: disable=unused-import # noqa: F401
|
||||
HFRLTrainerBuilder,
|
||||
)
|
||||
|
||||
@@ -2,10 +2,6 @@
|
||||
E2E tests for llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -13,9 +9,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_model_output_exists
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestLlama:
|
||||
"""
|
||||
@@ -52,13 +45,13 @@ class TestLlama:
|
||||
"sample_packing": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -100,13 +93,13 @@ class TestLlama:
|
||||
"sample_packing": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -145,13 +138,13 @@ class TestLlama:
|
||||
"sample_packing": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -186,13 +179,13 @@ class TestLlama:
|
||||
"batch_flattening": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
"""
|
||||
E2E tests for llama pretrain
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
"""E2E tests for llama pretrain"""
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -15,27 +9,19 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestPretrainLlama:
|
||||
"""
|
||||
Test case for Llama models w pretraining
|
||||
"""
|
||||
"""Test case for Llama models w pretraining"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_packing",
|
||||
[True, False],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"pretrain_multipack_attn",
|
||||
[True, False],
|
||||
("sample_packing", "pretrain_multipack_attn"),
|
||||
[
|
||||
(False, False),
|
||||
(True, True),
|
||||
(True, False),
|
||||
],
|
||||
)
|
||||
def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn):
|
||||
if not sample_packing and pretrain_multipack_attn:
|
||||
return
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -67,22 +53,22 @@ class TestPretrainLlama:
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
loss_threshold = 3.5
|
||||
loss_threshold = 3.6
|
||||
if sample_packing and not pretrain_multipack_attn:
|
||||
loss_threshold = 6.5
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs",
|
||||
"train/train_loss",
|
||||
loss_threshold,
|
||||
"Train Loss is too high",
|
||||
"Train Loss (%s) is too high",
|
||||
)
|
||||
|
||||
@@ -2,11 +2,8 @@
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestLlamaVision(unittest.TestCase):
|
||||
"""
|
||||
@@ -38,7 +32,7 @@ class TestLlamaVision(unittest.TestCase):
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
|
||||
"lora_target_modules": r"model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
|
||||
"val_set_size": 0,
|
||||
"chat_template": "llama3_2_vision",
|
||||
"datasets": [
|
||||
@@ -60,13 +54,13 @@ class TestLlamaVision(unittest.TestCase):
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -86,7 +80,7 @@ class TestLlamaVision(unittest.TestCase):
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
|
||||
"lora_target_modules": r"model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
|
||||
"val_set_size": 0,
|
||||
"chat_template": "llama3_2_vision",
|
||||
"datasets": [
|
||||
@@ -107,12 +101,12 @@ class TestLlamaVision(unittest.TestCase):
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -52,6 +52,8 @@ class TestLoadModelUtils:
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"tensor_parallel_size": 1,
|
||||
"context_parallel_size": 1,
|
||||
}
|
||||
)
|
||||
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
|
||||
|
||||
@@ -2,11 +2,8 @@
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestLoraLlama(unittest.TestCase):
|
||||
"""
|
||||
@@ -55,13 +49,13 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -2,13 +2,10 @@
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="skipping until upstreamed into transformers")
|
||||
class TestMamba(unittest.TestCase):
|
||||
@@ -57,13 +51,13 @@ class TestMamba(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": None,
|
||||
"save_safetensors": False,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -2,13 +2,10 @@
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestMistral(unittest.TestCase):
|
||||
"""
|
||||
@@ -61,13 +55,13 @@ class TestMistral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -102,6 +96,7 @@ class TestMistral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -111,8 +106,7 @@ class TestMistral(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -2,14 +2,11 @@
|
||||
E2E tests for mixtral
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -17,9 +14,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestMixtral(unittest.TestCase):
|
||||
"""
|
||||
@@ -67,13 +61,13 @@ class TestMixtral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
@@ -123,13 +117,13 @@ class TestMixtral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
@@ -178,6 +172,7 @@ class TestMixtral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -187,8 +182,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
@@ -237,6 +231,7 @@ class TestMixtral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -246,8 +241,7 @@ class TestMixtral(unittest.TestCase):
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
@@ -283,6 +277,7 @@ class TestMixtral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -292,8 +287,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -2,20 +2,20 @@
|
||||
E2E tests for custom optimizers using Llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
from .utils import (
|
||||
check_model_output_exists,
|
||||
require_torch_2_5_1,
|
||||
require_torch_2_6_0,
|
||||
require_torch_2_7_0,
|
||||
with_temp_dir,
|
||||
)
|
||||
|
||||
|
||||
class TestCustomOptimizers(unittest.TestCase):
|
||||
@@ -56,13 +56,13 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"optimizer": "optimi_adamw",
|
||||
"max_steps": 5,
|
||||
"lr_scheduler": "cosine",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -102,13 +102,13 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adopt_adamw",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -149,18 +149,61 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"optimizer": "muon",
|
||||
"lr_scheduler": "cosine",
|
||||
"weight_decay": 0.01,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
|
||||
|
||||
@with_temp_dir
@require_torch_2_7_0
def test_dion(self, temp_dir):
    """Train SmolLM2-135M for five steps with the ``dion`` optimizer.

    After training, the model output must exist in ``temp_dir`` and the
    optimizer wrapped by the trainer must have "Dion" in its class name.
    """
    # pylint: disable=duplicate-code
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    settings = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "model_type": "AutoModelForCausalLM",
        "tokenizer_type": "AutoTokenizer",
        "sequence_len": 1024,
        "val_set_size": 0.0,
        "special_tokens": {
            "pad_token": "<|endoftext|>",
        },
        "datasets": [alpaca_dataset],
        "num_epochs": 1,
        "max_steps": 5,
        "micro_batch_size": 8,
        "gradient_accumulation_steps": 1,
        "output_dir": temp_dir,
        "learning_rate": 0.00001,
        "optimizer": "dion",
        "dion_lr": 0.01,
        "dion_momentum": 0.95,
        "lr_scheduler": "cosine",
        "weight_decay": 0.01,
        "save_first_step": False,
    }

    cfg = validate_config(DictDefault(settings))
    normalize_config(cfg)

    _model, _unused, trainer = train(
        cfg=cfg, dataset_meta=load_datasets(cfg=cfg)
    )
    check_model_output_exists(temp_dir, cfg)

    # trainer.optimizer wraps the real optimizer; inspect the inner object.
    inner_optimizer = trainer.optimizer.optimizer
    assert "Dion" in inner_optimizer.__class__.__name__
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_schedule_free_adamw(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -188,19 +231,20 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"lr_scheduler": "constant",
|
||||
"save_safetensors": True,
|
||||
"max_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
@require_torch_2_6_0
|
||||
def test_came_pytorch(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
@@ -236,13 +280,13 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"adam_epsilon2": 1e-16,
|
||||
"max_steps": 5,
|
||||
"lr_scheduler": "cosine",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -2,13 +2,10 @@
|
||||
E2E tests for packed training
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -16,9 +13,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_tensorboard, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestPackedLlama(unittest.TestCase):
|
||||
"""
|
||||
@@ -54,6 +48,7 @@ class TestPackedLlama(unittest.TestCase):
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -63,11 +58,10 @@ class TestPackedLlama(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -2,11 +2,8 @@
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestPhi(unittest.TestCase):
|
||||
"""
|
||||
@@ -59,12 +53,12 @@ class TestPhi(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -109,12 +103,12 @@ class TestPhi(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
58
tests/e2e/test_preprocess.py
Normal file
58
tests/e2e/test_preprocess.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""E2E Test the preprocess cli"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class TestPreprocess:
    """Test cases for the ``axolotl preprocess`` CLI command."""

    def test_w_deepspeed(self, temp_dir):
        """Make sure preprocess doesn't choke when the config uses deepspeed.

        The config is written to ``config.yaml`` inside ``temp_dir``, the CLI
        is run as a subprocess, and the prepared-dataset directory must exist
        afterwards.
        """
        # pylint: disable=duplicate-code
        work_dir = Path(temp_dir)
        config_path = work_dir / "config.yaml"

        alpaca_subset = {
            "path": "tatsu-lab/alpaca",
            "type": "alpaca",
            "split": "train[:10%]",
        }
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [alpaca_subset],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "bf16": "auto",
                "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
                "dataset_prepared_path": temp_dir + "/last_run_prepared",
            }
        )

        # The CLI runs in a separate process, so it needs the config on disk.
        work_dir.mkdir(parents=True, exist_ok=True)
        with config_path.open("w", encoding="utf-8") as config_file:
            config_file.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

        command = ["axolotl", "preprocess", str(config_path)]
        execute_subprocess_async(command)

        prepared_dir = work_dir / "last_run_prepared"
        assert prepared_dir.exists()
|
||||
@@ -2,11 +2,8 @@
|
||||
E2E tests for process reward model w/ lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestProcessRewardSmolLM2(unittest.TestCase):
|
||||
"""
|
||||
@@ -55,12 +49,12 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
|
||||
"use_tensorboard": True,
|
||||
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||
"seed": 42,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_tensorboard(
|
||||
|
||||
113
tests/e2e/test_profiler.py
Normal file
113
tests/e2e/test_profiler.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
e2e gpu test for the pytorch profiler callback
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="profiler_base_cfg")
def fixture_profiler_base_cfg():
    """Provide the base config shared by the profiler tests.

    The settings describe an 8-bit LoRA run of SmolLM2-135M on the
    ``alpaca_2k_test`` dataset. They contain no ``output_dir``, ``max_steps``
    or profiler options.
    """
    base_settings = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "tokenizer_type": "AutoTokenizer",
        "sequence_len": 1024,
        "load_in_8bit": True,
        "adapter": "lora",
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        "lora_target_linear": True,
        "val_set_size": 0.02,
        "special_tokens": {"pad_token": "<|endoftext|>"},
        "datasets": [
            {
                "path": "mhenrichsen/alpaca_2k_test",
                "type": "alpaca",
            },
        ],
        "num_epochs": 1,
        "micro_batch_size": 2,
        "gradient_accumulation_steps": 1,
        "learning_rate": 0.00001,
        "optimizer": "adamw_torch_fused",
        "lr_scheduler": "cosine",
    }
    # Expand as keyword arguments, matching the keyword-style construction.
    return DictDefault(**base_settings)
|
||||
|
||||
|
||||
class TestProfiler:
    """
    Test cases for the PyTorch profiler callback.

    Every test trains for 5 steps with ``profiler_steps=3`` and then checks
    whether ``snapshot.pickle`` was written to the output directory.
    """

    @staticmethod
    def _train_and_locate_snapshot(base_cfg, output_dir, **profiler_overrides):
        """Run training with the merged config and return the snapshot path.

        Args:
            base_cfg: Base config from the ``profiler_base_cfg`` fixture.
            output_dir: Directory used as the run's ``output_dir``.
            **profiler_overrides: Extra config keys appended after the common
                ones (e.g. ``profiler_steps_start``).

        Returns:
            Path to ``snapshot.pickle`` inside ``output_dir``. The file may or
            may not exist.
        """
        cfg = base_cfg | DictDefault(
            output_dir=output_dir,
            max_steps=5,
            profiler_steps=3,
            **profiler_overrides,
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)

        train(cfg=cfg, dataset_meta=dataset_meta)
        return Path(output_dir) / "snapshot.pickle"

    def test_profiler_saves(self, profiler_base_cfg, temp_dir):
        """A snapshot is written when no start step is configured."""
        snapshot = self._train_and_locate_snapshot(profiler_base_cfg, temp_dir)
        assert snapshot.exists()

    def test_profiler_saves_w_start(self, profiler_base_cfg, temp_dir):
        """A snapshot is written when profiling starts at step 1."""
        snapshot = self._train_and_locate_snapshot(
            profiler_base_cfg, temp_dir, profiler_steps_start=1
        )
        assert snapshot.exists()

    @pytest.mark.parametrize(
        "profiler_steps_start",
        [3, 5],
    )
    def test_profiler_saves_past_end(
        self, profiler_base_cfg, temp_dir, profiler_steps_start
    ):
        """A snapshot is still written when profiling starts at step 3 or 5."""
        snapshot = self._train_and_locate_snapshot(
            profiler_base_cfg,
            temp_dir,
            profiler_steps_start=profiler_steps_start,
        )
        assert snapshot.exists()

    def test_profiler_never_started(self, profiler_base_cfg, temp_dir):
        """No snapshot is written when the start step (6) exceeds max_steps."""
        snapshot = self._train_and_locate_snapshot(
            profiler_base_cfg, temp_dir, profiler_steps_start=6
        )
        assert not snapshot.exists()
|
||||
135
tests/e2e/test_qat.py
Normal file
135
tests/e2e/test_qat.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
E2E tests for QAT
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
|
||||
class TestQATLlama:
    """
    Test cases for quantization-aware training (QAT) on SmolLM2-135M.
    """

    def test_qat(self, temp_dir):
        """Run 5 QAT steps on a chat-template dataset and check checkpoint-5."""
        # pylint: disable=duplicate-code
        finetome_dataset = {
            "path": "mlabonne/FineTome-100k",
            "type": "chat_template",
            "field_messages": "conversations",
            "message_property_mappings": {
                "role": "from",
                "content": "value",
            },
            "drop_system_message": True,
            "split": "train[:1%]",
        }
        qat_settings = {
            "quantize_embedding": True,
            "activation_dtype": "int8",
            "weight_dtype": "int8",
            "group_size": 8,
        }
        settings = {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "tokenizer_type": "AutoTokenizer",
            "sequence_len": 1024,
            "special_tokens": {
                "pad_token": "<|endoftext|>",
            },
            "datasets": [finetome_dataset],
            "chat_template": "chatml",
            "qat": qat_settings,
            "num_epochs": 1,
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 2,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "adamw_bnb_8bit",
            "lr_scheduler": "cosine",
            "max_steps": 5,
            "save_safetensors": True,
            "bf16": True,
            "save_first_step": False,
        }

        cfg = validate_config(DictDefault(settings))
        normalize_config(cfg)

        train(cfg=cfg, dataset_meta=load_datasets(cfg=cfg))

        # max_steps is 5, so the checkpoint checked is the one for step 5.
        final_checkpoint = Path(temp_dir) / "checkpoint-5"
        check_model_output_exists(final_checkpoint, cfg)

    def test_qat_dpo(self, temp_dir):
        """Run 5 DPO steps with QAT; check checkpoint-5 and the train loss."""
        # pylint: disable=duplicate-code
        dpo_dataset = {
            "path": "fozziethebeat/alpaca_messages_2k_dpo_test",
            "type": "chat_template.default",
            "field_messages": "conversation",
            "field_chosen": "chosen",
            "field_rejected": "rejected",
            "message_field_role": "role",
            "message_field_content": "content",
            "roles": {
                "system": ["system"],
                "user": ["user"],
                "assistant": ["assistant"],
            },
        }
        qat_settings = {
            "quantize_embedding": True,
            "activation_dtype": "int8",
            "weight_dtype": "int8",
            "group_size": 8,
        }
        settings = {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "sequence_len": 2048,
            "sample_packing": False,
            "eval_sample_packing": False,
            "pad_to_sequence_len": True,
            "val_set_size": 0.01,
            "special_tokens": {
                "pad_token": "<|endoftext|>",
            },
            "rl": "dpo",
            "chat_template": "chatml",
            "datasets": [dpo_dataset],
            "num_epochs": 1,
            "max_steps": 5,
            "micro_batch_size": 2,
            "gradient_accumulation_steps": 2,
            "output_dir": temp_dir,
            "warmup_steps": 0,
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch_fused",
            "lr_scheduler": "cosine",
            "flash_attention": True,
            "use_tensorboard": True,
            "bf16": True,
            "qat": qat_settings,
            "save_first_step": False,
        }

        cfg = validate_config(DictDefault(settings))
        normalize_config(cfg)

        # DPO uses the preference-dataset loader rather than load_datasets.
        train(cfg=cfg, dataset_meta=load_preference_datasets(cfg=cfg))

        final_checkpoint = Path(temp_dir) / "checkpoint-5"
        check_model_output_exists(final_checkpoint, cfg)

        check_tensorboard(
            temp_dir + "/runs",
            "train/train_loss",
            2.3,
            "Train Loss (%s) is too high",
        )
|
||||
350
tests/e2e/test_quantization.py
Normal file
350
tests/e2e/test_quantization.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
Tests for axolotl.utils.quantization
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
|
||||
from torchao.quantization.granularity import PerAxis, PerGroup
|
||||
from torchao.quantization.linear_activation_quantized_tensor import (
|
||||
LinearActivationQuantizedTensor,
|
||||
)
|
||||
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
|
||||
from torchao.quantization.qat.linear import FakeQuantizedLinear
|
||||
from torchao.quantization.quant_api import (
|
||||
Int4DynamicActivationInt4WeightConfig,
|
||||
Int4WeightOnlyConfig,
|
||||
Int8DynamicActivationInt8WeightConfig,
|
||||
Int8WeightOnlyConfig,
|
||||
UIntXWeightOnlyConfig,
|
||||
)
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers.trainer_callback import TrainerState
|
||||
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
from axolotl.utils.quantization import (
|
||||
convert_qat_model_for_ptq,
|
||||
get_ptq_config,
|
||||
prepare_model_for_qat,
|
||||
quantize_model_for_ptq,
|
||||
)
|
||||
from axolotl.utils.schemas.enums import TorchIntDType
|
||||
from axolotl.utils.schemas.quantization import QATConfig
|
||||
|
||||
from tests.e2e.utils import require_torch_2_6_0
|
||||
|
||||
|
||||
@pytest.fixture()
def model():
    """Load SmolLM2-135M in bfloat16 on CUDA with a fresh embedding layer.

    The pretrained ``embed_tokens`` module is replaced by a newly constructed
    ``torch.nn.Embedding`` of the same shape and dtype, created on the
    model's device.
    NOTE(review): presumably this gives the embedding weights of its own,
    independent of any tying to the output head -- confirm.
    """
    causal_lm = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceTB/SmolLM2-135M",
        device_map="cuda",
        torch_dtype=torch.bfloat16,
    )

    # Using the device as a context manager places the new module on it.
    with torch.device(causal_lm.device):
        pretrained_weight = causal_lm.model.embed_tokens.weight
        causal_lm.model.embed_tokens = torch.nn.Embedding(
            pretrained_weight.shape[0],
            pretrained_weight.shape[1],
            dtype=pretrained_weight.dtype,
        )

    return causal_lm
|
||||
|
||||
|
||||
# Parametrization for test_get_ptq_config. Each tuple is:
#   (weight_dtype, activation_dtype, group_size, expected_type, expected_params)
# where expected_params maps attribute names on the returned config to the
# values they must hold.
ptq_config_test_cases = [
    (
        TorchIntDType.uint4,
        None,
        None,
        UIntXWeightOnlyConfig,
        {"dtype": torch.uint4, "group_size": None},
    ),
    (
        TorchIntDType.int8,
        None,
        32,
        Int8WeightOnlyConfig,
        {"group_size": 32},
    ),
    (
        TorchIntDType.int4,
        None,
        4,
        Int4WeightOnlyConfig,
        {"group_size": 4},
    ),
    (
        TorchIntDType.int4,
        TorchIntDType.int4,
        None,
        Int4DynamicActivationInt4WeightConfig,
        {},
    ),
    (
        TorchIntDType.int8,
        TorchIntDType.int8,
        None,
        Int8DynamicActivationInt8WeightConfig,
        {},
    ),
]

# Parametrization for test_quantize_model_for_ptq. Each tuple is:
#   (weight_dtype, activation_dtype, group_size, quantize_embedding,
#    expected_exception)
# The last two rows pass no group size and expect a ValueError.
ptq_test_cases = [
    (TorchIntDType.int8, None, 8, False, None),
    (TorchIntDType.int4, None, 4, True, None),
    (TorchIntDType.uint4, None, 8, False, None),
    (TorchIntDType.int4, TorchIntDType.int4, 8, False, None),
    (TorchIntDType.int8, TorchIntDType.int8, 8, True, None),
    (TorchIntDType.int8, None, None, False, ValueError),
    (TorchIntDType.int4, None, None, False, ValueError),
]
|
||||
|
||||
|
||||
class TestQuantization:
    """
    Tests for the quantization utilities: PTQ config lookup, QAT model
    preparation and PTQ model quantization.
    """

    @pytest.mark.parametrize(
        "weight_dtype,activation_dtype,group_size,expected_type,expected_params",
        ptq_config_test_cases,
    )
    @require_torch_2_6_0
    def test_get_ptq_config(
        self, weight_dtype, activation_dtype, group_size, expected_type, expected_params
    ):
        """Check the type of the returned config and each expected attribute."""
        ptq_config = get_ptq_config(weight_dtype, activation_dtype, group_size)

        assert isinstance(ptq_config, expected_type)

        for attr_name, expected in expected_params.items():
            if isinstance(expected, PerAxis):
                # Granularity objects are compared by type and by axis.
                assert isinstance(getattr(ptq_config, attr_name), PerAxis)
                assert getattr(ptq_config, attr_name).axis == expected.axis
            elif isinstance(expected, PerGroup):
                # Granularity objects are compared by type and by group size.
                assert isinstance(getattr(ptq_config, attr_name), PerGroup)
                assert (
                    getattr(ptq_config, attr_name).group_size == expected.group_size
                )
            else:
                assert getattr(ptq_config, attr_name) == expected

    @pytest.mark.parametrize(
        "weight_dtype", [TorchIntDType.int8, TorchIntDType.int4, TorchIntDType.uint4]
    )
    @pytest.mark.parametrize(
        "activation_dtype", [None, TorchIntDType.int4, TorchIntDType.int8]
    )
    @pytest.mark.parametrize("group_size", [4, 8])
    @pytest.mark.parametrize("quantize_embedding", [False, True])
    @require_torch_2_6_0
    def test_prepare_model_for_qat(
        self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
    ):  # pylint: disable=redefined-outer-name
        """Check that QAT preparation installs fake quantizers as configured."""
        prepare_model_for_qat(
            model, weight_dtype, group_size, activation_dtype, quantize_embedding
        )

        if quantize_embedding:
            embed_tokens = model.model.embed_tokens
            assert isinstance(embed_tokens, FakeQuantizedEmbedding)
            assert hasattr(embed_tokens, "weight_fake_quantizer")
            embed_quant_cfg = embed_tokens.weight_fake_quantizer.config
            assert embed_quant_cfg.dtype == weight_dtype.value
            assert embed_quant_cfg.group_size == group_size

        # NOTE(review): children() yields direct submodules only, so linear
        # layers nested deeper in the model are not inspected here -- confirm
        # that this is intended.
        for layer in model.children():
            if not isinstance(layer, torch.nn.Linear):
                continue

            assert isinstance(layer, FakeQuantizedLinear)
            assert hasattr(layer, "weight_fake_quantizer")
            weight_quant_cfg = layer.weight_fake_quantizer.config
            assert weight_quant_cfg.dtype == weight_dtype.value
            assert weight_quant_cfg.group_size == group_size

            if activation_dtype:
                assert hasattr(layer, "activation_fake_quantizer")
                assert (
                    layer.activation_fake_quantizer.config.dtype
                    == activation_dtype.value
                )
            else:
                assert layer.activation_fake_quantizer is None

    @pytest.mark.parametrize(
        "weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception",
        ptq_test_cases,
    )
    @require_torch_2_6_0
    def test_quantize_model_for_ptq(
        self,
        model,
        weight_dtype,
        activation_dtype,
        group_size,
        quantize_embedding,
        expected_exception,
    ):  # pylint: disable=redefined-outer-name
        """Check PTQ quantizes the weights, or raises for the error cases."""
        if expected_exception:
            with pytest.raises(expected_exception):
                quantize_model_for_ptq(
                    model,
                    weight_dtype,
                    group_size,
                    activation_dtype,
                    quantize_embedding,
                )
            # Nothing further to verify once the expected error was raised.
            return

        quantize_model_for_ptq(
            model, weight_dtype, group_size, activation_dtype, quantize_embedding
        )

        if quantize_embedding:
            assert isinstance(
                model.model.embed_tokens.weight, AffineQuantizedTensor
            ), "Embedding weight should be quantized"

        # NOTE(review): as above, only direct submodules are inspected.
        for layer in model.children():
            if not isinstance(layer, torch.nn.Linear):
                continue

            if activation_dtype:
                assert isinstance(
                    layer.weight, LinearActivationQuantizedTensor
                ), "Linear weight should be quantized with activation quantization"
            else:
                assert isinstance(
                    layer.weight, AffineQuantizedTensor
                ), "Linear weight should be quantized without activation quantization"
|
||||
|
||||
|
||||
class TestQuantizationCallback:
|
||||
"""
|
||||
Test QATCallback
|
||||
"""
|
||||
|
||||
@pytest.fixture()
|
||||
def trainer_state(self):
|
||||
return TrainerState(
|
||||
global_step=0,
|
||||
)
|
||||
|
||||
@require_torch_2_6_0
|
||||
def test_qat_callback_fake_quant_after_n_steps(
|
||||
self, model, trainer_state
|
||||
): # pylint: disable=redefined-outer-name
|
||||
cfg = QATConfig(
|
||||
weight_dtype="int8",
|
||||
activation_dtype="int8",
|
||||
group_size=8,
|
||||
quantize_embedding=True,
|
||||
fake_quant_after_n_steps=100,
|
||||
)
|
||||
|
||||
prepare_model_for_qat(
|
||||
model,
|
||||
cfg.weight_dtype,
|
||||
cfg.group_size,
|
||||
cfg.activation_dtype,
|
||||
cfg.quantize_embedding,
|
||||
)
|
||||
|
||||
# ensure model has been quantized
|
||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
qat_callback = QATCallback(cfg)
|
||||
|
||||
# simulate first training step
|
||||
qat_callback.on_step_begin(
|
||||
args=None,
|
||||
state=trainer_state,
|
||||
control=None,
|
||||
model=model,
|
||||
)
|
||||
|
||||
# quantization should have been disabled
|
||||
assert not model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert not model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
trainer_state.global_step = 100
|
||||
qat_callback.on_step_begin(
|
||||
args=None,
|
||||
state=trainer_state,
|
||||
control=None,
|
||||
model=model,
|
||||
)
|
||||
|
||||
# quantization should have been enabled
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
@require_torch_2_6_0
|
||||
def test_qat_callback_fake_quant_after_n_steps_is_none(
|
||||
self, model, trainer_state
|
||||
): # pylint: disable=redefined-outer-name
|
||||
cfg = QATConfig(
|
||||
weight_dtype="int8",
|
||||
activation_dtype="int8",
|
||||
group_size=8,
|
||||
quantize_embedding=True,
|
||||
fake_quant_after_n_steps=None,
|
||||
)
|
||||
|
||||
prepare_model_for_qat(
|
||||
model,
|
||||
cfg.weight_dtype,
|
||||
cfg.group_size,
|
||||
cfg.activation_dtype,
|
||||
cfg.quantize_embedding,
|
||||
)
|
||||
|
||||
# ensure model has been quantized
|
||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
qat_callback = QATCallback(cfg)
|
||||
# simulate first training step
|
||||
qat_callback.on_step_begin(
|
||||
args=None,
|
||||
state=trainer_state,
|
||||
control=None,
|
||||
model=model,
|
||||
)
|
||||
|
||||
# quantization should be enabled from the get-go
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
|
||||
class TestConvertQATModelForPTQ:
|
||||
"""
|
||||
Test convert_qat_model_for_ptq
|
||||
"""
|
||||
|
||||
@require_torch_2_6_0
|
||||
def test_convert_qat_model_for_ptq(
|
||||
self, model
|
||||
): # pylint: disable=redefined-outer-name
|
||||
config = QATConfig(
|
||||
weight_dtype="int8",
|
||||
activation_dtype="int8",
|
||||
group_size=8,
|
||||
quantize_embedding=True,
|
||||
)
|
||||
|
||||
# quantize model for qat
|
||||
prepare_model_for_qat(
|
||||
model,
|
||||
config.weight_dtype,
|
||||
config.group_size,
|
||||
config.activation_dtype,
|
||||
config.quantize_embedding,
|
||||
)
|
||||
|
||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
||||
|
||||
# apply conversion
|
||||
convert_qat_model_for_ptq(
|
||||
model,
|
||||
quantize_embedding=config.quantize_embedding,
|
||||
)
|
||||
# ensure modules have been swapped out
|
||||
assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||
assert not isinstance(model.lm_head, FakeQuantizedLinear)
|
||||
|
||||
# ensure weights have been quantized
|
||||
assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
|
||||
assert isinstance(model.lm_head.weight, nn.Parameter)
|
||||
@@ -2,8 +2,6 @@
|
||||
E2E tests for qwen
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -13,9 +11,6 @@ from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.qwen")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestE2eQwen:
|
||||
"""
|
||||
@@ -64,6 +59,7 @@ class TestE2eQwen:
|
||||
"bf16": "auto",
|
||||
"tf32": True,
|
||||
"gradient_checkpointing": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -2,11 +2,8 @@
|
||||
E2E tests for reward model lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestRewardModelLoraSmolLM2(unittest.TestCase):
|
||||
"""
|
||||
@@ -64,15 +58,15 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
|
||||
"gradient_checkpointing": True,
|
||||
"warmup_ratio": 0.1,
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
|
||||
)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
102
tests/e2e/test_save_first_step.py
Normal file
102
tests/e2e/test_save_first_step.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
E2E tests for relora llama
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
|
||||
class TestSaveFirstStepCallback(unittest.TestCase):
|
||||
"""Test cases for save_first_step callback config."""
|
||||
|
||||
@with_temp_dir
|
||||
def test_save_first_step(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 512,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 3,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_no_save_first_step(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 512,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 3,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
with pytest.raises(AssertionError):
|
||||
check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)
|
||||
@@ -2,11 +2,8 @@
|
||||
E2E tests for custom schedulers using Llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -14,9 +11,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestCustomSchedulers(unittest.TestCase):
|
||||
"""
|
||||
@@ -57,13 +51,13 @@ class TestCustomSchedulers(unittest.TestCase):
|
||||
"lr_scheduler": "rex",
|
||||
"warmup_steps": 5,
|
||||
"cosine_min_lr_ratio": 0.05,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -10,8 +10,6 @@ from functools import wraps
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
# from importlib.metadata import version
|
||||
from packaging import version
|
||||
from tbparse import SummaryReader
|
||||
|
||||
@@ -79,6 +77,18 @@ def require_torch_2_6_0(test_case):
|
||||
return unittest.skipUnless(is_min_2_6_0(), "test requires torch>=2.6.0")(test_case)
|
||||
|
||||
|
||||
def require_torch_2_7_0(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torch >= 2.7.0
|
||||
"""
|
||||
|
||||
def is_min_2_7_0():
|
||||
torch_version = version.parse(torch.__version__)
|
||||
return torch_version >= version.parse("2.7.0")
|
||||
|
||||
return unittest.skipUnless(is_min_2_7_0(), "test requires torch>=2.7.0")(test_case)
|
||||
|
||||
|
||||
def require_torch_lt_2_6_0(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torch < 2.6.0
|
||||
@@ -132,6 +142,10 @@ def is_hopper():
|
||||
return compute_capability == (9, 0)
|
||||
|
||||
|
||||
def require_hopper(test_case):
|
||||
return unittest.skipUnless(is_hopper(), "test requires h100/hopper GPU")(test_case)
|
||||
|
||||
|
||||
def check_tensorboard(
|
||||
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
|
||||
) -> None:
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
config validation tests for swiglu args
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
@@ -12,6 +10,7 @@ from axolotl.utils.config import prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
@pytest.fixture(name="minimal_liger_cfg")
|
||||
def fixture_cfg():
|
||||
return DictDefault(
|
||||
@@ -41,7 +40,7 @@ class TestValidation:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
caplog.set_level(logging.WARNING)
|
||||
caplog.set_level("WARNING")
|
||||
self._caplog = caplog
|
||||
|
||||
def test_deprecated_swiglu(self, minimal_liger_cfg):
|
||||
@@ -52,9 +51,7 @@ class TestValidation:
|
||||
| minimal_liger_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(
|
||||
logging.WARNING, logger="axolotl.integrations.liger.args"
|
||||
):
|
||||
with self._caplog.at_level("WARNING", logger="axolotl.integrations.liger.args"):
|
||||
prepare_plugins(test_cfg)
|
||||
updated_cfg = validate_config(test_cfg)
|
||||
# TODO this test is brittle in CI
|
||||
|
||||
26
tests/monkeypatch/test_trainer_accelerator_args.py
Normal file
26
tests/monkeypatch/test_trainer_accelerator_args.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Unit tests for trainer accelerator args monkeypatch
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.monkeypatch.trainer_accelerator_args import (
|
||||
check_create_accelerate_code_is_patchable,
|
||||
)
|
||||
|
||||
|
||||
class TestTrainerAcceleratorArgs(unittest.TestCase):
|
||||
"""
|
||||
Unit test class for trainer accelerator args monkeypatch
|
||||
"""
|
||||
|
||||
def test_check_create_accelerate_code_is_patchable(self):
|
||||
"""
|
||||
Test that the upstream transformers code is still patchable.
|
||||
This will fail if the patched code changes upstream.
|
||||
"""
|
||||
assert check_create_accelerate_code_is_patchable()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
28
tests/monkeypatch/test_trainer_loss_calc.py
Normal file
28
tests/monkeypatch/test_trainer_loss_calc.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Unit tests for trainer loss calc monkeypatch."""
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||
check_evaluation_loop_is_fsdp2_patchable,
|
||||
check_evaluation_loop_is_patchable,
|
||||
check_maybe_log_save_evaluate_is_patchable,
|
||||
)
|
||||
|
||||
|
||||
class TestTrainerLossCalc(unittest.TestCase):
|
||||
"""
|
||||
Unit test class for trainer loss calc monkeypatch
|
||||
"""
|
||||
|
||||
def test_trainer_loss_calc_is_patchable(self):
|
||||
"""
|
||||
Test that the upstream transformers code is still patchable. This will fail if
|
||||
the patched code changes upstream.
|
||||
"""
|
||||
assert check_evaluation_loop_is_patchable()
|
||||
assert check_evaluation_loop_is_fsdp2_patchable()
|
||||
assert check_maybe_log_save_evaluate_is_patchable()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,7 +1,6 @@
|
||||
# pylint: disable=too-many-lines
|
||||
"""Module for testing the validation module"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional
|
||||
@@ -80,7 +79,7 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(test_cfg)
|
||||
assert (
|
||||
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
|
||||
@@ -218,7 +217,7 @@ class TestValidation(BaseValidation):
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert "batch_size is not recommended" in self._caplog.records[0].message
|
||||
|
||||
@@ -513,7 +512,7 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"BetterTransformers probably doesn't work with PEFT adapters"
|
||||
@@ -531,7 +530,7 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"probably set bfloat16 or float16" in record.message
|
||||
@@ -577,7 +576,7 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"adamw hyperparameters found, but no adamw optimizer set"
|
||||
@@ -595,7 +594,7 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"adamw hyperparameters found, but no adamw optimizer set"
|
||||
@@ -654,7 +653,7 @@ class TestValidation(BaseValidation):
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||
@@ -673,7 +672,7 @@ class TestValidation(BaseValidation):
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
with self._caplog.at_level(logging.INFO):
|
||||
with self._caplog.at_level("INFO"):
|
||||
cfg = validate_config(cfg)
|
||||
assert any(
|
||||
"Setting `pad_to_sequence_len: true` to prevent memory leaks when sample_packing"
|
||||
@@ -693,7 +692,7 @@ class TestValidation(BaseValidation):
|
||||
"bf16": True,
|
||||
"capabilities": {"bf16": False},
|
||||
"env_capabilities": {
|
||||
"torch_version": "2.5.1",
|
||||
"torch_version": "2.6.0",
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -1109,7 +1108,7 @@ class TestValidation(BaseValidation):
|
||||
def test_hub_model_id_save_value_warns_save_stragey_no(self, minimal_cfg):
|
||||
cfg = DictDefault({"hub_model_id": "test", "save_strategy": "no"}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 1
|
||||
|
||||
@@ -1118,7 +1117,7 @@ class TestValidation(BaseValidation):
|
||||
DictDefault({"hub_model_id": "test", "save_strategy": "test"}) | minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 1
|
||||
|
||||
@@ -1128,7 +1127,7 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
@@ -1138,28 +1137,28 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
def test_hub_model_id_save_value_none(self, minimal_cfg):
|
||||
cfg = DictDefault({"hub_model_id": "test", "save_strategy": None}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
def test_hub_model_id_save_value_no_set_save_strategy(self, minimal_cfg):
|
||||
cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
validate_config(cfg)
|
||||
assert len(self._caplog.records) == 0
|
||||
|
||||
def test_dpo_beta_deprecation(self, minimal_cfg):
|
||||
cfg = DictDefault({"dpo_beta": 0.2}) | minimal_cfg
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg["rl_beta"] == 0.2
|
||||
assert new_cfg["dpo_beta"] is None
|
||||
@@ -1175,7 +1174,7 @@ class TestValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.eval_strategy == "steps"
|
||||
assert (
|
||||
@@ -1203,7 +1202,7 @@ class TestValidation(BaseValidation):
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
)
|
||||
|
||||
env_capabilities = {"torch_version": "2.5.1"}
|
||||
env_capabilities = {"torch_version": "2.6.0"}
|
||||
capabilities = {"bf16": False}
|
||||
_ = validate_config(
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
@@ -1245,7 +1244,7 @@ class TestTorchCompileValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
env_capabilities = {"torch_version": "2.5.1"}
|
||||
env_capabilities = {"torch_version": "2.6.0"}
|
||||
capabilities = {"bf16": True}
|
||||
updated_cfg = validate_config(
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
@@ -1455,7 +1454,7 @@ class TestValidationWandb(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level("WARNING"):
|
||||
new_cfg = validate_config(cfg)
|
||||
assert any(
|
||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||
@@ -1505,7 +1504,6 @@ class TestValidationWandb(BaseValidation):
|
||||
assert os.environ.get("WANDB_MODE", "") == "online"
|
||||
assert os.environ.get("WANDB_WATCH", "") == "false"
|
||||
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
|
||||
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
||||
|
||||
os.environ.pop("WANDB_PROJECT", None)
|
||||
os.environ.pop("WANDB_NAME", None)
|
||||
@@ -1514,16 +1512,12 @@ class TestValidationWandb(BaseValidation):
|
||||
os.environ.pop("WANDB_MODE", None)
|
||||
os.environ.pop("WANDB_WATCH", None)
|
||||
os.environ.pop("WANDB_LOG_MODEL", None)
|
||||
os.environ.pop("WANDB_DISABLED", None)
|
||||
|
||||
def test_wandb_set_disabled(self, minimal_cfg):
|
||||
cfg = DictDefault({}) | minimal_cfg
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(new_cfg)
|
||||
|
||||
assert os.environ.get("WANDB_DISABLED", "") == "true"
|
||||
assert new_cfg.use_wandb is None
|
||||
|
||||
cfg = (
|
||||
DictDefault(
|
||||
@@ -1535,13 +1529,10 @@ class TestValidationWandb(BaseValidation):
|
||||
)
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(new_cfg)
|
||||
|
||||
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
||||
assert new_cfg.use_wandb is True
|
||||
|
||||
os.environ.pop("WANDB_PROJECT", None)
|
||||
os.environ.pop("WANDB_DISABLED", None)
|
||||
|
||||
|
||||
@pytest.mark.skipif(is_comet_available() is False, reason="comet_ml is not installed")
|
||||
@@ -1699,3 +1690,18 @@ class TestValidationMLflow(BaseValidation):
|
||||
assert new_cfg.use_mlflow is True
|
||||
|
||||
os.environ.pop("MLFLOW_EXPERIMENT_NAME", None)
|
||||
|
||||
|
||||
class TestDataloaderValidation(BaseValidation):
|
||||
"""
|
||||
tests for dataloader_* sane defaults
|
||||
"""
|
||||
|
||||
def test_dataloader_auto_defaults(self, minimal_cfg):
|
||||
cfg = minimal_cfg
|
||||
|
||||
new_cfg = validate_config(cfg, {"n_gpu": 8}, {"torch_version": "2.6.0"})
|
||||
|
||||
assert new_cfg.dataloader_num_workers == 8
|
||||
assert new_cfg.dataloader_pin_memory is True
|
||||
assert new_cfg.dataloader_prefetch_factor == 256
|
||||
|
||||
@@ -143,6 +143,12 @@ def fixture_phi35_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="phi4_tokenizer", scope="session", autouse=True)
|
||||
def fixture_phi4_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-reasoning")
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="gemma2_tokenizer", scope="session", autouse=True)
|
||||
def fixture_gemma2_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("mlx-community/gemma-2-9b-it-4bit")
|
||||
@@ -150,6 +156,30 @@ def fixture_gemma2_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="magistral_tokenizer")
|
||||
def fixture_magistral_tokenizer():
|
||||
from axolotl.utils.mistral import HFMistralTokenizer
|
||||
|
||||
tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Magistral-Small-2506")
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="devstral_tokenizer")
|
||||
def fixture_devstral_tokenizer():
|
||||
from axolotl.utils.mistral import HFMistralTokenizer
|
||||
|
||||
tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Devstral-Small-2505")
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="devstral_1_1_tokenizer")
|
||||
def fixture_devstral_1_1_tokenizer():
|
||||
from axolotl.utils.mistral import HFMistralTokenizer
|
||||
|
||||
tokenizer = HFMistralTokenizer.from_pretrained("mistralai/Devstral-Small-2507")
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
|
||||
def fixture_mistralv03_chat_template_jinja_w_system() -> str:
|
||||
return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not 
none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'
|
||||
|
||||
@@ -3,14 +3,13 @@ tests for chat_template prompt strategy
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
import logging
|
||||
import unittest
|
||||
|
||||
from axolotl.prompt_strategies.messages.chat import load
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
LOG = logging.getLogger("axolotl")
|
||||
LOG = get_logger(__name__, log_level="DEBUG")
|
||||
|
||||
|
||||
class TestMessagesChatLlama3:
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Tests for chat template prompt strategy with schema unification for none fields
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import StrategyLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="messages_w_tools")
|
||||
def fixture_messages_w_tools():
|
||||
jsons = """
|
||||
{"messages":[{"role":"user","content":"move to (0, 1)"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"move","arguments":{"x":0,"y":1}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
|
||||
{"messages":[{"role":"user","content":"turn 270 degree"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"turn","arguments":{"theta": 270}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
|
||||
{"messages":[{"role":"user","content":"jump high"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"invalid_prompt","arguments":{"message": "jump is not a valid action"}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
|
||||
""".strip().split(
|
||||
"\n"
|
||||
)
|
||||
rows = [json.loads(row) for row in jsons]
|
||||
return Dataset.from_list(rows)
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_tokenizer")
|
||||
def qwen3_tokenizer_fixture(
|
||||
download_qwen3_half_billion_model,
|
||||
): # pylint: disable=unused-argument
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_prompt_strategy")
|
||||
def qwen3_chat_template_strategy(qwen3_tokenizer):
|
||||
cfg = DictDefault(
|
||||
sequence_len=2048,
|
||||
chat_template="qwen3",
|
||||
eot_tokens=["<|im_end|>"],
|
||||
)
|
||||
ds_cfg = DictDefault(
|
||||
type="chat_template",
|
||||
)
|
||||
load = StrategyLoader()
|
||||
strat = load(qwen3_tokenizer, cfg, ds_cfg)
|
||||
return strat
|
||||
|
||||
|
||||
class TestSchemaUnification:
|
||||
"""
|
||||
Test class on handling null fields for tool calling
|
||||
"""
|
||||
|
||||
def test_schema_unification_single_prompt(
|
||||
self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer
|
||||
):
|
||||
for row in messages_w_tools:
|
||||
inputs = qwen3_prompt_strategy.tokenize_prompt(row)
|
||||
decoded = qwen3_tokenizer.decode(inputs["input_ids"])
|
||||
tool_call = decoded.split("<tool_call>")[-1].split("</tool_call>")[0]
|
||||
assert '"message": null' not in tool_call
|
||||
assert '"theta": null' not in tool_call
|
||||
|
||||
def test_schema_unification_batched(
|
||||
self, messages_w_tools, qwen3_prompt_strategy, qwen3_tokenizer
|
||||
):
|
||||
rows = messages_w_tools.map(qwen3_prompt_strategy.tokenize_prompt, batched=True)
|
||||
for row in rows:
|
||||
decoded = qwen3_tokenizer.decode(row["input_ids"])
|
||||
tool_call = decoded.split("<tool_call>")[-1].split("</tool_call>")[0]
|
||||
assert '"message": null' not in tool_call
|
||||
assert '"theta": null' not in tool_call
|
||||
@@ -2,7 +2,6 @@
|
||||
tests for chat_template prompt strategy
|
||||
"""
|
||||
|
||||
import logging
|
||||
import unittest
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import (
|
||||
@@ -13,9 +12,9 @@ from axolotl.prompt_strategies.chat_template import (
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||
from axolotl.utils.chat_templates import get_chat_template
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
LOG = logging.getLogger("axolotl")
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class TestAssistantChatTemplateLlama3:
|
||||
|
||||
@@ -4,7 +4,6 @@ tests for chat_template prompt strategy
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
@@ -18,11 +17,11 @@ from axolotl.prompt_strategies.chat_template import (
|
||||
)
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||
from axolotl.utils.chat_templates import get_chat_template
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
LOG = logging.getLogger("axolotl")
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
PARAMETRIZE_KEYS = "tokenizer, chat_template, chat_template_jinja, eos_token"
|
||||
PARAMETRIZE_PARAMS = [
|
||||
@@ -34,15 +33,14 @@ PARAMETRIZE_PARAMS = [
|
||||
"mistralv03_tokenizer_chat_template_jinja",
|
||||
"[/INST]",
|
||||
),
|
||||
# TODO: temporarily skip gemma due to gemma3 template
|
||||
# Re-enable on new chat_template implementation for perf
|
||||
# (
|
||||
# "gemma2_tokenizer",
|
||||
# "jinja",
|
||||
# "gemma2_tokenizer_chat_template_jinja",
|
||||
# "<end_of_turn>",
|
||||
# ),
|
||||
(
|
||||
"gemma2_tokenizer",
|
||||
"jinja",
|
||||
"gemma2_tokenizer_chat_template_jinja",
|
||||
"<end_of_turn>",
|
||||
),
|
||||
("phi35_tokenizer", "phi_35", None, "<|end|>"),
|
||||
("phi4_tokenizer", "phi_4", None, "<|im_end|>"),
|
||||
]
|
||||
|
||||
|
||||
@@ -96,11 +94,7 @@ class TestChatTemplateConfigurations:
|
||||
if (
|
||||
turn_idx == 0
|
||||
and turn.get("from") in ["system", "context"]
|
||||
and (
|
||||
"mistral" in tokenizer.name_or_path.lower()
|
||||
or "gemma"
|
||||
in tokenizer.name_or_path.lower() # temporarily skip gemma due to gemma3 template
|
||||
)
|
||||
and ("mistral" in tokenizer.name_or_path.lower())
|
||||
):
|
||||
assert (
|
||||
start_idx == -1 and end_idx == -1
|
||||
@@ -936,36 +930,14 @@ class TestChatTemplateConfigurations:
|
||||
"messages",
|
||||
)
|
||||
|
||||
if chat_template == "llama3":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "chatml":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
|
||||
assert variables == {"role", "content", "tool_call_id", "tool_calls"}, (
|
||||
f"Expected variables: {'role', 'content', 'tool_call_id', 'tool_calls'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "jinja" and tokenizer == "gemma2_tokenizer":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "phi_35":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
# Special case for Mistral with additional tool variables
|
||||
if chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
|
||||
expected_variables = {"role", "content", "tool_call_id", "tool_calls"}
|
||||
# Most chat templates use the standard role and content variables
|
||||
elif chat_template in ["llama3", "chatml", "phi_35", "phi_4"] or (
|
||||
chat_template == "jinja" and tokenizer == "gemma2_tokenizer"
|
||||
):
|
||||
expected_variables = {"role", "content"}
|
||||
else:
|
||||
LOG.warning(
|
||||
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
||||
@@ -974,6 +946,12 @@ class TestChatTemplateConfigurations:
|
||||
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
||||
)
|
||||
|
||||
assert variables == expected_variables, (
|
||||
f"Expected variables: {expected_variables} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
|
||||
def test_eot_tokens_conflict_with_eos_token(
|
||||
self,
|
||||
tokenizer,
|
||||
@@ -1281,3 +1259,162 @@ class TestChatTemplateConfigurations:
|
||||
assert (
|
||||
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
|
||||
|
||||
|
||||
class TestChatTemplateToolCalling:
|
||||
"""
|
||||
Test class for tool calling functionality with chat templates.
|
||||
"""
|
||||
|
||||
def test_tool_calling_with_llama4_template(
|
||||
self,
|
||||
llama3_tokenizer,
|
||||
):
|
||||
LOG.info("Testing tool calling with llama3 tokenizer and llama4 chat template")
|
||||
|
||||
# Create tool calling dataset
|
||||
tool_calling_dataset = [
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "xml_escape",
|
||||
"description": 'Replaces any "<", ">", or "&" characters in the input string with their corresponding XML entities.',
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"s": {
|
||||
"type": "string",
|
||||
"description": "The input string to be XML-escaped.",
|
||||
}
|
||||
},
|
||||
"required": ["s"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multiples",
|
||||
"description": "Generates a list of all the multiples of a number that are less than a given limit.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"number": {
|
||||
"type": "integer",
|
||||
"description": "The number to find multiples of.",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "The upper limit for the multiples.",
|
||||
},
|
||||
},
|
||||
"required": ["number", "limit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Can you help me find multiples of 5 that are less than 20?",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multiples",
|
||||
"arguments": {
|
||||
"number": 5,
|
||||
"limit": 20,
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "name": "multiples", "content": "5,10,15"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The multiples of 5 less than 20 are: 5, 10, and 15.",
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Setup tokenizer with llama4 chat template
|
||||
tokenizer = deepcopy(llama3_tokenizer)
|
||||
|
||||
# Add EOS token to the tokenizer
|
||||
eot_token = "<|eot_id|>"
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [eot_token]})
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
tokenizer,
|
||||
chat_template=get_chat_template("llama4"),
|
||||
message_property_mappings={"role": "role", "content": "content"},
|
||||
field_messages="messages",
|
||||
field_tools="tools",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
eot_tokens=[eot_token],
|
||||
)
|
||||
|
||||
res = strategy.tokenize_prompt(tool_calling_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
labels = res["labels"]
|
||||
|
||||
# Verify that the input_ids contain expected tokens
|
||||
assert len(input_ids) > 0, "Input IDs should not be empty"
|
||||
assert len(labels) == len(input_ids), "Labels should match input_ids length"
|
||||
|
||||
# Decode the full conversation to verify structure
|
||||
decoded_conversation = tokenizer.decode(input_ids)
|
||||
|
||||
# Verify tool calling structure is present in the decoded conversation
|
||||
assert (
|
||||
'"type": "function",' in decoded_conversation
|
||||
), "Tool type function should be in conversation"
|
||||
assert (
|
||||
'"name": "multiples",' in decoded_conversation
|
||||
), "Tool function name should be in conversation"
|
||||
|
||||
assert (
|
||||
'<|python_start|><|python_end|>{"name": "multiples", "parameters": {"number": 5, "limit": 20}}<|eot|>'
|
||||
in decoded_conversation
|
||||
), "Assistant tool call should be in conversation"
|
||||
assert (
|
||||
"<|header_start|>ipython<|header_end|>" in decoded_conversation
|
||||
), "IPython header should be in conversation"
|
||||
assert (
|
||||
'"5,10,15"' in decoded_conversation
|
||||
), "Tool response should be in conversation"
|
||||
|
||||
# Get conversation turns to verify labeling
|
||||
turns = strategy.get_conversation_thread(tool_calling_dataset[0])
|
||||
tools = strategy._get_tools( # pylint: disable=protected-access
|
||||
tool_calling_dataset[0]
|
||||
)
|
||||
|
||||
# Check that assistant responses are properly labeled
|
||||
for i, turn in enumerate(tool_calling_dataset[0]["messages"]):
|
||||
if turn["role"] == "assistant":
|
||||
start_idx, end_idx = strategy.find_turn(
|
||||
turns=turns, turn_idx=i, tools=tools
|
||||
)
|
||||
|
||||
assert (
|
||||
start_idx != -1 and end_idx != -1
|
||||
), f"Assistant turn {i} should be found"
|
||||
|
||||
# Verify that assistant responses have proper labels
|
||||
turn_labels = labels[start_idx:end_idx]
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in turn_labels
|
||||
), f"Assistant turn {i} should be unmasked"
|
||||
|
||||
851
tests/prompt_strategies/test_chat_templates_mistral.py
Normal file
851
tests/prompt_strategies/test_chat_templates_mistral.py
Normal file
@@ -0,0 +1,851 @@
|
||||
"""Test chat templates for mistral-common wrapper tokenizer"""
|
||||
|
||||
import unittest
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.utils.mistral import HFMistralTokenizer
|
||||
|
||||
|
||||
# fmt: off
|
||||
@pytest.mark.parametrize(
|
||||
("tokenizer_str", "assistant_toolcall_ids", "tool_result_ids"),
|
||||
(
|
||||
("magistral_tokenizer", (9, 44627, 3684, 33, 19881, 1049, 1050, 1051, 1052, 1053, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2), (7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8)),
|
||||
("devstral_tokenizer", (9, 1091, 19227, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 61906, 2811, 16753, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 4179, 1429, 1327, 2811, 1429, 19881, 1049, 1050, 1051, 1052, 1053, 1034, 27028, 2), (7, 19881, 1049, 1050, 1051, 1052, 1053, 19, 1049, 1044, 1050, 8)),
|
||||
("devstral_1_1_tokenizer", (9, 44627, 3684, 32, 19227, 12856, 2811, 1032, 1049, 1054, 1044, 1429, 33319, 2811, 1032, 1050, 1125, 2,), (7, 1049, 1044, 1050, 8)),
|
||||
)
|
||||
)
|
||||
# fmt: on
|
||||
def test_mistral_chat_template(
|
||||
tokenizer_str: str,
|
||||
assistant_toolcall_ids: tuple[int, ...],
|
||||
tool_result_ids: tuple[int, ...],
|
||||
request: pytest.FixtureRequest,
|
||||
):
|
||||
"""Test chat template with the Magistral/Devstral tokenizer"""
|
||||
# pylint: disable=duplicate-code
|
||||
from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy
|
||||
|
||||
tokenizer: HFMistralTokenizer = request.getfixturevalue(tokenizer_str)
|
||||
|
||||
# check bos, eos, pad, unk are accessible properties
|
||||
assert tokenizer.bos_token_id == 1
|
||||
assert tokenizer.eos_token_id == 2
|
||||
assert tokenizer.pad_token_id == 11
|
||||
assert tokenizer.unk_token_id == 0
|
||||
|
||||
assert tokenizer.pad_token == "<pad>"
|
||||
assert tokenizer.eos_token == "</s>"
|
||||
assert tokenizer.bos_token == "<s>"
|
||||
assert tokenizer.unk_token == "<unk>"
|
||||
|
||||
strategy = MistralStrategy(
|
||||
MistralPrompter(
|
||||
tokenizer,
|
||||
chat_template=None,
|
||||
message_property_mappings={"role": "role", "content": "content"},
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
train_on_eos="turn",
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
|
||||
# test chat template masking without system prompt
|
||||
res = strategy.tokenize_prompt(
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing great, thank you!"},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
assert res["input_ids"] == [
|
||||
1, # bos
|
||||
3, # [INST]
|
||||
22177, # Hello
|
||||
1044, # ,
|
||||
2606, # how
|
||||
1584, # are
|
||||
1636, # you
|
||||
1063, # ?
|
||||
4, # [/INST]
|
||||
1073, # I
|
||||
4525, # 'm
|
||||
6965, # doing
|
||||
4824, # great
|
||||
1044, # ,
|
||||
15412, # thank
|
||||
1636, # you
|
||||
1033, # !
|
||||
2, # </s>
|
||||
]
|
||||
|
||||
assert res["labels"] == [
|
||||
-100, # bos
|
||||
-100, # [INST]
|
||||
-100, # Hello
|
||||
-100, # ,
|
||||
-100, # how
|
||||
-100, # are
|
||||
-100, # you
|
||||
-100, # ?
|
||||
-100, # [/INST]
|
||||
1073, # I
|
||||
4525, # 'm
|
||||
6965, # doing
|
||||
4824, # great
|
||||
1044, # ,
|
||||
15412, # thank
|
||||
1636, # you
|
||||
1033, # !
|
||||
2, # </s>
|
||||
]
|
||||
|
||||
# test chat template masking with system prompt
|
||||
res = strategy.tokenize_prompt(
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing great, thank you!"},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
assert res["input_ids"] == [
|
||||
1, # bos
|
||||
17, # [SYSTEM_PROMPT]
|
||||
4568, # You
|
||||
1584, # are
|
||||
1261, # a
|
||||
20351, # helpful
|
||||
27089, # assistant
|
||||
1046, # .
|
||||
18, # [/SYSTEM_PROMPT]
|
||||
3, # [INST]
|
||||
22177, # Hello
|
||||
1044, # ,
|
||||
2606, # how
|
||||
1584, # are
|
||||
1636, # you
|
||||
1063, # ?
|
||||
4, # [/INST]
|
||||
1073, # I
|
||||
4525, # 'm
|
||||
6965, # doing
|
||||
4824, # great
|
||||
1044, # ,
|
||||
15412, # thank
|
||||
1636, # you
|
||||
1033, # !
|
||||
2, # </s>
|
||||
]
|
||||
|
||||
assert res["labels"] == [
|
||||
-100, # bos
|
||||
-100, # [SYSTEM_PROMPT]
|
||||
-100, # You
|
||||
-100, # are
|
||||
-100, # a
|
||||
-100, # helpful
|
||||
-100, # assistant
|
||||
-100, # .
|
||||
-100, # [/SYSTEM_PROMPT]
|
||||
-100, # [INST]
|
||||
-100, # Hello
|
||||
-100, # ,
|
||||
-100, # how
|
||||
-100, # are
|
||||
-100, # you
|
||||
-100, # ?
|
||||
-100, # [/INST]
|
||||
1073, # I
|
||||
4525, # 'm
|
||||
6965, # doing
|
||||
4824, # great
|
||||
1044, # ,
|
||||
15412, # thank
|
||||
1636, # you
|
||||
1033, # !
|
||||
2, # </s>
|
||||
]
|
||||
|
||||
# test chat template with tools
|
||||
res = strategy.tokenize_prompt(
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multiples",
|
||||
"description": "Generates a list of all the multiples of a number that are less than a given limit.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"number": {
|
||||
"type": "integer",
|
||||
"description": "The number to find multiples of.",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "The upper limit for the multiples.",
|
||||
},
|
||||
},
|
||||
"required": ["number", "limit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hey, can you give me a breakdown of how to throw an awesome themed party? Like, what themes work best, and how can I set everything up to really wow my guests? I want some ideas on decorations, food, and activities that will make the party unforgettable!",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call12345",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multiples",
|
||||
"arguments": {
|
||||
"number": 16,
|
||||
"limit": 2,
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call12345",
|
||||
"name": "multiples",
|
||||
"content": "1,2",
|
||||
},
|
||||
{"role": "assistant", "content": "The multiples of 16 is 1 and 2."},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
assert res["input_ids"] == [
|
||||
1, # bos
|
||||
5, 1091, 19227, 4994, 2811, 1429, 5165, 1897, 1429, 5165, 2811, 16753, 2391, 2811, 1429, 44627, 3684, 1897, 1429, 14653, 2811, 1429, 10639, 2130, 1261, 2951, 1307, 1747, 1278, 60092, 1307, 1261, 2782, 1455, 1584, 4289, 2224, 1261, 4265, 6139, 39249, 1429, 26204, 2811, 16753, 4994, 2811, 1429, 6371, 1897, 1429, 48649, 2811, 16753, 12856, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 2782, 1317, 3081, 60092, 1307, 2613, 4179, 1429, 33319, 2811, 16753, 4994, 2811, 1429, 49039, 1897, 1429, 14653, 2811, 1429, 1784, 9229, 6139, 1394, 1278, 60092, 2613, 47579, 1429, 15760, 2811, 12161, 12856, 1897, 1429, 33319, 4964, 2821, 27028, 6, # tool prompt
|
||||
3, 46634, 1044, 1710, 1636, 5628, 1639, 1261, 44433, 1307, 2606, 1317, 5388, 1420, 54191, 2424, 1286, 8967, 1063, 15621, 1044, 2549, 30305, 2196, 3560, 1044, 1321, 2606, 1710, 1362, 2016, 8605, 2015, 1317, 5524, 118931, 2036, 32951, 1063, 1362, 2933, 2269, 12106, 1408, 101987, 1044, 6939, 1044, 1321, 9216, 1455, 2084, 3180, 1278, 8967, 119141, 1689, 5935, 1033, 4, # user
|
||||
*assistant_toolcall_ids, # assistant tool calling
|
||||
*tool_result_ids, # tool result
|
||||
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
|
||||
2 # eos
|
||||
]
|
||||
|
||||
assert res["labels"] == [
|
||||
-100, # bos
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # tool prompt
|
||||
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, # user prompt
|
||||
*assistant_toolcall_ids, # assistant tool calling
|
||||
*([-100] * len(tool_result_ids)), # tool result
|
||||
1784, 60092, 1307, 1032, 1049, 1054, 1395, 1032, 1049, 1321, 1032, 1050, 1046, # assistant
|
||||
2 # eos
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
# test chat template with tokenize=False
|
||||
res = tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing great, thank you!"},
|
||||
],
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
assert res == "<s>[INST]Hello, how are you?[/INST]I'm doing great, thank you!</s>"
|
||||
|
||||
# test encode
|
||||
res = tokenizer.encode("Hello, how are you?", add_special_tokens=True)
|
||||
assert res == [
|
||||
1, # bos
|
||||
22177, # Hello
|
||||
1044, # ,
|
||||
2606, # how
|
||||
1584, # are
|
||||
1636, # you
|
||||
1063, # ?
|
||||
2, # eos
|
||||
]
|
||||
|
||||
# test decode no skip special tokens
|
||||
decoded_res = tokenizer.decode(res, skip_special_tokens=False)
|
||||
|
||||
assert decoded_res == "<s>Hello, how are you?</s>"
|
||||
|
||||
# test decode skip special tokens
|
||||
decoded_res = tokenizer.decode(res, skip_special_tokens=True)
|
||||
assert decoded_res == "Hello, how are you?"
|
||||
|
||||
# test encode no special tokens
|
||||
res = tokenizer.encode("Hello, how are you?", add_special_tokens=False)
|
||||
assert res == [
|
||||
22177, # Hello
|
||||
1044, # ,
|
||||
2606, # how
|
||||
1584, # are
|
||||
1636, # you
|
||||
1063, # ?
|
||||
]
|
||||
|
||||
# test convert ids to tokens
|
||||
res = tokenizer.convert_ids_to_tokens(res)
|
||||
# spacing are needed as we are converting without decoding
|
||||
assert res == ["Hello", ",", " how", " are", " you", "?"]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="TODO, fix for new HF wrapper call")
|
||||
def test_magistral_tokenizer_pad_method(magistral_tokenizer: "HFMistralTokenizer"):
|
||||
"""Test the MistralTokenizer pad method"""
|
||||
from axolotl.utils.collators.core import IGNORE_INDEX
|
||||
|
||||
magistral_pad_token_id = 11 # taken from tokenizer.pad_token_id
|
||||
|
||||
# Test padding with input_ids and labels only
|
||||
features = [
|
||||
{"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
|
||||
{"input_ids": [7, 8], "labels": [9, 10]},
|
||||
]
|
||||
|
||||
result = magistral_tokenizer.pad(features, padding=True, return_tensors="pt")
|
||||
|
||||
# Check that input_ids are padded correctly
|
||||
assert result["input_ids"].shape == (2, 3)
|
||||
assert result["input_ids"].tolist() == [[1, 2, 3], [7, 8, magistral_pad_token_id]]
|
||||
|
||||
# Check that labels are padded correctly
|
||||
assert result["labels"].shape == (2, 3)
|
||||
assert result["labels"].tolist() == [[4, 5, 6], [9, 10, IGNORE_INDEX]]
|
||||
|
||||
# Check that attention_mask and position_ids are NOT created
|
||||
assert "attention_mask" not in result
|
||||
assert "position_ids" not in result
|
||||
|
||||
# Test padding with attention_mask
|
||||
features_with_attention = [
|
||||
{"input_ids": [1, 2, 3], "labels": [4, 5, 6], "attention_mask": [1, 1, 1]},
|
||||
{"input_ids": [7, 8], "labels": [9, 10], "attention_mask": [1, 1]},
|
||||
]
|
||||
|
||||
result = magistral_tokenizer.pad(
|
||||
features_with_attention, padding=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
# Check that attention_mask is padded correctly
|
||||
assert result["attention_mask"].shape == (2, 3)
|
||||
assert result["attention_mask"].tolist() == [[1, 1, 1], [1, 1, 0]]
|
||||
|
||||
# Test padding with position_ids
|
||||
features_with_position = [
|
||||
{"input_ids": [1, 2, 3], "labels": [4, 5, 6], "position_ids": [0, 1, 2]},
|
||||
{"input_ids": [7, 8], "labels": [9, 10], "position_ids": [0, 1]},
|
||||
]
|
||||
|
||||
result = magistral_tokenizer.pad(
|
||||
features_with_position, padding=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
# Check that position_ids are padded correctly (continuing sequence)
|
||||
assert result["position_ids"].shape == (2, 3)
|
||||
assert result["position_ids"].tolist() == [[0, 1, 2], [0, 1, 2]]
|
||||
|
||||
# Test padding with all fields
|
||||
features_all = [
|
||||
{
|
||||
"input_ids": [1, 2, 3],
|
||||
"labels": [4, 5, 6],
|
||||
"attention_mask": [1, 1, 1],
|
||||
"position_ids": [0, 1, 2],
|
||||
},
|
||||
{
|
||||
"input_ids": [7, 8],
|
||||
"labels": [9, 10],
|
||||
"attention_mask": [1, 1],
|
||||
"position_ids": [0, 1],
|
||||
},
|
||||
]
|
||||
|
||||
result = magistral_tokenizer.pad(features_all, padding=True, return_tensors="pt")
|
||||
|
||||
# All fields should be present and correctly padded
|
||||
assert "input_ids" in result
|
||||
assert "labels" in result
|
||||
assert "attention_mask" in result
|
||||
assert "position_ids" in result
|
||||
|
||||
# Test padding with all sequences same length
|
||||
features_same_length = [
|
||||
{"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
|
||||
{"input_ids": [7, 8, 9], "labels": [10, 11, 12]},
|
||||
]
|
||||
|
||||
result = magistral_tokenizer.pad(
|
||||
features_same_length, padding=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
# Check match when no padding is needed
|
||||
assert result["input_ids"][0].tolist() == features_same_length[0]["input_ids"]
|
||||
assert result["labels"][0].tolist() == features_same_length[0]["labels"]
|
||||
|
||||
assert result["input_ids"][1].tolist() == features_same_length[1]["input_ids"]
|
||||
assert result["labels"][1].tolist() == features_same_length[1]["labels"]
|
||||
|
||||
# Test padding with max_length parameter
|
||||
result = magistral_tokenizer.pad(
|
||||
features, padding="max_length", max_length=5, return_tensors="pt"
|
||||
)
|
||||
|
||||
# Should pad to max_length
|
||||
assert result["input_ids"].shape == (2, 5)
|
||||
assert result["labels"].shape == (2, 5)
|
||||
|
||||
# Test numpy return type
|
||||
result = magistral_tokenizer.pad(features, padding=True, return_tensors="np")
|
||||
|
||||
# Should return numpy arrays
|
||||
import numpy as np
|
||||
|
||||
assert isinstance(result["input_ids"], np.ndarray)
|
||||
assert isinstance(result["labels"], np.ndarray)
|
||||
|
||||
# Test unsupported field rejection
|
||||
features_unsupported = [
|
||||
{"input_ids": [1, 2, 3], "labels": [4, 5, 6], "unsupported_field": [7, 8, 9]},
|
||||
]
|
||||
|
||||
with pytest.raises(NotImplementedError, match="unsupported_field"):
|
||||
magistral_tokenizer.pad(features_unsupported, padding=True, return_tensors="pt")
|
||||
|
||||
# Test token_type_ids rejection
|
||||
features_token_type = [
|
||||
{"input_ids": [1, 2, 3], "labels": [4, 5, 6], "token_type_ids": [0, 0, 0]},
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="token_type_ids is not supported"):
|
||||
magistral_tokenizer.pad(features_token_type, padding=True, return_tensors="pt")
|
||||
|
||||
|
||||
def test_magistral_tool_calling(magistral_tokenizer: "HFMistralTokenizer"):
    """Test tool calling with the Magistral tokenizer.

    Tokenizes four conversations through ``MistralStrategy`` and inspects the
    decoded ``input_ids``:

    1. a single tool call followed by its tool result,
    2. two sequential tool calls, each followed by its result,
    3. a tool call in a conversation that starts with a system message,
    4. a malformed conversation whose tool call is never answered by a
       ``tool`` message, which must raise ``InvalidMessageStructureException``.

    Args:
        magistral_tokenizer: Tokenizer fixture used both to build the strategy
            and to decode the tokenized prompt.
    """
    from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy

    strategy = MistralStrategy(
        MistralPrompter(
            magistral_tokenizer,
            chat_template=None,
            message_property_mappings={"role": "role", "content": "content"},
        ),
        tokenizer=magistral_tokenizer,
        train_on_inputs=False,
        train_on_eos="turn",
        sequence_len=512,
        roles_to_train=["assistant"],
    )

    # Test basic tool calling with single function
    basic_tool_calling = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather for a location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA",
                            },
                        },
                        "required": ["location"],
                    },
                },
            },
        ],
        "messages": [
            {
                "role": "user",
                "content": "What's the weather like in San Francisco?",
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "call12345",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": {
                                "location": "San Francisco, CA",
                            },
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "call12345",
                "name": "get_weather",
                "content": "Sunny, 72°F",
            },
            {
                "role": "assistant",
                "content": "The weather in San Francisco is sunny and 72°F.",
            },
        ],
    }

    res = strategy.tokenize_prompt(basic_tool_calling)

    # Basic validation
    assert "input_ids" in res
    assert "labels" in res
    assert len(res["input_ids"]) > 0
    assert len(res["labels"]) == len(res["input_ids"])

    # Decode and verify structure
    decoded = magistral_tokenizer.decode(res["input_ids"])
    assert (
        '<s>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather for a location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}}, "required": ["location"]}}}][/AVAILABLE_TOOLS]'
        in decoded
    )
    assert (
        '[TOOL_CALLS]get_weather[CALL_ID]call12345[ARGS]{"location": "San Francisco, CA"}</s>'
        in decoded
    )
    assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]Sunny, 72°F[/TOOL_RESULTS]" in decoded
    assert "The weather in San Francisco is sunny and 72°F.</s>" in decoded

    # Test multiple tool calls in sequence
    multi_tool_calling = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "add_numbers",
                    "description": "Add two numbers together",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "a": {"type": "number", "description": "First number"},
                            "b": {"type": "number", "description": "Second number"},
                        },
                        "required": ["a", "b"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "multiply_numbers",
                    "description": "Multiply two numbers",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "x": {"type": "number", "description": "First number"},
                            "y": {"type": "number", "description": "Second number"},
                        },
                        "required": ["x", "y"],
                    },
                },
            },
        ],
        "messages": [
            {
                "role": "user",
                "content": "Add 5 and 3, then multiply the result by 2",
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "call12345",
                        "type": "function",
                        "function": {
                            "name": "add_numbers",
                            "arguments": {"a": 5, "b": 3},
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "call12345",
                "name": "add_numbers",
                "content": "8",
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "call23456",
                        "type": "function",
                        "function": {
                            "name": "multiply_numbers",
                            "arguments": {"x": 8, "y": 2},
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "call23456",
                "name": "multiply_numbers",
                "content": "16",
            },
            {
                "role": "assistant",
                "content": "The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16.",
            },
        ],
    }

    res = strategy.tokenize_prompt(multi_tool_calling)

    # Validation
    assert len(res["input_ids"]) > 0
    assert len(res["labels"]) == len(res["input_ids"])

    decoded = magistral_tokenizer.decode(res["input_ids"])
    assert (
        '<s>[AVAILABLE_TOOLS][{"type": "function", "function": {"name": "add_numbers", "description": "Add two numbers together", "parameters": {"type": "object", "properties": {"a": {"type": "number", "description": "First number"}, "b": {"type": "number", "description": "Second number"}}, "required": ["a", "b"]}}}, {"type": "function", "function": {"name": "multiply_numbers", "description": "Multiply two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "First number"}, "y": {"type": "number", "description": "Second number"}}, "required": ["x", "y"]}}}][/AVAILABLE_TOOLS]'
        in decoded
    )
    assert (
        '[TOOL_CALLS]add_numbers[CALL_ID]call12345[ARGS]{"a": 5, "b": 3}</s>' in decoded
    )
    assert "[TOOL_RESULTS]call12345[TOOL_CONTENT]8[/TOOL_RESULTS]" in decoded
    assert (
        '[TOOL_CALLS]multiply_numbers[CALL_ID]call23456[ARGS]{"x": 8, "y": 2}</s>'
        in decoded
    )
    assert "[TOOL_RESULTS]call23456[TOOL_CONTENT]16[/TOOL_RESULTS]" in decoded
    assert (
        "The result is 16. I first added 5 and 3 to get 8, then multiplied 8 by 2 to get 16.</s>"
        in decoded
    )

    # Test tool calling with system message
    system_tool_calling = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "search_database",
                    "description": "Search for information in database",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "query": {"type": "string", "description": "Search query"},
                        },
                        "required": ["query"],
                    },
                },
            },
        ],
        "messages": [
            {
                "role": "system",
                "content": "You are a helpful assistant with access to a database.",
            },
            {
                "role": "user",
                "content": "Find information about Python programming",
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "search123",
                        "type": "function",
                        "function": {
                            "name": "search_database",
                            "arguments": {"query": "Python programming"},
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "search123",
                "name": "search_database",
                "content": "Python is a high-level programming language known for its simplicity.",
            },
            {
                "role": "assistant",
                "content": "Based on the database search, Python is a high-level programming language known for its simplicity and readability.",
            },
        ],
    }

    res = strategy.tokenize_prompt(system_tool_calling)

    # Validation
    assert len(res["input_ids"]) > 0
    assert len(res["labels"]) == len(res["input_ids"])

    decoded = magistral_tokenizer.decode(res["input_ids"])

    assert (
        '<s>[SYSTEM_PROMPT]You are a helpful assistant with access to a database.[/SYSTEM_PROMPT][AVAILABLE_TOOLS][{"type": "function", "function": {"name": "search_database", "description": "Search for information in database", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Search query"}}, "required": ["query"]}}}][/AVAILABLE_TOOLS]'
        in decoded
    )

    # Test error handling - missing tool response
    incomplete_tool_calling = {
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_time",
                    "description": "Get current time",
                    "parameters": {"type": "object", "properties": {}},
                },
            },
        ],
        "messages": [
            {
                "role": "user",
                "content": "What time is it?",
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "id": "time12345",
                        "type": "function",
                        "function": {
                            "name": "get_time",
                            "arguments": {},
                        },
                    }
                ],
            },
            {
                "role": "assistant",
                "content": "The current time is 12:00 PM.",
            },
        ],
    }

    from mistral_common.exceptions import InvalidMessageStructureException

    # pytest.raises fails the test when no exception is raised; a bare
    # try/except with the assertion inside the handler would pass silently
    # in that case and the error path would go unverified.
    with pytest.raises(
        InvalidMessageStructureException,
        match="Not the same number of function calls and responses",
    ):
        strategy.tokenize_prompt(incomplete_tool_calling)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="TODO, fix for new HF wrapper call")
def test_magistral_tokenizer_call_method(
    magistral_tokenizer: "HFMistralTokenizer", llama3_tokenizer: "PreTrainedTokenizer"
):
    """Check that calling the Magistral tokenizer mirrors a HuggingFace tokenizer.

    A deep copy of the llama3 tokenizer, with its EOS token reused as the pad
    token, serves as the reference for return types and batch dimensions.
    Token counts themselves are never compared between the two tokenizers.
    """
    from copy import deepcopy

    import numpy as np
    import torch

    reference = deepcopy(llama3_tokenizer)
    reference.pad_token = reference.eos_token

    single_text = "Hello, how are you?"
    text_pair = ["Hello world", "How are you?"]

    # Single string, return_tensors=None: expect a dict of flat python lists
    ref_plain: dict[str, list[int]] = reference(single_text, return_tensors=None)
    plain: dict[str, list[int]] = magistral_tokenizer(
        single_text, return_tensors=None
    )

    assert isinstance(plain, dict)
    assert set(plain.keys()) == {"input_ids", "attention_mask"}
    for field in ("input_ids", "attention_mask"):
        # same container type as the reference tokenizer returns (list)
        assert isinstance(plain[field], type(ref_plain[field]))
    assert len(plain["input_ids"]) == len(plain["attention_mask"])
    assert np.all(plain["attention_mask"])
    assert len(np.array(plain["input_ids"]).shape) == 1  # 1D array

    # Single string, return_tensors='pt'
    ref_pt: dict[str, torch.Tensor] = reference(single_text, return_tensors="pt")
    as_pt: dict[str, torch.Tensor] = magistral_tokenizer(
        single_text, return_tensors="pt"
    )

    # Structure and types
    assert isinstance(as_pt["input_ids"], torch.Tensor)
    assert isinstance(as_pt["attention_mask"], torch.Tensor)

    # Rank and batch dimension must agree; the token dimension is not compared
    ref_ids_shape = ref_pt["input_ids"].shape
    ids_shape = as_pt["input_ids"].shape
    assert len(ref_ids_shape) == len(ids_shape)
    assert ref_ids_shape[0] == ids_shape[0]
    assert as_pt["attention_mask"].shape == as_pt["input_ids"].shape
    assert torch.all(as_pt["attention_mask"] == 1)

    # Batch input with padding
    ref_batch: dict[str, torch.Tensor] = reference(
        text_pair, return_tensors="pt", padding=True
    )
    batch: dict[str, torch.Tensor] = magistral_tokenizer(
        text_pair, return_tensors="pt", padding=True
    )

    assert len(ref_batch["input_ids"].shape) == len(batch["input_ids"].shape)
    assert ref_batch["input_ids"].shape[0] == batch["input_ids"].shape[0]
    assert batch["attention_mask"].shape == batch["input_ids"].shape
    batch_mask = batch["attention_mask"]
    assert torch.any(batch_mask[0] == 0)  # padding in shorter sequence
    assert torch.all(batch_mask[1] == 1)  # no padding in longer sequence

    # Numpy tensors
    as_np: dict[str, np.ndarray] = magistral_tokenizer(
        single_text, return_tensors="np"
    )
    assert isinstance(as_np["input_ids"], np.ndarray)
    assert isinstance(as_np["attention_mask"], np.ndarray)

    # Consistency with encode()
    via_encode: list[int] = magistral_tokenizer.encode(
        single_text, add_special_tokens=True
    )
    via_call: dict[str, torch.Tensor] = magistral_tokenizer(
        single_text, return_tensors="pt"
    )
    assert via_encode == via_call["input_ids"][0].tolist()

    # Error handling
    with pytest.raises(ValueError, match="Unsupported kwargs"):
        magistral_tokenizer(single_text, unsupported_param=True)

    with pytest.raises(
        ValueError, match="return_tensors='pt' or 'np' requires padding or truncation"
    ):
        magistral_tokenizer(text_pair, return_tensors="pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -2,8 +2,6 @@
|
||||
Tests for splitting reasoning/thinking from content into separate field
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
@@ -13,11 +11,6 @@ from axolotl.prompt_strategies.chat_template import (
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
@pytest.fixture(name="messages_w_reasoning")
|
||||
def messages_w_reasoning_fixture():
|
||||
@@ -64,7 +57,6 @@ def messages_w_reasoning_fixture():
|
||||
|
||||
|
||||
@pytest.fixture(name="qwen3_tokenizer")
|
||||
@enable_hf_offline
|
||||
def qwen3_tokenizer_fixture(
|
||||
download_qwen3_half_billion_model,
|
||||
): # pylint: disable=unused-argument
|
||||
|
||||
@@ -103,7 +103,7 @@ class TestAssistantDPOChatTemplateLlama3:
|
||||
|
||||
def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn = default(
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
@@ -128,7 +128,7 @@ class TestAssistantDPOChatTemplateLlama3:
|
||||
|
||||
def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn = default(
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
@@ -169,7 +169,7 @@ class TestAssistantDPOChatTemplatePhi3:
|
||||
|
||||
def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn = default(
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "tokenizer_default",
|
||||
@@ -199,7 +199,7 @@ class TestAssistantDPOChatTemplateGemma:
|
||||
|
||||
def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn = default(
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "tokenizer_default",
|
||||
|
||||
@@ -6,8 +6,9 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.loaders.tokenizer import load_tokenizer
|
||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.data.rl import prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
@@ -55,7 +56,8 @@ class TestDPOChatml:
|
||||
# test that dpo.load works
|
||||
load_dpo("chatml", cfg)
|
||||
# now actually load the datasets with the strategy
|
||||
train_ds, _ = load_prepare_preference_datasets(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_ds, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
assert train_ds[0]["prompt"].startswith("<|im_start|>")
|
||||
assert train_ds[0]["prompt"].endswith("<|im_start|>assistant\n")
|
||||
assert "chosen" in train_ds[0]
|
||||
|
||||
@@ -2,14 +2,12 @@
|
||||
tests for jinja_template_analyzer
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
LOG = logging.getLogger("axolotl")
|
||||
LOG = get_logger(__name__, log_level="DEBUG")
|
||||
|
||||
|
||||
class TestJinjaTemplateAnalyzer:
|
||||
|
||||
40
tests/test_chunked_xentropy.py
Normal file
40
tests/test_chunked_xentropy.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
test suite for chunked cross entropy
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from axolotl.monkeypatch.loss.chunked import get_causal_lm_loss
|
||||
|
||||
|
||||
@pytest.fixture
def chunked_fixtures():
    """Build a linear LM head plus random hidden states and labels.

    Returns:
        Tuple ``(lm_head, hidden_state, labels, vocab_size)``: ``lm_head`` is
        an ``nn.Linear`` from 512 features to ``vocab_size`` (1024 * 256)
        outputs, ``hidden_state`` is a ``(1, 2048, 512)`` tensor from
        ``torch.randn``, and ``labels`` is a ``(1, 2048)`` tensor from
        ``torch.randint`` over ``[0, vocab_size)``.
    """
    vocab_size = 1024 * 256
    hidden_shape = (1, 2048, 512)  # (batch_size, seq_len, model_dim)

    # Creation order (head, hidden states, labels) is kept so draws from the
    # global RNG happen in the same sequence.
    head = nn.Linear(hidden_shape[-1], vocab_size)
    states = torch.randn(*hidden_shape)
    targets = torch.randint(low=0, high=vocab_size, size=hidden_shape[:2])
    return head, states, targets, vocab_size
|
||||
|
||||
|
||||
def test_chunked_forward(chunked_fixtures):  # pylint: disable=redefined-outer-name
    """Compare the chunked LM loss with a plain mean cross-entropy.

    Both losses are computed from the same logits and labels and must agree
    within ``atol=1e-2`` / ``rtol=1e-2``.
    """
    head, states, targets, vocab_size = chunked_fixtures
    chunked_loss_fn = get_causal_lm_loss()

    logits = head(states)
    chunked_value = chunked_loss_fn(logits, targets)

    # Reference: flatten to (tokens, vocab) and (tokens,), then take the mean
    # cross-entropy on float logits.
    reference_value = nn.functional.cross_entropy(
        logits.view(-1, vocab_size).float(),
        targets.view(-1),
        reduction="mean",
    )

    assert torch.allclose(chunked_value, reference_value, atol=1e-2, rtol=1e-2)
|
||||
@@ -1,10 +1,9 @@
|
||||
"""
|
||||
Test dataset loading under various conditions.
|
||||
"""
|
||||
"""Test dataset loading under various conditions."""
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -12,8 +11,9 @@ from datasets import Dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.utils.data import load_tokenized_prepared_datasets
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.loaders.tokenizer import load_tokenizer
|
||||
from axolotl.utils.data.rl import prepare_preference_datasets
|
||||
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.constants import (
|
||||
@@ -28,7 +28,9 @@ class TestDatasetPreparation:
|
||||
"""Test a configured dataloader."""
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer(self, tokenizer_huggyllama) -> PreTrainedTokenizer:
|
||||
def tokenizer(
|
||||
self, tokenizer_huggyllama
|
||||
) -> Generator[PreTrainedTokenizer, Any, Any]:
|
||||
tokenizer_huggyllama.add_special_tokens(SPECIAL_TOKENS)
|
||||
yield tokenizer_huggyllama
|
||||
|
||||
@@ -63,7 +65,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -107,7 +112,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -133,10 +141,14 @@ class TestDatasetPreparation:
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -145,7 +157,7 @@ class TestDatasetPreparation:
|
||||
|
||||
@enable_hf_offline
|
||||
def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
|
||||
"""Usual use case. Verify a directory of parquet files can be loaded."""
|
||||
"""Usual use case. Verify a directory of parquet files can be loaded."""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_dir = Path(tmp_dir) / "tmp_dataset"
|
||||
tmp_ds_dir.mkdir()
|
||||
@@ -168,10 +180,14 @@ class TestDatasetPreparation:
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -203,10 +219,14 @@ class TestDatasetPreparation:
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -232,10 +252,14 @@ class TestDatasetPreparation:
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -261,10 +285,14 @@ class TestDatasetPreparation:
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -286,10 +314,14 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
assert len(train_dataset) == 1800
|
||||
assert "conversation" in train_dataset.features
|
||||
assert "conversation" not in train_dataset.features
|
||||
assert "chosen" in train_dataset.features
|
||||
assert "rejected" in train_dataset.features
|
||||
assert "prompt" in train_dataset.features
|
||||
|
||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||
@enable_hf_offline
|
||||
@@ -315,7 +347,10 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -335,20 +370,27 @@ class TestDatasetPreparation:
|
||||
"rl": "dpo",
|
||||
"chat_template": "llama3",
|
||||
"datasets": [ALPACA_MESSAGES_CONFIG_REVISION],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset:
|
||||
with patch(
|
||||
"axolotl.utils.data.rl.load_dataset_with_config"
|
||||
) as mock_load_dataset:
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.return_value = (
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||
)
|
||||
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
assert len(train_dataset) == 1800
|
||||
assert "conversation" in train_dataset.features
|
||||
assert "conversation" not in train_dataset.features
|
||||
assert "chosen" in train_dataset.features
|
||||
assert "rejected" in train_dataset.features
|
||||
assert "prompt" in train_dataset.features
|
||||
|
||||
@enable_hf_offline
|
||||
@pytest.mark.skip("datasets bug with local datasets when offline")
|
||||
@@ -387,16 +429,18 @@ class TestDatasetPreparation:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||
"axolotl.utils.data.shared.load_dataset_with_config"
|
||||
) as mock_load_dataset:
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.return_value = (
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, prepared_path
|
||||
)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH",
|
||||
str(prepared_path),
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
@@ -428,10 +472,14 @@ class TestDatasetPreparation:
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
||||
with patch(
|
||||
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||
):
|
||||
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
|
||||
@@ -5,7 +5,6 @@ Additionally, this test suite includes tests for functions that indirectly call
|
||||
`deduplicate_and_log_datasets` during the execution of the preprocess command.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -14,8 +13,7 @@ from datasets import Dataset
|
||||
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -71,36 +69,14 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
self.expected_dataset = Dataset.from_dict(self.expected_data)
|
||||
|
||||
def test_deduplication(self):
|
||||
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=self.dataset)
|
||||
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=self.dataset)
|
||||
train_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
|
||||
eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
dataset=self.dataset, dataset_name="eval"
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
|
||||
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
|
||||
|
||||
def test_datasets_are_none(self):
|
||||
# Test when both datasets are None
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=None, eval_dataset=None
|
||||
)
|
||||
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
|
||||
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
|
||||
|
||||
def test_only_train_is_none(self):
|
||||
# Test when only train_dataset is None
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=None, eval_dataset=self.dataset
|
||||
)
|
||||
self.assertIsNone(train_dataset, "Expected train_dataset to be None")
|
||||
verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
|
||||
|
||||
def test_only_eval_is_none(self):
|
||||
# Test when only eval_dataset is None
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=self.dataset, eval_dataset=None
|
||||
)
|
||||
self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
|
||||
verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
|
||||
|
||||
def test_exact_duplicates(self):
|
||||
# Test when datasets are exact duplicates
|
||||
duplicate_data = {
|
||||
@@ -115,8 +91,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
expected_dataset = Dataset.from_dict(expected_data)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
|
||||
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
|
||||
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
||||
eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
dataset=dataset, dataset_name="eval"
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
|
||||
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
|
||||
@@ -139,8 +117,10 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
expected_dataset = Dataset.from_dict(expected_data)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
|
||||
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
|
||||
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
||||
eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
dataset=dataset, dataset_name="eval"
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset, "train_dataset")
|
||||
verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
|
||||
@@ -169,8 +149,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=dataset, eval_dataset=dataset
|
||||
train_dataset, eval_dataset = deduplicate_and_log_datasets(
|
||||
dataset=dataset, other_dataset=dataset
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
|
||||
@@ -206,8 +186,8 @@ class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
expected_dataset_eval = Dataset.from_dict(expected_data_eval)
|
||||
|
||||
# Run deduplication
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=dataset_train, eval_dataset=dataset_eval
|
||||
train_dataset, eval_dataset = deduplicate_and_log_datasets(
|
||||
dataset=dataset_train, other_dataset=dataset_eval
|
||||
)
|
||||
|
||||
verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
|
||||
@@ -230,6 +210,7 @@ class TestDeduplicateRLDataset:
|
||||
ALPACA_MESSAGES_CONFIG_REVISION,
|
||||
ALPACA_MESSAGES_CONFIG_REVISION,
|
||||
],
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
yield fixture
|
||||
@@ -245,7 +226,9 @@ class TestDeduplicateRLDataset:
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||
patch(
|
||||
"axolotl.utils.data.rl.load_dataset_with_config"
|
||||
) as mock_load_dataset,
|
||||
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
@@ -255,7 +238,8 @@ class TestDeduplicateRLDataset:
|
||||
]
|
||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
# Verify that the dataset has been deduplicated
|
||||
assert len(train_dataset) == 1800, "Dataset was not properly deduplicated"
|
||||
@@ -269,7 +253,9 @@ class TestDeduplicateRLDataset:
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||
patch(
|
||||
"axolotl.utils.data.rl.load_dataset_with_config"
|
||||
) as mock_load_dataset,
|
||||
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
@@ -279,9 +265,10 @@ class TestDeduplicateRLDataset:
|
||||
]
|
||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||
|
||||
cfg.dataset_exact_deduplication = False
|
||||
# Load the dataset without deduplication
|
||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||
cfg.dataset_exact_deduplication = False
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
# Verify that the dataset retains duplicates
|
||||
assert (
|
||||
@@ -335,7 +322,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
)
|
||||
|
||||
# Prepare dataset using the prepare_dataset function
|
||||
train_dataset, _, _, _ = prepare_dataset(
|
||||
train_dataset, _, _, _ = prepare_datasets(
|
||||
self.cfg_1,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
@@ -362,7 +349,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
)
|
||||
|
||||
# Prepare dataset using the prepare_dataset function
|
||||
_, eval_dataset, _, _ = prepare_dataset(
|
||||
_, eval_dataset, _, _ = prepare_datasets(
|
||||
self.cfg_1,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
@@ -389,7 +376,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
)
|
||||
|
||||
# Prepare dataset using the prepare_dataset function
|
||||
train_dataset, eval_dataset, _, _ = prepare_dataset(
|
||||
train_dataset, eval_dataset, _, _ = prepare_datasets(
|
||||
self.cfg_1,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
@@ -428,41 +415,8 @@ class TestWrongCollisions(unittest.TestCase):
|
||||
self.eval_dataset = Dataset.from_dict(self.eval_data)
|
||||
self.dataset = Dataset.from_dict(self.dataset_data)
|
||||
|
||||
@patch(
|
||||
"axolotl.utils.data.utils.sha256",
|
||||
side_effect=lambda x: (
|
||||
hashlib.sha256("forced_collision_hash".encode("utf-8")).hexdigest()
|
||||
if "sample 5" in x
|
||||
else hashlib.sha256(x.encode("utf-8")).hexdigest()
|
||||
),
|
||||
)
|
||||
def test_deduplication_wrong_collision_train_eval(self, _mock_sha256):
|
||||
dedup_train, dedup_eval, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
|
||||
)
|
||||
self.assertEqual(
|
||||
len(dedup_train),
|
||||
2,
|
||||
"train dataset should not deduplicate rows with forced hash collisions but different labels.",
|
||||
)
|
||||
self.assertEqual(
|
||||
len(dedup_eval),
|
||||
2,
|
||||
"Eval dataset should not deduplicate rows with forced hash collisions but different labels.",
|
||||
)
|
||||
self.assertEqual(
|
||||
len(dedup_eval),
|
||||
len(self.eval_dataset),
|
||||
"The output eval dataset should have the same number of rows as the input eval dataset.",
|
||||
)
|
||||
self.assertEqual(
|
||||
str(dedup_eval),
|
||||
str(self.eval_dataset),
|
||||
"The string representation of the output eval dataset should be identical to the input eval dataset.",
|
||||
)
|
||||
|
||||
def test_deduplication_dataset_only(self):
|
||||
_, _, dedup_dataset = deduplicate_and_log_datasets(dataset=self.dataset)
|
||||
dedup_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
|
||||
self.assertEqual(
|
||||
len(dedup_dataset), 3, "Dataset should have all original values"
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ from transformers.utils.import_utils import is_torch_mps_available
|
||||
|
||||
from axolotl.loaders import ModelLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import _get_parallel_config_kwargs
|
||||
|
||||
|
||||
class TestModelsUtils:
|
||||
@@ -171,3 +172,42 @@ class TestModelsUtils:
|
||||
message_property_mappings={"content": "different_content"},
|
||||
)
|
||||
assert "Conflicting message content fields" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"world_size, tensor_parallel_size, context_parallel_size, dp_shard_size, dp_replicate_size, is_fsdp, expected",
|
||||
[
|
||||
(16, 2, 2, 2, 2, True, (2, 2, 2, 2)),
|
||||
(16, 1, 1, None, None, True, (0, 0, 16, 1)),
|
||||
(16, 2, 2, 2, None, True, (2, 2, 2, 2)),
|
||||
(16, 2, 2, None, 2, True, (2, 2, 2, 2)),
|
||||
(16, 1, 1, None, 2, True, (0, 0, 8, 2)),
|
||||
(2, 1, 1, None, None, True, (0, 0, 2, 1)),
|
||||
],
|
||||
)
|
||||
def test_get_parallel_config_kwargs(
|
||||
self,
|
||||
world_size,
|
||||
tensor_parallel_size,
|
||||
context_parallel_size,
|
||||
dp_shard_size,
|
||||
dp_replicate_size,
|
||||
is_fsdp,
|
||||
expected,
|
||||
):
|
||||
res = _get_parallel_config_kwargs( # pylint: disable=protected-access
|
||||
world_size,
|
||||
tensor_parallel_size,
|
||||
context_parallel_size,
|
||||
dp_shard_size,
|
||||
dp_replicate_size,
|
||||
is_fsdp,
|
||||
)
|
||||
|
||||
if expected[0] > 1:
|
||||
assert res["tp_size"] == expected[0]
|
||||
if expected[1] > 1:
|
||||
assert res["cp_size"] == expected[1]
|
||||
if expected[2] > 1:
|
||||
assert res["dp_shard_size"] == expected[2]
|
||||
if expected[3] > 1:
|
||||
assert res["dp_replicate_size"] == expected[3]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user