Merge branch 'main' into telemetry-opt-in
This commit is contained in:
@@ -17,16 +17,23 @@ class BaseCliTest:
|
||||
command: Command to test (train/evaluate)
|
||||
"""
|
||||
# Test missing config file
|
||||
result = cli_runner.invoke(cli, [command, "--no-accelerate"])
|
||||
result = cli_runner.invoke(cli, [command, "--launcher", "python"])
|
||||
assert result.exit_code != 0
|
||||
|
||||
# Test non-existent config file
|
||||
result = cli_runner.invoke(cli, [command, "nonexistent.yml", "--no-accelerate"])
|
||||
result = cli_runner.invoke(
|
||||
cli, [command, "nonexistent.yml", "--launcher", "python"]
|
||||
)
|
||||
assert result.exit_code != 0
|
||||
assert "Error: Invalid value for 'CONFIG'" in result.output
|
||||
|
||||
def _test_basic_execution(
|
||||
self, cli_runner, tmp_path: Path, valid_test_config: str, command: str
|
||||
self,
|
||||
cli_runner,
|
||||
tmp_path: Path,
|
||||
valid_test_config: str,
|
||||
command: str,
|
||||
train: bool = True,
|
||||
):
|
||||
"""Test basic execution with accelerate.
|
||||
|
||||
@@ -35,24 +42,37 @@ class BaseCliTest:
|
||||
tmp_path: Temporary path fixture
|
||||
valid_test_config: Valid config fixture
|
||||
command: Command to test (train/evaluate)
|
||||
train: Whether to test training (default) or evaluation
|
||||
"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch("subprocess.run") as mock:
|
||||
mock_fn = "os.execvpe" if command == "train" else "subprocess.run"
|
||||
|
||||
with patch(mock_fn) as mock:
|
||||
result = cli_runner.invoke(cli, [command, str(config_path)])
|
||||
|
||||
assert mock.called
|
||||
assert mock.call_args.args[0] == [
|
||||
|
||||
expected = [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"-m",
|
||||
f"axolotl.cli.{command}",
|
||||
str(config_path),
|
||||
"--debug-num-examples",
|
||||
"0",
|
||||
"--debug=False",
|
||||
"--debug-text-only=False",
|
||||
"--debug-num-examples=0",
|
||||
]
|
||||
assert mock.call_args.kwargs == {"check": True}
|
||||
if train:
|
||||
expected.append("--shard=False")
|
||||
|
||||
if command == "train":
|
||||
assert mock.call_args.args[0] == "accelerate"
|
||||
assert mock.call_args.args[1] == expected
|
||||
else:
|
||||
assert mock.call_args.args[0] == expected
|
||||
assert mock.call_args.kwargs == {"check": True}
|
||||
assert result.exit_code == 0
|
||||
|
||||
def _test_cli_overrides(self, tmp_path: Path, valid_test_config: str):
|
||||
|
||||
@@ -18,7 +18,9 @@ class TestEvaluateCommand(BaseCliTest):
|
||||
|
||||
def test_evaluate_basic_execution(self, cli_runner, tmp_path, valid_test_config):
|
||||
"""Test basic successful execution"""
|
||||
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "evaluate")
|
||||
self._test_basic_execution(
|
||||
cli_runner, tmp_path, valid_test_config, "evaluate", train=False
|
||||
)
|
||||
|
||||
def test_evaluate_basic_execution_no_accelerate(
|
||||
self, cli_runner, tmp_path, valid_test_config
|
||||
@@ -33,7 +35,8 @@ class TestEvaluateCommand(BaseCliTest):
|
||||
[
|
||||
"evaluate",
|
||||
str(config_path),
|
||||
"--no-accelerate",
|
||||
"--launcher",
|
||||
"python",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
@@ -55,7 +58,8 @@ class TestEvaluateCommand(BaseCliTest):
|
||||
"2",
|
||||
"--sequence-len",
|
||||
"128",
|
||||
"--no-accelerate",
|
||||
"--launcher",
|
||||
"python",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
@@ -65,3 +69,104 @@ class TestEvaluateCommand(BaseCliTest):
|
||||
cfg = mock_evaluate.call_args[0][0]
|
||||
assert cfg.micro_batch_size == 2
|
||||
assert cfg.sequence_len == 128
|
||||
|
||||
def test_evaluate_with_launcher_args_torchrun(
    self, cli_runner, tmp_path, valid_test_config
):
    """Arguments after ``--`` must be forwarded to torchrun for ``evaluate``.

    ``subprocess.run`` is patched and the command list it receives is
    inspected.
    """
    cfg_file = tmp_path / "config.yml"
    cfg_file.write_text(valid_test_config)

    argv = [
        "evaluate",
        str(cfg_file),
        "--launcher",
        "torchrun",
        "--",
        "--nproc_per_node=2",
        "--nnodes=1",
    ]

    with patch("subprocess.run") as run_mock:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

        assert outcome.exit_code == 0
        run_mock.assert_called_once()

        # The command list is the first positional argument of the call.
        launched = run_mock.call_args.args[0]
        assert launched[0] == "torchrun"
        for token in (
            "--nproc_per_node=2",
            "--nnodes=1",
            "-m",
            "axolotl.cli.evaluate",
        ):
            assert token in launched
|
||||
|
||||
def test_evaluate_with_launcher_args_accelerate(
    self, cli_runner, tmp_path, valid_test_config
):
    """Arguments after ``--`` must be forwarded to ``accelerate launch``.

    ``subprocess.run`` is patched and the command list it receives is
    inspected.
    """
    cfg_file = tmp_path / "config.yml"
    cfg_file.write_text(valid_test_config)

    argv = [
        "evaluate",
        str(cfg_file),
        "--launcher",
        "accelerate",
        "--",
        "--config_file=accelerate_config.yml",
        "--num_processes=4",
    ]

    with patch("subprocess.run") as run_mock:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

        assert outcome.exit_code == 0
        run_mock.assert_called_once()

        # The command list is the first positional argument of the call.
        launched = run_mock.call_args.args[0]
        assert launched[0] == "accelerate"
        assert launched[1] == "launch"
        for token in (
            "--config_file=accelerate_config.yml",
            "--num_processes=4",
            "-m",
            "axolotl.cli.evaluate",
        ):
            assert token in launched
|
||||
|
||||
def test_evaluate_backward_compatibility_no_launcher_args(
    self, cli_runner, tmp_path, valid_test_config
):
    """``evaluate`` without a ``--`` section must add no launcher args.

    Only regular CLI options are supplied; the command handed to the patched
    ``subprocess.run`` must have nothing between ``launch`` and ``-m``.
    """
    cfg_file = tmp_path / "config.yml"
    cfg_file.write_text(valid_test_config)

    argv = [
        "evaluate",
        str(cfg_file),
        "--launcher",
        "accelerate",
        "--micro-batch-size",
        "2",
    ]

    with patch("subprocess.run") as run_mock:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

        assert outcome.exit_code == 0
        run_mock.assert_called_once()

        launched = run_mock.call_args.args[0]
        assert launched[0] == "accelerate"
        assert launched[1] == "launch"

        # Everything between index 2 and the "-m" flag would be launcher args.
        module_flag_pos = launched.index("-m")
        between = launched[2:module_flag_pos]
        assert len(between) == 0
|
||||
|
||||
@@ -10,7 +10,7 @@ def test_inference_basic(cli_runner, config_path):
|
||||
with patch("axolotl.cli.inference.do_inference") as mock:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
["inference", str(config_path), "--no-accelerate"],
|
||||
["inference", str(config_path), "--launcher", "python"],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
@@ -23,9 +23,124 @@ def test_inference_gradio(cli_runner, config_path):
|
||||
with patch("axolotl.cli.inference.do_inference_gradio") as mock:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
["inference", str(config_path), "--no-accelerate", "--gradio"],
|
||||
["inference", str(config_path), "--launcher", "python", "--gradio"],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert mock.called
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_inference_with_launcher_args_torchrun(cli_runner, config_path):
    """Arguments after ``--`` must be forwarded to torchrun for ``inference``."""
    argv = [
        "inference",
        str(config_path),
        "--launcher",
        "torchrun",
        "--",
        "--nproc_per_node=2",
        "--nnodes=1",
    ]

    with patch("subprocess.run") as run_mock:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

        assert outcome.exit_code == 0
        run_mock.assert_called_once()

        # The command list is the first positional argument of the call.
        launched = run_mock.call_args.args[0]
        assert launched[0] == "torchrun"
        for token in (
            "--nproc_per_node=2",
            "--nnodes=1",
            "-m",
            "axolotl.cli.inference",
        ):
            assert token in launched
|
||||
|
||||
|
||||
def test_inference_with_launcher_args_accelerate(cli_runner, config_path):
    """Arguments after ``--`` must be forwarded to ``accelerate launch``."""
    argv = [
        "inference",
        str(config_path),
        "--launcher",
        "accelerate",
        "--",
        "--config_file=accelerate_config.yml",
        "--num_processes=4",
    ]

    with patch("subprocess.run") as run_mock:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

        assert outcome.exit_code == 0
        run_mock.assert_called_once()

        # The command list is the first positional argument of the call.
        launched = run_mock.call_args.args[0]
        assert launched[0] == "accelerate"
        assert launched[1] == "launch"
        for token in (
            "--config_file=accelerate_config.yml",
            "--num_processes=4",
            "-m",
            "axolotl.cli.inference",
        ):
            assert token in launched
|
||||
|
||||
|
||||
def test_inference_gradio_with_launcher_args(cli_runner, config_path):
    """``--gradio`` and launcher args must both reach the launched command."""
    argv = [
        "inference",
        str(config_path),
        "--launcher",
        "accelerate",
        "--gradio",
        "--",
        "--num_processes=2",
    ]

    with patch("subprocess.run") as run_mock:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

        assert outcome.exit_code == 0
        run_mock.assert_called_once()

        launched = run_mock.call_args.args[0]
        assert launched[0] == "accelerate"
        assert launched[1] == "launch"
        # Launcher arg, the gradio flag and the module invocation all appear.
        for token in (
            "--num_processes=2",
            "--gradio",
            "-m",
            "axolotl.cli.inference",
        ):
            assert token in launched
|
||||
|
||||
|
||||
def test_inference_backward_compatibility_no_launcher_args(cli_runner, config_path):
    """``inference`` without a ``--`` section must add no launcher args."""
    argv = ["inference", str(config_path), "--launcher", "accelerate"]

    with patch("subprocess.run") as run_mock:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

        assert outcome.exit_code == 0
        run_mock.assert_called_once()

        launched = run_mock.call_args.args[0]
        assert launched[0] == "accelerate"
        assert launched[1] == "launch"

        # Everything between index 2 and the "-m" flag would be launcher args.
        module_flag_pos = launched.index("-m")
        between = launched[2:module_flag_pos]
        assert len(between) == 0  # No launcher args between 'launch' and '-m'
|
||||
|
||||
@@ -18,11 +18,10 @@ def test_build_command():
|
||||
assert result == [
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--learning-rate",
|
||||
"0.0001",
|
||||
"--batch-size",
|
||||
"8",
|
||||
"--debug",
|
||||
"--learning-rate=0.0001",
|
||||
"--batch-size=8",
|
||||
"--debug=True",
|
||||
"--use-fp16=False",
|
||||
]
|
||||
|
||||
|
||||
@@ -38,7 +37,7 @@ def test_invalid_command_options(cli_runner):
|
||||
],
|
||||
)
|
||||
assert result.exit_code != 0
|
||||
assert "No such option" in result.output
|
||||
assert "does not exist" in result.output
|
||||
|
||||
|
||||
def test_required_config_argument(cli_runner):
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
@@ -11,9 +9,101 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
|
||||
"""Test merge_sharded_fsdp_weights command without accelerate"""
|
||||
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
|
||||
result = cli_runner.invoke(
|
||||
cli, ["merge-sharded-fsdp-weights", str(config_path), "--no-accelerate"]
|
||||
cli,
|
||||
["merge-sharded-fsdp-weights", str(config_path), "--launcher", "python"],
|
||||
)
|
||||
|
||||
assert mock.called
|
||||
assert mock.call_args.kwargs["config"] == str(config_path)
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_merge_sharded_fsdp_weights_with_launcher_args_torchrun(
    cli_runner, config_path
):
    """Arguments after ``--`` must be forwarded to torchrun for the merge command."""
    argv = [
        "merge-sharded-fsdp-weights",
        str(config_path),
        "--launcher",
        "torchrun",
        "--",
        "--nproc_per_node=2",
        "--nnodes=1",
    ]

    with patch("subprocess.run") as run_mock:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

        assert outcome.exit_code == 0
        run_mock.assert_called_once()

        # The command list is the first positional argument of the call.
        launched = run_mock.call_args.args[0]
        assert launched[0] == "torchrun"
        for token in (
            "--nproc_per_node=2",
            "--nnodes=1",
            "-m",
            "axolotl.cli.merge_sharded_fsdp_weights",
        ):
            assert token in launched
|
||||
|
||||
|
||||
def test_merge_sharded_fsdp_weights_with_launcher_args_accelerate(
    cli_runner, config_path
):
    """Arguments after ``--`` must be forwarded to ``accelerate launch``."""
    argv = [
        "merge-sharded-fsdp-weights",
        str(config_path),
        "--launcher",
        "accelerate",
        "--",
        "--config_file=accelerate_config.yml",
        "--num_processes=4",
    ]

    with patch("subprocess.run") as run_mock:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

        assert outcome.exit_code == 0
        run_mock.assert_called_once()

        # The command list is the first positional argument of the call.
        launched = run_mock.call_args.args[0]
        assert launched[0] == "accelerate"
        assert launched[1] == "launch"
        for token in (
            "--config_file=accelerate_config.yml",
            "--num_processes=4",
            "-m",
            "axolotl.cli.merge_sharded_fsdp_weights",
        ):
            assert token in launched
|
||||
|
||||
|
||||
def test_merge_sharded_fsdp_weights_backward_compatibility_no_launcher_args(
    cli_runner, config_path
):
    """The merge command without a ``--`` section must add no launcher args."""
    argv = [
        "merge-sharded-fsdp-weights",
        str(config_path),
        "--launcher",
        "accelerate",
    ]

    with patch("subprocess.run") as run_mock:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

        assert outcome.exit_code == 0
        run_mock.assert_called_once()

        launched = run_mock.call_args.args[0]
        assert launched[0] == "accelerate"
        assert launched[1] == "launch"

        # Everything between index 2 and the "-m" flag would be launcher args.
        module_flag_pos = launched.index("-m")
        between = launched[2:module_flag_pos]
        assert len(between) == 0  # No launcher args between 'launch' and '-m'
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
unit tests for generating sweep configurations
|
||||
"""
|
||||
|
||||
from axolotl.cli.main import generate_sweep_configs
|
||||
from axolotl.cli.utils import generate_sweep_configs
|
||||
|
||||
|
||||
def test_generate_sweep_configs_no_pairs():
|
||||
|
||||
@@ -18,7 +18,9 @@ class TestTrainCommand(BaseCliTest):
|
||||
|
||||
def test_train_basic_execution(self, cli_runner, tmp_path, valid_test_config):
|
||||
"""Test basic successful execution"""
|
||||
self._test_basic_execution(cli_runner, tmp_path, valid_test_config, "train")
|
||||
self._test_basic_execution(
|
||||
cli_runner, tmp_path, valid_test_config, "train", train=True
|
||||
)
|
||||
|
||||
def test_train_basic_execution_no_accelerate(
|
||||
self, cli_runner, tmp_path, valid_test_config
|
||||
@@ -37,7 +39,8 @@ class TestTrainCommand(BaseCliTest):
|
||||
[
|
||||
"train",
|
||||
str(config_path),
|
||||
"--no-accelerate",
|
||||
"--launcher",
|
||||
"python",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
@@ -59,11 +62,10 @@ class TestTrainCommand(BaseCliTest):
|
||||
[
|
||||
"train",
|
||||
str(config_path),
|
||||
"--learning-rate",
|
||||
"1e-4",
|
||||
"--micro-batch-size",
|
||||
"2",
|
||||
"--no-accelerate",
|
||||
"--learning-rate=1e-4",
|
||||
"--micro-batch-size=2",
|
||||
"--launcher",
|
||||
"python",
|
||||
],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
@@ -73,3 +75,177 @@ class TestTrainCommand(BaseCliTest):
|
||||
cfg = mock_train.call_args[1]["cfg"]
|
||||
assert cfg["learning_rate"] == 1e-4
|
||||
assert cfg["micro_batch_size"] == 2
|
||||
|
||||
def test_train_with_launcher_args_torchrun(
    self, cli_runner, tmp_path, valid_test_config
):
    """Arguments after ``--`` must be forwarded to torchrun for ``train``.

    This test patches ``os.execvpe`` (not ``subprocess.run``) and reads the
    command list from the second positional argument of the call.
    """
    cfg_file = tmp_path / "config.yml"
    cfg_file.write_text(valid_test_config)

    argv = [
        "train",
        str(cfg_file),
        "--launcher",
        "torchrun",
        "--",
        "--nproc_per_node=2",
        "--nnodes=1",
    ]

    with patch("os.execvpe") as exec_mock:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

        assert outcome.exit_code == 0
        exec_mock.assert_called_once()

        launched = exec_mock.call_args.args[1]
        assert launched[0] == "torchrun"
        for token in (
            "--nproc_per_node=2",
            "--nnodes=1",
            "-m",
            "axolotl.cli.train",
        ):
            assert token in launched
|
||||
|
||||
def test_train_with_launcher_args_accelerate(
    self, cli_runner, tmp_path, valid_test_config
):
    """Arguments after ``--`` must be forwarded to ``accelerate launch``.

    ``os.execvpe`` is patched; its first positional argument is checked as
    the executable name and its second as the command list.
    """
    cfg_file = tmp_path / "config.yml"
    cfg_file.write_text(valid_test_config)

    argv = [
        "train",
        str(cfg_file),
        "--launcher",
        "accelerate",
        "--",
        "--config_file=accelerate_config.yml",
        "--num_processes=4",
    ]

    with patch("os.execvpe") as exec_mock:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

        assert outcome.exit_code == 0
        exec_mock.assert_called_once()

        exec_args = exec_mock.call_args.args
        assert exec_args[0] == "accelerate"
        launched = exec_args[1]
        assert launched[0] == "accelerate"
        assert launched[1] == "launch"
        for token in (
            "--config_file=accelerate_config.yml",
            "--num_processes=4",
            "-m",
            "axolotl.cli.train",
        ):
            assert token in launched
|
||||
|
||||
def test_train_backward_compatibility_no_launcher_args(
    self, cli_runner, tmp_path, valid_test_config
):
    """``train`` without a ``--`` section must add no launcher args.

    Only regular CLI options are supplied; the command handed to the patched
    ``os.execvpe`` must have nothing between ``launch`` and ``-m``.
    """
    cfg_file = tmp_path / "config.yml"
    cfg_file.write_text(valid_test_config)

    argv = [
        "train",
        str(cfg_file),
        "--launcher",
        "accelerate",
        "--learning-rate",
        "1e-4",
    ]

    with patch("os.execvpe") as exec_mock:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

        assert outcome.exit_code == 0
        exec_mock.assert_called_once()

        exec_args = exec_mock.call_args.args
        assert exec_args[0] == "accelerate"
        launched = exec_args[1]
        assert launched[0] == "accelerate"
        assert launched[1] == "launch"

        # Everything between index 2 and the "-m" flag would be launcher args.
        module_flag_pos = launched.index("-m")
        between = launched[2:module_flag_pos]
        assert len(between) == 0
|
||||
|
||||
def test_train_mixed_args_with_launcher_args(
    self, cli_runner, tmp_path, valid_test_config
):
    """Regular CLI options and launcher args must both reach the command.

    Options given before ``--`` are expected in ``--key=value`` form in the
    launched command; the arg given after ``--`` is expected verbatim.
    """
    cfg_file = tmp_path / "config.yml"
    cfg_file.write_text(valid_test_config)

    argv = [
        "train",
        str(cfg_file),
        "--launcher",
        "torchrun",
        "--learning-rate",
        "2e-4",
        "--micro-batch-size",
        "4",
        "--",
        "--nproc_per_node=8",
    ]

    with patch("os.execvpe") as exec_mock:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

        assert outcome.exit_code == 0
        exec_mock.assert_called_once()

        exec_args = exec_mock.call_args.args
        assert exec_args[0] == "torchrun"
        launched = exec_args[1]
        # Launcher arg first, then the two axolotl options.
        for token in (
            "--nproc_per_node=8",
            "--learning-rate=2e-4",
            "--micro-batch-size=4",
        ):
            assert token in launched
|
||||
|
||||
def test_train_cloud_with_launcher_args(
    self, cli_runner, tmp_path, valid_test_config
):
    """``train --cloud`` must hand launcher and launcher args to cloud training.

    ``axolotl.cli.cloud.do_cli_train`` is patched and its keyword arguments
    are inspected.
    """
    cfg_file = tmp_path / "config.yml"
    cfg_file.write_text(valid_test_config)

    cloud_file = tmp_path / "cloud.yml"
    cloud_file.write_text("provider: modal\ngpu: a100")

    argv = [
        "train",
        str(cfg_file),
        "--cloud",
        str(cloud_file),
        "--launcher",
        "torchrun",
        "--",
        "--nproc_per_node=4",
        "--nnodes=2",
    ]

    with patch("axolotl.cli.cloud.do_cli_train") as cloud_mock:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

        assert outcome.exit_code == 0
        cloud_mock.assert_called_once()

        passed = cloud_mock.call_args.kwargs
        assert passed["launcher"] == "torchrun"
        assert passed["launcher_args"] == ["--nproc_per_node=4", "--nnodes=2"]
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""pytest tests for axolotl CLI utils."""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@@ -25,7 +23,7 @@ MOCK_TREE_RESPONSE = {
|
||||
def mock_responses():
|
||||
"""Mock responses for API and file downloads"""
|
||||
|
||||
def mock_get(url, timeout=None): # pylint: disable=unused-argument
|
||||
def mock_get(url, timeout=None):
|
||||
response = Mock()
|
||||
if "api.github.com" in url:
|
||||
response.text = json.dumps(MOCK_TREE_RESPONSE)
|
||||
@@ -72,3 +70,160 @@ def test_fetch_from_github_network_error():
|
||||
with patch("requests.get", side_effect=requests.RequestException):
|
||||
with pytest.raises(requests.RequestException):
|
||||
fetch_from_github("examples/", None)
|
||||
|
||||
|
||||
def assert_launcher_args_in_command(
    mock_subprocess_call,
    launcher: str,
    expected_launcher_args: list[str],
    command_module: str,
):
    """
    Assert that a mocked launch call carries the launcher, its args and the module.

    Args:
        mock_subprocess_call: Mock whose ``call_args.args[0]`` is the command list
        launcher: Token expected at position 0 of the command
        expected_launcher_args: Tokens that must each appear in the command
        command_module: Module name that must appear in the command

    Raises:
        AssertionError: On the first check that fails, in the order
            called -> launcher -> launcher args -> ``-m`` -> module.
    """
    assert mock_subprocess_call.called, "subprocess.run should have been called"
    cmd = mock_subprocess_call.call_args.args[0]

    actual_launcher = cmd[0]
    assert actual_launcher == launcher, (
        f"Expected launcher {launcher}, got {cmd[0]}"
    )

    # Report the first expected launcher arg that is absent, if any.
    missing = [token for token in expected_launcher_args if token not in cmd]
    assert not missing, (
        f"Expected launcher arg '{missing[0]}' not found in command: {cmd}"
    )

    assert "-m" in cmd, "Expected -m flag for module execution"
    assert command_module in cmd, (
        f"Expected module {command_module} not found in command: {cmd}"
    )
|
||||
|
||||
|
||||
def assert_no_launcher_args_contamination(mock_subprocess_call, launcher: str):
    """
    Assert that nothing sits between the launcher prefix and the ``-m`` flag.

    Args:
        mock_subprocess_call: Mock whose ``call_args.args[0]`` is the command list
        launcher: ``"accelerate"`` (section starts after ``launch``) or
            ``"torchrun"`` (section starts after ``torchrun``); for any other
            value only the "was called" check runs.

    Raises:
        AssertionError: If the mock was not called or the section is non-empty.
        ValueError: If the anchor token or ``-m`` is absent (from ``list.index``).
    """
    assert mock_subprocess_call.called, "subprocess.run should have been called"
    cmd = mock_subprocess_call.call_args.args[0]

    # Pick the token after which launcher args would start.
    if launcher == "accelerate":
        anchor = "launch"
    elif launcher == "torchrun":
        anchor = "torchrun"
    else:
        return

    section_start = cmd.index(anchor) + 1
    section_end = cmd.index("-m")
    extras = cmd[section_start:section_end]
    assert len(extras) == 0, (
        f"Unexpected launcher args found: {extras}"
    )
|
||||
|
||||
|
||||
@pytest.fixture
def common_launcher_args():
    """Return sample launcher argument lists keyed by launcher name."""
    torchrun_args = ["--nproc_per_node=2", "--nnodes=1"]
    accelerate_args = ["--config_file=accelerate_config.yml", "--num_processes=4"]
    return {"torchrun": torchrun_args, "accelerate": accelerate_args}
|
||||
|
||||
|
||||
def test_add_default_rdzv_args_with_endpoint():
    """A default rdzv backend must be added when only ``--rdzv_endpoint`` is given."""
    from axolotl.cli.utils.train import _add_default_rdzv_args

    given = ["--nnodes=2", "--rdzv_endpoint=127.0.0.1:29400"]
    augmented = _add_default_rdzv_args(given)

    # Backend default: both the flag token and the "c10d" token must appear.
    for token in ("--rdzv_backend", "c10d"):
        assert token in augmented

    # The caller's own arguments survive.
    for token in ("--nnodes=2", "--rdzv_endpoint=127.0.0.1:29400"):
        assert token in augmented
|
||||
|
||||
|
||||
def test_add_default_rdzv_args_with_existing_backend():
    """A user-supplied ``--rdzv_backend`` must be kept and not duplicated."""
    from axolotl.cli.utils.train import _add_default_rdzv_args

    given = [
        "--nnodes=2",
        "--rdzv_endpoint=127.0.0.1:29400",
        "--rdzv_backend=static",
    ]
    augmented = _add_default_rdzv_args(given)

    # Exactly one token may mention the backend flag.
    backend_tokens = [token for token in augmented if "--rdzv_backend" in token]
    assert len(backend_tokens) == 1
    assert "--rdzv_backend=static" in augmented
|
||||
|
||||
|
||||
def test_add_default_rdzv_args_with_existing_id():
    """A user-supplied ``--rdzv_id`` must be kept while the backend default is added."""
    from axolotl.cli.utils.train import _add_default_rdzv_args

    given = [
        "--nnodes=2",
        "--rdzv_endpoint=127.0.0.1:29400",
        "--rdzv_id=my_job_123",
    ]
    augmented = _add_default_rdzv_args(given)

    # Exactly one token may mention the id flag.
    id_tokens = [token for token in augmented if "--rdzv_id" in token]
    assert len(id_tokens) == 1
    assert "--rdzv_id=my_job_123" in augmented

    # The backend default is still expected.
    for token in ("--rdzv_backend", "c10d"):
        assert token in augmented
|
||||
|
||||
|
||||
def test_add_default_rdzv_args_without_endpoint():
    """Test that no RDZV args are added when rdzv_endpoint is not present."""
    from axolotl.cli.utils.train import _add_default_rdzv_args

    launcher_args = ["--nnodes=2", "--nproc_per_node=4"]
    # Snapshot taken before the call: comparing the result with the very list
    # that was passed in would pass trivially if the helper appended to its
    # argument in place and returned it.
    expected = list(launcher_args)
    result = _add_default_rdzv_args(launcher_args)

    # Should not add any rdzv args
    assert "--rdzv_backend" not in result
    assert result == expected
|
||||
|
||||
|
||||
def test_add_default_rdzv_args_with_all_existing():
    """Test that no defaults are added when all RDZV args are present."""
    from axolotl.cli.utils.train import _add_default_rdzv_args

    launcher_args = [
        "--nnodes=2",
        "--rdzv_endpoint=127.0.0.1:29400",
        "--rdzv_backend=static",
        "--rdzv_id=existing_job",
    ]
    # Snapshot taken before the call: comparing the result with the very list
    # that was passed in would pass trivially if the helper appended to its
    # argument in place and returned it.
    expected = list(launcher_args)
    result = _add_default_rdzv_args(launcher_args)

    # Should not add any additional args
    assert len(result) == len(expected)
    assert result == expected
|
||||
|
||||
@@ -2,32 +2,38 @@
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import datasets
|
||||
import pytest
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.errors import LocalEntryNotFoundError
|
||||
from tokenizers import AddedToken
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.hf_offline_utils import (
|
||||
enable_hf_offline,
|
||||
hf_offline_context,
|
||||
)
|
||||
|
||||
logging.getLogger("filelock").setLevel(logging.CRITICAL)
|
||||
|
||||
|
||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||
# pylint: disable=duplicate-code
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
|
||||
def wrapper(*args, **kwargs):
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
@@ -162,7 +168,7 @@ def download_argilla_distilabel_intel_orca_dpo_dataset():
|
||||
# @disable_hf_offline
|
||||
# def dataset_fozzie_alpaca_dpo_dataset(
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||
# ):
|
||||
# return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
||||
#
|
||||
#
|
||||
@@ -170,7 +176,7 @@ def download_argilla_distilabel_intel_orca_dpo_dataset():
|
||||
# @disable_hf_offline
|
||||
# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||
# ):
|
||||
# return load_dataset(
|
||||
# "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
||||
# )
|
||||
@@ -350,7 +356,7 @@ def download_llama32_1b_model_fixture():
|
||||
@enable_hf_offline
|
||||
def tokenizer_huggyllama(
|
||||
download_huggyllama_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
tokenizer.pad_token = "</s>"
|
||||
|
||||
@@ -361,7 +367,7 @@ def tokenizer_huggyllama(
|
||||
@enable_hf_offline
|
||||
def tokenizer_huggyllama_w_special_tokens(
|
||||
tokenizer_huggyllama,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
tokenizer_huggyllama.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
@@ -377,7 +383,7 @@ def tokenizer_huggyllama_w_special_tokens(
|
||||
@enable_hf_offline
|
||||
def tokenizer_llama2_7b(
|
||||
download_llama2_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
||||
|
||||
return tokenizer
|
||||
@@ -387,7 +393,7 @@ def tokenizer_llama2_7b(
|
||||
@enable_hf_offline
|
||||
def tokenizer_mistral_7b_instruct(
|
||||
download_mlx_mistral_7b_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
):
|
||||
return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
|
||||
|
||||
|
||||
@@ -409,7 +415,7 @@ def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
def temp_dir() -> Generator[str, None, None]:
|
||||
# Create a temporary directory
|
||||
_temp_dir = tempfile.mkdtemp()
|
||||
yield _temp_dir
|
||||
@@ -417,6 +423,11 @@ def temp_dir():
|
||||
shutil.rmtree(_temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def torch_manual_seed():
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def cleanup_monkeypatches():
|
||||
from transformers import Trainer
|
||||
@@ -428,9 +439,7 @@ def cleanup_monkeypatches():
|
||||
# original_fa2_forward = LlamaFlashAttention2.forward
|
||||
original_llama_attn_forward = LlamaAttention.forward
|
||||
original_llama_forward = LlamaForCausalLM.forward
|
||||
original_trainer_inner_training_loop = (
|
||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||
)
|
||||
original_trainer_inner_training_loop = Trainer._inner_training_loop
|
||||
original_trainer_training_step = Trainer.training_step
|
||||
# monkey patches can happen inside the tests
|
||||
yield
|
||||
@@ -438,9 +447,7 @@ def cleanup_monkeypatches():
|
||||
# LlamaFlashAttention2.forward = original_fa2_forward
|
||||
LlamaAttention.forward = original_llama_attn_forward
|
||||
LlamaForCausalLM.forward = original_llama_forward
|
||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||
original_trainer_inner_training_loop
|
||||
)
|
||||
Trainer._inner_training_loop = original_trainer_inner_training_loop
|
||||
Trainer.training_step = original_trainer_training_step
|
||||
|
||||
# Reset other known monkeypatches
|
||||
@@ -476,7 +483,7 @@ def cleanup_monkeypatches():
|
||||
@pytest.fixture
|
||||
def dataset_winglian_tiny_shakespeare(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
|
||||
return datasets.load_from_disk(ds_path)
|
||||
|
||||
@@ -484,7 +491,7 @@ def dataset_winglian_tiny_shakespeare(
|
||||
@pytest.fixture
|
||||
def dataset_tatsu_lab_alpaca(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
@@ -492,7 +499,7 @@ def dataset_tatsu_lab_alpaca(
|
||||
@pytest.fixture
|
||||
def dataset_mhenrichsen_alpaca_2k_test(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
@@ -500,7 +507,7 @@ def dataset_mhenrichsen_alpaca_2k_test(
|
||||
@pytest.fixture
|
||||
def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = (
|
||||
download_ds_fixture_bundle
|
||||
/ "argilla__ultrafeedback-binarized-preferences-cleaned"
|
||||
@@ -511,7 +518,7 @@ def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
|
||||
@pytest.fixture
|
||||
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
@@ -519,7 +526,7 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
|
||||
@pytest.fixture
|
||||
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = (
|
||||
download_ds_fixture_bundle
|
||||
/ "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
|
||||
@@ -527,7 +534,23 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
|
||||
# # pylint: disable=redefined-outer-name,unused-argument
|
||||
@pytest.fixture(name="min_base_cfg")
|
||||
def fixture_min_base_cfg():
|
||||
return DictDefault(
|
||||
base_model="HuggingFaceTB/SmolLM2-135M",
|
||||
learning_rate=1e-3,
|
||||
datasets=[
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
micro_batch_size=1,
|
||||
gradient_accumulation_steps=1,
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
|
||||
reason="Not running in CI cache preload",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
This module contains constants and configuration dictionaries used for
|
||||
datasets and other utilities in the Axolotl project, specifically for testing.
|
||||
"""
|
||||
|
||||
# Configuration for Alpaca Messages Dataset
|
||||
ALPACA_MESSAGES_CONFIG_OG = {
|
||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Unit tests for axolotl.core.builders"""
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
@@ -12,7 +10,7 @@ from axolotl.common.datasets import load_datasets
|
||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.loaders import ModelLoader, load_tokenizer
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.data import prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
@@ -64,7 +62,8 @@ def fixture_base_cfg():
|
||||
"dataloader_num_workers": 1,
|
||||
"dataloader_pin_memory": True,
|
||||
"dataloader_prefetch_factor": 2,
|
||||
"sequence_parallel_degree": 1,
|
||||
"context_parallel_size": 1,
|
||||
"tensor_parallel_size": 1,
|
||||
# Dtype
|
||||
"fp16": False,
|
||||
"bf16": False,
|
||||
@@ -81,6 +80,7 @@ def fixture_base_cfg():
|
||||
"ddp_timeout": 1800,
|
||||
"ddp_bucket_cap_mb": 25,
|
||||
"ddp_broadcast_buffers": False,
|
||||
"dataset_processes": 4,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -279,7 +279,9 @@ class TestHFRLTrainerBuilder:
|
||||
# Other settings
|
||||
assert training_arguments.dataloader_num_workers == 1
|
||||
assert training_arguments.dataloader_pin_memory is True
|
||||
assert training_arguments.gradient_checkpointing is False
|
||||
|
||||
# TODO(wing): restore once trl releases 0.22.0
|
||||
# assert training_arguments.gradient_checkpointing is True
|
||||
|
||||
def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
|
||||
builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
|
||||
@@ -326,7 +328,6 @@ def rand_reward_func(prompts, completions) -> list[float]:
|
||||
)
|
||||
|
||||
def test_grpo_training_arguments(self, grpo_cfg, model, tokenizer, tmp_path):
|
||||
|
||||
rewards_dir = tmp_path / "rewards_test"
|
||||
self._write_rewards_file(rewards_dir)
|
||||
|
||||
@@ -439,6 +440,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Unhandled cfg_string: {cfg_string}")
|
||||
cfg["dataset_num_proc"] = 4
|
||||
|
||||
if cfg_string == "grpo_cfg":
|
||||
rewards_dir = tmp_path / "rewards_test"
|
||||
@@ -451,15 +453,19 @@ def rand_reward_func(prompts, completions) -> list[float]:
|
||||
# Only use mock for the commented out configs
|
||||
if dataset_name is not None:
|
||||
with patch(
|
||||
"axolotl.utils.data.rl.load_dataset_w_config"
|
||||
"axolotl.utils.data.rl.load_dataset_with_config"
|
||||
) as mock_load_dataset:
|
||||
mock_load_dataset.return_value = request.getfixturevalue(
|
||||
dataset_name
|
||||
)
|
||||
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
|
||||
train_dataset, eval_dataset = prepare_preference_datasets(
|
||||
cfg, tokenizer
|
||||
)
|
||||
else:
|
||||
# Load actual datasets for orpo_cfg and kto_cfg
|
||||
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
|
||||
train_dataset, eval_dataset = prepare_preference_datasets(
|
||||
cfg, tokenizer
|
||||
)
|
||||
|
||||
builder.train_dataset = train_dataset
|
||||
builder.eval_dataset = eval_dataset
|
||||
@@ -468,7 +474,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
|
||||
|
||||
assert trainer.optimizer_cls_and_kwargs is not None
|
||||
|
||||
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
|
||||
from axolotl.contribs.mit.muon import (
|
||||
Muon,
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
@@ -550,7 +556,7 @@ class TestHFCausalTrainerBuilder:
|
||||
|
||||
assert trainer.optimizer_cls_and_kwargs is not None
|
||||
|
||||
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
|
||||
from axolotl.contribs.mit.muon import (
|
||||
Muon,
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
@@ -590,6 +596,6 @@ class TestTrainerClsPlugin:
|
||||
except TypeError as e:
|
||||
# Error raised if trainer_cls is None
|
||||
assert "'tuple' object has no attribute 'config'" not in str(e)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
except Exception:
|
||||
# Another error happens, so we passed trainer_cls to builder
|
||||
pass
|
||||
|
||||
@@ -4,7 +4,6 @@ Simple end-to-end test for Cut Cross Entropy integration
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils import get_pytorch_version
|
||||
@@ -13,8 +12,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def min_cfg(temp_dir):
|
||||
@@ -45,6 +42,7 @@ def min_cfg(temp_dir):
|
||||
"save_safetensors": True,
|
||||
"max_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
|
||||
|
||||
@@ -53,14 +51,12 @@ class TestCutCrossEntropyIntegration:
|
||||
e2e tests for cut_cross_entropy integration with Axolotl
|
||||
"""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
def test_llama_w_cce(self, min_cfg, temp_dir):
|
||||
cfg = DictDefault(min_cfg)
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
@@ -70,7 +66,6 @@ class TestCutCrossEntropyIntegration:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
def test_qwen2_w_cce(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -100,13 +95,13 @@ class TestCutCrossEntropyIntegration:
|
||||
"save_safetensors": True,
|
||||
"max_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
@@ -134,8 +129,7 @@ class TestCutCrossEntropyIntegration:
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
|
||||
61
tests/e2e/integrations/test_fp8.py
Normal file
61
tests/e2e/integrations/test_fp8.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Simple end-to-end smoke tests for FP8 mixed precision training
|
||||
"""
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_model_output_exists, require_torch_2_7_0
|
||||
|
||||
|
||||
class FP8IntegrationTestCase:
|
||||
"""
|
||||
e2e smoke tests for FP8 mixed precision training with Axolotl
|
||||
"""
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_fp8_single_gpu_smoke(self, temp_dir):
|
||||
"""Smoke test for single GPU FP8 + torch.compile training"""
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"trust_remote_code": True,
|
||||
"sequence_len": 512,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 3, # Very short smoke test
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"sdp_attention": True,
|
||||
"pad_to_seq_len": True,
|
||||
"sample_packing": True,
|
||||
"fp8": True,
|
||||
"torch_compile": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -5,7 +5,6 @@ e2e tests to make sure all the hooks are fired on the plugin
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.train import train
|
||||
@@ -29,85 +28,81 @@ class LogHooksPlugin(BasePlugin):
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
|
||||
def post_trainer_create(self, cfg, trainer):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("post_trainer_create\n")
|
||||
|
||||
def pre_model_load(self, cfg): # pylint: disable=unused-argument
|
||||
def pre_model_load(self, cfg):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("pre_model_load\n")
|
||||
|
||||
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
|
||||
def post_model_build(self, cfg, model):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("post_model_build\n")
|
||||
|
||||
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
def pre_lora_load(self, cfg, model):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("pre_lora_load\n")
|
||||
|
||||
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
def post_lora_load(self, cfg, model):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("post_lora_load\n")
|
||||
|
||||
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
def post_model_load(self, cfg, model):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("post_model_load\n")
|
||||
|
||||
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
|
||||
def create_optimizer(self, cfg, trainer):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("create_optimizer\n")
|
||||
|
||||
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument
|
||||
def get_trainer_cls(self, cfg):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("get_trainer_cls\n")
|
||||
|
||||
def create_lr_scheduler(
|
||||
self, cfg, trainer, optimizer, num_training_steps
|
||||
): # pylint: disable=unused-argument
|
||||
def create_lr_scheduler(self, cfg, trainer, optimizer, num_training_steps):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("create_lr_scheduler\n")
|
||||
|
||||
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
||||
def add_callbacks_pre_trainer(self, cfg, model):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("add_callbacks_pre_trainer\n")
|
||||
return []
|
||||
|
||||
def add_callbacks_post_trainer(
|
||||
self, cfg, trainer
|
||||
): # pylint: disable=unused-argument
|
||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("add_callbacks_post_trainer\n")
|
||||
return []
|
||||
|
||||
def post_train(self, cfg, model): # pylint: disable=unused-argument
|
||||
def post_train(self, cfg, model):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("post_train\n")
|
||||
|
||||
def post_train_unload(self, cfg): # pylint: disable=unused-argument
|
||||
def post_train_unload(self, cfg):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
@@ -120,7 +115,6 @@ class TestPluginHooks:
|
||||
"""
|
||||
|
||||
def test_plugin_hooks(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -154,14 +148,14 @@ class TestPluginHooks:
|
||||
"max_steps": 5,
|
||||
"flash_attention": True,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,11 +5,9 @@ e2e tests for kd trainer support in Axolotl
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
|
||||
@@ -18,8 +16,8 @@ from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
|
||||
@pytest.fixture(name="kd_min_cfg")
|
||||
def min_cfg(temp_dir):
|
||||
return {
|
||||
"base_model": "osllmai-community/Llama-3.2-1B",
|
||||
"tokenizer_config": "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
|
||||
"base_model": "Qwen/Qwen3-0.6B",
|
||||
"tokenizer_config": "winglian/qwen3-14b-math",
|
||||
"plugins": [
|
||||
"axolotl.integrations.kd.KDPlugin",
|
||||
"axolotl.integrations.liger.LigerPlugin",
|
||||
@@ -27,25 +25,27 @@ def min_cfg(temp_dir):
|
||||
"liger_rms_norm": True,
|
||||
"liger_glu_activation": True,
|
||||
"torch_compile": True,
|
||||
"chat_template": "llama3",
|
||||
"chat_template": "qwen3",
|
||||
"kd_trainer": True,
|
||||
"kd_ce_alpha": 0.1,
|
||||
"kd_alpha": 0.9,
|
||||
"kd_temperature": 1.0,
|
||||
"kd_beta": 0.0,
|
||||
"kd_normalize_topk": True,
|
||||
"dataloader_prefetch_factor": 8,
|
||||
"dataloader_num_workers": 4,
|
||||
"dataloader_pin_memory": True,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample",
|
||||
"type": "axolotl.integrations.kd.chat_template",
|
||||
"field_messages": "messages_combined",
|
||||
"path": "winglian/OpenThoughts-114k-math-correct-qwen3-14b-math-prepared-topk128-normalized",
|
||||
"type": "chat_template",
|
||||
"split": "train",
|
||||
"logprobs_field": "llm_text_generation_vllm_logprobs",
|
||||
"temperature": 1.0,
|
||||
"preprocess_shards": 2,
|
||||
"split_thinking": True,
|
||||
"eot_tokens": ["<|im_end|>"],
|
||||
"data_files": ["train/batch-000000.parquet"],
|
||||
},
|
||||
],
|
||||
"skip_prepare_dataset": True,
|
||||
"val_set_size": 0.0,
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": True,
|
||||
@@ -67,6 +67,7 @@ def min_cfg(temp_dir):
|
||||
"output_dir": temp_dir,
|
||||
"save_safetensors": True,
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
|
||||
|
||||
@@ -80,14 +81,24 @@ class TestKnowledgeDistillation:
|
||||
@require_torch_2_5_1
|
||||
def test_llama_kd(self, temp_dir, kd_min_cfg):
|
||||
cfg = DictDefault(kd_min_cfg)
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"1",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high"
|
||||
@@ -108,17 +119,30 @@ class TestKnowledgeDistillation:
|
||||
"lora_r": 16,
|
||||
"lora_alpha": 32,
|
||||
"lora_dropout": 0.0,
|
||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||
"lora_mlp_kernel": False,
|
||||
"lora_qkv_kernel": False,
|
||||
"lora_o_kernel": False,
|
||||
}
|
||||
| kd_min_cfg
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"1",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
Simple end-to-end test for Liger integration
|
||||
"""
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
@@ -18,7 +17,6 @@ class LigerIntegrationTestCase:
|
||||
|
||||
@require_torch_2_4_1
|
||||
def test_llama_wo_flce(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -51,21 +49,20 @@ class LigerIntegrationTestCase:
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"max_steps": 5,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@require_torch_2_4_1
|
||||
def test_llama_w_flce(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -98,14 +95,14 @@ class LigerIntegrationTestCase:
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"max_steps": 5,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
@@ -82,14 +81,14 @@ class TestLLMCompressorIntegration:
|
||||
},
|
||||
"save_compressed": save_compressed,
|
||||
},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
prepare_plugins(cfg)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
try:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
@@ -19,8 +19,15 @@ def test_geglu_forward_shape():
|
||||
assert out.device == gate.device
|
||||
|
||||
|
||||
def test_geglu_forward_values():
|
||||
@pytest.mark.flaky(retries=1, delay=5)
|
||||
@pytest.mark.parametrize(
|
||||
"torch_seed",
|
||||
[0, 42],
|
||||
)
|
||||
def test_geglu_forward_values(torch_seed):
|
||||
"""Test GEGLU forward pass matches PyTorch reference implementation."""
|
||||
torch.manual_seed(torch_seed)
|
||||
|
||||
gate = torch.randn(2, 3, 64, device="cuda")
|
||||
up = torch.randn(2, 3, 64, device="cuda")
|
||||
|
||||
@@ -33,6 +40,7 @@ def test_geglu_forward_values():
|
||||
assert torch.allclose(triton_out, torch_out, rtol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=1, delay=5)
|
||||
@pytest.mark.parametrize(
|
||||
"torch_seed",
|
||||
[0, 42],
|
||||
@@ -77,6 +85,6 @@ def test_geglu_inplace_preservation():
|
||||
|
||||
assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
|
||||
assert not torch.equal(up, up_copy), "Up should be modified in-place"
|
||||
assert not torch.equal(
|
||||
grad_output, grad_copy
|
||||
), "Grad output should be modified in-place"
|
||||
assert not torch.equal(grad_output, grad_copy), (
|
||||
"Grad output should be modified in-place"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Tests for LoRA custom autograd."""
|
||||
|
||||
# pylint: disable=invalid-name,redefined-outer-name
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from bitsandbytes.functional import QuantState
|
||||
@@ -64,6 +62,7 @@ def sample_tensors():
|
||||
batch_size, seq_len, hidden_dim, device="cuda", dtype=torch.float16
|
||||
),
|
||||
"W": torch.randn(out_dim, hidden_dim, device="cuda", dtype=torch.float16),
|
||||
"b": torch.randn(out_dim, device="cuda", dtype=torch.float16),
|
||||
"scale": 0.5,
|
||||
"shapes": {
|
||||
"batch": batch_size,
|
||||
@@ -103,23 +102,24 @@ def mock_proj():
|
||||
def test_get_lora_parameters(mock_proj):
|
||||
"""Tests get_lora_parameters function"""
|
||||
# Test with LoRA enabled
|
||||
W, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
|
||||
assert isinstance(W, torch.Tensor)
|
||||
assert W.shape == (128, 64)
|
||||
assert b.shape == (128,)
|
||||
assert A.shape == (8, 64)
|
||||
assert B.shape == (128, 8)
|
||||
assert s == 0.5
|
||||
|
||||
# Test with LoRA disabled
|
||||
mock_proj.disable_adapters = True
|
||||
W, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
assert A is None and B is None and s is None
|
||||
|
||||
# Test with merged state
|
||||
mock_proj.disable_adapters = False
|
||||
mock_proj.merged = True
|
||||
W, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
assert A is None and B is None and s is None
|
||||
|
||||
|
||||
@@ -127,6 +127,7 @@ def test_matmul_lora(sample_tensors):
|
||||
"""Tests matmul_lora function"""
|
||||
X = sample_tensors["X"]
|
||||
W = sample_tensors["W"]
|
||||
b = sample_tensors["b"]
|
||||
scale = sample_tensors["scale"]
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
@@ -138,19 +139,20 @@ def test_matmul_lora(sample_tensors):
|
||||
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
||||
|
||||
# Test base matmul
|
||||
out1 = matmul_lora(X, W, None, None, None, None)
|
||||
expected1 = torch.matmul(X, W.t())
|
||||
out1 = matmul_lora(X, W, b, None, None, None, None)
|
||||
matmul = torch.matmul(X, W.t())
|
||||
expected1 = matmul + b
|
||||
assert torch.allclose(out1, expected1, rtol=1e-3)
|
||||
|
||||
# Test with LoRA
|
||||
out2 = matmul_lora(X, W, None, A, B, scale)
|
||||
out2 = matmul_lora(X, W, b, None, A, B, scale)
|
||||
lora_term = scale * torch.matmul(torch.matmul(X, A.t()), B.t())
|
||||
expected2 = expected1 + lora_term
|
||||
expected2 = matmul + lora_term + b
|
||||
assert torch.allclose(out2, expected2, rtol=1e-3)
|
||||
|
||||
# Test 3D input reshaping
|
||||
X_3d = X.clone()
|
||||
out3 = matmul_lora(X_3d, W, None, A, B, scale)
|
||||
out3 = matmul_lora(X_3d, W, b, None, A, B, scale)
|
||||
assert out3.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
|
||||
|
||||
@@ -175,16 +177,19 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
|
||||
output = LoRA_MLP.apply(
|
||||
X,
|
||||
gate_proj.weight,
|
||||
gate_proj.bias,
|
||||
None, # gate_quant
|
||||
None, # gate_A
|
||||
None, # gate_B
|
||||
None, # gate_scale
|
||||
up_proj.weight,
|
||||
up_proj.bias,
|
||||
None, # up_quant
|
||||
None, # up_A
|
||||
None, # up_B
|
||||
None, # up_scale
|
||||
down_proj.weight,
|
||||
down_proj.bias,
|
||||
None, # down_quant
|
||||
None, # down_A
|
||||
None, # down_B
|
||||
@@ -243,16 +248,19 @@ def test_lora_mlp_with_adapters(
|
||||
output = LoRA_MLP.apply(
|
||||
X,
|
||||
gate_proj.weight,
|
||||
gate_proj.bias,
|
||||
None,
|
||||
gate_A,
|
||||
gate_B,
|
||||
scale,
|
||||
up_proj.weight,
|
||||
up_proj.bias,
|
||||
None,
|
||||
up_A,
|
||||
up_B,
|
||||
scale,
|
||||
down_proj.weight,
|
||||
down_proj.bias,
|
||||
None,
|
||||
down_A,
|
||||
down_B,
|
||||
@@ -323,6 +331,7 @@ def test_lora_qkv(sample_tensors):
|
||||
X.requires_grad = True
|
||||
|
||||
# Test without LoRA adapters
|
||||
|
||||
Q1, K1, V1 = LoRA_QKV.apply(
|
||||
X,
|
||||
q_weight,
|
||||
@@ -330,16 +339,19 @@ def test_lora_qkv(sample_tensors):
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
k_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
v_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
)
|
||||
|
||||
@@ -356,16 +368,19 @@ def test_lora_qkv(sample_tensors):
|
||||
X,
|
||||
q_weight,
|
||||
None,
|
||||
None,
|
||||
q_A,
|
||||
q_B,
|
||||
scale,
|
||||
k_weight,
|
||||
None,
|
||||
None,
|
||||
k_A,
|
||||
k_B,
|
||||
scale,
|
||||
v_weight,
|
||||
None,
|
||||
None,
|
||||
v_A,
|
||||
v_B,
|
||||
scale,
|
||||
@@ -399,6 +414,7 @@ def test_lora_o(sample_tensors):
|
||||
"""Tests LoRA output projection"""
|
||||
X = sample_tensors["X"]
|
||||
W = sample_tensors["W"]
|
||||
b = sample_tensors["b"]
|
||||
scale = sample_tensors["scale"]
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
@@ -411,7 +427,7 @@ def test_lora_o(sample_tensors):
|
||||
|
||||
# Test forward pass
|
||||
X.requires_grad = True
|
||||
output = LoRA_O.apply(X, W, None, A, B, scale)
|
||||
output = LoRA_O.apply(X, W, b, None, A, B, scale)
|
||||
|
||||
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
|
||||
@@ -425,6 +441,7 @@ def test_with_quantization(sample_tensors, mock_quantstate):
|
||||
"""Tests LoRA with quantized weights"""
|
||||
X = sample_tensors["X"] # [batch, seq, hidden]
|
||||
W = sample_tensors["W"] # [out, hidden]
|
||||
b = sample_tensors["b"] # [out]
|
||||
scale = 0.5
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
@@ -436,13 +453,13 @@ def test_with_quantization(sample_tensors, mock_quantstate):
|
||||
B = torch.randn(out_dim, rank, device="cuda", dtype=torch.float16)
|
||||
|
||||
# Test matmul with quantization
|
||||
out = matmul_lora(X, W, mock_quantstate, A, B, scale)
|
||||
out = matmul_lora(X, W, b, mock_quantstate, A, B, scale)
|
||||
assert out.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
assert not torch.isnan(out).any()
|
||||
|
||||
# Test with different batch sizes
|
||||
X2 = torch.randn(4, 6, hidden_dim, device="cuda", dtype=torch.float16)
|
||||
out2 = matmul_lora(X2, W, mock_quantstate, A, B, scale)
|
||||
out2 = matmul_lora(X2, W, b, mock_quantstate, A, B, scale)
|
||||
assert out2.shape == (4, 6, W.shape[0])
|
||||
assert not torch.isnan(out2).any()
|
||||
|
||||
@@ -459,11 +476,12 @@ def test_shapes_and_dimensions(batch, seq, hidden, rank, out):
|
||||
"""Tests various input shapes and dimensions"""
|
||||
X = torch.randn(batch, seq, hidden, device="cuda", dtype=torch.float16)
|
||||
W = torch.randn(out, hidden, device="cuda", dtype=torch.float16)
|
||||
b = torch.randn(out, device="cuda", dtype=torch.float16)
|
||||
A = torch.randn(rank, hidden, device="cuda", dtype=torch.float16)
|
||||
B = torch.randn(out, rank, device="cuda", dtype=torch.float16)
|
||||
scale = 0.5
|
||||
|
||||
result = matmul_lora(X, W, None, A, B, scale)
|
||||
result = matmul_lora(X, W, b, None, A, B, scale)
|
||||
assert result.shape == (batch, seq, out)
|
||||
|
||||
|
||||
@@ -471,6 +489,7 @@ def test_gradient_flow(sample_tensors):
|
||||
"""Tests gradient flow through LoRA layers"""
|
||||
X = sample_tensors["X"].clone()
|
||||
W = sample_tensors["W"].clone()
|
||||
b = sample_tensors["b"].clone()
|
||||
scale = sample_tensors["scale"]
|
||||
|
||||
shapes = sample_tensors["shapes"]
|
||||
@@ -486,7 +505,7 @@ def test_gradient_flow(sample_tensors):
|
||||
B.requires_grad = True
|
||||
|
||||
# Forward pass
|
||||
out = matmul_lora(X, W, None, A, B, scale)
|
||||
out = matmul_lora(X, W, b, None, A, B, scale)
|
||||
loss = out.sum()
|
||||
|
||||
# Backward pass
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Tests for quantization utility functions."""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
import torch
|
||||
from bitsandbytes.functional import QuantState
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Tests for SwiGLU activation function Triton kernels."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -74,6 +72,6 @@ def test_swiglu_inplace_preservation():
|
||||
|
||||
assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
|
||||
assert not torch.equal(up, up_copy), "Up should be modified in-place"
|
||||
assert not torch.equal(
|
||||
grad_output, grad_copy
|
||||
), "Grad output should be modified in-place"
|
||||
assert not torch.equal(grad_output, grad_copy), (
|
||||
"Grad output should be modified in-place"
|
||||
)
|
||||
|
||||
@@ -54,6 +54,7 @@ class TestSequenceParallelism:
|
||||
"micro_batch_size": micro_batch_size,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -66,8 +67,9 @@ class TestSequenceParallelism:
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_size": 2,
|
||||
"ring_attn_func": ring_attn_func,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -91,7 +93,10 @@ class TestSequenceParallelism:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", threshold, "Train Loss is too high"
|
||||
temp_dir + "/runs",
|
||||
"train/train_loss",
|
||||
threshold,
|
||||
"Train Loss (%s) is too high",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -100,13 +105,13 @@ class TestSequenceParallelism:
|
||||
(True, 1, True, None, 2.5), # defaults to varlen_llama3 ring_attn_func
|
||||
(False, 2, True, None, 2.5), # defaults to batch_ring ring_attn_func
|
||||
# (False, 2, True, "batch_zigzag", 2.5),
|
||||
(False, 2, False, None, 2.5), # defaults to batch_ring ring_attn_func
|
||||
# (False, 2, False, None, 2.65), # defaults to batch_ring ring_attn_func
|
||||
],
|
||||
ids=[
|
||||
"sample_packing, varlen_llama3 ring_attn_func",
|
||||
"no sample_packing, pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
# "no sample_packing, no pad_to_sequence_len, batch_zigzag ring_attn_func",
|
||||
"no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
# "no sample_packing, no pad_to_sequence_len, batch_ring ring_attn_func",
|
||||
],
|
||||
)
|
||||
def test_sequence_parallel_training(
|
||||
|
||||
@@ -31,7 +31,6 @@ class TestPackedFlex:
|
||||
|
||||
@require_torch_2_6_0
|
||||
def test_loss_llama(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -54,12 +53,14 @@ class TestPackedFlex:
|
||||
"gradient_accumulation_steps": 2,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 2,
|
||||
"use_tensorboard": True,
|
||||
"save_strategy": "no",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -85,5 +86,5 @@ class TestPackedFlex:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -80,7 +80,7 @@ def start_vllm(
|
||||
cmd_env = env.copy()
|
||||
cmd_env.update({"VLLM_LOGGING_CONFIG_PATH": vllm_logging_json})
|
||||
# start `trl vllm-serve` command in the background and capture the process id
|
||||
process = subprocess.Popen( # pylint: disable=consider-using-with
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
env=cmd_env,
|
||||
stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
|
||||
@@ -105,7 +105,7 @@ def start_vllm(
|
||||
print(f"{i}: VLLM server failed to start: {str(exc)}")
|
||||
|
||||
# also check if the process.pid is still running
|
||||
if not process.poll() is None:
|
||||
if process.poll() is not None:
|
||||
break
|
||||
|
||||
time.sleep(period_seconds)
|
||||
@@ -141,6 +141,7 @@ def recursive_kill(process: subprocess.Popen):
|
||||
os.kill(process.pid, 9)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="flaky vllm tests in modal")
|
||||
class TestGRPO:
|
||||
"""
|
||||
Test case for GRPO training using multiple GPUs
|
||||
@@ -222,6 +223,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -296,7 +298,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"sequence_parallel_degree": 2,
|
||||
"context_parallel_size": 2,
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"special_tokens": {
|
||||
@@ -309,12 +311,14 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"warmup_steps": 10,
|
||||
"val_set_size": 0.0,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -400,12 +404,14 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"warmup_steps": 10,
|
||||
"val_set_size": 0.0,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ class TestMultiGPUEval:
|
||||
"""
|
||||
|
||||
def test_eval_sample_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -38,12 +37,13 @@ class TestMultiGPUEval:
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||
"val_set_size": 0.004,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "teknium/GPT4-LLM-Cleaned",
|
||||
"type": "alpaca",
|
||||
"split": "train[:5%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
@@ -51,6 +51,7 @@ class TestMultiGPUEval:
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -65,6 +66,7 @@ class TestMultiGPUEval:
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -90,7 +92,6 @@ class TestMultiGPUEval:
|
||||
check_tensorboard(temp_dir + "/runs", "eval/loss", 2.5, "Eval Loss is too high")
|
||||
|
||||
def test_eval(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -107,12 +108,13 @@ class TestMultiGPUEval:
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||
"val_set_size": 0.0004,
|
||||
"val_set_size": 0.01,
|
||||
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "teknium/GPT4-LLM-Cleaned",
|
||||
"type": "alpaca",
|
||||
"split": "train[:5%]",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
@@ -120,6 +122,7 @@ class TestMultiGPUEval:
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -134,6 +137,7 @@ class TestMultiGPUEval:
|
||||
"logging_steps": 1,
|
||||
"weight_decay": 0.0,
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
119
tests/e2e/multigpu/test_fp8_fsdp2.py
Normal file
119
tests/e2e/multigpu/test_fp8_fsdp2.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Test module for FP8 mixed precision with FSDP2 multi-GPU functionality."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from tbparse import SummaryReader
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import most_recent_subdir, require_hopper, require_torch_2_7_0
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
def verify_fp8_training_success(temp_dir):
    """Verify that FP8 training completed successfully by checking artifacts and loss.

    Args:
        temp_dir: Training output directory, given as a ``str`` or a
            ``pathlib.Path``.

    Raises:
        AssertionError: If no ``*.bin``/``*.safetensors`` file or no
            ``checkpoint-*`` entry exists directly under ``temp_dir``, or if
            the last logged ``train/train_loss`` value is NaN.
    """
    output_path = Path(temp_dir)

    model_files = [*output_path.glob("*.bin"), *output_path.glob("*.safetensors")]
    assert len(model_files) > 0, "No model files found - training may have failed"

    checkpoint_files = list(output_path.glob("checkpoint-*"))
    assert len(checkpoint_files) > 0, (
        "No checkpoint files found - training may have failed"
    )

    # Build the path with pathlib so a ``Path`` argument works as well; plain
    # ``temp_dir + "/runs"`` raises TypeError for anything that is not a str.
    tb_log_path = most_recent_subdir(str(output_path / "runs"))
    if not tb_log_path:
        # The loss check is best-effort: without a run directory there is
        # nothing to inspect.
        return

    event_files = sorted(os.listdir(tb_log_path))
    if not event_files:
        return

    reader = SummaryReader(os.path.join(tb_log_path, event_files[0]))
    df = reader.scalars
    train_loss_df = df[df.tag == "train/train_loss"]
    if len(train_loss_df) == 0:
        return

    final_loss = train_loss_df.value.values[-1]
    assert not torch.isnan(torch.tensor(final_loss)), (
        f"Training loss is NaN: {final_loss}"
    )
|
||||
|
||||
|
||||
class TestFP8FSDP2:
    """Test class for FP8 mixed precision with FSDP2 functionality."""

    @require_torch_2_7_0
    @require_hopper
    def test_fp8_fsdp2_smoke(self, temp_dir):
        """Smoke test for 2-GPU FP8 + torch.compile + FSDP2 training.

        Writes the config to ``<temp_dir>/config.yaml``, launches
        ``axolotl train`` with two processes, then checks the outputs with
        ``verify_fp8_training_success``.
        """
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "tokenizer_type": "AutoTokenizer",
                "trust_remote_code": True,
                "sequence_len": 512,
                "val_set_size": 0.05,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "max_steps": 3,  # Very short smoke test
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch_fused",  # Use standard optimizer for stability
                "lr_scheduler": "cosine",
                "sdp_attention": True,
                # Spelled ``pad_to_sequence_len`` to match the option name used
                # elsewhere in the test suite. NOTE(review): the previous
                # ``pad_to_seq_len`` spelling is presumably not a recognised
                # option - confirm against the config schema.
                "pad_to_sequence_len": True,
                "sample_packing": True,
                # FP8 configuration
                "fp8": True,
                "fp8_enable_fsdp_float8_all_gather": True,
                "torch_compile": True,
                # FSDP2 configuration
                "fsdp_version": 2,
                "fsdp_config": {
                    "offload_params": False,
                    "cpu_ram_efficient_loading": False,
                    "transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
                    "state_dict_type": "FULL_STATE_DICT",
                    "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
                    "reshard_after_forward": True,
                },
                "use_tensorboard": True,
                "save_safetensors": True,
                "save_first_step": False,
            }
        )

        # write cfg to yaml file
        Path(temp_dir).mkdir(parents=True, exist_ok=True)
        config_path = Path(temp_dir) / "config.yaml"
        with open(config_path, "w", encoding="utf-8") as fout:
            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

        execute_subprocess_async(
            [
                "axolotl",
                "train",
                str(config_path),
                "--num-processes",
                "2",
                "--main-process-port",
                f"{get_torch_dist_unique_port()}",
            ]
        )

        verify_fp8_training_success(temp_dir)
|
||||
324
tests/e2e/multigpu/test_fsdp1.py
Normal file
324
tests/e2e/multigpu/test_fsdp1.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""Test module for FSDP1 multi-GPU functionality."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from tbparse import SummaryReader
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import most_recent_subdir
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
def verify_training_success(temp_dir):
    """Verify that training completed successfully by checking artifacts and loss.

    Args:
        temp_dir: Training output directory, given as a ``str`` or a
            ``pathlib.Path``.

    Raises:
        AssertionError: If no ``*.bin``/``*.safetensors`` file or no
            ``checkpoint-*`` entry exists directly under ``temp_dir``, or if
            the last logged ``train/train_loss`` value is NaN.
    """
    output_path = Path(temp_dir)

    model_files = [*output_path.glob("*.bin"), *output_path.glob("*.safetensors")]
    assert len(model_files) > 0, "No model files found - training may have failed"

    checkpoint_files = list(output_path.glob("checkpoint-*"))
    assert len(checkpoint_files) > 0, (
        "No checkpoint files found - training may have failed"
    )

    # Build the path with pathlib so a ``Path`` argument works as well; plain
    # ``temp_dir + "/runs"`` raises TypeError for anything that is not a str.
    tb_log_path = most_recent_subdir(str(output_path / "runs"))
    if not tb_log_path:
        # The loss check is best-effort: without a run directory there is
        # nothing to inspect.
        return

    event_files = sorted(os.listdir(tb_log_path))
    if not event_files:
        return

    reader = SummaryReader(os.path.join(tb_log_path, event_files[0]))
    df = reader.scalars
    train_loss_df = df[df.tag == "train/train_loss"]
    if len(train_loss_df) == 0:
        return

    final_loss = train_loss_df.value.values[-1]
    assert not torch.isnan(torch.tensor(final_loss)), (
        f"Training loss is NaN: {final_loss}"
    )
|
||||
|
||||
|
||||
class TestFSDP1:
    """Test class for FSDP1 functionality.

    Every test builds a config from the shared FSDP1 settings plus its own
    dataset/adapter options, runs ``axolotl train`` on two processes and then
    checks the output directory with ``verify_training_success``.
    """

    @staticmethod
    def _base_config(temp_dir, cpu_ram_efficient_loading=True):
        """Return a fresh dict of the settings shared by every test in this class."""
        return {
            "base_model": "Qwen/Qwen2.5-0.5B",
            "sequence_len": 2048,
            "val_set_size": 0.01,
            "num_epochs": 1,
            "max_steps": 2,
            "micro_batch_size": 2,
            "gradient_accumulation_steps": 1,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch_fused",
            "lr_scheduler": "cosine",
            "flash_attention": True,
            "fsdp_version": "1",
            "fsdp_config": {
                "fsdp_offload_params": False,
                "fsdp_cpu_ram_efficient_loading": cpu_ram_efficient_loading,
                "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
                "fsdp_state_dict_type": "FULL_STATE_DICT",
                "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
                "fsdp_sharding_strategy": "FULL_SHARD",
                "fsdp_sync_module_states": True,
                "fsdp_use_orig_params": False,
            },
            "use_tensorboard": True,
        }

    @staticmethod
    def _sft_settings():
        """Return the alpaca dataset settings used by the SFT tests."""
        return {
            "datasets": [
                {
                    "path": "tatsu-lab/alpaca",
                    "type": "alpaca",
                    "split": "train[:10%]",
                },
            ],
        }

    @staticmethod
    def _dpo_settings():
        """Return the DPO trainer and preference-dataset settings."""
        return {
            "rl": "dpo",
            "chat_template": "chatml",
            "datasets": [
                {
                    "path": "Intel/orca_dpo_pairs",
                    "split": "train",
                    "type": "chatml.intel",
                },
            ],
        }

    @staticmethod
    def _lora_settings(adapter_config):
        """Return the LoRA/QLoRA settings for one parametrized adapter case."""
        return {
            "adapter": adapter_config["adapter"],
            "load_in_4bit": adapter_config["load_in_4bit"],
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.05,
            "lora_target_linear": True,
        }

    @staticmethod
    def _train_and_verify(temp_dir, settings):
        """Write ``settings`` as YAML, launch 2-process training and verify outputs."""
        cfg = DictDefault(settings)

        # write cfg to yaml file
        Path(temp_dir).mkdir(parents=True, exist_ok=True)
        config_path = Path(temp_dir) / "config.yaml"
        with open(config_path, "w", encoding="utf-8") as fout:
            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

        execute_subprocess_async(
            [
                "axolotl",
                "train",
                str(config_path),
                "--num-processes",
                "2",
                "--main-process-port",
                f"{get_torch_dist_unique_port()}",
            ]
        )

        verify_training_success(temp_dir)

    @pytest.mark.parametrize(
        "fsdp_cpu_ram_efficient_loading",
        [True, False],
    )
    def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
        """Full fine-tune SFT, with and without CPU-RAM-efficient loading."""
        settings = {
            **self._base_config(temp_dir, fsdp_cpu_ram_efficient_loading),
            **self._sft_settings(),
            "bf16": True,
        }
        self._train_and_verify(temp_dir, settings)

    @pytest.mark.parametrize(
        "adapter_config",
        [
            {
                "adapter": "lora",
                "load_in_4bit": False,
            },
            {
                "adapter": "qlora",
                "load_in_4bit": True,
            },
        ],
    )
    def test_lora_sft(self, temp_dir, adapter_config):
        """SFT with a LoRA or QLoRA adapter."""
        settings = {
            **self._base_config(temp_dir),
            **self._sft_settings(),
            **self._lora_settings(adapter_config),
            "bf16": True,
        }
        self._train_and_verify(temp_dir, settings)

    def test_dpo_fft(self, temp_dir):
        """Full fine-tune DPO; this case sets no ``bf16`` option."""
        settings = {
            **self._base_config(temp_dir),
            **self._dpo_settings(),
        }
        self._train_and_verify(temp_dir, settings)

    @pytest.mark.parametrize(
        "adapter_config",
        [
            {
                "adapter": "lora",
                "load_in_4bit": False,
            },
            {
                "adapter": "qlora",
                "load_in_4bit": True,
            },
        ],
    )
    def test_dpo_lora(self, temp_dir, adapter_config):
        """DPO with a LoRA or QLoRA adapter, using ``bf16: auto`` and ``tf32``."""
        settings = {
            **self._base_config(temp_dir),
            **self._dpo_settings(),
            **self._lora_settings(adapter_config),
            "bf16": "auto",
            "tf32": True,
        }
        self._train_and_verify(temp_dir, settings)
|
||||
480
tests/e2e/multigpu/test_fsdp2.py
Normal file
480
tests/e2e/multigpu/test_fsdp2.py
Normal file
@@ -0,0 +1,480 @@
|
||||
"""Test module for FSDP2 multi-GPU functionality."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from tbparse import SummaryReader
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
def verify_training_success(temp_dir):
    """Verify that training completed successfully by checking artifacts and loss.

    Args:
        temp_dir: Training output directory, given as a ``str`` or a
            ``pathlib.Path``.

    Raises:
        AssertionError: If no ``*.bin``/``*.safetensors`` file or no
            ``checkpoint-*`` entry exists directly under ``temp_dir``, or if
            the last logged ``train/train_loss`` value is NaN.
    """
    output_path = Path(temp_dir)

    model_files = [*output_path.glob("*.bin"), *output_path.glob("*.safetensors")]
    assert len(model_files) > 0, "No model files found - training may have failed"

    checkpoint_files = list(output_path.glob("checkpoint-*"))
    assert len(checkpoint_files) > 0, (
        "No checkpoint files found - training may have failed"
    )

    # Build the path with pathlib so a ``Path`` argument works as well; plain
    # ``temp_dir + "/runs"`` raises TypeError for anything that is not a str.
    tb_log_path = most_recent_subdir(str(output_path / "runs"))
    if not tb_log_path:
        # The loss check is best-effort: without a run directory there is
        # nothing to inspect.
        return

    event_files = sorted(os.listdir(tb_log_path))
    if not event_files:
        return

    reader = SummaryReader(os.path.join(tb_log_path, event_files[0]))
    df = reader.scalars
    train_loss_df = df[df.tag == "train/train_loss"]
    if len(train_loss_df) == 0:
        return

    final_loss = train_loss_df.value.values[-1]
    assert not torch.isnan(torch.tensor(final_loss)), (
        f"Training loss is NaN: {final_loss}"
    )
|
||||
|
||||
|
||||
class TestFSDP2:
|
||||
"""Test class for FSDP2 functionality."""
|
||||
|
||||
@require_torch_2_7_0
@pytest.mark.parametrize(
    "fsdp_cpu_ram_efficient_loading",
    [True, False],
)
def test_fft_sft(self, temp_dir, fsdp_cpu_ram_efficient_loading):
    """Run a 2-process full fine-tune SFT job under FSDP2 and verify its outputs."""
    alpaca_subset = {
        "path": "tatsu-lab/alpaca",
        "type": "alpaca",
        "split": "train[:10%]",
    }
    fsdp_settings = {
        "offload_params": False,
        "cpu_ram_efficient_loading": fsdp_cpu_ram_efficient_loading,
        "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
        "state_dict_type": "FULL_STATE_DICT",
        "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "reshard_after_forward": True,
    }
    cfg = DictDefault(
        {
            "base_model": "Qwen/Qwen2.5-0.5B",
            "sequence_len": 2048,
            "val_set_size": 0.01,
            "datasets": [alpaca_subset],
            "num_epochs": 1,
            "max_steps": 2,
            "micro_batch_size": 2,
            "gradient_accumulation_steps": 1,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch_fused",
            "lr_scheduler": "cosine",
            "flash_attention": True,
            "fsdp_version": 2,
            "fsdp_config": fsdp_settings,
            "use_tensorboard": True,
            "bf16": True,
        }
    )

    # Serialise the config into the output directory for the CLI to read.
    run_dir = Path(temp_dir)
    run_dir.mkdir(parents=True, exist_ok=True)
    config_file = run_dir / "config.yaml"
    config_file.write_text(
        yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper), encoding="utf-8"
    )

    launch_cmd = ["axolotl", "train", str(config_file)]
    launch_cmd += ["--num-processes", "2"]
    launch_cmd += ["--main-process-port", f"{get_torch_dist_unique_port()}"]
    execute_subprocess_async(launch_cmd)

    verify_training_success(temp_dir)
|
||||
|
||||
@require_torch_2_7_0
@pytest.mark.parametrize("peft_use_dora", [True, False])
def test_lora_sft(self, temp_dir, peft_use_dora):
    """Run a 2-process LoRA SFT job under FSDP2, with and without DoRA."""
    alpaca_subset = {
        "path": "tatsu-lab/alpaca",
        "type": "alpaca",
        "split": "train[:10%]",
    }
    adapter_settings = {
        "peft_use_dora": peft_use_dora,
        "adapter": "lora",
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        "lora_target_linear": True,
    }
    fsdp_settings = {
        "offload_params": False,
        "cpu_ram_efficient_loading": False,
        "transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
        "state_dict_type": "FULL_STATE_DICT",
        "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "reshard_after_forward": True,
    }
    cfg = DictDefault(
        {
            "base_model": "Qwen/Qwen2.5-0.5B",
            "sequence_len": 2048,
            "val_set_size": 0.01,
            "datasets": [alpaca_subset],
            **adapter_settings,
            "num_epochs": 1,
            "max_steps": 2,
            "micro_batch_size": 2,
            "gradient_accumulation_steps": 1,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch_fused",
            "lr_scheduler": "cosine",
            "flash_attention": True,
            "fsdp_version": 2,
            "fsdp_config": fsdp_settings,
            "use_tensorboard": True,
            "bf16": True,
        }
    )

    # Serialise the config into the output directory for the CLI to read.
    run_dir = Path(temp_dir)
    run_dir.mkdir(parents=True, exist_ok=True)
    config_file = run_dir / "config.yaml"
    config_file.write_text(
        yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper), encoding="utf-8"
    )

    launch_cmd = ["axolotl", "train", str(config_file)]
    launch_cmd += ["--num-processes", "2"]
    launch_cmd += ["--main-process-port", f"{get_torch_dist_unique_port()}"]
    execute_subprocess_async(launch_cmd)

    verify_training_success(temp_dir)
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_lora_sft_kernels(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]",
|
||||
},
|
||||
],
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_target_linear": True,
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"fsdp_version": 2,
|
||||
"fsdp_config": {
|
||||
"offload_params": False,
|
||||
"cpu_ram_efficient_loading": False,
|
||||
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
||||
"state_dict_type": "FULL_STATE_DICT",
|
||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
"lora_mlp_kernel": True,
|
||||
"lora_qkv_kernel": True,
|
||||
"lora_o_kernel": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
verify_training_success(temp_dir)
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_qlora_sft(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]",
|
||||
},
|
||||
],
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"fsdp_version": 2,
|
||||
"fsdp_config": {
|
||||
"offload_params": False,
|
||||
"cpu_ram_efficient_loading": False,
|
||||
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
||||
"state_dict_type": "FULL_STATE_DICT",
|
||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
verify_training_success(temp_dir)
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_qlora_sft_kernels(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]",
|
||||
},
|
||||
],
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_target_linear": True,
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"fsdp_version": 2,
|
||||
"fsdp_config": {
|
||||
"offload_params": False,
|
||||
"cpu_ram_efficient_loading": False,
|
||||
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
||||
"state_dict_type": "FULL_STATE_DICT",
|
||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
"lora_mlp_kernel": True,
|
||||
"lora_qkv_kernel": True,
|
||||
"lora_o_kernel": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
verify_training_success(temp_dir)
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_dpo_fft(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"rl": "dpo",
|
||||
"chat_template": "chatml",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "Intel/orca_dpo_pairs",
|
||||
"split": "train",
|
||||
"type": "chatml.intel",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"fsdp_version": 2,
|
||||
"fsdp_config": {
|
||||
"offload_params": False,
|
||||
"cpu_ram_efficient_loading": False,
|
||||
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
||||
"state_dict_type": "FULL_STATE_DICT",
|
||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
verify_training_success(temp_dir)
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_dpo_lora(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
"sequence_len": 2048,
|
||||
"rl": "dpo",
|
||||
"chat_template": "chatml",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "Intel/orca_dpo_pairs",
|
||||
"split": "train",
|
||||
"type": "chatml.intel",
|
||||
},
|
||||
],
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"fsdp_version": 2,
|
||||
"fsdp_config": {
|
||||
"offload_params": False,
|
||||
"cpu_ram_efficient_loading": False,
|
||||
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
|
||||
"state_dict_type": "FULL_STATE_DICT",
|
||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
verify_training_success(temp_dir)
|
||||
@@ -29,7 +29,6 @@ class TestMultiGPUGemma3:
|
||||
"""
|
||||
|
||||
def test_lora_ddp_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-mirrors/gemma-3-4b-pt",
|
||||
@@ -64,12 +63,14 @@ class TestMultiGPUGemma3:
|
||||
},
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -91,5 +92,5 @@ class TestMultiGPUGemma3:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 1.8, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -35,7 +35,6 @@ class TestMultiGPULlama:
|
||||
"""
|
||||
|
||||
def test_lora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -62,12 +61,14 @@ class TestMultiGPULlama:
|
||||
"gradient_accumulation_steps": 2,
|
||||
# "gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -89,7 +90,7 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.8, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -97,7 +98,6 @@ class TestMultiGPULlama:
|
||||
[1, 2],
|
||||
)
|
||||
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -127,12 +127,14 @@ class TestMultiGPULlama:
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
# "gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -154,11 +156,10 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
def test_dpo_lora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -198,8 +199,9 @@ class TestMultiGPULlama:
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
# "gradient_checkpointing": True,
|
||||
"gradient_checkpointing": False,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"warmup_steps": 0,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
@@ -207,6 +209,7 @@ class TestMultiGPULlama:
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -232,11 +235,10 @@ class TestMultiGPULlama:
|
||||
temp_dir + "/runs",
|
||||
"train/train_loss",
|
||||
loss_threshold,
|
||||
"Train Loss is too high",
|
||||
"Train Loss (%s) is too high",
|
||||
)
|
||||
|
||||
def test_dpo_qlora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -276,8 +278,9 @@ class TestMultiGPULlama:
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
# "gradient_checkpointing": True,
|
||||
"gradient_checkpointing": False,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"warmup_steps": 0,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
@@ -285,6 +288,7 @@ class TestMultiGPULlama:
|
||||
"flash_attention": True,
|
||||
"use_tensorboard": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -310,7 +314,7 @@ class TestMultiGPULlama:
|
||||
temp_dir + "/runs",
|
||||
"train/train_loss",
|
||||
loss_threshold,
|
||||
"Train Loss is too high",
|
||||
"Train Loss (%s) is too high",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -318,7 +322,6 @@ class TestMultiGPULlama:
|
||||
[1, 2],
|
||||
)
|
||||
def test_fsdp(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -340,6 +343,7 @@ class TestMultiGPULlama:
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
# "gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -349,7 +353,6 @@ class TestMultiGPULlama:
|
||||
"auto_wrap",
|
||||
],
|
||||
"fsdp_config": {
|
||||
"fsdp_limit_all_gathers": True,
|
||||
"fsdp_offload_params": False,
|
||||
"fsdp_sync_module_states": True,
|
||||
"fsdp_use_orig_params": False,
|
||||
@@ -359,6 +362,8 @@ class TestMultiGPULlama:
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -380,15 +385,17 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fsdp_state_dict_type",
|
||||
["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
|
||||
[
|
||||
"FULL_STATE_DICT",
|
||||
# "SHARDED_STATE_DICT", # not supported since intermediate checkpoints fail with fsdp1
|
||||
],
|
||||
)
|
||||
def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -407,11 +414,13 @@ class TestMultiGPULlama:
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"max_steps": 3,
|
||||
"save_steps": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
# "gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -421,7 +430,6 @@ class TestMultiGPULlama:
|
||||
"auto_wrap",
|
||||
],
|
||||
"fsdp_config": {
|
||||
"fsdp_limit_all_gathers": True,
|
||||
"fsdp_offload_params": False,
|
||||
"fsdp_sync_module_states": True,
|
||||
"fsdp_use_orig_params": False,
|
||||
@@ -431,6 +439,7 @@ class TestMultiGPULlama:
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -452,7 +461,7 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@require_torch_2_6_0
|
||||
@@ -467,7 +476,6 @@ class TestMultiGPULlama:
|
||||
def test_fsdp2_packed(
|
||||
self, temp_dir, attention_backend, fsdp_reshard_after_forward
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -491,6 +499,7 @@ class TestMultiGPULlama:
|
||||
"gradient_accumulation_steps": 2,
|
||||
"gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -508,6 +517,7 @@ class TestMultiGPULlama:
|
||||
"fsdp_reshard_after_forward": fsdp_reshard_after_forward,
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if attention_backend == "flash":
|
||||
@@ -533,11 +543,11 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.skip("regression failure from v4.57.0")
|
||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
|
||||
@@ -573,6 +583,7 @@ class TestMultiGPULlama:
|
||||
"gradient_accumulation_steps": 2,
|
||||
# "gradient_checkpointing": True,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -582,16 +593,16 @@ class TestMultiGPULlama:
|
||||
"auto_wrap",
|
||||
],
|
||||
"fsdp_config": {
|
||||
"fsdp_limit_all_gathers": True,
|
||||
"fsdp_offload_params": False,
|
||||
"fsdp_sync_module_states": True,
|
||||
"fsdp_use_orig_params": False,
|
||||
"fsdp_cpu_ram_efficient_loading": True,
|
||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||
"fsdp_state_dict_type": "FULL_STATE_DICT",
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -613,7 +624,7 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -635,7 +646,6 @@ class TestMultiGPULlama:
|
||||
def test_ds_zero3_packed(
|
||||
self, temp_dir, gradient_accumulation_steps, deepspeed, qlora
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
if qlora:
|
||||
adapter = {
|
||||
"adapter": "qlora",
|
||||
@@ -669,12 +679,14 @@ class TestMultiGPULlama:
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"deepspeed": str(AXOLOTL_ROOT / deepspeed),
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
**adapter,
|
||||
}
|
||||
)
|
||||
@@ -697,7 +709,7 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.45, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -709,7 +721,6 @@ class TestMultiGPULlama:
|
||||
[True, False],
|
||||
)
|
||||
def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora):
|
||||
# pylint: disable=duplicate-code
|
||||
if qlora:
|
||||
adapter = {
|
||||
"adapter": "qlora",
|
||||
@@ -743,12 +754,15 @@ class TestMultiGPULlama:
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
|
||||
"use_tensorboard": True,
|
||||
"seed": 42,
|
||||
"save_first_step": False,
|
||||
**adapter,
|
||||
}
|
||||
)
|
||||
@@ -771,7 +785,7 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -783,7 +797,6 @@ class TestMultiGPULlama:
|
||||
[True, False],
|
||||
)
|
||||
def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora):
|
||||
# pylint: disable=duplicate-code
|
||||
if qlora:
|
||||
adapter = {
|
||||
"adapter": "qlora",
|
||||
@@ -817,12 +830,14 @@ class TestMultiGPULlama:
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
**adapter,
|
||||
}
|
||||
)
|
||||
@@ -845,14 +860,13 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="fix untrained tokens brittle with lots of edge cases in latest transformers"
|
||||
)
|
||||
def test_fix_untrained_tokens(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -891,6 +905,7 @@ class TestMultiGPULlama:
|
||||
"save_safetensors": True,
|
||||
# "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -912,5 +927,5 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 4.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
192
tests/e2e/multigpu/test_locking.py
Normal file
192
tests/e2e/multigpu/test_locking.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Tests for FileLockLoader class."""
|
||||
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.data.lock import FileLockLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestFileLockLoader:
    """Tests for the load/cleanup behaviour of FileLockLoader."""

    @pytest.fixture
    def temp_dir(self):
        """Yield a scratch directory as a ``Path``; it is removed afterwards."""
        with tempfile.TemporaryDirectory() as scratch:
            yield Path(scratch)

    @pytest.fixture
    def cfg(self, temp_dir):
        """Return a config whose ``dataset_prepared_path`` is the scratch directory."""
        return DictDefault({"dataset_prepared_path": str(temp_dir)})

    @pytest.fixture
    def loader(self, cfg):
        """Return a FileLockLoader built from the test config."""
        return FileLockLoader(cfg)

    def test_load_first_process(self, loader):
        """With no ready flag present, load() runs load_fn and creates the flag."""
        load_fn = Mock(return_value="test_data")

        outcome = loader.load(load_fn)

        # The loader must have invoked the callable and passed its value through
        load_fn.assert_called_once()
        assert outcome == "test_data"

        # ... and must have left the ready flag behind
        assert loader.ready_flag_path.exists()

    def test_load_subsequent_process(self, loader):
        """With the ready flag already present, load() still runs load_fn once."""
        # Pretend another process already finished preparing the data
        loader.ready_flag_path.touch()

        load_fn = Mock(return_value="loaded_data")

        outcome = loader.load(load_fn)

        # load_fn is still called, to read the already-prepared data
        load_fn.assert_called_once()
        assert outcome == "loaded_data"

    def test_load_concurrent_processes(self, cfg):
        """Several loaders running in parallel threads all return data."""
        collected = []
        state = {"calls": 0}

        def slow_load_fn():
            state["calls"] += 1
            time.sleep(0.1)  # Simulate slow loading
            return f"data_{state['calls']}"

        def worker():
            collected.append(FileLockLoader(cfg).load(slow_load_fn))

        # Launch all workers, then wait for every one of them
        workers = [threading.Thread(target=worker) for _ in range(3)]
        for thread in workers:
            thread.start()
        for thread in workers:
            thread.join()

        # Every worker must have produced a result coming from slow_load_fn.
        # How often slow_load_fn ran is deliberately not asserted here.
        assert len(collected) == 3
        assert all(item.startswith("data_") for item in collected)

    @patch("time.sleep")
    def test_load_waiting_for_ready_flag(self, mock_sleep, loader):
        """load() sleeps and polls until the ready flag shows up."""
        load_fn = Mock(return_value="waiting_data")
        exists_calls = []

        def fake_exists():
            exists_calls.append(None)
            nth_call = len(exists_calls)
            # Call 1: flag exists (so this is not the first process).
            # Calls 2-3: polling loop still sees no flag.
            # Any later call: flag is there.
            return not 2 <= nth_call <= 3

        fake_flag_path = Mock()
        fake_flag_path.exists.side_effect = fake_exists

        # Swap in the fake path for the duration of the call only
        with patch.object(loader, "ready_flag_path", fake_flag_path):
            outcome = loader.load(load_fn)

        # Two failed polls mean two one-second sleeps
        assert mock_sleep.call_count == 2
        mock_sleep.assert_called_with(1)

        # Once the flag appeared, the data was loaded
        load_fn.assert_called_once()
        assert outcome == "waiting_data"

    def test_complete_workflow_with_cleanup(self, loader):
        """A single process: load() creates flag and counter, cleanup() removes both."""
        load_fn = Mock(return_value="test_data")

        # load() is expected to create both the ready flag and the counter file
        outcome = loader.load(load_fn)
        assert outcome == "test_data"
        assert loader.ready_flag_path.exists()
        assert loader.counter_path.exists()

        # Only one process took part, so cleanup removes everything
        loader.cleanup()
        assert not loader.ready_flag_path.exists()
        assert not loader.counter_path.exists()

    def test_multiple_processes_workflow(self, loader):
        """cleanup() counts down and only the last call removes the files."""
        # Fake the state left behind by three processes
        loader.ready_flag_path.touch()
        loader.counter_path.write_text("3")  # 3 processes

        # The first two cleanups only decrement the counter
        for remaining in ("2", "1"):
            loader.cleanup()
            assert loader.ready_flag_path.exists()
            assert loader.counter_path.read_text().strip() == remaining

        # The final cleanup removes flag and counter
        loader.cleanup()
        assert not loader.ready_flag_path.exists()
        assert not loader.counter_path.exists()

    def test_load_exception_handling(self, loader):
        """An exception from load_fn propagates and no ready flag is written."""
        failing_load_fn = Mock(side_effect=ValueError("Load failed"))

        with pytest.raises(ValueError, match="Load failed"):
            loader.load(failing_load_fn)

        # A failed load must not be advertised as ready
        assert not loader.ready_flag_path.exists()

    def test_file_lock_called(self, loader):
        """load() builds a FileLock on lock_file_path and uses it as a context manager."""
        load_fn = Mock(return_value="locked_data")

        with patch("axolotl.utils.data.lock.FileLock") as mock_filelock:
            lock_instance = mock_filelock.return_value
            lock_instance.__enter__ = Mock(return_value=MagicMock())
            lock_instance.__exit__ = Mock(return_value=None)

            loader.load(load_fn)

            # The lock was constructed with the loader's lock file path
            mock_filelock.assert_called_once_with(str(loader.lock_file_path))

            # The lock was entered and exited exactly once
            lock_instance.__enter__.assert_called_once()
            lock_instance.__exit__.assert_called_once()
|
||||
@@ -1,92 +0,0 @@
|
||||
"""
|
||||
E2E tests for multigpu qwen2
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestMultiGPUQwen2:
    """
    Multi-GPU test case for Qwen2 models using QLoRA with FSDP and DPO
    """

    @pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
    def test_qlora_fsdp_dpo(self, base_model, temp_dir):
        """Launch a two-process ``axolotl train`` run from a generated config."""
        # pylint: disable=duplicate-code
        dpo_dataset = {
            "path": "Intel/orca_dpo_pairs",
            "split": "train",
            "type": "chatml.intel",
        }
        fsdp_settings = {
            "fsdp_limit_all_gathers": True,
            "fsdp_offload_params": False,
            "fsdp_sync_module_states": True,
            "fsdp_use_orig_params": False,
            "fsdp_cpu_ram_efficient_loading": False,
            "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
            "fsdp_state_dict_type": "FULL_STATE_DICT",
            "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
            "fsdp_sharding_strategy": "FULL_SHARD",
        }
        cfg = DictDefault(
            {
                "base_model": base_model,
                "load_in_4bit": True,
                "rl": "dpo",
                "chat_template": "chatml",
                "sequence_len": 2048,
                "adapter": "qlora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "val_set_size": 0.01,
                "datasets": [dpo_dataset],
                "num_epochs": 1,
                "max_steps": 2,
                "warmup_steps": 20,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 2,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "bf16": "auto",
                "tf32": True,
                # "gradient_checkpointing": True,
                "gradient_checkpointing_kwargs": {
                    "use_reentrant": False,
                },
                "fsdp": [
                    "full_shard",
                    "auto_wrap",
                ],
                "fsdp_config": fsdp_settings,
            }
        )

        # Serialize the config to a YAML file inside the temp directory.
        run_dir = Path(temp_dir)
        run_dir.mkdir(parents=True, exist_ok=True)
        config_path = run_dir / "config.yaml"
        with open(config_path, "w", encoding="utf-8") as fout:
            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

        launch_cmd = [
            "axolotl",
            "train",
            str(config_path),
            "--num-processes",
            "2",
            "--main-process-port",
            f"{get_torch_dist_unique_port()}",
        ]
        execute_subprocess_async(launch_cmd)
|
||||
@@ -10,7 +10,10 @@ from accelerate.test_utils import execute_subprocess_async
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
|
||||
from tests.e2e.utils import (
|
||||
check_tensorboard,
|
||||
require_torch_2_7_0,
|
||||
)
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
@@ -20,9 +23,8 @@ class TestMultiGPURay:
|
||||
Test cases for AnyScale Ray post training
|
||||
"""
|
||||
|
||||
@require_torch_lt_2_6_0
|
||||
@require_torch_2_7_0
|
||||
def test_lora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -48,6 +50,7 @@ class TestMultiGPURay:
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -55,6 +58,7 @@ class TestMultiGPURay:
|
||||
"use_tensorboard": True,
|
||||
"use_ray": True,
|
||||
"ray_num_workers": 2,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -75,16 +79,15 @@ class TestMultiGPURay:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@require_torch_lt_2_6_0
|
||||
@require_torch_2_7_0
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 2],
|
||||
)
|
||||
def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -107,12 +110,14 @@ class TestMultiGPURay:
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero2.json"),
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -133,5 +138,72 @@ class TestMultiGPURay:
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@require_torch_2_7_0
@pytest.mark.parametrize(
    "gradient_accumulation_steps",
    [1, 2],
)
def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps):
    """Run packed SFT with FSDP2 via ``axolotl train --use-ray`` on two workers.

    After the run, the logged train loss is checked against a 2.3 threshold.
    """
    alpaca_dataset = {
        "path": "tatsu-lab/alpaca",
        "type": "alpaca",
        "split": "train[:10%]",
    }
    fsdp_settings = {
        "offload_params": False,
        "cpu_ram_efficient_loading": False,
        "transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
        "state_dict_type": "FULL_STATE_DICT",
        "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "reshard_after_forward": True,
    }
    cfg = DictDefault(
        {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "sample_packing": True,
            "pad_to_sequence_len": True,
            "sequence_len": 1024,
            "val_set_size": 0.01,
            "special_tokens": {
                "pad_token": "<|endoftext|>",
            },
            "datasets": [alpaca_dataset],
            "num_epochs": 1,
            "max_steps": 2,
            "micro_batch_size": 1,
            "gradient_accumulation_steps": gradient_accumulation_steps,
            "output_dir": temp_dir,
            "dataset_prepared_path": temp_dir + "/last_run_prepared",
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch",
            "lr_scheduler": "cosine",
            "flash_attention": True,
            "fsdp_version": 2,
            "fsdp_config": fsdp_settings,
            "use_tensorboard": True,
            "save_first_step": False,
        }
    )

    # Serialize the config to a YAML file inside the temp directory.
    run_dir = Path(temp_dir)
    run_dir.mkdir(parents=True, exist_ok=True)
    config_path = run_dir / "config.yaml"
    with open(config_path, "w", encoding="utf-8") as fout:
        fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

    launch_cmd = [
        "axolotl",
        "train",
        str(config_path),
        "--use-ray",
        "--ray-num-workers",
        "2",
    ]
    execute_subprocess_async(launch_cmd)

    tensorboard_dir = temp_dir + "/runs"
    check_tensorboard(
        tensorboard_dir, "train/train_loss", 2.3, "Train Loss (%s) is too high"
    )
|
||||
|
||||
68
tests/e2e/multigpu/test_tp.py
Normal file
68
tests/e2e/multigpu/test_tp.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""multigpu e2e test for tensor parallelism."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_tensorboard, require_torch_2_7_0
|
||||
|
||||
|
||||
class TestTensorParallel:
    """Test class for Tensor Parallel functionality."""

    @pytest.mark.skip(
        reason="TP doesn't work with models with tied weights (embeddings)"
    )
    @require_torch_2_7_0
    def test_fft_sft(self, temp_dir):
        """Run a two-process SFT job configured with ``tensor_parallel_size`` 2.

        After the run, the logged train loss is checked against a 1.0 threshold.
        """
        alpaca_dataset = {
            "path": "tatsu-lab/alpaca",
            "type": "alpaca",
            "split": "train[:10%]",
        }
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [alpaca_dataset],
                "num_epochs": 1,
                "max_steps": 2,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "tensor_parallel_size": 2,
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "use_tensorboard": True,
                "bf16": True,
            }
        )

        # Serialize the config to a YAML file inside the temp directory.
        run_dir = Path(temp_dir)
        run_dir.mkdir(parents=True, exist_ok=True)
        config_path = run_dir / "config.yaml"
        with open(config_path, "w", encoding="utf-8") as fout:
            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

        launch_cmd = [
            "axolotl",
            "train",
            str(config_path),
            "--num-processes",
            "2",
            "--main-process-port",
            f"{get_torch_dist_unique_port()}",
        ]
        execute_subprocess_async(launch_cmd)

        tensorboard_dir = temp_dir + "/runs"
        check_tensorboard(
            tensorboard_dir, "train/train_loss", 1.0, "Train Loss (%s) is too high"
        )
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Integration tests for LoRA activation and attention kernels."""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -25,7 +23,9 @@ from axolotl.loaders.model import ModelLoader
|
||||
from axolotl.loaders.tokenizer import load_tokenizer
|
||||
from axolotl.monkeypatch.lora_kernels import (
|
||||
apply_lora_kernel_patches,
|
||||
find_self_attn_in_layer,
|
||||
get_attention_cls_from_config,
|
||||
get_layers,
|
||||
patch_self_attn_lora,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -86,7 +86,7 @@ def test_attention_patching_integration(model_name, attention_cls):
|
||||
cfg = DictDefault({"base_model": model_name})
|
||||
|
||||
# Store the original implementation
|
||||
original_forward = getattr(attention_cls, "forward")
|
||||
original_forward = attention_cls.forward
|
||||
|
||||
# Apply patch
|
||||
patch_self_attn_lora(cfg)
|
||||
@@ -102,7 +102,7 @@ def test_attention_patching_integration(model_name, attention_cls):
|
||||
assert hasattr(attention_cls, "_original_forward")
|
||||
|
||||
# Clean up
|
||||
setattr(attention_cls, "forward", original_forward)
|
||||
attention_cls.forward = original_forward
|
||||
delattr(attention_cls, "_original_forward")
|
||||
|
||||
|
||||
@@ -160,7 +160,7 @@ def test_geglu_model_integration():
|
||||
"""Test GeGLU activation with Gemma model."""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"trl-internal-testing/tiny-Gemma2ForCausalLM",
|
||||
torch_dtype=torch.float16,
|
||||
dtype=torch.float16,
|
||||
device_map="cuda:0",
|
||||
)
|
||||
peft_config = get_peft_config(
|
||||
@@ -377,9 +377,9 @@ def test_model_architecture(model_config):
|
||||
|
||||
# Verify correct activation function
|
||||
layer = patched_model.model.model.layers[0]
|
||||
assert (
|
||||
layer.mlp.forward.__func__ is model_config["expected_activation"]
|
||||
), f"Wrong activation for {model_config['name']}"
|
||||
assert layer.mlp.forward.__func__ is model_config["expected_activation"], (
|
||||
f"Wrong activation for {model_config['name']}"
|
||||
)
|
||||
|
||||
# Test forward pass
|
||||
inputs = get_test_inputs(model)
|
||||
@@ -388,13 +388,12 @@ def test_model_architecture(model_config):
|
||||
patched_output = patched_model(inputs).logits
|
||||
|
||||
# Check outputs match
|
||||
assert torch.allclose(
|
||||
original_output, patched_output, rtol=1e-4
|
||||
), f"Outputs don't match for {model_config['name']}"
|
||||
assert torch.allclose(original_output, patched_output, rtol=1e-4), (
|
||||
f"Outputs don't match for {model_config['name']}"
|
||||
)
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def test_kernel_training_integration():
|
||||
def test_kernel_training_integration(temp_dir):
|
||||
"""Test model loading with kernel patches enabled."""
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
|
||||
@@ -424,6 +423,14 @@ def test_kernel_training_integration():
|
||||
}
|
||||
)
|
||||
|
||||
# Write cfg to yaml file
|
||||
path = Path(temp_dir) / "config.yaml"
|
||||
with open(path, "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
# Load config
|
||||
cfg = load_cfg(str(path))
|
||||
|
||||
# Load model
|
||||
model, _, _ = load_model_and_tokenizer(cfg=cfg)
|
||||
|
||||
@@ -501,3 +508,69 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
||||
break
|
||||
|
||||
assert found_patched_attn
|
||||
|
||||
|
||||
def test_kernel_training_integration_dropout_non_zero(temp_dir):
    """Check that LoRA kernel patches are not applied when dropout is non-zero."""

    from axolotl.cli.utils import load_model_and_tokenizer

    # Minimal LoRA config; the relevant setting is the non-zero lora_dropout.
    dataset_spec = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    cfg = DictDefault(
        {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
            "learning_rate": 0.000001,
            "datasets": [dataset_spec],
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "adapter": "lora",
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.1,
            "lora_target_linear": True,
            "sequence_len": 1024,
        }
    )

    # Round-trip the config through a YAML file and load_cfg.
    config_path = Path(temp_dir) / "config.yaml"
    with open(config_path, "w", encoding="utf-8") as fout:
        fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
    cfg = load_cfg(str(config_path))

    # Remember the attention class's forward before any patching is attempted.
    attention_cls = get_attention_cls_from_config(cfg)
    forward_before = attention_cls.forward

    model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)

    # ModelLoader is instantiated only to reach its patch manager; the model
    # itself was already loaded above.
    model_loader = ModelLoader(cfg, tokenizer)

    # The self-attention patch must leave the class forward untouched.
    model_loader.patch_manager._apply_self_attention_lora_patch()
    assert attention_cls.forward == forward_before

    # The kernel patch must not attach apply_qkv / apply_o to any attention
    # module.
    model_loader.patch_manager._apply_lora_kernel_patch(model)
    attention_modules = (
        self_attn
        for layer in get_layers(model)
        for self_attn in find_self_attn_in_layer(layer)
    )
    for self_attn in attention_modules:
        assert not hasattr(self_attn, "apply_qkv")
        assert not hasattr(self_attn, "apply_o")
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for multipack fft llama using 4d attention masks
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -20,7 +19,6 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_sdp_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -56,19 +54,18 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"fp16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_torch_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -104,12 +101,12 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"fp16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ import pytest
|
||||
import transformers
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -33,10 +32,9 @@ class TestActivationCheckpointing:
|
||||
def test_activation_checkpointing_offload(
|
||||
self,
|
||||
temp_dir,
|
||||
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
|
||||
fix_checkpoint_after_test,
|
||||
gradient_checkpointing,
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -70,13 +68,14 @@ class TestActivationCheckpointing:
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"gradient_checkpointing": gradient_checkpointing,
|
||||
"save_first_step": False,
|
||||
"dataset_num_proc": 4,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -10,7 +10,6 @@ from axolotl.cli.config import load_cfg
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
class TestPluginArgs:
|
||||
"""
|
||||
test class for plugin args loaded from the config file
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for lora llama
|
||||
import pytest
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -24,7 +23,6 @@ class TestFAXentropyLlama:
|
||||
[1, 4],
|
||||
)
|
||||
def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -63,6 +61,7 @@ class TestFAXentropyLlama:
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -73,12 +72,11 @@ class TestFAXentropyLlama:
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -23,7 +22,6 @@ class TestFalconPatched(unittest.TestCase):
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_qlora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||
@@ -59,12 +57,12 @@ class TestFalconPatched(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -72,7 +70,6 @@ class TestFalconPatched(unittest.TestCase):
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||
@@ -101,12 +98,12 @@ class TestFalconPatched(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
81
tests/e2e/patched/test_flattening.py
Normal file
81
tests/e2e/patched/test_flattening.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
E2E tests for flattening batches
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
|
||||
class TestFAFlattening:
    """
    Test case for Llama models using LoRA w batch flattening
    """

    @pytest.mark.parametrize(
        "gradient_accumulation_steps",
        [1, 4],
    )
    def test_lora_packing_flattening(self, temp_dir, gradient_accumulation_steps):
        """Train an 8-bit LoRA model with ``batch_flattening`` for five steps.

        After training, model output must exist and the logged train loss is
        checked against a 1.5 threshold.
        """
        chat_dataset = {
            "path": "mlabonne/FineTome-100k",
            "field_messages": "conversations",
            "message_field_content": "value",
            "message_field_role": "from",
            "type": "chat_template",
            "split": "train[:2%]",
        }
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "sequence_len": 1024,
                "batch_flattening": True,
                "flash_attention": True,
                "load_in_8bit": True,
                "adapter": "lora",
                "lora_r": 8,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "val_set_size": 0.05,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                },
                "chat_template": "chatml",
                "datasets": [chat_dataset],
                "num_epochs": 1,
                "max_steps": 5,
                "save_steps": 5,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": gradient_accumulation_steps,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_8bit",
                "lr_scheduler": "cosine",
                "use_tensorboard": True,
                "save_first_step": False,
            }
        )

        # Use bf16 when the GPU supports it, fp16 otherwise.
        if not is_torch_bf16_gpu_available():
            cfg.fp16 = True
        else:
            cfg.bf16 = True

        cfg = validate_config(cfg)
        normalize_config(cfg)

        dataset_meta = load_datasets(cfg=cfg)
        train(cfg=cfg, dataset_meta=dataset_meta)
        check_model_output_exists(temp_dir, cfg)

        tensorboard_dir = temp_dir + "/runs"
        check_tensorboard(
            tensorboard_dir, "train/train_loss", 1.5, "Train Loss (%s) is too high"
        )
|
||||
30
tests/e2e/patched/test_fsdp2_qlora.py
Normal file
30
tests/e2e/patched/test_fsdp2_qlora.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Integration tests for FSDP2 Params4bit patches."""
|
||||
|
||||
import pytest
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
|
||||
|
||||
|
||||
class TestFSDPPatchIntegration:
    """Test FSDP patch integration."""

    @pytest.mark.integration
    def test_fsdp2_init_patches(self):
        """Test that both ``FSDPParam`` init patches can be applied together.

        The original methods are restored once the assertions have run, so the
        patched ``FSDPParam`` does not leak into other tests executed in the
        same process.
        """
        from axolotl.monkeypatch.fsdp2_qlora import (
            apply_init_sharded_param_patch,
            apply_init_unsharded_param_patch,
        )

        # Capture the unpatched methods: they are the baseline for the
        # assertions and are needed to undo the patches afterwards.
        original_init_sharded = FSDPParam._init_sharded_param
        original_init_unsharded = FSDPParam.init_unsharded_param

        try:
            # Apply patches
            apply_init_sharded_param_patch()
            apply_init_unsharded_param_patch()

            assert FSDPParam._init_sharded_param != original_init_sharded, (
                "_init_sharded_param was not patched"
            )
            assert FSDPParam.init_unsharded_param != original_init_unsharded, (
                "init_unsharded_param was not patched"
            )
        finally:
            # Clean up even when an assertion fails, so the class-level
            # monkeypatches never outlive this test.
            FSDPParam._init_sharded_param = original_init_sharded
            FSDPParam.init_unsharded_param = original_init_unsharded
|
||||
@@ -7,7 +7,6 @@ import unittest
|
||||
import pytest
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -24,13 +23,11 @@ class TestFusedLlama(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"flash_attention": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"flash_attn_fuse_qkv": True,
|
||||
"flash_attn_fuse_mlp": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 1024,
|
||||
@@ -54,6 +51,7 @@ class TestFusedLlama(unittest.TestCase):
|
||||
"max_steps": 10,
|
||||
"save_steps": 5,
|
||||
"eval_steps": 5,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -62,8 +60,7 @@ class TestFusedLlama(unittest.TestCase):
|
||||
cfg.fp16 = True
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -23,7 +22,6 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_s2_attn(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -59,20 +57,19 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||
"save_steps": 5,
|
||||
"eval_steps": 5,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_s2_attn(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -102,13 +99,13 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||
"save_steps": 5,
|
||||
"eval_steps": 5,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -7,7 +7,6 @@ import unittest
|
||||
import pytest
|
||||
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -23,7 +22,6 @@ class TestLoraLlama(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -56,6 +54,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -65,8 +64,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -74,7 +72,6 @@ class TestLoraLlama(unittest.TestCase):
|
||||
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
|
||||
@with_temp_dir
|
||||
def test_lora_gptq_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "lilmeaty/SmolLM2-135M-Instruct-GPTQ",
|
||||
@@ -110,12 +107,12 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -4,13 +4,12 @@ E2E tests for lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
from ..utils import check_model_output_exists, require_torch_2_6_0, with_temp_dir
|
||||
|
||||
|
||||
class TestMistral(unittest.TestCase):
|
||||
@@ -18,9 +17,9 @@ class TestMistral(unittest.TestCase):
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@require_torch_2_6_0
|
||||
@with_temp_dir
|
||||
def test_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
@@ -56,19 +55,18 @@ class TestMistral(unittest.TestCase):
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
@@ -98,12 +96,12 @@ class TestMistral(unittest.TestCase):
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for mixtral
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -20,7 +19,6 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_qlora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
@@ -53,19 +51,18 @@ class TestMixtral(unittest.TestCase):
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
@@ -92,12 +89,12 @@ class TestMixtral(unittest.TestCase):
|
||||
"save_steps": 3,
|
||||
"eval_steps": 4,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -45,6 +45,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
@@ -78,6 +79,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
@@ -87,5 +89,5 @@ class TestModelPatches(unittest.TestCase):
|
||||
|
||||
assert (
|
||||
"torch.jit"
|
||||
in transformers.modeling_flash_attention_utils._get_unpad_data.__module__ # pylint: disable=protected-access
|
||||
in transformers.modeling_flash_attention_utils._get_unpad_data.__module__
|
||||
)
|
||||
|
||||
@@ -15,7 +15,6 @@ class TestLlamaPeftEmbeddings:
|
||||
"""
|
||||
|
||||
def test_peft_embeddings_upcast(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -49,6 +48,7 @@ class TestLlamaPeftEmbeddings:
|
||||
"bf16": "auto",
|
||||
"save_safetensors": True,
|
||||
"embeddings_skip_upcast": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -20,7 +19,6 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
@@ -55,20 +53,19 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
"eval_steps": 3,
|
||||
"save_steps": 4,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_qlora_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
@@ -107,13 +104,13 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
"eval_steps": 3,
|
||||
"save_steps": 4,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -7,7 +7,6 @@ import subprocess
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -23,7 +22,6 @@ class TestResumeLlama:
|
||||
|
||||
@require_torch_2_6_0
|
||||
def test_resume_lora_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -59,6 +57,7 @@ class TestResumeLlama:
|
||||
"max_steps": 15,
|
||||
"use_tensorboard": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -67,8 +66,7 @@ class TestResumeLlama:
|
||||
cfg.fp16 = True
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
@@ -78,7 +76,6 @@ class TestResumeLlama:
|
||||
}
|
||||
)
|
||||
normalize_config(resume_cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
|
||||
train(cfg=resume_cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -1,480 +0,0 @@
|
||||
"""Tests for sequence parallelism functionality."""
|
||||
|
||||
# pylint: disable=redefined-outer-name,unused-argument
|
||||
|
||||
import functools
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from accelerate.state import PartialState
|
||||
|
||||
from axolotl.monkeypatch.ring_attn import (
|
||||
get_ring_attn_group,
|
||||
register_ring_attn,
|
||||
set_ring_attn_group,
|
||||
)
|
||||
from axolotl.utils.ctx_managers.sequence_parallel import apply_sequence_parallelism
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
|
||||
|
||||
@pytest.fixture
def partial_state():
    """Provide a genuine (non-mocked) ``PartialState`` instance to the requesting test."""
    return PartialState()
|
||||
|
||||
|
||||
@pytest.fixture(name="cfg")
def fixture_cfg():
    """Build the minimal training config that tests request under the name ``cfg``."""
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    settings = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "datasets": [alpaca_dataset],
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "learning_rate": 1e-3,
        "output_dir": "./model-out",
        "sequence_len": 512,
        "special_tokens": {
            "pad_token": "<|endoftext|>",
        },
    }
    return DictDefault(settings)
|
||||
|
||||
|
||||
@pytest.fixture
def sequence_parallel_batch():
    """Return a one-row, eight-token batch for the sequence parallelism tests.

    The batch holds ``input_ids`` (0..7), an all-ones ``attention_mask``,
    ``position_ids`` (0..7) and ``labels`` (an independent copy of ``input_ids``).
    """
    rows, tokens = 1, 8

    token_ids = torch.arange(rows * tokens).reshape(rows, tokens)

    return {
        "input_ids": token_ids,
        "attention_mask": torch.ones(rows, tokens),
        "position_ids": torch.arange(tokens).expand(rows, tokens),
        # Cloned so that labels do not alias the input_ids tensor.
        "labels": token_ids.clone(),
    }
|
||||
|
||||
|
||||
class TestRingAttention:
    """Tests for the ring attention functionality."""

    @patch("torch.distributed.get_rank")
    @patch("torch.distributed.get_world_size")
    def test_get_ring_attn_group_no_registration(
        self, mock_world_size, mock_rank, partial_state
    ):
        """Test that get_ring_attn_group raises RuntimeError when no group has been registered."""
        # Pretend to be rank 0 of a 4-process job.
        mock_world_size.return_value = 4
        mock_rank.return_value = 0

        # Nothing has been registered, so looking up the group must fail.
        with pytest.raises(
            RuntimeError, match="register_ring_attn\\(\\) not yet called"
        ):
            get_ring_attn_group()

    @patch("torch.distributed.new_group")
    @patch("torch.distributed.get_rank")
    @patch("torch.distributed.get_world_size")
    def test_register_ring_attn(
        self, mock_world_size, mock_rank, mock_new_group, partial_state
    ):
        """Test that ring attention groups are created correctly."""
        # Pretend to be GPU #3 of 8; every new_group() call yields a mock group.
        mock_world_size.return_value = 8
        mock_rank.return_value = 3
        mock_new_group.return_value = MagicMock()

        try:
            register_ring_attn(
                sequence_parallel_degree=4,
                heads_k_stride=1,
                ring_attn_func=RingAttnFunc.VARLEN_LLAMA3,
            )

            # Only the number of group creations is checked, not their arguments.
            # NOTE(review): 2 presumably corresponds to world_size / degree (8 / 4).
            assert mock_new_group.call_count == 2
        finally:
            # Always reset the registered group, even when registration or the
            # assertion fails, so that state does not leak into other tests
            # (e.g. the "no registration" test above).
            set_ring_attn_group(None)
|
||||
|
||||
|
||||
class TestConfigValidation:
    """Tests for validating sequence parallelism configurations."""

    @pytest.fixture(autouse=True)
    def setup_mocks(self, monkeypatch):
        """Set up mocks for all tests in this class."""
        # Register a MagicMock under the name "ring_flash_attn" in sys.modules
        # for the duration of each test.
        monkeypatch.setitem(sys.modules, "ring_flash_attn", MagicMock())

    @pytest.fixture
    def base_cfg(self):
        """Create a base configuration for testing."""
        return DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "learning_rate": 1e-3,
                "output_dir": "./model-out",
                "sequence_len": 512,
                "special_tokens": {"pad_token": "<|endoftext|>"},
            }
        )

    @pytest.mark.parametrize(
        "config_updates, expected_values, should_pass, error_msg",
        [
            # Valid configuration
            (
                {"sequence_parallel_degree": 2, "flash_attention": True},
                {"sequence_parallel_degree": 2, "flash_attention": True},
                True,
                None,
            ),
            # Default sequence_parallel_degree
            ({}, {"sequence_parallel_degree": 1}, True, None),
            # Invalid: sequence_parallel_degree > 1 without flash_attention
            (
                {"sequence_parallel_degree": 2, "flash_attention": False},
                None,
                False,
                "flash_attention: true must be set",
            ),
            # Invalid: sequence_parallel_degree > 1 with sample_packing and micro_batch_size > 1
            (
                {
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "sample_packing": True,
                    "micro_batch_size": 2,
                    "pad_to_sequence_len": True,
                },
                None,
                False,
                "micro_batch_size must be set to 1",
            ),
            # Valid: SP + Liger loss without `rl: grpo` set
            (
                {
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "micro_batch_size": 2,
                    "trl": {"use_liger_loss": True},
                },
                {
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "micro_batch_size": 2,
                    "trl": TRLConfig(use_liger_loss=True),
                },
                True,
                # A passing case has no error to match, like the other valid cases.
                None,
            ),
            # Invalid: GRPO config with Liger loss
            (
                {
                    "rl": "grpo",
                    "sequence_parallel_degree": 2,
                    "flash_attention": True,
                    "micro_batch_size": 2,
                    "trl": {"use_liger_loss": True},
                },
                None,
                False,
                "GRPO + SP + Liger not currently supported",
            ),
        ],
        ids=[
            "valid_config",
            "default_sp_degree",
            "without_flash_attention",
            "sample_packing_with_large_batch",
            "valid_grpo",
            "grpo_with_liger_loss",
        ],
    )
    def test_sequence_parallel_config_validation(
        self, base_cfg, config_updates, expected_values, should_pass, error_msg
    ):
        """Test various sequence parallelism configuration scenarios.

        Args:
            base_cfg: Base configuration fixture; updated in place with the overrides.
            config_updates: Overrides applied on top of the base configuration.
            expected_values: Attribute values the validated config must expose
                (only read when ``should_pass`` is true).
            should_pass: Whether validation is expected to succeed.
            error_msg: Substring expected in the error text
                (only read when ``should_pass`` is false).
        """
        # Imported lazily so the import happens while the autouse
        # ring_flash_attn mock is installed.
        from axolotl.utils.schemas.config import AxolotlInputConfig

        # Apply updates to base config
        cfg = base_cfg
        cfg.update(config_updates)

        if should_pass:
            # Should validate without errors
            config = AxolotlInputConfig(**cfg)

            # Check expected values
            for key, value in expected_values.items():
                assert getattr(config, key) == value
        else:
            # Should raise exception
            with pytest.raises(ValueError) as excinfo:
                AxolotlInputConfig(**cfg)
            assert error_msg in str(excinfo.value)

    @pytest.mark.parametrize(
        "ring_attn_func, sample_packing, expected_func",
        [
            (None, True, RingAttnFunc.VARLEN_LLAMA3),
            (None, False, RingAttnFunc.BATCH_RING),
        ],
        ids=["default_with_sample_packing", "default_without_sample_packing"],
    )
    def test_ring_attn_func_validation(
        self, base_cfg, ring_attn_func, sample_packing, expected_func
    ):
        """Test ring_attn_func validation and defaults.

        Args:
            base_cfg: Base configuration fixture.
            ring_attn_func: Explicit value to set, or ``None`` to leave it unset.
            sample_packing: Value for the ``sample_packing`` option.
            expected_func: Value ``ring_attn_func`` should have after validation.
        """
        from axolotl.utils.schemas.config import AxolotlInputConfig

        # Apply updates to base config
        cfg = base_cfg | {
            "sequence_parallel_degree": 2,
            "flash_attention": True,
            "sample_packing": sample_packing,
        }

        if ring_attn_func is not None:
            cfg["ring_attn_func"] = ring_attn_func

        # Should validate without errors
        config = AxolotlInputConfig(**cfg)

        # Check ring_attn_func value
        assert config.ring_attn_func.value == expected_func

    def test_invalid_ring_attn_func(self, base_cfg):
        """Test that an invalid ring_attn_func is rejected."""
        from axolotl.utils.schemas.config import AxolotlInputConfig

        # Invalid configuration with invalid ring_attn_func
        cfg = base_cfg | {
            "sequence_parallel_degree": 2,
            "flash_attention": True,
            "ring_attn_func": "INVALID_FUNC",
        }

        # Validation is expected to fail with a ValueError
        with pytest.raises(ValueError) as excinfo:
            AxolotlInputConfig(**cfg)

        # Verify error message
        assert "Input should be 'varlen_llama3' or 'batch_ring'" in str(excinfo.value)
|
||||
|
||||
|
||||
class TestApplySequenceParallelism:
    """Tests for the apply_sequence_parallelism function."""

    @pytest.fixture(autouse=True)
    def mock_distributed(self, monkeypatch):
        """Replace torch.distributed queries and ring-attn helpers with stubs."""
        # torch.distributed reports an initialized job: rank 0 of world size 2.
        distributed_stubs = {
            "is_initialized": lambda: True,
            "get_rank": lambda *args, **kwargs: 0,
            "get_world_size": lambda *args, **kwargs: 2,
        }
        for attr_name, stub in distributed_stubs.items():
            monkeypatch.setattr(torch.distributed, attr_name, stub)

        # Looking up the process group yields a fresh MagicMock.
        monkeypatch.setattr(
            "axolotl.monkeypatch.ring_attn.get_ring_attn_group",
            MagicMock,
        )

        # Updating ring-attn parameters becomes a no-op.
        monkeypatch.setattr(
            "axolotl.monkeypatch.ring_attn.update_ring_attn_params",
            lambda **kwargs: None,
        )

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
        """A local world size of 1 must hand back the batch untouched."""
        mock_get_ring_attn_group.return_value = 0

        outputs = apply_sequence_parallelism(
            batch=sequence_parallel_batch,
            local_rank=0,
            local_world_size=1,
            gradient_accumulation_steps=1,
            ring_attn_func=RingAttnFunc.BATCH_RING,
        )
        result, _, _ = outputs

        # Nothing to shard, so the result equals the input batch.
        assert result == sequence_parallel_batch

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
        """BATCH_RING sharding as seen by rank 0 of a 2-process group."""
        mock_get_ring_attn_group.return_value = 0

        half = sequence_parallel_batch["input_ids"].size(1) // 2

        result, _, _ = apply_sequence_parallelism(
            batch=sequence_parallel_batch,
            local_rank=0,
            local_world_size=2,
            gradient_accumulation_steps=1,
            ring_attn_func=RingAttnFunc.BATCH_RING,
        )

        # The sequence dimension is cut in half.
        assert result["input_ids"].shape[1] == half
        assert result["attention_mask"].shape[1] == half

        # Rank 0 keeps the leading half of the sequence.
        assert torch.equal(
            result["input_ids"], sequence_parallel_batch["input_ids"][:, :half]
        )
        assert torch.equal(
            result["position_ids"], sequence_parallel_batch["position_ids"][:, :half]
        )

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
        """BATCH_RING sharding as seen by rank 1 of a 2-process group."""
        mock_get_ring_attn_group.return_value = 0

        half = sequence_parallel_batch["input_ids"].size(1) // 2
        # Snapshot taken before the call, for comparison afterwards.
        reference_ids = sequence_parallel_batch["input_ids"].clone()

        result, _, _ = apply_sequence_parallelism(
            batch=sequence_parallel_batch,
            local_rank=1,
            local_world_size=2,
            gradient_accumulation_steps=1,
            ring_attn_func=RingAttnFunc.BATCH_RING,
        )

        # Rank 1 keeps the trailing half of the sequence.
        assert torch.equal(result["input_ids"], reference_ids[:, half:])

    # TODO(djsaunde): add back once implemented.
    # def test_batch_zigzag(self, sequence_parallel_batch):
    #     """Test BATCH_ZIGZAG sharding pattern."""
    #     batch = sequence_parallel_batch
    #     original_input_ids = batch["input_ids"].clone()
    #     seq_len = batch["input_ids"].size(1)

    #     # Test rank 0
    #     result_rank0 = apply_sequence_parallelism(
    #         batch={k: v.clone() for k, v in batch.items()},
    #         local_rank=0,
    #         local_world_size=2,
    #         ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
    #     )

    #     # Test rank 1
    #     result_rank1 = apply_sequence_parallelism(
    #         batch={k: v.clone() for k, v in batch.items()},
    #         local_rank=1,
    #         local_world_size=2,
    #         ring_attn_func=RingAttnFunc.BATCH_ZIGZAG,
    #     )

    #     # Checks for both ranks
    #     assert result_rank0["input_ids"].shape[1] == seq_len // 2
    #     assert result_rank1["input_ids"].shape[1] == seq_len // 2

    #     # For a 2-rank system with 8 tokens, check specific zigzag pattern
    #     # Rank 0 should get chunks [0, 1] and [6, 7]
    #     # Rank 1 should get chunks [2, 3] and [4, 5]
    #     if seq_len == 8:
    #         # Create expected tensors for comparison
    #         rank0_expected = torch.cat(
    #             [original_input_ids[:, :2], original_input_ids[:, 6:8]], dim=1
    #         )

    #         rank1_expected = torch.cat(
    #             [original_input_ids[:, 2:4], original_input_ids[:, 4:6]], dim=1
    #         )

    #         assert torch.equal(result_rank0["input_ids"], rank0_expected)
    #         assert torch.equal(result_rank1["input_ids"], rank1_expected)

    @patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
    def test_partial_application(
        self, mock_get_ring_attn_group, sequence_parallel_batch
    ):
        """The function can be pre-bound with functools.partial and still shards."""
        mock_get_ring_attn_group.return_value = 0

        # Snapshot taken before the call, for comparison afterwards.
        reference_ids = sequence_parallel_batch["input_ids"].clone()
        half = reference_ids.shape[1] // 2

        # Bind everything except the batch itself.
        shard_for_rank0 = functools.partial(
            apply_sequence_parallelism,
            local_rank=0,
            local_world_size=2,
            gradient_accumulation_steps=1,
            ring_attn_func=RingAttnFunc.BATCH_RING,
        )

        result, _, _ = shard_for_rank0(batch=sequence_parallel_batch)

        # Same outcome as a direct rank-0 call: the leading half.
        assert result["input_ids"].shape[1] == half
        assert torch.equal(result["input_ids"], reference_ids[:, :half])

    def test_missing_position_ids(self, sequence_parallel_batch):
        """A batch lacking position_ids is still processed and gains that key."""
        # Shallow copy of the fixture batch with position_ids removed.
        batch = dict(sequence_parallel_batch)
        batch.pop("position_ids", None)
        reference_ids = batch["input_ids"].clone()

        # Must not raise even though position_ids is absent.
        result, _, _ = apply_sequence_parallelism(
            batch=batch,
            local_rank=0,
            local_world_size=2,
            gradient_accumulation_steps=1,
            ring_attn_func=RingAttnFunc.BATCH_RING,
        )

        # position_ids is present in the result and matches the sharded length.
        assert "position_ids" in result
        assert result["input_ids"].shape[1] == result["position_ids"].shape[1]
        assert result["input_ids"].shape[1] == reference_ids.shape[1] // 2
|
||||
@@ -4,7 +4,6 @@ e2e tests for unsloth qlora
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -13,7 +12,6 @@ from axolotl.utils.dict import DictDefault
|
||||
from ..utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
@pytest.mark.skip(
|
||||
reason="Unsloth integration will be broken going into latest transformers"
|
||||
)
|
||||
@@ -63,19 +61,19 @@ class TestUnslothQLoRA:
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
|
||||
@@ -114,19 +112,19 @@ class TestUnslothQLoRA:
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -170,17 +168,17 @@ class TestUnslothQLoRA:
|
||||
"lr_scheduler": "cosine",
|
||||
"use_tensorboard": True,
|
||||
"fp16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ import unittest
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -23,7 +22,6 @@ class TestPackedFlex(unittest.TestCase):
|
||||
@require_torch_2_6_0
|
||||
@with_temp_dir
|
||||
def test_loss_llama(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -50,6 +48,7 @@ class TestPackedFlex(unittest.TestCase):
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -59,11 +58,10 @@ class TestPackedFlex(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.1, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for relora llama
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -21,7 +20,6 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_relora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -35,9 +33,10 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_modules": ["q_proj", "v_proj"],
|
||||
"relora_steps": 50,
|
||||
"relora_warmup_steps": 10,
|
||||
"relora_anneal_steps": 10,
|
||||
"relora": True,
|
||||
"jagged_restart_steps": 50,
|
||||
"jagged_restart_warmup_steps": 10,
|
||||
"jagged_restart_anneal_steps": 10,
|
||||
"relora_prune_ratio": 0.9,
|
||||
"relora_cpu_offload": True,
|
||||
"val_set_size": 0.0,
|
||||
@@ -66,19 +65,19 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
|
||||
assert (
|
||||
Path(temp_dir) / "checkpoint-100/relora/model.safetensors"
|
||||
).exists(), "Relora model checkpoint not found"
|
||||
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists(), (
|
||||
"Relora model checkpoint not found"
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
|
||||
|
||||
80
tests/e2e/test_activation_offloading.py
Normal file
80
tests/e2e/test_activation_offloading.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
E2E tests for activation offloading
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists
|
||||
|
||||
|
||||
class TestActivationOffloading:
    """
    E2E test cases for activation offloading
    """

    @pytest.mark.parametrize(
        "adapter",
        ["lora", "qlora", None],
    )
    def test_activation_offloading(
        self,
        temp_dir,
        adapter,
    ):
        """Train for two steps with activation offloading on and check the output.

        Parametrized over ``adapter``: "lora", "qlora" (which also enables
        4-bit loading) or ``None`` (no adapter key is set).
        """
        chat_dataset = {
            "chat_template": "chatml",
            "path": "mlabonne/FineTome-100k",
            "type": "chat_template",
            "split": "train[:10%]",
            "field_messages": "conversations",
            "message_field_role": "from",
            "message_field_content": "value",
        }
        tokens = {
            "pad_token": "<|endoftext|>",
            "eos_token": "<|im_end|>",
        }
        settings = {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "sequence_len": 1024,
            "val_set_size": 0.0,
            "special_tokens": tokens,
            "datasets": [chat_dataset],
            "num_epochs": 1,
            "max_steps": 2,
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch_8bit",
            "lr_scheduler": "cosine",
            "flash_attention": True,
            "sample_packing": True,
            "bf16": "auto",
            "save_safetensors": True,
            "gradient_checkpointing": True,
            "activation_offloading": True,
            "save_first_step": False,
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_target_linear": True,
        }
        cfg = DictDefault(settings)

        # The adapter key is only set for the adapter-based variants.
        if adapter is not None:
            cfg["adapter"] = adapter
        if adapter == "qlora":
            cfg["load_in_4bit"] = True

        cfg = validate_config(cfg)
        normalize_config(cfg)

        train(cfg=cfg, dataset_meta=load_datasets(cfg=cfg))
        check_model_output_exists(temp_dir, cfg)
|
||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -26,7 +25,6 @@ class TestDeepseekV3:
|
||||
[True, False],
|
||||
)
|
||||
def test_lora_deepseekv3(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/DeepSeek-V3-11M",
|
||||
@@ -68,12 +66,12 @@ class TestDeepseekV3:
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
@@ -84,7 +82,6 @@ class TestDeepseekV3:
|
||||
[True, False],
|
||||
)
|
||||
def test_fft_deepseekv3(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/DeepSeek-V3-11M",
|
||||
@@ -118,12 +115,12 @@ class TestDeepseekV3:
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
139
tests/e2e/test_diffusion.py
Normal file
139
tests/e2e/test_diffusion.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""E2E smoke test for diffusion training plugin."""
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_model_output_exists
|
||||
|
||||
|
||||
class TestDiffusion:
|
||||
"""Test case for diffusion training plugin."""
|
||||
|
||||
def test_diffusion_smoke_test(self, temp_dir):
|
||||
"""
|
||||
Smoke test for diffusion training to ensure the plugin loads and trains without
|
||||
error.
|
||||
"""
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"trust_remote_code": True,
|
||||
"sequence_len": 256,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 3,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
"logging_steps": 1,
|
||||
"eval_steps": 3,
|
||||
# Diffusion-specific config
|
||||
"plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
|
||||
"diffusion": {
|
||||
# sample generation
|
||||
"generate_samples": True,
|
||||
"generation_interval": 1,
|
||||
"num_generation_samples": 1,
|
||||
"generation_steps": 2,
|
||||
"generation_max_length": 32,
|
||||
"generation_temperature": 0.0,
|
||||
# training-specific
|
||||
"mask_token_id": 16,
|
||||
"eps": 1e-3,
|
||||
"importance_weighting": False,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_diffusion_sft_labels(self, temp_dir):
|
||||
"""Test that diffusion training properly handles SFT data with labels."""
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"trust_remote_code": True,
|
||||
"sequence_len": 256,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 3,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.0001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
"logging_steps": 1,
|
||||
"eval_steps": 2,
|
||||
# Diffusion-specific config
|
||||
"plugins": ["axolotl.integrations.diffusion.DiffusionPlugin"],
|
||||
"diffusion": {
|
||||
# sample generation
|
||||
"generate_samples": True,
|
||||
"generation_interval": 1,
|
||||
"num_generation_samples": 1,
|
||||
"generation_steps": 2,
|
||||
"generation_max_length": 32,
|
||||
"generation_temperature": 0.0,
|
||||
# training-specific
|
||||
"mask_token_id": 16,
|
||||
"eps": 1e-3,
|
||||
"importance_weighting": True,
|
||||
},
|
||||
# Ensure we have proper SFT labels
|
||||
"train_on_inputs": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
# Verify that the dataset has labels
|
||||
sample = dataset_meta.train_dataset[0]
|
||||
assert "labels" in sample, "SFT dataset should have labels"
|
||||
|
||||
# Check that some labels are -100 (prompt tokens)
|
||||
labels = sample["labels"]
|
||||
if hasattr(labels, "tolist"):
|
||||
labels = labels.tolist()
|
||||
assert -100 in labels, "SFT dataset should have -100 labels for prompt tokens"
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -1,6 +1,4 @@
|
||||
"""
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
"""E2E tests for lora llama"""
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
@@ -23,7 +21,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_dpo_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -58,6 +55,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -71,7 +69,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_dpo_nll_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -107,6 +104,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -120,7 +118,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_dpo_use_weighting(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -156,6 +153,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -170,7 +168,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
@pytest.mark.skip("kto_pair no longer supported in trl")
|
||||
@with_temp_dir
|
||||
def test_kto_pair_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -205,6 +202,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -218,7 +216,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_ipo_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -253,6 +250,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -266,7 +264,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_orpo_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -304,6 +301,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -318,7 +316,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
@pytest.mark.skip(reason="Fix the implementation")
|
||||
@with_temp_dir
|
||||
def test_kto_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -372,6 +369,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for llama pretrain
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -20,7 +19,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_train_w_embedding_lr_scale(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -49,13 +47,13 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -66,7 +64,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_train_w_embedding_lr(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -95,12 +92,12 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -13,7 +13,6 @@ class TestE2eEvaluate:
|
||||
"""Test cases for evaluate CLI"""
|
||||
|
||||
def test_evaluate(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -36,6 +35,7 @@ class TestE2eEvaluate:
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -23,7 +22,6 @@ class TestFalcon(unittest.TestCase):
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||
@@ -61,13 +59,13 @@ class TestFalcon(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -75,7 +73,6 @@ class TestFalcon(unittest.TestCase):
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_lora_added_vocab(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||
@@ -117,13 +114,13 @@ class TestFalcon(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -131,7 +128,6 @@ class TestFalcon(unittest.TestCase):
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||
@@ -159,13 +155,13 @@ class TestFalcon(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -23,7 +22,6 @@ class TestGemma2:
|
||||
[True, False],
|
||||
)
|
||||
def test_lora_gemma2(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/gemma-2-33M",
|
||||
@@ -69,8 +67,7 @@ class TestGemma2:
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
@@ -80,7 +77,6 @@ class TestGemma2:
|
||||
[True, False],
|
||||
)
|
||||
def test_fft_gemma2(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/gemma-2-33M",
|
||||
@@ -121,8 +117,7 @@ class TestGemma2:
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
@@ -6,7 +6,6 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -23,7 +22,6 @@ class TestGemma3Text:
|
||||
[True, False],
|
||||
)
|
||||
def test_lora_gemma3_text(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/gemma-3-34M",
|
||||
@@ -64,12 +62,12 @@ class TestGemma3Text:
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
@@ -79,7 +77,6 @@ class TestGemma3Text:
|
||||
[True, False],
|
||||
)
|
||||
def test_fft_gemma3_text(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/gemma-3-34M",
|
||||
@@ -115,12 +112,12 @@ class TestGemma3Text:
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
@@ -11,11 +11,7 @@ class TestImports(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def test_import_causal_trainer(self):
|
||||
from axolotl.core.builders import ( # pylint: disable=unused-import # noqa: F401
|
||||
HFCausalTrainerBuilder,
|
||||
)
|
||||
pass
|
||||
|
||||
def test_import_rl_trainer(self):
|
||||
from axolotl.core.builders import ( # pylint: disable=unused-import # noqa: F401
|
||||
HFRLTrainerBuilder,
|
||||
)
|
||||
pass
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
E2E tests for llama
|
||||
"""
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -17,7 +16,6 @@ class TestLlama:
|
||||
"""
|
||||
|
||||
def test_fft_trust_remote_code(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -46,19 +44,18 @@ class TestLlama:
|
||||
"sample_packing": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_fix_untrained_tokens(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -94,19 +91,18 @@ class TestLlama:
|
||||
"sample_packing": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_fix_untrained_tokens_already_trained(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -139,19 +135,18 @@ class TestLlama:
|
||||
"sample_packing": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_batch_flattening(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -180,13 +175,13 @@ class TestLlama:
|
||||
"batch_flattening": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
"""
|
||||
E2E tests for llama pretrain
|
||||
"""
|
||||
"""E2E tests for llama pretrain"""
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -14,23 +11,17 @@ from .utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
|
||||
class TestPretrainLlama:
|
||||
"""
|
||||
Test case for Llama models w pretraining
|
||||
"""
|
||||
"""Test case for Llama models w pretraining"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_packing",
|
||||
[True, False],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"pretrain_multipack_attn",
|
||||
[True, False],
|
||||
("sample_packing", "pretrain_multipack_attn"),
|
||||
[
|
||||
(False, False),
|
||||
(True, True),
|
||||
(True, False),
|
||||
],
|
||||
)
|
||||
def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn):
|
||||
if not sample_packing and pretrain_multipack_attn:
|
||||
return
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -38,7 +29,7 @@ class TestPretrainLlama:
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": sample_packing,
|
||||
"pretrain_multipack_attn": pretrain_multipack_attn,
|
||||
"dataset_processes": 1,
|
||||
"dataset_num_proc": 1,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
@@ -61,22 +52,22 @@ class TestPretrainLlama:
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
loss_threshold = 3.5
|
||||
loss_threshold = 3.6
|
||||
if sample_packing and not pretrain_multipack_attn:
|
||||
loss_threshold = 6.5
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs",
|
||||
"train/train_loss",
|
||||
loss_threshold,
|
||||
"Train Loss is too high",
|
||||
"Train Loss (%s) is too high",
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -20,7 +19,6 @@ class TestLlamaVision(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_llama_vision_text_only_dataset(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
|
||||
@@ -55,20 +53,19 @@ class TestLlamaVision(unittest.TestCase):
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
|
||||
@@ -102,12 +99,12 @@ class TestLlamaVision(unittest.TestCase):
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -52,15 +52,15 @@ class TestLoadModelUtils:
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"tensor_parallel_size": 1,
|
||||
"context_parallel_size": 1,
|
||||
}
|
||||
)
|
||||
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
|
||||
ModelLoader(
|
||||
cfg=self.cfg,
|
||||
tokenizer="",
|
||||
inference=False,
|
||||
reference_model=True,
|
||||
)
|
||||
self.model_loader = ModelLoader(
|
||||
cfg=self.cfg,
|
||||
tokenizer="",
|
||||
inference=False,
|
||||
reference_model=True,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("embedding_modules", ["embed_tokens", "lm_head"])
|
||||
@@ -72,7 +72,7 @@ class TestLoadModelUtils:
|
||||
self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune
|
||||
):
|
||||
self.cfg.output_dir = temp_dir
|
||||
self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
|
||||
self.model_loader.tokenizer = load_tokenizer(self.cfg)
|
||||
self.model_loader.load()
|
||||
self.model_loader._convert_embedding_modules_dtype(
|
||||
embedding_modules, dist_dtype, before_kbit_train_or_finetune
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -20,7 +19,6 @@ class TestLoraLlama(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -50,13 +48,13 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -23,7 +22,6 @@ class TestMamba(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "state-spaces/mamba-130m",
|
||||
@@ -52,13 +50,13 @@ class TestMamba(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": None,
|
||||
"save_safetensors": False,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ import unittest
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -22,7 +21,6 @@ class TestMistral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
@@ -56,20 +54,19 @@ class TestMistral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
@@ -97,6 +94,7 @@ class TestMistral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -106,8 +104,7 @@ class TestMistral(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -7,7 +7,6 @@ import unittest
|
||||
import torch
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -23,7 +22,6 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_qlora_w_fa2(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
@@ -62,13 +60,13 @@ class TestMixtral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
@@ -79,7 +77,6 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_qlora_wo_fa2(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
@@ -118,13 +115,13 @@ class TestMixtral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
@@ -135,7 +132,6 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_16bit_lora_w_fa2(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
@@ -173,6 +169,7 @@ class TestMixtral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -182,8 +179,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
@@ -194,7 +190,6 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_16bit_lora_wo_fa2(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
@@ -232,6 +227,7 @@ class TestMixtral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -241,8 +237,7 @@ class TestMixtral(unittest.TestCase):
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
assert (
|
||||
@@ -253,7 +248,6 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
@@ -278,6 +272,7 @@ class TestMixtral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -287,8 +282,7 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for custom optimizers using Llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -14,6 +13,7 @@ from .utils import (
|
||||
check_model_output_exists,
|
||||
require_torch_2_5_1,
|
||||
require_torch_2_6_0,
|
||||
require_torch_2_7_0,
|
||||
with_temp_dir,
|
||||
)
|
||||
|
||||
@@ -25,7 +25,6 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_optimi_adamw(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -56,13 +55,13 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"optimizer": "optimi_adamw",
|
||||
"max_steps": 5,
|
||||
"lr_scheduler": "cosine",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -71,7 +70,6 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
@with_temp_dir
|
||||
@require_torch_2_5_1
|
||||
def test_adopt_adamw(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -102,13 +100,13 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adopt_adamw",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -117,7 +115,6 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
@with_temp_dir
|
||||
@require_torch_2_5_1
|
||||
def test_muon(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -149,21 +146,62 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"optimizer": "muon",
|
||||
"lr_scheduler": "cosine",
|
||||
"weight_decay": 0.01,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
_, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
assert "Muon" in trainer.optimizer.optimizer.__class__.__name__
|
||||
|
||||
@with_temp_dir
@require_torch_2_7_0
def test_dion(self, temp_dir):
    """Train SmolLM2-135M with ``optimizer="dion"`` and verify the optimizer.

    The run is capped at ``max_steps=5``. The test passes when
    ``check_model_output_exists`` accepts ``temp_dir`` and the class name of
    the object at ``trainer.optimizer.optimizer`` contains "Dion".
    """
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    raw_cfg = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "model_type": "AutoModelForCausalLM",
        "tokenizer_type": "AutoTokenizer",
        "sequence_len": 1024,
        "val_set_size": 0.0,
        "special_tokens": {
            "pad_token": "<|endoftext|>",
        },
        "datasets": [alpaca_dataset],
        "num_epochs": 1,
        "max_steps": 5,
        "micro_batch_size": 8,
        "gradient_accumulation_steps": 1,
        "output_dir": temp_dir,
        "learning_rate": 0.00001,
        "optimizer": "dion",
        "dion_lr": 0.01,
        "dion_momentum": 0.95,
        "lr_scheduler": "cosine",
        "weight_decay": 0.01,
        "save_first_step": False,
    }

    cfg = validate_config(DictDefault(raw_cfg))
    normalize_config(cfg)
    dataset_meta = load_datasets(cfg=cfg)

    _, _, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
    check_model_output_exists(temp_dir, cfg)

    # The class name is read from the object held at trainer.optimizer.optimizer.
    inner_optimizer = trainer.optimizer.optimizer
    assert "Dion" in inner_optimizer.__class__.__name__
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_schedule_free_adamw(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -188,14 +226,13 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"lr_scheduler": "constant",
|
||||
"save_safetensors": True,
|
||||
"max_steps": 10,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -203,7 +240,6 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
@with_temp_dir
|
||||
@require_torch_2_6_0
|
||||
def test_came_pytorch(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
@@ -237,13 +273,13 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"adam_epsilon2": 1e-16,
|
||||
"max_steps": 5,
|
||||
"lr_scheduler": "cosine",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ import unittest
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -22,7 +21,6 @@ class TestPackedLlama(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_loss_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -49,6 +47,7 @@ class TestPackedLlama(unittest.TestCase):
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -58,11 +57,10 @@ class TestPackedLlama(unittest.TestCase):
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -20,7 +19,6 @@ class TestPhi(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_phi_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
@@ -54,19 +52,18 @@ class TestPhi(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_phi_qlora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
@@ -104,12 +101,12 @@ class TestPhi(unittest.TestCase):
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
58
tests/e2e/test_preprocess.py
Normal file
58
tests/e2e/test_preprocess.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""E2E Test the preprocess cli"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
# Repository root: three directory levels up from this test module.
AXOLOTL_ROOT = Path(__file__).parent.parent.parent


class TestPreprocess:
    """Test cases for the ``axolotl preprocess`` CLI command."""

    def test_w_deepspeed(self, temp_dir):
        """Make sure preprocess doesn't choke when the config uses deepspeed.

        Writes a config that references the ``zero1.json`` deepspeed file,
        runs ``axolotl preprocess`` on it in a subprocess, and checks that the
        ``last_run_prepared`` directory was created under ``temp_dir``.
        """
        alpaca_slice = {
            "path": "tatsu-lab/alpaca",
            "type": "alpaca",
            "split": "train[:10%]",
        }
        cfg = DictDefault(
            {
                "base_model": "Qwen/Qwen2.5-0.5B",
                "sequence_len": 2048,
                "val_set_size": 0.01,
                "datasets": [alpaca_slice],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch_fused",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "bf16": "auto",
                "deepspeed": str(AXOLOTL_ROOT / "deepspeed_configs/zero1.json"),
                "dataset_prepared_path": temp_dir + "/last_run_prepared",
            }
        )

        # Serialize the config to YAML so the CLI can load it from disk.
        work_dir = Path(temp_dir)
        work_dir.mkdir(parents=True, exist_ok=True)
        config_file = work_dir / "config.yaml"
        config_file.write_text(
            yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper), encoding="utf-8"
        )

        execute_subprocess_async(["axolotl", "preprocess", str(config_file)])

        assert (work_dir / "last_run_prepared").exists()
|
||||
@@ -4,7 +4,6 @@ E2E tests for process reward model w/ lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -20,7 +19,6 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_prm(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -50,12 +48,12 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
|
||||
"use_tensorboard": True,
|
||||
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||
"seed": 42,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_tensorboard(
|
||||
|
||||
113
tests/e2e/test_profiler.py
Normal file
113
tests/e2e/test_profiler.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
e2e gpu test for the pytorch profiler callback
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="profiler_base_cfg")
def fixture_profiler_base_cfg():
    """Return the base 8-bit LoRA config shared by the profiler tests.

    The config targets SmolLM2-135M on the ``alpaca_2k_test`` dataset. It does
    not set ``output_dir``, ``max_steps`` or any profiler option; each test
    merges those in on top of this base.
    """
    return DictDefault(
        {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "tokenizer_type": "AutoTokenizer",
            "sequence_len": 1024,
            "load_in_8bit": True,
            "adapter": "lora",
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.05,
            "lora_target_linear": True,
            "val_set_size": 0.02,
            "special_tokens": {"pad_token": "<|endoftext|>"},
            "datasets": [
                {
                    "path": "mhenrichsen/alpaca_2k_test",
                    "type": "alpaca",
                },
            ],
            "num_epochs": 1,
            "micro_batch_size": 2,
            "gradient_accumulation_steps": 1,
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch_fused",
            "lr_scheduler": "cosine",
        }
    )
|
||||
|
||||
|
||||
class TestProfiler:
    """
    Test cases for the PyTorch profiler callback.

    Every test trains with ``max_steps=5`` and ``profiler_steps=3`` and then
    checks whether ``snapshot.pickle`` exists in the output directory.
    """

    @staticmethod
    def _train_and_locate_snapshot(base_cfg, temp_dir, **profiler_overrides):
        """Run one training pass and return the expected snapshot path.

        Args:
            base_cfg: Base config provided by the ``profiler_base_cfg`` fixture.
            temp_dir: Output directory for the run.
            **profiler_overrides: Extra config keys (e.g. ``profiler_steps_start``)
                merged in after the common profiler settings.

        Returns:
            ``Path`` to ``snapshot.pickle`` inside ``temp_dir``; the file may or
            may not exist.
        """
        cfg = base_cfg | DictDefault(
            output_dir=temp_dir,
            max_steps=5,
            profiler_steps=3,
            **profiler_overrides,
        )
        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)

        train(cfg=cfg, dataset_meta=dataset_meta)
        return Path(temp_dir) / "snapshot.pickle"

    def test_profiler_saves(self, profiler_base_cfg, temp_dir):
        """A snapshot is written when no start step is configured."""
        snapshot = self._train_and_locate_snapshot(profiler_base_cfg, temp_dir)
        assert snapshot.exists()

    def test_profiler_saves_w_start(self, profiler_base_cfg, temp_dir):
        """A snapshot is written when profiling starts at step 1."""
        snapshot = self._train_and_locate_snapshot(
            profiler_base_cfg, temp_dir, profiler_steps_start=1
        )
        assert snapshot.exists()

    @pytest.mark.parametrize(
        "profiler_steps_start",
        [3, 5],
    )
    def test_profiler_saves_past_end(
        self, profiler_base_cfg, temp_dir, profiler_steps_start
    ):
        """A snapshot is still written for start steps 3 and 5.

        With ``profiler_steps=3`` and ``max_steps=5``, start + steps exceeds the
        step limit for both parametrized values.
        """
        snapshot = self._train_and_locate_snapshot(
            profiler_base_cfg,
            temp_dir,
            profiler_steps_start=profiler_steps_start,
        )
        assert snapshot.exists()

    def test_profiler_never_started(self, profiler_base_cfg, temp_dir):
        """No snapshot is written when the start step (6) is beyond ``max_steps`` (5)."""
        snapshot = self._train_and_locate_snapshot(
            profiler_base_cfg, temp_dir, profiler_steps_start=6
        )
        assert not snapshot.exists()
|
||||
@@ -2,26 +2,22 @@
|
||||
E2E tests for QAT
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
from .utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
|
||||
class TestQATLlama(unittest.TestCase):
|
||||
class TestQATLlama:
|
||||
"""
|
||||
Test case for QAT Llama models
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_qat_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
def test_qat(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -47,7 +43,7 @@ class TestQATLlama(unittest.TestCase):
|
||||
"qat": {
|
||||
"quantize_embedding": True,
|
||||
"activation_dtype": "int8",
|
||||
"weight_dtype": "int8",
|
||||
"weight_dtype": "int4",
|
||||
"group_size": 8,
|
||||
},
|
||||
"num_epochs": 1,
|
||||
@@ -60,12 +56,78 @@ class TestQATLlama(unittest.TestCase):
|
||||
"max_steps": 5,
|
||||
"save_safetensors": True,
|
||||
"bf16": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)
|
||||
|
||||
def test_qat_dpo(self, temp_dir):
    """Run DPO training with QAT enabled and check the checkpoint and loss.

    Training is capped at ``max_steps=5``. The test then verifies model output
    under ``checkpoint-5`` and passes the tensorboard ``train/train_loss``
    series to ``check_tensorboard`` with a threshold of 2.3.
    """
    dpo_dataset = {
        "path": "fozziethebeat/alpaca_messages_2k_dpo_test",
        "type": "chat_template.default",
        "field_messages": "conversation",
        "field_chosen": "chosen",
        "field_rejected": "rejected",
        "message_field_role": "role",
        "message_field_content": "content",
        "roles": {
            "system": ["system"],
            "user": ["user"],
            "assistant": ["assistant"],
        },
    }
    qat_settings = {
        "quantize_embedding": True,
        "activation_dtype": "int8",
        "weight_dtype": "int4",
        "group_size": 8,
    }
    raw_cfg = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "sequence_len": 2048,
        "sample_packing": False,
        "eval_sample_packing": False,
        "pad_to_sequence_len": True,
        "val_set_size": 0.01,
        "special_tokens": {
            "pad_token": "<|endoftext|>",
        },
        "rl": "dpo",
        "chat_template": "chatml",
        "datasets": [dpo_dataset],
        "num_epochs": 1,
        "max_steps": 5,
        "micro_batch_size": 2,
        "gradient_accumulation_steps": 2,
        "output_dir": temp_dir,
        "warmup_steps": 0,
        "learning_rate": 0.00001,
        "optimizer": "adamw_torch_fused",
        "lr_scheduler": "cosine",
        "flash_attention": True,
        "use_tensorboard": True,
        "bf16": True,
        "qat": qat_settings,
        "save_first_step": False,
    }

    cfg = validate_config(DictDefault(raw_cfg))
    normalize_config(cfg)
    # DPO uses the preference-dataset loader rather than load_datasets.
    dataset_meta = load_preference_datasets(cfg=cfg)

    train(cfg=cfg, dataset_meta=dataset_meta)
    check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)

    check_tensorboard(
        temp_dir + "/runs",
        "train/train_loss",
        2.3,
        "Train Loss (%s) is too high",
    )
|
||||
|
||||
@@ -5,42 +5,41 @@ Tests for axolotl.utils.quantization
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
|
||||
from torchao.quantization.granularity import PerAxis, PerGroup
|
||||
from torchao.quantization.linear_activation_quantized_tensor import (
|
||||
LinearActivationQuantizedTensor,
|
||||
)
|
||||
from torchao.quantization import LinearActivationQuantizedTensor
|
||||
from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
|
||||
from torchao.quantization.qat.linear import FakeQuantizedLinear
|
||||
from torchao.quantization.quant_api import (
|
||||
Int4DynamicActivationInt4WeightConfig,
|
||||
Int4WeightOnlyConfig,
|
||||
Int8DynamicActivationInt8WeightConfig,
|
||||
Int8WeightOnlyConfig,
|
||||
UIntXWeightOnlyConfig,
|
||||
Float8DynamicActivationFloat8WeightConfig,
|
||||
Float8DynamicActivationInt4WeightConfig,
|
||||
Int8DynamicActivationInt4WeightConfig,
|
||||
)
|
||||
from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers.trainer_callback import TrainerState
|
||||
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
from axolotl.utils.quantization import (
|
||||
convert_qat_model_for_ptq,
|
||||
get_ptq_config,
|
||||
convert_qat_model,
|
||||
get_quantization_config,
|
||||
prepare_model_for_qat,
|
||||
quantize_model_for_ptq,
|
||||
quantize_model,
|
||||
)
|
||||
from axolotl.utils.schemas.enums import TorchIntDType
|
||||
from axolotl.utils.schemas.enums import TorchAOQuantDType
|
||||
from axolotl.utils.schemas.quantization import QATConfig
|
||||
|
||||
from tests.e2e.utils import require_torch_2_6_0
|
||||
from tests.e2e.utils import (
|
||||
require_torch_2_8_0,
|
||||
requires_cuda_ge_8_9,
|
||||
requires_sm_ge_100,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def model():
|
||||
dummy_model = AutoModelForCausalLM.from_pretrained(
|
||||
"HuggingFaceTB/SmolLM2-135M",
|
||||
device_map="cuda",
|
||||
torch_dtype=torch.bfloat16,
|
||||
"Qwen/Qwen2-0.5B",
|
||||
device_map="auto",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
with torch.device(dummy_model.device):
|
||||
dummy_model.model.embed_tokens = torch.nn.Embedding(
|
||||
@@ -48,45 +47,56 @@ def model():
|
||||
dummy_model.model.embed_tokens.weight.shape[1],
|
||||
dtype=dummy_model.model.embed_tokens.weight.dtype,
|
||||
)
|
||||
return dummy_model
|
||||
yield dummy_model
|
||||
del dummy_model
|
||||
|
||||
|
||||
ptq_config_test_cases = [
|
||||
# weight_dtype, activation_dtype, group_size, expected_type, expected_params
|
||||
# weight_dtype, activation_dtype, group_size, expected_type
|
||||
(
|
||||
TorchIntDType.uint4,
|
||||
TorchAOQuantDType.int4,
|
||||
TorchAOQuantDType.int8,
|
||||
None,
|
||||
None,
|
||||
UIntXWeightOnlyConfig,
|
||||
{"dtype": torch.uint4, "group_size": None},
|
||||
),
|
||||
(TorchIntDType.int8, None, 32, Int8WeightOnlyConfig, {"group_size": 32}),
|
||||
(TorchIntDType.int4, None, 4, Int4WeightOnlyConfig, {"group_size": 4}),
|
||||
(
|
||||
TorchIntDType.int4,
|
||||
TorchIntDType.int4,
|
||||
None,
|
||||
Int4DynamicActivationInt4WeightConfig,
|
||||
{},
|
||||
Int8DynamicActivationInt4WeightConfig,
|
||||
),
|
||||
(
|
||||
TorchIntDType.int8,
|
||||
TorchIntDType.int8,
|
||||
TorchAOQuantDType.float8_e4m3fn,
|
||||
TorchAOQuantDType.float8_e4m3fn,
|
||||
None,
|
||||
Int8DynamicActivationInt8WeightConfig,
|
||||
{},
|
||||
Float8DynamicActivationFloat8WeightConfig,
|
||||
),
|
||||
(
|
||||
TorchAOQuantDType.int4,
|
||||
TorchAOQuantDType.float8_e4m3fn,
|
||||
None,
|
||||
Float8DynamicActivationInt4WeightConfig,
|
||||
),
|
||||
]
|
||||
|
||||
ptq_test_cases = [
|
||||
# weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception
|
||||
(TorchIntDType.int8, None, 8, False, None),
|
||||
(TorchIntDType.int4, None, 4, True, None),
|
||||
(TorchIntDType.uint4, None, 8, False, None),
|
||||
(TorchIntDType.int4, TorchIntDType.int4, 8, False, None),
|
||||
(TorchIntDType.int8, TorchIntDType.int8, 8, True, None),
|
||||
(TorchIntDType.int8, None, None, False, ValueError),
|
||||
(TorchIntDType.int4, None, None, False, ValueError),
|
||||
# weight_dtype, activation_dtype, group_size, quantize_embedding, expected_exception, expected_tensor_class
|
||||
(TorchAOQuantDType.int4, None, 4, True, None, Int4Tensor),
|
||||
(
|
||||
TorchAOQuantDType.int4,
|
||||
TorchAOQuantDType.int8,
|
||||
8,
|
||||
False,
|
||||
None,
|
||||
LinearActivationQuantizedTensor,
|
||||
),
|
||||
# (
|
||||
# TorchAOQuantDType.int4,
|
||||
# TorchAOQuantDType.float8_e4m3fn,
|
||||
# None,
|
||||
# False,
|
||||
# None,
|
||||
# Int4Tensor,
|
||||
# ),
|
||||
(TorchAOQuantDType.int4, None, None, False, None, Int4Tensor),
|
||||
# Deprecated configs
|
||||
(TorchAOQuantDType.int8, None, 8, False, ValueError, None),
|
||||
(TorchAOQuantDType.int4, TorchAOQuantDType.int4, 8, False, ValueError, None),
|
||||
(TorchAOQuantDType.int8, TorchAOQuantDType.int8, 8, True, ValueError, None),
|
||||
]
|
||||
|
||||
|
||||
@@ -96,44 +106,132 @@ class TestQuantization:
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"weight_dtype,activation_dtype,group_size,expected_type,expected_params",
|
||||
"weight_dtype,activation_dtype,group_size,expected_type",
|
||||
ptq_config_test_cases,
|
||||
)
|
||||
@require_torch_2_6_0
|
||||
@requires_cuda_ge_8_9
|
||||
@require_torch_2_8_0
|
||||
def test_get_ptq_config(
|
||||
self, weight_dtype, activation_dtype, group_size, expected_type, expected_params
|
||||
self, weight_dtype, activation_dtype, group_size, expected_type
|
||||
):
|
||||
config = get_ptq_config(weight_dtype, activation_dtype, group_size)
|
||||
|
||||
config = get_quantization_config(weight_dtype, activation_dtype, group_size)
|
||||
assert isinstance(config, expected_type)
|
||||
|
||||
for param_name, param_value in expected_params.items():
|
||||
if isinstance(param_value, (PerAxis, PerGroup)):
|
||||
if isinstance(param_value, PerAxis):
|
||||
assert isinstance(getattr(config, param_name), PerAxis)
|
||||
assert getattr(config, param_name).axis == param_value.axis
|
||||
else:
|
||||
assert isinstance(getattr(config, param_name), PerGroup)
|
||||
assert (
|
||||
getattr(config, param_name).group_size == param_value.group_size
|
||||
)
|
||||
else:
|
||||
assert getattr(config, param_name) == param_value
|
||||
@requires_cuda_ge_8_9
@require_torch_2_8_0
def test_get_ptq_config_int4_weight_only(self):
    """int4 weights, no activation dtype and group size 4 give ``Int4WeightOnlyConfig``."""
    from torchao.quantization.quant_api import Int4WeightOnlyConfig

    weight_only_config = get_quantization_config(TorchAOQuantDType.int4, None, 4)
    assert isinstance(weight_only_config, Int4WeightOnlyConfig)
|
||||
|
||||
@pytest.mark.parametrize(
    "weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception,expected_tensor_class",
    ptq_test_cases,
)
@requires_cuda_ge_8_9
@require_torch_2_8_0
def test_quantize_model_for_ptq(
    self,
    model,
    weight_dtype,
    activation_dtype,
    group_size,
    quantize_embedding,
    expected_exception,
    expected_tensor_class,
):
    """Quantize the model and check tensor classes, or the expected exception.

    When ``expected_exception`` is set, ``quantize_model`` must raise it and
    nothing else is checked. Otherwise the embedding weight (if
    ``quantize_embedding``) and the weight of every ``torch.nn.Linear`` that is
    a direct child of ``model`` must be instances of ``expected_tensor_class``.
    """
    quantize_args = (
        model,
        weight_dtype,
        group_size,
        activation_dtype,
        quantize_embedding,
    )

    if expected_exception:
        with pytest.raises(expected_exception):
            quantize_model(*quantize_args)
        return

    quantize_model(*quantize_args)

    if quantize_embedding:
        assert isinstance(
            model.model.embed_tokens.weight, expected_tensor_class
        ), "Embedding weight should be quantized"

    # Module.children() yields direct children only; nested layers are not visited.
    for submodule in model.children():
        if isinstance(submodule, torch.nn.Linear):
            assert isinstance(submodule.weight, expected_tensor_class)
|
||||
|
||||
@require_torch_2_8_0
@requires_sm_ge_100
def test_quantize_model_for_ptq_fp8(
    self,
    model,
):
    """Quantize with float8_e4m3fn weights and activations and check the result.

    Every ``torch.nn.Linear`` that is a direct child of ``model`` must hold a
    ``Float8Tensor`` weight whose ``act_quant_kwargs`` is a
    ``QuantizeTensorToFloat8Kwargs`` instance.
    """
    from torchao.quantization.quantize_.workflows.float8.float8_tensor import (
        Float8Tensor,
        QuantizeTensorToFloat8Kwargs,
    )

    fp8_dtype = TorchAOQuantDType.float8_e4m3fn
    # Positional order: model, weight dtype, group size, activation dtype.
    quantize_model(model, fp8_dtype, None, fp8_dtype)

    for submodule in model.children():
        if not isinstance(submodule, torch.nn.Linear):
            continue
        assert isinstance(submodule.weight, Float8Tensor)
        activation_kwargs = submodule.weight.act_quant_kwargs
        assert activation_kwargs is not None and isinstance(
            activation_kwargs, QuantizeTensorToFloat8Kwargs
        )
|
||||
|
||||
@require_torch_2_8_0
@requires_sm_ge_100
def test_quantize_model_for_ptq_nvfp4(self, model):
    """PTQ with NVFP4 weights and activations yields ``NVFP4Tensor`` weights.

    Calls ``quantize_model`` on ``model`` with ``nvfp4`` as both the weight
    and activation dtype and ``16`` as the third positional argument (the
    group size in the other ``quantize_model`` calls in this file), then
    checks that each direct ``torch.nn.Linear`` child has an ``NVFP4Tensor``
    weight whose ``act_quant_kwargs`` is a ``QuantizeTensorToNVFP4Kwargs``.
    """
    # Imported inside the test body rather than at module import time.
    from torchao.prototype.mx_formats.nvfp4_tensor import (
        NVFP4Tensor,
        QuantizeTensorToNVFP4Kwargs,
    )

    nvfp4_dtype = TorchAOQuantDType.nvfp4
    quantize_model(model, nvfp4_dtype, 16, nvfp4_dtype)

    # Only direct children are inspected; model.children() is not recursive.
    linear_children = [
        module for module in model.children() if isinstance(module, torch.nn.Linear)
    ]
    for linear in linear_children:
        assert isinstance(linear.weight, NVFP4Tensor)
        act_kwargs = linear.weight.act_quant_kwargs
        assert act_kwargs is not None and isinstance(
            act_kwargs, QuantizeTensorToNVFP4Kwargs
        )
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"activation_dtype", [None, TorchIntDType.int4, TorchIntDType.int8]
|
||||
"weight_dtype,activation_dtype,group_size,quantize_embedding",
|
||||
[
|
||||
(TorchAOQuantDType.int4, None, 8, False),
|
||||
(TorchAOQuantDType.int4, None, 16, True),
|
||||
(TorchAOQuantDType.int4, TorchAOQuantDType.int8, 8, False),
|
||||
(TorchAOQuantDType.int4, TorchAOQuantDType.int8, 16, True),
|
||||
(
|
||||
TorchAOQuantDType.float8_e4m3fn,
|
||||
TorchAOQuantDType.float8_e4m3fn,
|
||||
None,
|
||||
False,
|
||||
),
|
||||
(TorchAOQuantDType.int4, TorchAOQuantDType.float8_e4m3fn, None, True),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("group_size", [4, 8])
|
||||
@pytest.mark.parametrize("quantize_embedding", [False, True])
|
||||
@require_torch_2_6_0
|
||||
@require_torch_2_8_0
|
||||
@requires_cuda_ge_8_9
|
||||
def test_prepare_model_for_qat(
|
||||
self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
prepare_model_for_qat(
|
||||
model, weight_dtype, group_size, activation_dtype, quantize_embedding
|
||||
model,
|
||||
weight_dtype,
|
||||
group_size,
|
||||
activation_dtype,
|
||||
quantize_embedding,
|
||||
)
|
||||
if quantize_embedding:
|
||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||
@@ -142,17 +240,19 @@ class TestQuantization:
|
||||
model.model.embed_tokens.weight_fake_quantizer.config.dtype
|
||||
== weight_dtype.value
|
||||
)
|
||||
assert (
|
||||
model.model.embed_tokens.weight_fake_quantizer.config.group_size
|
||||
== group_size
|
||||
)
|
||||
if group_size:
|
||||
assert (
|
||||
model.model.embed_tokens.weight_fake_quantizer.config.group_size
|
||||
== group_size
|
||||
)
|
||||
|
||||
for child in list(model.children()):
|
||||
if isinstance(child, torch.nn.Linear):
|
||||
assert isinstance(child, FakeQuantizedLinear)
|
||||
assert hasattr(child, "weight_fake_quantizer")
|
||||
assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
|
||||
assert child.weight_fake_quantizer.config.group_size == group_size
|
||||
if group_size:
|
||||
assert child.weight_fake_quantizer.config.group_size == group_size
|
||||
if activation_dtype:
|
||||
assert hasattr(child, "activation_fake_quantizer")
|
||||
assert (
|
||||
@@ -162,47 +262,40 @@ class TestQuantization:
|
||||
else:
|
||||
assert child.activation_fake_quantizer is None
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"weight_dtype,activation_dtype,group_size,quantize_embedding,expected_exception",
|
||||
ptq_test_cases,
|
||||
)
|
||||
@require_torch_2_6_0
|
||||
def test_quantize_model_for_ptq(
|
||||
self,
|
||||
model,
|
||||
weight_dtype,
|
||||
activation_dtype,
|
||||
group_size,
|
||||
quantize_embedding,
|
||||
expected_exception,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
if expected_exception:
|
||||
with pytest.raises(expected_exception):
|
||||
quantize_model_for_ptq(
|
||||
model,
|
||||
weight_dtype,
|
||||
group_size,
|
||||
activation_dtype,
|
||||
quantize_embedding,
|
||||
)
|
||||
else:
|
||||
quantize_model_for_ptq(
|
||||
model, weight_dtype, group_size, activation_dtype, quantize_embedding
|
||||
)
|
||||
if quantize_embedding:
|
||||
assert isinstance(
|
||||
model.model.embed_tokens.weight, AffineQuantizedTensor
|
||||
), "Embedding weight should be quantized"
|
||||
for child in list(model.children()):
|
||||
if isinstance(child, torch.nn.Linear):
|
||||
if activation_dtype:
|
||||
assert isinstance(
|
||||
child.weight, LinearActivationQuantizedTensor
|
||||
), "Linear weight should be quantized with activation quantization"
|
||||
else:
|
||||
assert isinstance(
|
||||
child.weight, AffineQuantizedTensor
|
||||
), "Linear weight should be quantized without activation quantization"
|
||||
@require_torch_2_8_0
@requires_cuda_ge_8_9
def test_convert_qat_model(self, model):
    """Round-trip a model through QAT preparation and conversion.

    After ``prepare_model_for_qat`` the embedding and LM head must be
    fake-quantized modules; after ``convert_qat_model`` they must no longer
    be, and their weights must be ``nn.Parameter`` instances.
    """
    qat_cfg = QATConfig(
        weight_dtype="int4",
        activation_dtype="int8",
        group_size=8,
        quantize_embedding=True,
    )

    # Step 1: swap in the fake-quantized modules used during QAT.
    prepare_model_for_qat(
        model,
        qat_cfg.weight_dtype,
        qat_cfg.group_size,
        qat_cfg.activation_dtype,
        qat_cfg.quantize_embedding,
    )
    embed_tokens = model.model.embed_tokens
    assert isinstance(embed_tokens, FakeQuantizedEmbedding)
    lm_head = model.lm_head
    assert isinstance(lm_head, FakeQuantizedLinear)

    # Step 2: convert the prepared model.
    convert_qat_model(model, qat_cfg.quantize_embedding)

    # The fake-quantized wrappers must be gone. Re-read the attributes from
    # the model, since conversion is expected to replace the modules.
    embed_tokens = model.model.embed_tokens
    assert not isinstance(embed_tokens, FakeQuantizedEmbedding)
    lm_head = model.lm_head
    assert not isinstance(lm_head, FakeQuantizedLinear)

    # The converted weights are only checked to be parameters here; the
    # assertions do not inspect the tensor subclass of the weights.
    assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
    assert isinstance(model.lm_head.weight, nn.Parameter)
|
||||
|
||||
|
||||
class TestQuantizationCallback:
|
||||
@@ -216,12 +309,10 @@ class TestQuantizationCallback:
|
||||
global_step=0,
|
||||
)
|
||||
|
||||
@require_torch_2_6_0
|
||||
def test_qat_callback_fake_quant_after_n_steps(
|
||||
self, model, trainer_state
|
||||
): # pylint: disable=redefined-outer-name
|
||||
@require_torch_2_8_0
|
||||
def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state):
|
||||
cfg = QATConfig(
|
||||
weight_dtype="int8",
|
||||
weight_dtype="int4",
|
||||
activation_dtype="int8",
|
||||
group_size=8,
|
||||
quantize_embedding=True,
|
||||
@@ -268,12 +359,10 @@ class TestQuantizationCallback:
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
@require_torch_2_6_0
|
||||
def test_qat_callback_fake_quant_after_n_steps_is_none(
|
||||
self, model, trainer_state
|
||||
): # pylint: disable=redefined-outer-name
|
||||
@require_torch_2_8_0
|
||||
def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
|
||||
cfg = QATConfig(
|
||||
weight_dtype="int8",
|
||||
weight_dtype="int4",
|
||||
activation_dtype="int8",
|
||||
group_size=8,
|
||||
quantize_embedding=True,
|
||||
@@ -306,45 +395,3 @@ class TestQuantizationCallback:
|
||||
# quantization should be enabled from the get-go
|
||||
assert model.model.embed_tokens.weight_fake_quantizer.enabled
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
|
||||
class TestConvertQATModelForPTQ:
|
||||
"""
|
||||
Test convert_qat_model_for_ptq
|
||||
"""
|
||||
|
||||
@require_torch_2_6_0
|
||||
def test_convert_qat_model_for_ptq(
|
||||
self, model
|
||||
): # pylint: disable=redefined-outer-name
|
||||
config = QATConfig(
|
||||
weight_dtype="int8",
|
||||
activation_dtype="int8",
|
||||
group_size=8,
|
||||
quantize_embedding=True,
|
||||
)
|
||||
|
||||
# quantize model for qat
|
||||
prepare_model_for_qat(
|
||||
model,
|
||||
config.weight_dtype,
|
||||
config.group_size,
|
||||
config.activation_dtype,
|
||||
config.quantize_embedding,
|
||||
)
|
||||
|
||||
assert isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||
assert isinstance(model.lm_head, FakeQuantizedLinear)
|
||||
|
||||
# apply conversion
|
||||
convert_qat_model_for_ptq(
|
||||
model,
|
||||
quantize_embedding=config.quantize_embedding,
|
||||
)
|
||||
# ensure modules have been swapped out
|
||||
assert not isinstance(model.model.embed_tokens, FakeQuantizedEmbedding)
|
||||
assert not isinstance(model.lm_head, FakeQuantizedLinear)
|
||||
|
||||
# ensure weights have been quantized
|
||||
assert isinstance(model.model.embed_tokens.weight, nn.Parameter)
|
||||
assert isinstance(model.lm_head.weight, nn.Parameter)
|
||||
|
||||
@@ -19,7 +19,6 @@ class TestE2eQwen:
|
||||
|
||||
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
|
||||
def test_dpo(self, base_model, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": base_model,
|
||||
@@ -59,6 +58,7 @@ class TestE2eQwen:
|
||||
"bf16": "auto",
|
||||
"tf32": True,
|
||||
"gradient_checkpointing": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for reward model lora llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -20,7 +19,6 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_rm_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -59,15 +57,15 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
|
||||
"gradient_checkpointing": True,
|
||||
"warmup_ratio": 0.1,
|
||||
"use_tensorboard": True,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high"
|
||||
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss (%s) is too high"
|
||||
)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
100
tests/e2e/test_save_first_step.py
Normal file
100
tests/e2e/test_save_first_step.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
E2E tests for the save_first_step callback
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
|
||||
class TestSaveFirstStepCallback(unittest.TestCase):
    """Test cases for save_first_step callback config."""

    @staticmethod
    def _train_with_save_first_step(temp_dir, save_first_step):
        """Run the shared 3-step training job and return the validated config.

        Both tests use the same SmolLM2-135M / alpaca_2k_test setup and
        differ only in the ``save_first_step`` flag, so the config and the
        validate -> normalize -> load -> train sequence live here.

        Args:
            temp_dir: Directory used as ``output_dir`` for the run.
            save_first_step: Value for the ``save_first_step`` config key.

        Returns:
            The config returned by ``validate_config`` (after
            ``normalize_config`` has been applied to it).
        """
        cfg = DictDefault(
            {
                "base_model": "HuggingFaceTB/SmolLM2-135M",
                "tokenizer_type": "AutoTokenizer",
                "sequence_len": 512,
                "val_set_size": 0.02,
                "special_tokens": {
                    "pad_token": "<|endoftext|>",
                },
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "max_steps": 3,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "sample_packing": True,
                "bf16": True,
                "save_safetensors": True,
                "save_first_step": save_first_step,
            }
        )

        cfg = validate_config(cfg)
        normalize_config(cfg)
        dataset_meta = load_datasets(cfg=cfg)

        train(cfg=cfg, dataset_meta=dataset_meta)
        return cfg

    @with_temp_dir
    def test_save_first_step(self, temp_dir):
        """With ``save_first_step=True`` model output exists in ``checkpoint-1``."""
        cfg = self._train_with_save_first_step(temp_dir, save_first_step=True)
        check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)

    @with_temp_dir
    def test_no_save_first_step(self, temp_dir):
        """With ``save_first_step=False`` the ``checkpoint-1`` check must fail."""
        cfg = self._train_with_save_first_step(temp_dir, save_first_step=False)
        # check_model_output_exists is expected to raise AssertionError when
        # the checkpoint directory holds no model output.
        with pytest.raises(AssertionError):
            check_model_output_exists(str(Path(temp_dir) / "checkpoint-1"), cfg)
|
||||
@@ -4,7 +4,6 @@ E2E tests for custom schedulers using Llama
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
@@ -20,7 +19,6 @@ class TestCustomSchedulers(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_rex_scheduler(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -52,13 +50,13 @@ class TestCustomSchedulers(unittest.TestCase):
|
||||
"lr_scheduler": "rex",
|
||||
"warmup_steps": 5,
|
||||
"cosine_min_lr_ratio": 0.05,
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
73
tests/e2e/test_streaming.py
Normal file
73
tests/e2e/test_streaming.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""E2E tests for streaming dataset functionality"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
|
||||
class TestStreamingDatasets:
    """Test case for streaming datasets"""

    @pytest.mark.parametrize("sample_packing", [True, False])
    def test_streaming_dataset(self, temp_dir, sample_packing):
        """Train for 3 steps on a streamed dataset, with and without packing."""
        model_settings = {
            "base_model": "HuggingFaceTB/SmolLM2-135M",
            "flash_attention": True,
            "sequence_len": 1024,
            "sample_packing": sample_packing,
            "pretrain_multipack_attn": sample_packing,
            "streaming_multipack_buffer_size": 10000,
            "dataset_processes": 1,
            "special_tokens": {
                "pad_token": "<|endoftext|>",
            },
        }
        data_settings = {
            "datasets": [
                {
                    "path": "mhenrichsen/alpaca_2k_test",
                    "type": "alpaca",
                },
            ],
            # Streaming config
            "streaming": True,
        }
        training_settings = {
            "max_steps": 3,
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "val_set_size": 0.0,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "adamw_torch_fused",
            "lr_scheduler": "cosine",
            "save_safetensors": True,
            "bf16": "auto",
            "use_tensorboard": True,
            "save_first_step": False,
        }
        # Merged in this order so the resulting key order is unchanged.
        cfg = DictDefault({**model_settings, **data_settings, **training_settings})

        cfg = validate_config(cfg)
        normalize_config(cfg)
        streamed_dataset_meta = load_datasets(cfg=cfg)

        train(cfg=cfg, dataset_meta=streamed_dataset_meta)
        check_model_output_exists(temp_dir, cfg)

        # Sanity-check the run via tensorboard: the train loss is compared
        # against an upper bound of 3.0.
        tensorboard_dir = temp_dir + "/runs"
        check_tensorboard(
            tensorboard_dir,
            "train/train_loss",
            3.0,
            "Train Loss (%s) is too high",
        )
|
||||
63
tests/e2e/test_tokenizer.py
Normal file
63
tests/e2e/test_tokenizer.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
e2e test for saving the tokenizer
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import check_model_output_exists
|
||||
|
||||
|
||||
def test_tokenizer_no_save_jinja_files(temp_dir):
    """With ``tokenizer_save_jinja_files=False`` the saved config has a chat template.

    ``train`` runs with ``axolotl.train.execute_training`` patched out, so
    the training loop itself is replaced by a mock. Afterwards the saved
    ``tokenizer_config.json`` must mention ``chat_template``.
    NOTE(review): presumably this verifies the template is embedded in the
    JSON rather than written to a separate jinja file -- confirm.
    """
    # pylint: disable=duplicate-code
    model_and_adapter = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "tokenizer_type": "AutoTokenizer",
        "sequence_len": 1024,
        "load_in_8bit": True,
        "adapter": "lora",
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        "lora_target_linear": True,
    }
    data_settings = {
        "val_set_size": 0.02,
        "special_tokens": {
            "pad_token": "<|endoftext|>",
        },
        "chat_template": "chatml",
        "datasets": [
            {
                "path": "mhenrichsen/alpaca_2k_test",
                "type": "alpaca",
            },
        ],
    }
    training_settings = {
        "num_epochs": 1,
        "micro_batch_size": 2,
        "gradient_accumulation_steps": 1,
        "output_dir": temp_dir,
        "learning_rate": 0.00001,
        "optimizer": "adamw_torch_fused",
        "lr_scheduler": "cosine",
        "max_steps": 5,
        "save_first_step": False,
        "fp16": False,
        "tokenizer_save_jinja_files": False,
    }
    # Merged in this order so the resulting key order is unchanged.
    cfg = DictDefault({**model_and_adapter, **data_settings, **training_settings})

    cfg = validate_config(cfg)
    normalize_config(cfg)
    dataset_meta = load_datasets(cfg=cfg)

    with patch("axolotl.train.execute_training"):
        train(cfg=cfg, dataset_meta=dataset_meta)

    check_model_output_exists(temp_dir, cfg)

    tokenizer_config_path = f"{temp_dir}/tokenizer_config.json"
    with open(tokenizer_config_path, "r", encoding="utf-8") as config_file:
        raw_tokenizer_config = config_file.read()
    # Plain substring check on the raw file contents; the JSON is not parsed.
    assert "chat_template" in raw_tokenizer_config
|
||||
@@ -2,6 +2,7 @@
|
||||
helper utils for tests
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
@@ -77,6 +78,30 @@ def require_torch_2_6_0(test_case):
|
||||
return unittest.skipUnless(is_min_2_6_0(), "test requires torch>=2.6.0")(test_case)
|
||||
|
||||
|
||||
def require_torch_2_7_0(test_case):
    """
    Decorator marking a test that requires torch >= 2.7.0
    """
    # The version check runs once, when the decorator is applied.
    installed = version.parse(torch.__version__)
    meets_minimum = installed >= version.parse("2.7.0")
    skip_unless_new_enough = unittest.skipUnless(
        meets_minimum, "test requires torch>=2.7.0"
    )
    return skip_unless_new_enough(test_case)
|
||||
|
||||
|
||||
def require_torch_2_8_0(test_case):
    """
    Decorator marking a test that requires torch >= 2.8.0
    """
    # The version check runs once, when the decorator is applied.
    installed = version.parse(torch.__version__)
    meets_minimum = installed >= version.parse("2.8.0")
    skip_unless_new_enough = unittest.skipUnless(
        meets_minimum, "test requires torch>=2.8.0"
    )
    return skip_unless_new_enough(test_case)
|
||||
|
||||
|
||||
def require_torch_lt_2_6_0(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torch < 2.6.0
|
||||
@@ -95,12 +120,7 @@ def require_vllm(test_case):
|
||||
"""
|
||||
|
||||
def is_vllm_installed():
|
||||
try:
|
||||
import vllm # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
return importlib.util.find_spec("vllm") is not None
|
||||
|
||||
return unittest.skipUnless(
|
||||
is_vllm_installed(), "test requires vllm to be installed"
|
||||
@@ -113,25 +133,46 @@ def require_llmcompressor(test_case):
|
||||
"""
|
||||
|
||||
def is_llmcompressor_installed():
|
||||
try:
|
||||
import llmcompressor # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
return importlib.util.find_spec("llmcompressor") is not None
|
||||
|
||||
return unittest.skipUnless(
|
||||
is_llmcompressor_installed(), "test requires llmcompressor to be installed"
|
||||
)(test_case)
|
||||
|
||||
|
||||
def requires_sm_ge_100(test_case):
    """
    Decorator marking a test that requires a CUDA device with compute
    capability >= (10, 0)
    """
    minimum_capability = (10, 0)
    # Short-circuits: the device is only queried when CUDA is available and
    # torch reports a CUDA version.
    supported = (
        torch.cuda.is_available()
        and torch.version.cuda
        and torch.cuda.get_device_capability() >= minimum_capability
    )
    skip_unless_supported = unittest.skipUnless(supported, "test requires sm>=100")
    return skip_unless_supported(test_case)
|
||||
|
||||
|
||||
def requires_cuda_ge_8_9(test_case):
    """
    Decorator marking a test that requires a CUDA device with compute
    capability >= (8, 9)

    Despite the name, the check is on the device's compute capability
    (``torch.cuda.get_device_capability()``), not on the CUDA toolkit version.
    """
    # Short-circuits: the device is only queried when CUDA is available and
    # torch reports a CUDA version.
    is_cuda_ge_8_9 = (
        torch.cuda.is_available()
        and torch.version.cuda
        and torch.cuda.get_device_capability() >= (8, 9)
    )
    # The skip reason names compute capability so it is not misread as a
    # CUDA toolkit version requirement.
    return unittest.skipUnless(
        is_cuda_ge_8_9, "test requires compute capability>=8.9"
    )(test_case)
|
||||
|
||||
|
||||
def is_hopper():
    """Return True when the current CUDA device has compute capability (9, 0).

    Returns False when CUDA is unavailable instead of letting
    ``torch.cuda.get_device_capability()`` raise. This function is called
    while decorators are applied, so an exception here would abort test
    collection on hosts without a GPU.
    """
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() == (9, 0)
|
||||
|
||||
|
||||
def require_hopper(test_case):
    """Decorator marking a test that is skipped unless ``is_hopper()`` is true."""
    # is_hopper() is evaluated once, when the decorator is applied.
    on_hopper = is_hopper()
    skip_unless_hopper = unittest.skipUnless(on_hopper, "test requires h100/hopper GPU")
    return skip_unless_hopper(test_case)
|
||||
|
||||
|
||||
def check_tensorboard(
|
||||
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
|
||||
temp_run_dir: str,
|
||||
tag: str,
|
||||
lt_val: float,
|
||||
assertion_err: str,
|
||||
rtol: float = 0.02,
|
||||
) -> None:
|
||||
"""
|
||||
helper function to parse and check tensorboard logs
|
||||
@@ -139,8 +180,9 @@ def check_tensorboard(
|
||||
tb_log_path = most_recent_subdir(temp_run_dir)
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == tag)] # pylint: disable=invalid-name
|
||||
df = reader.scalars
|
||||
df = df[(df.tag == tag)]
|
||||
lt_val = (1 + rtol) * lt_val
|
||||
if "%s" in assertion_err:
|
||||
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
|
||||
else:
|
||||
|
||||
@@ -20,7 +20,7 @@ def reload_modules(hf_hub_offline):
|
||||
importlib.reload(huggingface_hub.constants)
|
||||
huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
|
||||
importlib.reload(datasets.config)
|
||||
setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline)
|
||||
datasets.config.HF_HUB_OFFLINE = hf_hub_offline
|
||||
reset_sessions()
|
||||
|
||||
|
||||
|
||||
274
tests/integrations/test_diffusion.py
Normal file
274
tests/integrations/test_diffusion.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""Tests for diffusion trainer integration."""
|
||||
|
||||
# pylint: disable=redefined-outer-name,protected-access
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.diffusion import DiffusionTrainer
|
||||
from axolotl.integrations.diffusion.utils import create_bidirectional_attention_mask
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture
def mock_tokenizer():
    """Provide a Mock tokenizer with fixed bos/eos/pad token ids (1, 2, 0)."""
    # Keyword arguments to Mock() are set as attributes on the new mock.
    return Mock(bos_token_id=1, eos_token_id=2, pad_token_id=0)
|
||||
|
||||
|
||||
@pytest.fixture
def diffusion_config():
    """Provide a minimal diffusion config with sample packing disabled."""
    diffusion_settings = {
        "mask_token_id": 32000,
        "eps": 1e-3,
        "importance_weighting": False,
    }
    config = DictDefault(
        {
            "diffusion": diffusion_settings,
            "sample_packing": False,
        }
    )
    return config
|
||||
|
||||
|
||||
@pytest.fixture
def diffusion_trainer_instance(mock_tokenizer, diffusion_config):
    """Provide a bare ``DiffusionTrainer`` for calling its methods directly.

    The instance is allocated without running ``__init__``; only the
    attributes assigned below are set on it.
    """
    bare_trainer = object.__new__(DiffusionTrainer)  # Bypass __init__

    # Assigned in the same order as before, one attribute at a time.
    attributes = (
        ("cfg", diffusion_config),
        ("_special_token_ids", {0, 1, 2}),  # pad, bos, eos
        ("processing_class", mock_tokenizer),
        ("store_metrics", Mock()),  # Mock metrics storage
    )
    for attr_name, attr_value in attributes:
        setattr(bare_trainer, attr_name, attr_value)
    return bare_trainer
|
||||
|
||||
|
||||
class TestDiffusionTrainer:
|
||||
"""Test the DiffusionTrainer class."""
|
||||
|
||||
def test_forward_process_basic(self, diffusion_trainer_instance):
|
||||
"""Test basic forward process without labels."""
|
||||
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
|
||||
|
||||
noisy_batch, masked_indices, p_mask = (
|
||||
diffusion_trainer_instance._forward_process(input_ids, eps=0.1)
|
||||
)
|
||||
|
||||
# Check shapes
|
||||
assert noisy_batch.shape == input_ids.shape
|
||||
assert masked_indices.shape == input_ids.shape
|
||||
assert p_mask.shape == input_ids.shape
|
||||
|
||||
# Check that special tokens are not masked
|
||||
special_token_positions = (input_ids == 1) | (input_ids == 2) | (input_ids == 0)
|
||||
assert not masked_indices[special_token_positions].any()
|
||||
|
||||
# Check that mask token is applied
|
||||
mask_token_id = diffusion_trainer_instance.cfg.diffusion.mask_token_id
|
||||
masked_positions = masked_indices
|
||||
if masked_positions.any():
|
||||
assert (noisy_batch[masked_positions] == mask_token_id).all()
|
||||
|
||||
def test_forward_process_with_labels(self, diffusion_trainer_instance):
|
||||
"""Test forward process with SFT labels."""
|
||||
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
|
||||
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
|
||||
|
||||
noisy_batch, masked_indices, p_mask = (
|
||||
diffusion_trainer_instance._forward_process(
|
||||
input_ids, labels=labels, eps=0.1
|
||||
)
|
||||
)
|
||||
|
||||
# Check shapes
|
||||
assert noisy_batch.shape == input_ids.shape
|
||||
assert masked_indices.shape == input_ids.shape
|
||||
assert p_mask.shape == input_ids.shape
|
||||
|
||||
# Check that only answer tokens can be masked (where labels != -100)
|
||||
non_answer_mask = labels == -100
|
||||
|
||||
# No masking should occur on non-answer tokens
|
||||
assert not masked_indices[non_answer_mask].any()
|
||||
|
||||
# p_mask should be the same for all positions (sampled timestep),
|
||||
# but masking is only applied to answer tokens
|
||||
assert p_mask.shape == input_ids.shape
|
||||
# Verify that masked_indices respects the answer mask
|
||||
assert not masked_indices[non_answer_mask].any()
|
||||
|
||||
def test_forward_process_with_attention_mask(self, diffusion_trainer_instance):
|
||||
"""Test forward process with attention mask."""
|
||||
input_ids = torch.tensor([[1, 10, 20, 0]], dtype=torch.long)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.long)
|
||||
|
||||
_, masked_indices, p_mask = diffusion_trainer_instance._forward_process(
|
||||
input_ids, attention_mask=attention_mask, eps=0.1
|
||||
)
|
||||
|
||||
# Check that padding tokens are not masked
|
||||
padding_positions = attention_mask == 0
|
||||
assert not masked_indices[padding_positions].any()
|
||||
assert (p_mask[padding_positions] == 0).all()
|
||||
|
||||
def test_bidirectional_attention_mask_no_packing(self, diffusion_trainer_instance):
|
||||
"""Test bidirectional attention mask without sample packing."""
|
||||
input_ids = torch.tensor([[1, 10, 20, 2]], dtype=torch.long)
|
||||
|
||||
mask = create_bidirectional_attention_mask(input_ids)
|
||||
|
||||
# Should be all-to-all attention
|
||||
expected_shape = (1, 1, 4, 4)
|
||||
assert mask.shape == expected_shape
|
||||
assert mask.all()
|
||||
|
||||
def test_bidirectional_attention_mask_with_packing(
|
||||
self, diffusion_trainer_instance
|
||||
):
|
||||
"""Test bidirectional attention mask with sample packing."""
|
||||
diffusion_trainer_instance.cfg.sample_packing = True
|
||||
input_ids = torch.tensor([[1, 10, 20, 30, 40, 2]], dtype=torch.long)
|
||||
# Sample IDs: first sample (1), second sample (2)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2]], dtype=torch.long)
|
||||
|
||||
mask = create_bidirectional_attention_mask(
|
||||
input_ids, attention_mask, sample_packing=True
|
||||
)
|
||||
|
||||
# Check that tokens within same sample can attend to each other
|
||||
# but not across samples
|
||||
assert mask[0, 0, 0, 1].item() # First sample tokens can attend to each other
|
||||
assert mask[0, 0, 1, 2].item()
|
||||
assert not mask[0, 0, 0, 3].item() # Can't attend across samples
|
||||
assert not mask[0, 0, 2, 4].item()
|
||||
assert mask[0, 0, 3, 4].item() # Second sample tokens can attend to each other
|
||||
|
||||
def test_compute_loss_basic(self, diffusion_trainer_instance):
|
||||
"""Test basic loss computation."""
|
||||
# Mock model that returns logits
|
||||
mock_model = Mock()
|
||||
mock_outputs = Mock()
|
||||
vocab_size = 1000
|
||||
seq_len = 5
|
||||
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
|
||||
mock_model.return_value = mock_outputs
|
||||
mock_model.training = True
|
||||
|
||||
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
|
||||
|
||||
loss, outputs = diffusion_trainer_instance._compute_diffusion_loss(
|
||||
mock_model, input_ids
|
||||
)
|
||||
|
||||
# Check that loss is computed
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert loss.requires_grad
|
||||
assert outputs == mock_outputs
|
||||
|
||||
# Check that metrics were stored
|
||||
diffusion_trainer_instance.store_metrics.assert_called_once()
|
||||
|
||||
def test_compute_loss_sft(self, diffusion_trainer_instance):
|
||||
"""Test loss computation with SFT labels."""
|
||||
# Mock model
|
||||
mock_model = Mock()
|
||||
mock_outputs = Mock()
|
||||
vocab_size = 1000
|
||||
seq_len = 5
|
||||
mock_outputs.logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)
|
||||
mock_model.return_value = mock_outputs
|
||||
mock_model.training = True
|
||||
diffusion_trainer_instance.cfg.datasets = Mock()
|
||||
|
||||
input_ids = torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long)
|
||||
labels = torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long)
|
||||
|
||||
loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
|
||||
mock_model, input_ids, labels=labels
|
||||
)
|
||||
|
||||
# Check that loss is computed
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
assert loss.requires_grad
|
||||
|
||||
# Check that SFT metrics were added
|
||||
call_args = diffusion_trainer_instance.store_metrics.call_args[0][0]
|
||||
assert "answer_ratio" in call_args
|
||||
assert "avg_answer_length" in call_args
|
||||
|
||||
def test_compute_loss_no_masked_tokens(self, diffusion_trainer_instance):
|
||||
"""Test loss computation when no tokens are masked."""
|
||||
# Mock model
|
||||
mock_model = Mock()
|
||||
mock_outputs = Mock()
|
||||
vocab_size = 1000
|
||||
seq_len = 3
|
||||
mock_outputs.logits = torch.randn(1, seq_len, vocab_size)
|
||||
mock_model.return_value = mock_outputs
|
||||
mock_model.training = True
|
||||
|
||||
# Only special tokens (which won't be masked)
|
||||
input_ids = torch.tensor([[1, 0, 2]], dtype=torch.long)
|
||||
|
||||
loss, _ = diffusion_trainer_instance._compute_diffusion_loss(
|
||||
mock_model, input_ids
|
||||
)
|
||||
|
||||
# Loss should be zero when no tokens are masked
|
||||
assert loss.item() == 0.0
|
||||
assert loss.requires_grad
|
||||
|
||||
def test_cache_special_token_ids(self, mock_tokenizer):
|
||||
"""Test caching of special token IDs."""
|
||||
trainer = object.__new__(DiffusionTrainer)
|
||||
trainer.processing_class = mock_tokenizer
|
||||
trainer._cache_special_token_ids()
|
||||
assert trainer._special_token_ids == {0, 1, 2}
|
||||
|
||||
def test_cache_special_token_ids_no_tokenizer(self):
    """Verify that caching yields an empty set when there is no tokenizer."""
    # Allocate the trainer without running __init__ and leave it tokenizer-less.
    bare_trainer = object.__new__(DiffusionTrainer)
    bare_trainer.processing_class = None

    bare_trainer._cache_special_token_ids()

    cached_ids = bare_trainer._special_token_ids
    assert cached_ids == set()
|
||||
|
||||
def test_main_compute_loss_interface(self, diffusion_trainer_instance):
    """Exercise ``compute_loss`` with and without ``return_outputs``."""
    # Stand-in model: calling it returns an object carrying random logits.
    fake_outputs = Mock(logits=torch.randn(1, 5, 1000))
    fake_model = Mock(return_value=fake_outputs, training=True)

    batch = {
        "input_ids": torch.tensor([[1, 10, 20, 30, 2]], dtype=torch.long),
        "attention_mask": torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.long),
        "labels": torch.tensor([[-100, -100, 20, 30, 2]], dtype=torch.long),
    }

    # Default call: only the loss tensor comes back.
    loss_only = diffusion_trainer_instance.compute_loss(fake_model, batch)
    assert isinstance(loss_only, torch.Tensor)

    # return_outputs=True: a (loss, outputs) pair comes back, and the outputs
    # are the object the model returned.
    loss_value, returned_outputs = diffusion_trainer_instance.compute_loss(
        fake_model, batch, return_outputs=True
    )
    assert isinstance(loss_value, torch.Tensor)
    assert returned_outputs == fake_outputs
|
||||
|
||||
def test_missing_input_ids_raises_error(self, diffusion_trainer_instance):
    """Check that ``compute_loss`` raises ValueError without ``input_ids``."""
    # Batch deliberately lacks the "input_ids" key.
    batch_without_ids = {"attention_mask": torch.tensor([[1, 1, 1]])}
    fake_model = Mock()

    with pytest.raises(ValueError, match="input_ids is required"):
        diffusion_trainer_instance.compute_loss(fake_model, batch_without_ids)
|
||||
92
tests/integrations/test_diffusion_callback.py
Normal file
92
tests/integrations/test_diffusion_callback.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tests for diffusion generation callback dataloader selection and triggering."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.integrations.diffusion import DiffusionGenerationCallback
|
||||
|
||||
|
||||
class DummyTrainer:
    """Lightweight trainer stand-in exposing what the callback test touches.

    It records, in ``requested``, which dataloader getter was called so a test
    can assert on the selection made by the code under test.
    """

    def __init__(self, use_eval: bool):
        """Build the double.

        Args:
            use_eval: When true, ``eval_dataset`` is a placeholder object;
                otherwise it is ``None``.
        """
        # Diffusion settings read by the callback.
        diffusion_settings = SimpleNamespace(
            generation_interval=1,
            num_generation_samples=1,
            generation_max_length=32,
            generation_steps=4,
            generation_temperature=0.0,
            mask_token_id=16,
        )
        self.cfg = SimpleNamespace(diffusion=diffusion_settings, use_wandb=False)

        # Model/tokenizer are passed through to generate_samples; not used here.
        self.model = Mock()
        self.processing_class = Mock()

        # Dataset placeholder plus two distinct sentinel loaders so identity
        # checks can tell them apart.
        self.eval_dataset = None if not use_eval else object()
        self._train_loader = object()
        self._eval_loader = object()

        # State for world process check.
        self.state = SimpleNamespace(is_world_process_zero=True)

        # Ordered log of the loaders that were asked for ("train" / "eval").
        self.requested: list[str] = []

    def get_train_dataloader(self):
        """Log a "train" request and hand back the train sentinel loader."""
        self.requested.append("train")
        return self._train_loader

    def get_eval_dataloader(self):
        """Log an "eval" request and hand back the eval sentinel loader."""
        self.requested.append("eval")
        return self._eval_loader
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_eval", [False, True])
def test_callback_uses_correct_dataloader(monkeypatch, use_eval):
    """Check which dataloader the callback hands to ``generate_samples``.

    With an eval dataset present the eval loader is expected; otherwise the
    train loader is expected.
    """
    trainer = DummyTrainer(use_eval=use_eval)
    callback = DiffusionGenerationCallback(trainer)

    captured = {}

    def fake_generate_samples(**kwargs):
        """Record the ``dataloader`` kwarg and return a single dummy sample."""
        captured["dataloader"] = kwargs.get("dataloader")
        # One sample is returned so the logging path is exercised as well.
        dummy_sample = {
            "original": "o",
            "masked": "m",
            "generated": "g",
            "mask_ratio": 0.5,
            "masked_tokens": 1,
            "total_tokens": 2,
        }
        return [dummy_sample]

    # Swap out generate_samples in the callback module's namespace.
    monkeypatch.setattr(
        "axolotl.integrations.diffusion.callbacks.generate_samples",
        fake_generate_samples,
    )

    # Step 1 with generation_interval=1 triggers the callback.
    callback.on_step_end(
        args=SimpleNamespace(),
        state=SimpleNamespace(global_step=1),
        control=SimpleNamespace(),
    )

    # Pick the expectation for this parametrization, then assert once.
    if use_eval:
        expected_name, expected_loader = "eval", trainer._eval_loader
    else:
        expected_name, expected_loader = "train", trainer._train_loader

    assert trainer.requested[0] == expected_name
    assert captured["dataloader"] is expected_loader
|
||||
@@ -10,7 +10,6 @@ from axolotl.utils.config import prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
@pytest.fixture(name="minimal_liger_cfg")
|
||||
def fixture_cfg():
|
||||
return DictDefault(
|
||||
@@ -30,7 +29,6 @@ def fixture_cfg():
|
||||
)
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
class TestValidation:
|
||||
"""
|
||||
Test the validation module for liger
|
||||
|
||||
35
tests/monkeypatch/test_mistral_tokenizer_patch.py
Normal file
35
tests/monkeypatch/test_mistral_tokenizer_patch.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Integration tests for MistralCommonTokenizer patches."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestMistralTokenizerPatchIntegration:
    """Test MistralCommonTokenizer patch integration."""

    @pytest.mark.integration
    def test_mistral_tokenizer_image_patch(self):
        """Test that MistralCommonTokenizer image patch can be applied.

        The test is skipped when ``MistralCommonTokenizer`` cannot be imported.
        The original ``apply_chat_template`` is put back afterwards, even when
        an assertion fails, so the class-level patch does not leak into other
        tests running in the same process.
        """
        try:
            from transformers.tokenization_mistral_common import MistralCommonTokenizer
        except ImportError:
            pytest.skip("MistralCommonTokenizer not available")

        from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
            apply_mistral_tokenizer_image_patch,
        )

        # Store original method so it can be compared against and restored.
        original_apply_chat_template = MistralCommonTokenizer.apply_chat_template

        try:
            # Apply patch
            apply_mistral_tokenizer_image_patch()

            # Verify patch was applied
            assert (
                MistralCommonTokenizer.apply_chat_template
                is not original_apply_chat_template
            ), "apply_chat_template was not patched"

            # Verify the method is still callable
            assert callable(MistralCommonTokenizer.apply_chat_template), (
                "Patched method is not callable"
            )
        finally:
            # The patch function returns no unpatch handle here, so restore
            # the saved method by hand.
            MistralCommonTokenizer.apply_chat_template = original_apply_chat_template
|
||||
77
tests/monkeypatch/test_pixtral_flash_attention_patch.py
Normal file
77
tests/monkeypatch/test_pixtral_flash_attention_patch.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Integration tests for Pixtral Flash Attention patches."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
class TestPixtralFlashAttentionPatchIntegration:
    """Test Pixtral Flash Attention patch integration."""

    @pytest.mark.integration
    def test_pixtral_flash_attention_patch(self):
        """Test that Pixtral Flash Attention patch can be applied and works correctly.

        The patched ``_is_packed_sequence`` is probed with several
        ``position_ids`` inputs and must return a real ``bool`` each time. The
        unpatch function always runs, even when a probe fails, so the patched
        module attribute does not leak into other tests.
        """
        try:
            from transformers import modeling_flash_attention_utils
        except ImportError:
            pytest.skip("Flash Attention utils not available")

        from axolotl.monkeypatch.models.pixtral.modeling_flash_attention_utils import (
            apply_patch_is_packed_sequence,
        )

        # Store original method
        original_is_packed_sequence = modeling_flash_attention_utils._is_packed_sequence

        # (position_ids, batch_size, expected result, failure message)
        cases = [
            # 1D position_ids, 1 sequence
            (
                torch.tensor([0, 1, 2, 3]),
                1,
                False,
                "1D sequential position_ids should not be packed",
            ),
            # 1D packed, 2 sequences
            (
                torch.tensor([0, 1, 2, 0, 1, 2]),
                1,
                True,
                "1D packed position_ids should be detected as packed",
            ),
            # 2D packed, 2 sequences
            (
                torch.tensor([[0, 1, 2, 3, 0, 1]]),
                1,
                True,
                "2D packed position_ids should be detected as packed",
            ),
            # 2D, 1 sequence
            (
                torch.tensor([[0, 1, 2, 3, 4, 5]]),
                1,
                False,
                "2D sequential position_ids should not be packed",
            ),
            # 2D, batch size 2
            (
                torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8]]),
                2,
                False,
                "2D position_ids batch 2 should not be packed",
            ),
            # None case
            (None, 1, False, "None position_ids should return False"),
        ]

        # Apply patch and get unpatch function
        unpatch_fn = apply_patch_is_packed_sequence()

        try:
            # Verify patch was applied
            assert (
                modeling_flash_attention_utils._is_packed_sequence
                is not original_is_packed_sequence
            ), "_is_packed_sequence was not patched"

            patched_fn = modeling_flash_attention_utils._is_packed_sequence

            for position_ids, batch_size, expected, message in cases:
                result = patched_fn(position_ids, batch_size=batch_size)
                assert isinstance(result, bool), "Function should return a boolean"
                assert result is expected, message
        finally:
            # Always undo the patch, including when an assertion above failed.
            unpatch_fn()

        # Test unpatch function
        assert (
            modeling_flash_attention_utils._is_packed_sequence
            is original_is_packed_sequence
        ), "unpatch function did not restore original method"
|
||||
111
tests/monkeypatch/test_qwen3_next_modeling_patch.py
Normal file
111
tests/monkeypatch/test_qwen3_next_modeling_patch.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Integration tests for Qwen3 Next modeling patches."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip entire module if qwen3_next not available
|
||||
qwen3_next = pytest.importorskip("transformers.models.qwen3_next.modeling_qwen3_next")
|
||||
|
||||
|
||||
class TestQwen3NextModelingPatchIntegration:
    """Test Qwen3 Next modeling patch integration."""

    @staticmethod
    def _check_forward_patch_roundtrip(layer_cls, patch_fn, not_patched_message):
        """Patch ``layer_cls.forward`` via ``patch_fn`` and verify the round trip.

        Args:
            layer_cls: Class whose ``forward`` attribute is expected to be replaced.
            patch_fn: Zero-argument callable applying the patch; it may return
                an unpatch callable or a falsy value.
            not_patched_message: Assertion message used when ``forward`` is
                unchanged after patching.

        The unpatch callable, when one is returned, always runs, even if an
        assertion fails, so the class-level patch does not leak into other
        tests. Restoration is only asserted when an unpatch callable exists.
        """
        # Store original method
        original_forward = layer_cls.forward

        # Apply patch and get unpatch function
        unpatch_fn = patch_fn()

        try:
            # Verify patch was applied
            assert layer_cls.forward is not original_forward, not_patched_message

            # Verify the method is still callable
            assert callable(layer_cls.forward), "Patched method is not callable"
        finally:
            if unpatch_fn:
                unpatch_fn()

        # Test unpatch function
        if unpatch_fn:
            assert layer_cls.forward is original_forward, (
                "unpatch function did not restore original method"
            )

    @pytest.mark.integration
    def test_qwen3_next_decoder_layer_patch(self):
        """Test that Qwen3Next decoder layer patch can be applied."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_decoder_layer,
        )

        self._check_forward_patch_roundtrip(
            qwen3_next.Qwen3NextDecoderLayer,
            patch_qwen3_next_decoder_layer,
            "decoder layer forward method was not patched",
        )

    @pytest.mark.integration
    def test_qwen3_next_gateddelta_layer_patch(self):
        """Test that Qwen3Next GatedDeltaNet patch can be applied."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_gateddelta_layer,
        )

        self._check_forward_patch_roundtrip(
            qwen3_next.Qwen3NextGatedDeltaNet,
            patch_qwen3_next_gateddelta_layer,
            "GatedDeltaNet forward method was not patched",
        )

    @pytest.mark.integration
    def test_qwen3_next_imports_patch(self):
        """Test that Qwen3Next imports patch can be applied without errors."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_imports,
        )

        # Apply patch - should not raise any exceptions even if modules unavailable
        unpatch_fn = patch_qwen3_next_imports()

        try:
            # Test that unpatch function is returned (or None if skipped)
            assert unpatch_fn is None or callable(unpatch_fn), (
                "patch_qwen3_next_imports should return None or callable unpatch function"
            )
        finally:
            # Undo the patch when a handle was returned so it does not leak
            # into other tests.
            if callable(unpatch_fn):
                unpatch_fn()

    @pytest.mark.integration
    def test_qwen3_next_modeling_packing_patch(self):
        """Test that all Qwen3Next modeling patches can be applied together."""
        from axolotl.monkeypatch.models.qwen3_next.modeling import (
            patch_qwen3_next_modeling_packing,
        )

        # This should not raise any exceptions
        patch_qwen3_next_modeling_packing()
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_get_cu_seqlens_utility():
    """Check ``get_cu_seqlens`` on a small packed ``position_ids`` tensor."""
    from axolotl.monkeypatch.models.qwen3_next.modeling import get_cu_seqlens

    # One row holding two runs that each restart at 0: lengths 3 and 2.
    packed_positions = torch.tensor([[0, 1, 2, 0, 1]])
    actual = get_cu_seqlens(packed_positions)

    assert actual.dtype == torch.int32, "Should be int32 dtype"

    # Expected: the start offset of each run followed by the total length.
    expected = torch.tensor([0, 3, 5], dtype=torch.int32)
    matches = torch.equal(actual, expected)
    assert matches, f"Expected {expected}, got {actual}"
|
||||
26
tests/monkeypatch/test_trainer_accelerator_args.py
Normal file
26
tests/monkeypatch/test_trainer_accelerator_args.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Unit tests for trainer accelerator args monkeypatch
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.monkeypatch.trainer_accelerator_args import (
|
||||
check_create_accelerate_code_is_patchable,
|
||||
)
|
||||
|
||||
|
||||
class TestTrainerAcceleratorArgs(unittest.TestCase):
    """
    Test case covering the trainer accelerator args monkeypatch
    """

    def test_check_create_accelerate_code_is_patchable(self):
        """
        Confirm the upstream transformers code can still be patched.

        A failure here signals that the upstream code targeted by the patch
        has changed.
        """
        is_patchable = check_create_accelerate_code_is_patchable()
        assert is_patchable
|
||||
|
||||
|
||||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
66
tests/monkeypatch/test_trainer_context_parallel_patch.py
Normal file
66
tests/monkeypatch/test_trainer_context_parallel_patch.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Tests for the HF Trainer context parallel patch."""
|
||||
|
||||
import pytest
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.monkeypatch.transformers.trainer_context_parallel import (
|
||||
GUARD_PATTERN,
|
||||
PATCHED_GUARD,
|
||||
patch_prepare_context_parallel_inputs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def restore_trainer_prepare_method():
    """Ensure Trainer._prepare_context_parallel_inputs is restored after a test.

    Before the test runs, the method to restore is captured: the saved
    ``_original_prepare_context_parallel_inputs`` if a patch is already in
    place, otherwise the current method. After the test, that method is put
    back and every bookkeeping attribute the patch may have left on
    ``Trainer`` is removed.

    Each marker is removed based on whether it exists at teardown time, not on
    whether it existed before the test. Otherwise a marker set during the test
    would survive while the method itself had been reset, and a marker removed
    by the test would make ``delattr`` raise ``AttributeError``.
    """
    original_method = getattr(
        Trainer,
        "_original_prepare_context_parallel_inputs",
        Trainer._prepare_context_parallel_inputs,
    )

    yield

    Trainer._prepare_context_parallel_inputs = original_method

    # Drop whatever patch bookkeeping is present now.
    for marker in (
        "_axolotl_prepare_context_parallel_inputs_patched",
        "_original_prepare_context_parallel_inputs",
        "_axolotl_prepare_context_parallel_inputs_source",
    ):
        if hasattr(Trainer, marker):
            delattr(Trainer, marker)
|
||||
|
||||
|
||||
def test_patch_attention_guard(restore_trainer_prepare_method):
    """Patch should swap the guard to allow sdpa or flash attention."""
    saved_original_attr = "_original_prepare_context_parallel_inputs"
    patched_flag_attr = "_axolotl_prepare_context_parallel_inputs_patched"

    # Start from the unpatched method: undo any earlier patch and clear its
    # bookkeeping attributes.
    if hasattr(Trainer, saved_original_attr):
        Trainer._prepare_context_parallel_inputs = getattr(
            Trainer, saved_original_attr
        )
        delattr(Trainer, saved_original_attr)
    if hasattr(Trainer, patched_flag_attr):
        delattr(Trainer, patched_flag_attr)

    patch_prepare_context_parallel_inputs()

    # The method must still exist and the patched flag must be truthy.
    assert Trainer._prepare_context_parallel_inputs is not None
    assert getattr(Trainer, patched_flag_attr, False)

    # The stored patched source must hold the new guard and not the old one.
    patched_source = Trainer._axolotl_prepare_context_parallel_inputs_source
    assert GUARD_PATTERN not in patched_source
    assert PATCHED_GUARD in patched_source
|
||||
|
||||
|
||||
def test_patch_is_idempotent(restore_trainer_prepare_method):
    """Applying the patch a second time must not replace the patched function."""
    snapshots = []
    for _ in range(2):
        patch_prepare_context_parallel_inputs()
        # Capture the method object installed after each application.
        snapshots.append(Trainer._prepare_context_parallel_inputs)

    after_first, after_second = snapshots
    assert after_first is after_second
|
||||
26
tests/monkeypatch/test_trainer_loss_calc.py
Normal file
26
tests/monkeypatch/test_trainer_loss_calc.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Unit tests for trainer loss calc monkeypatch."""
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
|
||||
check_evaluation_loop_is_patchable,
|
||||
check_maybe_log_save_evaluate_is_patchable,
|
||||
)
|
||||
|
||||
|
||||
class TestTrainerLossCalc(unittest.TestCase):
    """
    Test case covering the trainer loss calc monkeypatch
    """

    def test_trainer_loss_calc_is_patchable(self):
        """
        Confirm the upstream transformers code can still be patched. A failure
        here signals that the upstream code targeted by the patch has changed.
        """
        # Checked in order; the second check only runs if the first passes.
        evaluation_loop_ok = check_evaluation_loop_is_patchable()
        assert evaluation_loop_ok

        log_save_evaluate_ok = check_maybe_log_save_evaluate_is_patchable()
        assert log_save_evaluate_ok
|
||||
|
||||
|
||||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
43
tests/monkeypatch/test_voxtral_modeling_patch.py
Normal file
43
tests/monkeypatch/test_voxtral_modeling_patch.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Integration tests for Voxtral modeling patches."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestVoxtralModelingPatchIntegration:
    """Test Voxtral modeling patch integration."""

    @pytest.mark.integration
    def test_voxtral_conditional_generation_patch(self):
        """Test that Voxtral conditional generation patch can be applied.

        The test is skipped when ``VoxtralForConditionalGeneration`` cannot be
        imported. The unpatch function always runs, even when an assertion
        fails, so the patched ``forward`` does not leak into other tests.
        """
        try:
            from transformers.models.voxtral.modeling_voxtral import (
                VoxtralForConditionalGeneration,
            )
        except ImportError:
            pytest.skip("VoxtralForConditionalGeneration not available")

        from axolotl.monkeypatch.models.voxtral.modeling import (
            patch_voxtral_conditional_generation_forward,
        )

        # Store original method
        original_forward = VoxtralForConditionalGeneration.forward

        # Apply patch and get unpatch function
        unpatch_fn = patch_voxtral_conditional_generation_forward()

        try:
            # Verify patch was applied
            assert VoxtralForConditionalGeneration.forward is not original_forward, (
                "forward method was not patched"
            )

            # Verify the method is still callable
            assert callable(VoxtralForConditionalGeneration.forward), (
                "Patched method is not callable"
            )
        finally:
            # Always undo the patch, including when an assertion above failed.
            unpatch_fn()

        # Test unpatch function
        assert VoxtralForConditionalGeneration.forward is original_forward, (
            "unpatch function did not restore original method"
        )
|
||||
@@ -1,4 +1,3 @@
|
||||
# pylint: disable=too-many-lines
|
||||
"""Module for testing the validation module"""
|
||||
|
||||
import os
|
||||
@@ -49,7 +48,6 @@ class BaseValidation:
|
||||
self._caplog = caplog
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
class TestValidation(BaseValidation):
|
||||
"""
|
||||
Test the validation module
|
||||
@@ -241,7 +239,7 @@ class TestValidation(BaseValidation):
|
||||
|
||||
def test_lr_as_float(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"learning_rate": "5e-5",
|
||||
}
|
||||
@@ -303,7 +301,7 @@ class TestValidation(BaseValidation):
|
||||
)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
@@ -315,7 +313,7 @@ class TestValidation(BaseValidation):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"gptq": True,
|
||||
}
|
||||
@@ -327,7 +325,7 @@ class TestValidation(BaseValidation):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_4bit": False,
|
||||
}
|
||||
@@ -339,7 +337,7 @@ class TestValidation(BaseValidation):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
}
|
||||
@@ -361,7 +359,7 @@ class TestValidation(BaseValidation):
|
||||
)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
@@ -373,7 +371,7 @@ class TestValidation(BaseValidation):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"gptq": True,
|
||||
}
|
||||
@@ -385,7 +383,7 @@ class TestValidation(BaseValidation):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
}
|
||||
@@ -692,7 +690,7 @@ class TestValidation(BaseValidation):
|
||||
"bf16": True,
|
||||
"capabilities": {"bf16": False},
|
||||
"env_capabilities": {
|
||||
"torch_version": "2.5.1",
|
||||
"torch_version": "2.6.0",
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -1202,7 +1200,7 @@ class TestValidation(BaseValidation):
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
)
|
||||
|
||||
env_capabilities = {"torch_version": "2.5.1"}
|
||||
env_capabilities = {"torch_version": "2.6.0"}
|
||||
capabilities = {"bf16": False}
|
||||
_ = validate_config(
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
@@ -1244,7 +1242,7 @@ class TestTorchCompileValidation(BaseValidation):
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
env_capabilities = {"torch_version": "2.5.1"}
|
||||
env_capabilities = {"torch_version": "2.6.0"}
|
||||
capabilities = {"bf16": True}
|
||||
updated_cfg = validate_config(
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
@@ -1690,3 +1688,18 @@ class TestValidationMLflow(BaseValidation):
|
||||
assert new_cfg.use_mlflow is True
|
||||
|
||||
os.environ.pop("MLFLOW_EXPERIMENT_NAME", None)
|
||||
|
||||
|
||||
class TestDataloaderValidation(BaseValidation):
    """
    Checks on the dataloader_* defaults filled in by config validation
    """

    def test_dataloader_auto_defaults(self, minimal_cfg):
        """Validate the minimal config and check the dataloader defaults.

        Validation is given ``{"n_gpu": 8}`` and ``{"torch_version": "2.6.0"}``
        as its second and third positional arguments.
        """
        validated = validate_config(
            minimal_cfg, {"n_gpu": 8}, {"torch_version": "2.6.0"}
        )

        assert validated.dataloader_num_workers == 8
        assert validated.dataloader_pin_memory is True
        assert validated.dataloader_prefetch_factor == 256
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user