moving tests around for flash_attn install
This commit is contained in:
@@ -84,9 +84,9 @@ class LlamaDifferentialAttention(nn.Module):
|
|||||||
|
|
||||||
if config.split_heads:
|
if config.split_heads:
|
||||||
# Split heads mode
|
# Split heads mode
|
||||||
assert (
|
# assert (
|
||||||
self.base_num_heads % 2 == 0
|
# self.base_num_heads % 2 == 0
|
||||||
), "Number of heads must be even for splitting"
|
# ), "Number of heads must be even for splitting"
|
||||||
self.heads_per_component = self.base_num_heads // 2
|
self.heads_per_component = self.base_num_heads // 2
|
||||||
|
|
||||||
# Single projections
|
# Single projections
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Shared pytest fixtures for cli module."""
|
"""Shared pytest fixtures for cli module."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
|||||||
@@ -1,48 +0,0 @@
|
|||||||
"""Tests for convert-differential-transformer CLI command."""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from axolotl.cli.main import cli
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_validation(cli_runner):
|
|
||||||
"""Test CLI validation for a command.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cli_runner: CLI runner fixture
|
|
||||||
"""
|
|
||||||
# Test missing config file
|
|
||||||
result = cli_runner.invoke(cli, ["convert-differential-transformer"])
|
|
||||||
assert result.exit_code != 0
|
|
||||||
assert "Error: Missing argument 'CONFIG'." in result.output
|
|
||||||
|
|
||||||
# Test non-existent config file
|
|
||||||
result = cli_runner.invoke(
|
|
||||||
cli, ["convert-differential-transformer", "nonexistent.yml"]
|
|
||||||
)
|
|
||||||
assert result.exit_code != 0
|
|
||||||
assert "Error: Invalid value for 'CONFIG'" in result.output
|
|
||||||
|
|
||||||
|
|
||||||
def test_basic_execution(cli_runner, tmp_path: Path, valid_test_config: str):
|
|
||||||
"""Test basic execution.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cli_runner: CLI runner fixture
|
|
||||||
tmp_path: Temporary path fixture
|
|
||||||
valid_test_config: Valid config fixture
|
|
||||||
"""
|
|
||||||
config_path = tmp_path / "config.yml"
|
|
||||||
config_path.write_text(valid_test_config)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"axolotl.cli.integrations.convert_differential_transformer.do_cli"
|
|
||||||
) as mock_do_cli:
|
|
||||||
result = cli_runner.invoke(
|
|
||||||
cli, ["convert-differential-transformer", str(config_path)]
|
|
||||||
)
|
|
||||||
assert result.exit_code == 0
|
|
||||||
|
|
||||||
mock_do_cli.assert_called_once()
|
|
||||||
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
|
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
"""pytest tests for axolotl CLI fetch command."""
|
"""pytest tests for axolotl CLI fetch command."""
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from axolotl.cli.main import fetch
|
from axolotl.cli.main import fetch
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""pytest tests for axolotl CLI inference command."""
|
"""pytest tests for axolotl CLI inference command."""
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from axolotl.cli.main import cli
|
from axolotl.cli.main import cli
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""General pytest tests for axolotl.cli.main interface."""
|
"""General pytest tests for axolotl.cli.main interface."""
|
||||||
|
|
||||||
from axolotl.cli.main import build_command, cli
|
from axolotl.cli.main import build_command, cli
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""pytest tests for axolotl CLI merge_lora command."""
|
"""pytest tests for axolotl CLI merge_lora command."""
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from axolotl.cli.main import cli
|
from axolotl.cli.main import cli
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
|
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from axolotl.cli.main import cli
|
from axolotl.cli.main import cli
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""pytest tests for axolotl CLI preprocess command."""
|
"""pytest tests for axolotl CLI preprocess command."""
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""pytest tests for axolotl CLI shard command."""
|
"""pytest tests for axolotl CLI shard command."""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from axolotl.cli.main import cli
|
from axolotl.cli.main import cli
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""pytest tests for axolotl CLI --version"""
|
"""pytest tests for axolotl CLI --version"""
|
||||||
|
|
||||||
from axolotl.cli.main import cli
|
from axolotl.cli.main import cli
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""pytest tests for axolotl CLI utils."""
|
"""pytest tests for axolotl CLI utils."""
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Shared fixtures for differential transformer conversion tests."""
|
"""Shared fixtures for differential transformer conversion tests."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
@@ -26,3 +27,8 @@ def base_config():
|
|||||||
"pad_token": "<|endoftext|>",
|
"pad_token": "<|endoftext|>",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def cli_runner():
|
||||||
|
return CliRunner()
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
@@ -12,9 +13,41 @@ from axolotl.cli import load_cfg
|
|||||||
from axolotl.cli.integrations.convert_differential_transformer import (
|
from axolotl.cli.integrations.convert_differential_transformer import (
|
||||||
convert_differential_transformer,
|
convert_differential_transformer,
|
||||||
)
|
)
|
||||||
|
from axolotl.cli.main import cli
|
||||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs
|
from axolotl.common.cli import ConvertDiffTransformerCliArgs
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_validation(cli_runner):
|
||||||
|
# Test missing config file
|
||||||
|
result = cli_runner.invoke(cli, ["convert-differential-transformer"])
|
||||||
|
assert result.exit_code != 0
|
||||||
|
assert "Error: Missing argument 'CONFIG'." in result.output
|
||||||
|
|
||||||
|
# Test non-existent config file
|
||||||
|
result = cli_runner.invoke(
|
||||||
|
cli, ["convert-differential-transformer", "nonexistent.yml"]
|
||||||
|
)
|
||||||
|
assert result.exit_code != 0
|
||||||
|
assert "Error: Invalid value for 'CONFIG'" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_basic_execution(cli_runner, tmp_path: Path, base_config):
|
||||||
|
config_path = tmp_path / "config.yml"
|
||||||
|
with open(config_path, "w", encoding="utf-8") as file:
|
||||||
|
yaml.dump(base_config, file)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"axolotl.cli.integrations.convert_differential_transformer.do_cli"
|
||||||
|
) as mock_do_cli:
|
||||||
|
result = cli_runner.invoke(
|
||||||
|
cli, ["convert-differential-transformer", str(config_path)]
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
mock_do_cli.assert_called_once()
|
||||||
|
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
|
||||||
|
|
||||||
|
|
||||||
def test_conversion_cli_basic(tmp_path: Path, base_config):
|
def test_conversion_cli_basic(tmp_path: Path, base_config):
|
||||||
output_dir = tmp_path / "converted"
|
output_dir = tmp_path / "converted"
|
||||||
base_config["output_dir"] = str(output_dir)
|
base_config["output_dir"] = str(output_dir)
|
||||||
@@ -113,7 +146,6 @@ def test_conversion_cli_repoduce_attentions(
|
|||||||
)
|
)
|
||||||
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
|
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
|
||||||
output_dir = tmp_path / "converted"
|
output_dir = tmp_path / "converted"
|
||||||
base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
|
|
||||||
base_config["output_dir"] = str(output_dir)
|
base_config["output_dir"] = str(output_dir)
|
||||||
base_config[attention] = True
|
base_config[attention] = True
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user