convert-differential-transformer test coverage
This commit is contained in:
0
tests/cli/integrations/__init__.py
Normal file
0
tests/cli/integrations/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Tests for convert-differential-transformer CLI command."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
|
||||
def test_cli_validation(cli_runner):
|
||||
"""Test CLI validation for a command.
|
||||
|
||||
Args:
|
||||
cli_runner: CLI runner fixture
|
||||
"""
|
||||
# Test missing config file
|
||||
result = cli_runner.invoke(cli, ["convert-differential-transformer"])
|
||||
assert result.exit_code != 0
|
||||
assert "Error: Missing argument 'CONFIG'." in result.output
|
||||
|
||||
# Test non-existent config file
|
||||
result = cli_runner.invoke(
|
||||
cli, ["convert-differential-transformer", "nonexistent.yml"]
|
||||
)
|
||||
assert result.exit_code != 0
|
||||
assert "Error: Invalid value for 'CONFIG'" in result.output
|
||||
|
||||
|
||||
def test_basic_execution(cli_runner, tmp_path: Path, valid_test_config: str):
|
||||
"""Test basic execution.
|
||||
|
||||
Args:
|
||||
cli_runner: CLI runner fixture
|
||||
tmp_path: Temporary path fixture
|
||||
valid_test_config: Valid config fixture
|
||||
"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch(
|
||||
"axolotl.cli.integrations.convert_differential_transformer.do_cli"
|
||||
) as mock_do_cli:
|
||||
result = cli_runner.invoke(
|
||||
cli, ["convert-differential-transformer", str(config_path)]
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
|
||||
mock_do_cli.assert_called_once()
|
||||
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
|
||||
127
tests/e2e/integrations/test_convert_differential_transformer.py
Normal file
127
tests/e2e/integrations/test_convert_differential_transformer.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""End-to-end tests for differential transformer conversion."""
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from axolotl.cli import load_cfg
|
||||
from axolotl.cli.integrations.convert_differential_transformer import (
|
||||
convert_differential_transformer,
|
||||
)
|
||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def base_config():
|
||||
"""Basic config for testing."""
|
||||
return {
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-4,
|
||||
"val_set_size": 0.1,
|
||||
"micro_batch_size": 1,
|
||||
"sequence_len": 2048,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_conversion_cli_basic(tmp_path: Path, base_config):
|
||||
output_dir = tmp_path / "converted"
|
||||
base_config["output_dir"] = str(output_dir)
|
||||
|
||||
config_path = tmp_path / "config.yml"
|
||||
with open(config_path, "w", encoding="utf-8") as file:
|
||||
yaml.dump(base_config, file)
|
||||
|
||||
# Load config the same way do_cli does
|
||||
cfg = load_cfg(str(config_path))
|
||||
|
||||
# Create CLI args
|
||||
cli_args = ConvertDiffTransformerCliArgs()
|
||||
|
||||
# Call convert_differential_transformer directly
|
||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
||||
|
||||
assert not debug_info
|
||||
assert (output_dir / "model.safetensors").exists()
|
||||
assert (output_dir / "config.json").exists()
|
||||
assert (output_dir / "axolotl_config.yml").exists()
|
||||
|
||||
|
||||
def test_conversion_cli_debug(tmp_path: Path, base_config):
|
||||
output_dir = tmp_path / "converted"
|
||||
base_config["output_dir"] = str(output_dir)
|
||||
|
||||
config_path = tmp_path / "config.yml"
|
||||
with open(config_path, "w", encoding="utf-8") as file:
|
||||
yaml.dump(base_config, file)
|
||||
|
||||
# Load config the same way do_cli does
|
||||
cfg = load_cfg(str(config_path))
|
||||
|
||||
# Create CLI args
|
||||
cli_args = ConvertDiffTransformerCliArgs(debug=True)
|
||||
|
||||
# Call convert_differential_transformer directly
|
||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
||||
|
||||
assert not debug_info["generations_match"]
|
||||
assert not debug_info["match_expected"]
|
||||
assert (output_dir / "model.safetensors").exists()
|
||||
assert (output_dir / "config.json").exists()
|
||||
assert (output_dir / "axolotl_config.yml").exists()
|
||||
|
||||
|
||||
def test_conversion_cli_reproduce(tmp_path: Path, base_config):
|
||||
output_dir = tmp_path / "converted"
|
||||
base_config["output_dir"] = str(output_dir)
|
||||
|
||||
config_path = tmp_path / "config.yml"
|
||||
with open(config_path, "w", encoding="utf-8") as file:
|
||||
yaml.dump(base_config, file)
|
||||
|
||||
cfg = load_cfg(str(config_path))
|
||||
cli_args = ConvertDiffTransformerCliArgs(
|
||||
debug=True, zero_init=True, sublayer_norm=False
|
||||
)
|
||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
||||
|
||||
assert debug_info["generations_match"] is True
|
||||
assert (output_dir / "model.safetensors").exists()
|
||||
assert (output_dir / "config.json").exists()
|
||||
assert (output_dir / "axolotl_config.yml").exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("attention", ["sdp_attention", "flash_attention"])
|
||||
def test_conversion_cli_repoduce_attentions(
|
||||
tmp_path: Path, base_config, attention: Optional[str]
|
||||
):
|
||||
output_dir = tmp_path / "converted"
|
||||
base_config["output_dir"] = str(output_dir)
|
||||
base_config[attention] = True
|
||||
|
||||
config_path = tmp_path / "config.yml"
|
||||
with open(config_path, "w", encoding="utf-8") as file:
|
||||
yaml.dump(base_config, file)
|
||||
|
||||
cfg = load_cfg(str(config_path))
|
||||
cli_args = ConvertDiffTransformerCliArgs(
|
||||
debug=True, zero_init=True, sublayer_norm=False
|
||||
)
|
||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
||||
|
||||
assert debug_info["generations_match"] is True
|
||||
assert (output_dir / "model.safetensors").exists()
|
||||
assert (output_dir / "config.json").exists()
|
||||
assert (output_dir / "axolotl_config.yml").exists()
|
||||
Reference in New Issue
Block a user