moving tests around for flash_attn install

This commit is contained in:
Dan Saunders
2024-12-18 19:36:23 +00:00
parent d4e29e5b67
commit 544f2a8a27
15 changed files with 52 additions and 52 deletions

View File

@@ -1,6 +1,7 @@
"""Shared fixtures for differential transformer conversion tests."""
import pytest
from click.testing import CliRunner
@pytest.fixture()
@@ -26,3 +27,8 @@ def base_config():
"pad_token": "<|endoftext|>",
},
}
@pytest.fixture
def cli_runner():
return CliRunner()

View File

@@ -4,6 +4,7 @@
from pathlib import Path
from typing import Optional
from unittest.mock import patch
import pytest
import yaml
@@ -12,9 +13,41 @@ from axolotl.cli import load_cfg
from axolotl.cli.integrations.convert_differential_transformer import (
convert_differential_transformer,
)
from axolotl.cli.main import cli
from axolotl.common.cli import ConvertDiffTransformerCliArgs
def test_cli_validation(cli_runner):
# Test missing config file
result = cli_runner.invoke(cli, ["convert-differential-transformer"])
assert result.exit_code != 0
assert "Error: Missing argument 'CONFIG'." in result.output
# Test non-existent config file
result = cli_runner.invoke(
cli, ["convert-differential-transformer", "nonexistent.yml"]
)
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
def test_basic_execution(cli_runner, tmp_path: Path, base_config):
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
with patch(
"axolotl.cli.integrations.convert_differential_transformer.do_cli"
) as mock_do_cli:
result = cli_runner.invoke(
cli, ["convert-differential-transformer", str(config_path)]
)
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
def test_conversion_cli_basic(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
@@ -113,7 +146,6 @@ def test_conversion_cli_repoduce_attentions(
)
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
output_dir = tmp_path / "converted"
base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True