moving tests around for flash_attn install
This commit is contained in:
@@ -84,9 +84,9 @@ class LlamaDifferentialAttention(nn.Module):
|
||||
|
||||
if config.split_heads:
|
||||
# Split heads mode
|
||||
assert (
|
||||
self.base_num_heads % 2 == 0
|
||||
), "Number of heads must be even for splitting"
|
||||
# assert (
|
||||
# self.base_num_heads % 2 == 0
|
||||
# ), "Number of heads must be even for splitting"
|
||||
self.heads_per_component = self.base_num_heads // 2
|
||||
|
||||
# Single projections
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Shared pytest fixtures for cli module."""
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
"""Tests for convert-differential-transformer CLI command."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
|
||||
def test_cli_validation(cli_runner):
|
||||
"""Test CLI validation for a command.
|
||||
|
||||
Args:
|
||||
cli_runner: CLI runner fixture
|
||||
"""
|
||||
# Test missing config file
|
||||
result = cli_runner.invoke(cli, ["convert-differential-transformer"])
|
||||
assert result.exit_code != 0
|
||||
assert "Error: Missing argument 'CONFIG'." in result.output
|
||||
|
||||
# Test non-existent config file
|
||||
result = cli_runner.invoke(
|
||||
cli, ["convert-differential-transformer", "nonexistent.yml"]
|
||||
)
|
||||
assert result.exit_code != 0
|
||||
assert "Error: Invalid value for 'CONFIG'" in result.output
|
||||
|
||||
|
||||
def test_basic_execution(cli_runner, tmp_path: Path, valid_test_config: str):
|
||||
"""Test basic execution.
|
||||
|
||||
Args:
|
||||
cli_runner: CLI runner fixture
|
||||
tmp_path: Temporary path fixture
|
||||
valid_test_config: Valid config fixture
|
||||
"""
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
with patch(
|
||||
"axolotl.cli.integrations.convert_differential_transformer.do_cli"
|
||||
) as mock_do_cli:
|
||||
result = cli_runner.invoke(
|
||||
cli, ["convert-differential-transformer", str(config_path)]
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
|
||||
mock_do_cli.assert_called_once()
|
||||
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
|
||||
@@ -1,4 +1,5 @@
|
||||
"""pytest tests for axolotl CLI fetch command."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import fetch
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""pytest tests for axolotl CLI inference command."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""General pytest tests for axolotl.cli.main interface."""
|
||||
|
||||
from axolotl.cli.main import build_command, cli
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""pytest tests for axolotl CLI merge_lora command."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""pytest tests for axolotl CLI preprocess command."""
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""pytest tests for axolotl CLI shard command."""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""pytest tests for axolotl CLI --version"""
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""pytest tests for axolotl CLI utils."""
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Shared fixtures for differential transformer conversion tests."""
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -26,3 +27,8 @@ def base_config():
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cli_runner():
|
||||
return CliRunner()
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
@@ -12,9 +13,41 @@ from axolotl.cli import load_cfg
|
||||
from axolotl.cli.integrations.convert_differential_transformer import (
|
||||
convert_differential_transformer,
|
||||
)
|
||||
from axolotl.cli.main import cli
|
||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs
|
||||
|
||||
|
||||
def test_cli_validation(cli_runner):
|
||||
# Test missing config file
|
||||
result = cli_runner.invoke(cli, ["convert-differential-transformer"])
|
||||
assert result.exit_code != 0
|
||||
assert "Error: Missing argument 'CONFIG'." in result.output
|
||||
|
||||
# Test non-existent config file
|
||||
result = cli_runner.invoke(
|
||||
cli, ["convert-differential-transformer", "nonexistent.yml"]
|
||||
)
|
||||
assert result.exit_code != 0
|
||||
assert "Error: Invalid value for 'CONFIG'" in result.output
|
||||
|
||||
|
||||
def test_basic_execution(cli_runner, tmp_path: Path, base_config):
|
||||
config_path = tmp_path / "config.yml"
|
||||
with open(config_path, "w", encoding="utf-8") as file:
|
||||
yaml.dump(base_config, file)
|
||||
|
||||
with patch(
|
||||
"axolotl.cli.integrations.convert_differential_transformer.do_cli"
|
||||
) as mock_do_cli:
|
||||
result = cli_runner.invoke(
|
||||
cli, ["convert-differential-transformer", str(config_path)]
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
|
||||
mock_do_cli.assert_called_once()
|
||||
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
|
||||
|
||||
|
||||
def test_conversion_cli_basic(tmp_path: Path, base_config):
|
||||
output_dir = tmp_path / "converted"
|
||||
base_config["output_dir"] = str(output_dir)
|
||||
@@ -113,7 +146,6 @@ def test_conversion_cli_repoduce_attentions(
|
||||
)
|
||||
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
|
||||
output_dir = tmp_path / "converted"
|
||||
base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
|
||||
base_config["output_dir"] = str(output_dir)
|
||||
base_config[attention] = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user