isolating problematic test

This commit is contained in:
Dan Saunders
2024-12-18 03:30:35 +00:00
parent dda9b25994
commit 80ba0d8dd1
7 changed files with 85 additions and 71 deletions

View File

@@ -1,6 +1,6 @@
""" """
modal application to run axolotl gpu tests in Modal modal application to run axolotl gpu tests in Modal
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import os import os

View File

@@ -12,7 +12,6 @@ from axolotl.integrations.differential_transformer.differential_attention import
def patch_llama_attention_classes(): def patch_llama_attention_classes():
"""Patch transformers to support differential attention""" """Patch transformers to support differential attention"""
# Add our attention class to the registry # Add our attention class to the registry
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention

View File

@@ -843,7 +843,6 @@ class ModelLoader:
if self.cfg.is_multimodal: if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config self.model_config.text_config = self.text_model_config
# self.model._attn_implementation_autoset = False
self.model = self.AutoModelLoader.from_pretrained( self.model = self.AutoModelLoader.from_pretrained(
self.base_model, self.base_model,
config=self.model_config, config=self.model_config,

View File

@@ -0,0 +1,28 @@
"""Shared fixtures for differential transformer conversion tests."""
import pytest
@pytest.fixture()
def base_config():
"""Basic config for testing."""
return {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin",
],
"datasets": [
{
"path": "axolotl-ai-co/alpaca_100_test",
"type": "alpaca",
},
],
"gradient_accumulation_steps": 1,
"learning_rate": 1e-4,
"val_set_size": 0.1,
"micro_batch_size": 1,
"sequence_len": 2048,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
}

View File

@@ -0,0 +1,53 @@
"""End-to-end tests for differential transformer conversion and evaluation."""
# pylint: disable=duplicate-code
from pathlib import Path
import yaml
from pytest import approx
from axolotl.cli import load_cfg
from axolotl.cli.evaluate import do_evaluate
from axolotl.cli.integrations.convert_differential_transformer import (
convert_differential_transformer,
)
from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs
def test_conversion_and_eval_cli(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
eval_cfg = load_cfg(str(output_dir))
eval_cli_args = EvaluateCliArgs()
all_metrics = do_evaluate(eval_cfg, eval_cli_args)
assert list(all_metrics.keys()) == [
"train_loss",
"train_model_preparation_time",
"train_runtime",
"train_samples_per_second",
"train_steps_per_second",
"eval_loss",
"eval_model_preparation_time",
"eval_runtime",
"eval_samples_per_second",
"eval_steps_per_second",
]
assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4)
assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4)

View File

@@ -1,44 +1,18 @@
"""End-to-end tests for differential transformer conversion.""" """End-to-end tests for differential transformer conversion."""
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
# pylint: disable=duplicate-code
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import pytest import pytest
import yaml import yaml
from pytest import approx
from axolotl.cli import load_cfg from axolotl.cli import load_cfg
from axolotl.cli.evaluate import do_evaluate
from axolotl.cli.integrations.convert_differential_transformer import ( from axolotl.cli.integrations.convert_differential_transformer import (
convert_differential_transformer, convert_differential_transformer,
) )
from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs from axolotl.common.cli import ConvertDiffTransformerCliArgs
@pytest.fixture()
def base_config():
"""Basic config for testing."""
return {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin",
],
"datasets": [
{
"path": "axolotl-ai-co/alpaca_100_test",
"type": "alpaca",
},
],
"gradient_accumulation_steps": 1,
"learning_rate": 1e-4,
"val_set_size": 0.1,
"micro_batch_size": 1,
"sequence_len": 2048,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
}
def test_conversion_cli_basic(tmp_path: Path, base_config): def test_conversion_cli_basic(tmp_path: Path, base_config):
@@ -132,42 +106,3 @@ def test_conversion_cli_repoduce_attentions(
assert (output_dir / "model.safetensors").exists() assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists() assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists() assert (output_dir / "axolotl_config.yml").exists()
def test_conversion_and_eval_cli(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
eval_cfg = load_cfg(str(output_dir))
eval_cli_args = EvaluateCliArgs()
all_metrics = do_evaluate(eval_cfg, eval_cli_args)
assert list(all_metrics.keys()) == [
"train_loss",
"train_model_preparation_time",
"train_runtime",
"train_samples_per_second",
"train_steps_per_second",
"eval_loss",
"eval_model_preparation_time",
"eval_runtime",
"eval_samples_per_second",
"eval_steps_per_second",
]
assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4)
assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4)