plugin implementation

This commit is contained in:
Dan Saunders
2024-12-18 01:26:41 +00:00
parent d22e1136bc
commit ea07a7086e
13 changed files with 118 additions and 30 deletions

View File

@@ -6,12 +6,14 @@ from typing import Optional
import pytest
import yaml
from pytest import approx
from axolotl.cli import load_cfg
from axolotl.cli.evaluate import do_evaluate
from axolotl.cli.integrations.convert_differential_transformer import (
convert_differential_transformer,
)
from axolotl.common.cli import ConvertDiffTransformerCliArgs
from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs
@pytest.fixture()
@@ -19,9 +21,12 @@ def base_config():
"""Basic config for testing."""
return {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin",
],
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"path": "axolotl-ai-co/alpaca_100_test",
"type": "alpaca",
},
],
@@ -103,7 +108,9 @@ def test_conversion_cli_reproduce(tmp_path: Path, base_config):
assert (output_dir / "axolotl_config.yml").exists()
@pytest.mark.parametrize("attention", ["sdp_attention", "flash_attention"])
@pytest.mark.parametrize(
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
)
def test_conversion_cli_repoduce_attentions(
tmp_path: Path, base_config, attention: Optional[str]
):
@@ -125,3 +132,42 @@ def test_conversion_cli_repoduce_attentions(
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
def test_conversion_and_eval_cli(tmp_path: Path, base_config):
    """Convert a model, then evaluate the converted checkpoint end to end.

    The test writes ``base_config`` to a YAML file under ``tmp_path``, runs
    ``convert_differential_transformer`` on it, and checks that the debug info
    reports matching generations and that the expected artifacts were written.
    It then reloads the config from the output directory, runs
    ``do_evaluate``, and checks the reported metric names and loss values.

    Args:
        tmp_path: Per-test temporary directory provided by pytest.
        base_config: Base config mapping provided by the ``base_config``
            fixture; it is mutated here to set ``output_dir``.
    """
    converted_dir = tmp_path / "converted"
    config_file = tmp_path / "config.yml"
    base_config["output_dir"] = str(converted_dir)

    # Persist the config so it can be loaded the same way the CLI loads it.
    with open(config_file, "w", encoding="utf-8") as handle:
        yaml.dump(base_config, handle)

    convert_cfg = load_cfg(str(config_file))
    convert_args = ConvertDiffTransformerCliArgs(
        debug=True,
        zero_init=True,
        sublayer_norm=False,
    )
    _converted, debug_info = convert_differential_transformer(
        convert_cfg, convert_args, str(config_file)
    )

    assert debug_info["generations_match"] is True

    # The conversion is expected to leave weights, model config and the
    # axolotl config in the output directory.
    for artifact in ("model.safetensors", "config.json", "axolotl_config.yml"):
        assert (converted_dir / artifact).exists()

    # Evaluation loads its config from the conversion output directory.
    eval_cfg = load_cfg(str(converted_dir))
    eval_args = EvaluateCliArgs()
    all_metrics = do_evaluate(eval_cfg, eval_args)

    # Both the set of metrics and their order are checked.
    expected_metric_names = [
        "train_loss",
        "train_model_preparation_time",
        "train_runtime",
        "train_samples_per_second",
        "train_steps_per_second",
        "eval_loss",
        "eval_model_preparation_time",
        "eval_runtime",
        "eval_samples_per_second",
        "eval_steps_per_second",
    ]
    reported_metric_names = list(all_metrics.keys())
    assert reported_metric_names == expected_metric_names

    assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4)
    assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4)