isolating problematic test
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
modal application to run axolotl gpu tests in Modal
|
modal application to run axolotl gpu tests in Modal
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from axolotl.integrations.differential_transformer.differential_attention import
|
|||||||
|
|
||||||
def patch_llama_attention_classes():
|
def patch_llama_attention_classes():
|
||||||
"""Patch transformers to support differential attention"""
|
"""Patch transformers to support differential attention"""
|
||||||
|
|
||||||
# Add our attention class to the registry
|
# Add our attention class to the registry
|
||||||
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
|
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
|
||||||
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
|
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
|
||||||
|
|||||||
@@ -843,7 +843,6 @@ class ModelLoader:
|
|||||||
if self.cfg.is_multimodal:
|
if self.cfg.is_multimodal:
|
||||||
self.model_config.text_config = self.text_model_config
|
self.model_config.text_config = self.text_model_config
|
||||||
|
|
||||||
# self.model._attn_implementation_autoset = False
|
|
||||||
self.model = self.AutoModelLoader.from_pretrained(
|
self.model = self.AutoModelLoader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
|
|||||||
@@ -0,0 +1,28 @@
|
|||||||
|
"""Shared fixtures for differential transformer conversion tests."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def base_config():
|
||||||
|
"""Basic config for testing."""
|
||||||
|
return {
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"plugins": [
|
||||||
|
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin",
|
||||||
|
],
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "axolotl-ai-co/alpaca_100_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
}
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
"""End-to-end tests for differential transformer conversion and evaluation."""
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from pytest import approx
|
||||||
|
|
||||||
|
from axolotl.cli import load_cfg
|
||||||
|
from axolotl.cli.evaluate import do_evaluate
|
||||||
|
from axolotl.cli.integrations.convert_differential_transformer import (
|
||||||
|
convert_differential_transformer,
|
||||||
|
)
|
||||||
|
from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs
|
||||||
|
|
||||||
|
|
||||||
|
def test_conversion_and_eval_cli(tmp_path: Path, base_config):
|
||||||
|
output_dir = tmp_path / "converted"
|
||||||
|
base_config["output_dir"] = str(output_dir)
|
||||||
|
|
||||||
|
config_path = tmp_path / "config.yml"
|
||||||
|
with open(config_path, "w", encoding="utf-8") as file:
|
||||||
|
yaml.dump(base_config, file)
|
||||||
|
|
||||||
|
cfg = load_cfg(str(config_path))
|
||||||
|
cli_args = ConvertDiffTransformerCliArgs(
|
||||||
|
debug=True, zero_init=True, sublayer_norm=False
|
||||||
|
)
|
||||||
|
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
||||||
|
|
||||||
|
assert debug_info["generations_match"] is True
|
||||||
|
assert (output_dir / "model.safetensors").exists()
|
||||||
|
assert (output_dir / "config.json").exists()
|
||||||
|
assert (output_dir / "axolotl_config.yml").exists()
|
||||||
|
|
||||||
|
eval_cfg = load_cfg(str(output_dir))
|
||||||
|
eval_cli_args = EvaluateCliArgs()
|
||||||
|
all_metrics = do_evaluate(eval_cfg, eval_cli_args)
|
||||||
|
|
||||||
|
assert list(all_metrics.keys()) == [
|
||||||
|
"train_loss",
|
||||||
|
"train_model_preparation_time",
|
||||||
|
"train_runtime",
|
||||||
|
"train_samples_per_second",
|
||||||
|
"train_steps_per_second",
|
||||||
|
"eval_loss",
|
||||||
|
"eval_model_preparation_time",
|
||||||
|
"eval_runtime",
|
||||||
|
"eval_samples_per_second",
|
||||||
|
"eval_steps_per_second",
|
||||||
|
]
|
||||||
|
assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4)
|
||||||
|
assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4)
|
||||||
@@ -1,44 +1,18 @@
|
|||||||
"""End-to-end tests for differential transformer conversion."""
|
"""End-to-end tests for differential transformer conversion."""
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from pytest import approx
|
|
||||||
|
|
||||||
from axolotl.cli import load_cfg
|
from axolotl.cli import load_cfg
|
||||||
from axolotl.cli.evaluate import do_evaluate
|
|
||||||
from axolotl.cli.integrations.convert_differential_transformer import (
|
from axolotl.cli.integrations.convert_differential_transformer import (
|
||||||
convert_differential_transformer,
|
convert_differential_transformer,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs
|
from axolotl.common.cli import ConvertDiffTransformerCliArgs
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def base_config():
|
|
||||||
"""Basic config for testing."""
|
|
||||||
return {
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"plugins": [
|
|
||||||
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin",
|
|
||||||
],
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "axolotl-ai-co/alpaca_100_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"learning_rate": 1e-4,
|
|
||||||
"val_set_size": 0.1,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"sequence_len": 2048,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_conversion_cli_basic(tmp_path: Path, base_config):
|
def test_conversion_cli_basic(tmp_path: Path, base_config):
|
||||||
@@ -132,42 +106,3 @@ def test_conversion_cli_repoduce_attentions(
|
|||||||
assert (output_dir / "model.safetensors").exists()
|
assert (output_dir / "model.safetensors").exists()
|
||||||
assert (output_dir / "config.json").exists()
|
assert (output_dir / "config.json").exists()
|
||||||
assert (output_dir / "axolotl_config.yml").exists()
|
assert (output_dir / "axolotl_config.yml").exists()
|
||||||
|
|
||||||
|
|
||||||
def test_conversion_and_eval_cli(tmp_path: Path, base_config):
|
|
||||||
output_dir = tmp_path / "converted"
|
|
||||||
base_config["output_dir"] = str(output_dir)
|
|
||||||
|
|
||||||
config_path = tmp_path / "config.yml"
|
|
||||||
with open(config_path, "w", encoding="utf-8") as file:
|
|
||||||
yaml.dump(base_config, file)
|
|
||||||
|
|
||||||
cfg = load_cfg(str(config_path))
|
|
||||||
cli_args = ConvertDiffTransformerCliArgs(
|
|
||||||
debug=True, zero_init=True, sublayer_norm=False
|
|
||||||
)
|
|
||||||
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
|
|
||||||
|
|
||||||
assert debug_info["generations_match"] is True
|
|
||||||
assert (output_dir / "model.safetensors").exists()
|
|
||||||
assert (output_dir / "config.json").exists()
|
|
||||||
assert (output_dir / "axolotl_config.yml").exists()
|
|
||||||
|
|
||||||
eval_cfg = load_cfg(str(output_dir))
|
|
||||||
eval_cli_args = EvaluateCliArgs()
|
|
||||||
all_metrics = do_evaluate(eval_cfg, eval_cli_args)
|
|
||||||
|
|
||||||
assert list(all_metrics.keys()) == [
|
|
||||||
"train_loss",
|
|
||||||
"train_model_preparation_time",
|
|
||||||
"train_runtime",
|
|
||||||
"train_samples_per_second",
|
|
||||||
"train_steps_per_second",
|
|
||||||
"eval_loss",
|
|
||||||
"eval_model_preparation_time",
|
|
||||||
"eval_runtime",
|
|
||||||
"eval_samples_per_second",
|
|
||||||
"eval_steps_per_second",
|
|
||||||
]
|
|
||||||
assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4)
|
|
||||||
assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4)
|
|
||||||
Reference in New Issue
Block a user