moving tests around for flash_attn install

This commit is contained in:
Dan Saunders
2024-12-18 19:36:23 +00:00
parent d4e29e5b67
commit 544f2a8a27
15 changed files with 52 additions and 52 deletions

View File

@@ -84,9 +84,9 @@ class LlamaDifferentialAttention(nn.Module):
if config.split_heads: if config.split_heads:
# Split heads mode # Split heads mode
assert ( # assert (
self.base_num_heads % 2 == 0 # self.base_num_heads % 2 == 0
), "Number of heads must be even for splitting" # ), "Number of heads must be even for splitting"
self.heads_per_component = self.base_num_heads // 2 self.heads_per_component = self.base_num_heads // 2
# Single projections # Single projections

View File

@@ -1,4 +1,5 @@
"""Shared pytest fixtures for cli module.""" """Shared pytest fixtures for cli module."""
import pytest import pytest
from click.testing import CliRunner from click.testing import CliRunner

View File

@@ -1,48 +0,0 @@
"""Tests for convert-differential-transformer CLI command."""
from pathlib import Path
from unittest.mock import patch
from axolotl.cli.main import cli
def test_cli_validation(cli_runner):
"""Test CLI validation for a command.
Args:
cli_runner: CLI runner fixture
"""
# Test missing config file
result = cli_runner.invoke(cli, ["convert-differential-transformer"])
assert result.exit_code != 0
assert "Error: Missing argument 'CONFIG'." in result.output
# Test non-existent config file
result = cli_runner.invoke(
cli, ["convert-differential-transformer", "nonexistent.yml"]
)
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
def test_basic_execution(cli_runner, tmp_path: Path, valid_test_config: str):
"""Test basic execution.
Args:
cli_runner: CLI runner fixture
tmp_path: Temporary path fixture
valid_test_config: Valid config fixture
"""
config_path = tmp_path / "config.yml"
config_path.write_text(valid_test_config)
with patch(
"axolotl.cli.integrations.convert_differential_transformer.do_cli"
) as mock_do_cli:
result = cli_runner.invoke(
cli, ["convert-differential-transformer", str(config_path)]
)
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI fetch command.""" """pytest tests for axolotl CLI fetch command."""
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import fetch from axolotl.cli.main import fetch

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI inference command.""" """pytest tests for axolotl CLI inference command."""
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import cli from axolotl.cli.main import cli

View File

@@ -1,4 +1,5 @@
"""General pytest tests for axolotl.cli.main interface.""" """General pytest tests for axolotl.cli.main interface."""
from axolotl.cli.main import build_command, cli from axolotl.cli.main import build_command, cli

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI merge_lora command.""" """pytest tests for axolotl CLI merge_lora command."""
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import cli from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command.""" """pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import cli from axolotl.cli.main import cli

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI preprocess command.""" """pytest tests for axolotl CLI preprocess command."""
import shutil import shutil
from pathlib import Path from pathlib import Path
from unittest.mock import patch from unittest.mock import patch

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI shard command.""" """pytest tests for axolotl CLI shard command."""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import cli from axolotl.cli.main import cli

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI --version""" """pytest tests for axolotl CLI --version"""
from axolotl.cli.main import cli from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI utils.""" """pytest tests for axolotl CLI utils."""
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
import json import json
from unittest.mock import Mock, patch from unittest.mock import Mock, patch

View File

@@ -1,6 +1,7 @@
"""Shared fixtures for differential transformer conversion tests.""" """Shared fixtures for differential transformer conversion tests."""
import pytest import pytest
from click.testing import CliRunner
@pytest.fixture() @pytest.fixture()
@@ -26,3 +27,8 @@ def base_config():
"pad_token": "<|endoftext|>", "pad_token": "<|endoftext|>",
}, },
} }
@pytest.fixture
def cli_runner():
return CliRunner()

View File

@@ -4,6 +4,7 @@
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from unittest.mock import patch
import pytest import pytest
import yaml import yaml
@@ -12,9 +13,41 @@ from axolotl.cli import load_cfg
from axolotl.cli.integrations.convert_differential_transformer import ( from axolotl.cli.integrations.convert_differential_transformer import (
convert_differential_transformer, convert_differential_transformer,
) )
from axolotl.cli.main import cli
from axolotl.common.cli import ConvertDiffTransformerCliArgs from axolotl.common.cli import ConvertDiffTransformerCliArgs
def test_cli_validation(cli_runner):
# Test missing config file
result = cli_runner.invoke(cli, ["convert-differential-transformer"])
assert result.exit_code != 0
assert "Error: Missing argument 'CONFIG'." in result.output
# Test non-existent config file
result = cli_runner.invoke(
cli, ["convert-differential-transformer", "nonexistent.yml"]
)
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
def test_basic_execution(cli_runner, tmp_path: Path, base_config):
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
with patch(
"axolotl.cli.integrations.convert_differential_transformer.do_cli"
) as mock_do_cli:
result = cli_runner.invoke(
cli, ["convert-differential-transformer", str(config_path)]
)
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
def test_conversion_cli_basic(tmp_path: Path, base_config): def test_conversion_cli_basic(tmp_path: Path, base_config):
output_dir = tmp_path / "converted" output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir) base_config["output_dir"] = str(output_dir)
@@ -113,7 +146,6 @@ def test_conversion_cli_repoduce_attentions(
) )
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str): def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
output_dir = tmp_path / "converted" output_dir = tmp_path / "converted"
base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
base_config["output_dir"] = str(output_dir) base_config["output_dir"] = str(output_dir)
base_config[attention] = True base_config[attention] = True