convert-differential-transformer test coverage
This commit is contained in:
@@ -1,185 +0,0 @@
|
|||||||
"""CLI to convert a transformers model's attns to diff attns."""
|
|
||||||
import logging
|
|
||||||
import warnings
|
|
||||||
from pathlib import Path
|
|
||||||
from time import time
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import fire
|
|
||||||
import torch
|
|
||||||
import yaml
|
|
||||||
from colorama import Fore
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from transformers import HfArgumentParser
|
|
||||||
|
|
||||||
from axolotl.cli import load_cfg, print_axolotl_text_art
|
|
||||||
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
|
|
||||||
from axolotl.integrations.differential_transformer.convert import (
|
|
||||||
convert_to_diff_attention,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Module-level logger, registered under a fixed name rather than the module path.
LOG = logging.getLogger("axolotl.cli.convert_attention")
|
|
||||||
|
|
||||||
|
|
||||||
def test_inference(model, tokenizer, prompt="The quick brown fox"):
    """Generate a short greedy continuation of ``prompt`` and time it.

    Args:
        model: Object exposing ``device`` and ``generate``.
        tokenizer: Callable tokenizer that also provides ``decode`` and
            ``pad_token_id``.
        prompt: Text fed to the model.

    Returns:
        Tuple ``(elapsed, generated_text)``: the seconds spent in
        ``model.generate`` and the decoded first output sequence.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    try:
        encoded = tokenizer(prompt, return_tensors="pt")

        # Move every tokenizer output onto the model's device as int64.
        model_inputs = {}
        for name, tensor in encoded.items():
            model_inputs[name] = tensor.to(device=model.device, dtype=torch.long)

        began = time()
        with torch.no_grad():
            sequences = model.generate(
                **model_inputs,
                max_new_tokens=20,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=False,
            )
        elapsed = time() - began

        first_sequence = sequences[0]
        generated_text = tokenizer.decode(first_sequence, skip_special_tokens=True)

        report = (
            ("Prompt: %s", prompt),
            ("Generated: %s", generated_text),
            ("Generation time: %.2fs", elapsed),
        )
        for template, value in report:
            LOG.info(template, value)
    except Exception as exc:
        LOG.error("Inference failed: %s", str(exc))
        raise

    return elapsed, generated_text
|
|
||||||
|
|
||||||
|
|
||||||
def convert_diff_transformer(cfg, cli_args, config_path):
    """Convert a model's attention to differential attention and optionally save it.

    Args:
        cfg: Config object; ``device``, ``torch_dtype`` and ``output_dir`` are
            read from it.
        cli_args: CLI arguments; ``debug``, ``zero_init`` and ``sublayer_norm``
            are read from it.
        config_path: YAML config file that is re-read to write the updated
            ``axolotl_config.yml`` next to the saved model.

    Returns:
        The converted model.

    Raises:
        Exception: Any failure is logged as "Process failed" and re-raised.
            Conversion failures are additionally logged as "Conversion failed".
    """
    try:
        # Loading is noisy; silence warnings for this step only.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
            model.to(cfg.device, dtype=cfg.torch_dtype)

        model_config = model.config
        LOG.info(
            "Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
            model_config.hidden_size,
            model_config.num_attention_heads,
        )

        # Baseline generation, only collected in debug mode.
        if cli_args.debug:
            LOG.info("Testing original model...")
            orig_time, orig_text = test_inference(model, tokenizer)

        LOG.info("Converting to differential attention...")
        try:
            model = convert_to_diff_attention(
                model=model,
                zero_init=cli_args.zero_init,
                sublayer_norm=cli_args.sublayer_norm,
            )
            model.to(cfg.device, dtype=cfg.torch_dtype)
        except Exception as exc:
            LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
            raise

        # Post-conversion generation, only collected in debug mode.
        if cli_args.debug:
            LOG.info("Testing converted model...")
            conv_time, conv_text = test_inference(model, tokenizer)

        if not cfg.output_dir:
            LOG.info("Not saving converted model to disk")
            LOG.info("Pass --output-dir path/to/save to save model")
        else:
            save_dir = cfg.output_dir
            LOG.info("Saving converted model to %s", save_dir)
            model.save_pretrained(save_dir)
            tokenizer.save_pretrained(save_dir)

            # Write a config that points at the converted checkpoint.
            output_config_path = Path(save_dir) / "axolotl_config.yml"
            LOG.info("Saving updated config to %s", output_config_path)

            with open(config_path, "r", encoding="utf-8") as source:
                settings = yaml.safe_load(source) or {}

            settings["base_model"] = save_dir
            settings["diff_attention"] = True

            with open(output_config_path, "w", encoding="utf-8") as target:
                yaml.dump(settings, target)

        if cli_args.debug:
            LOG.info(
                f"{Fore.GREEN}Conversion successful!\n"
                f"Original generation time: {orig_time:.2f}s\n"
                f"Converted generation time: {conv_time:.2f}s{Fore.RESET}"
            )

            rule = "*" * 50
            if orig_text != conv_text:
                message = (
                    "Generations do not match.\n"
                    f"Original generation:\n{rule}\n{orig_text}\n{rule}\n"
                    f"Converted generation:\n{rule}\n{conv_text}\n{rule}\n"
                )

                if not cli_args.zero_init or cli_args.sublayer_norm:
                    # Without zero-init and no sublayer norm a mismatch is normal.
                    LOG.info(
                        Fore.YELLOW
                        + message
                        + "However, this is expected since --zero-init"
                        + " and --no-sublayer-norm were not passed."
                        + Fore.RESET
                    )
                else:
                    LOG.info(Fore.RED + message + Fore.RESET)
            else:
                LOG.info(
                    f"{Fore.GREEN}Generations match!\n"
                    f"Model generation:\n{rule}\n{orig_text}\n{rule}\n{Fore.RESET}"
                )
    except Exception as exc:
        LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc))
        raise

    return model
|
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
    """Entry point: load the config, parse CLI args and run the conversion.

    Args:
        config: Config location handed to ``load_cfg`` and forwarded to
            ``convert_diff_transformer`` as the config path.
        **kwargs: Extra options forwarded to ``load_cfg``.
    """
    print_axolotl_text_art()

    cfg = load_cfg(config, **kwargs)

    arg_parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
    parsed = arg_parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # Second element holds the strings the parser did not consume; unused here.
    cli_args, _leftover = parsed

    convert_diff_transformer(cfg, cli_args, config)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Load variables from a .env file before handing control to the CLI.
    load_dotenv()
    fire.Fire(do_cli)
|
|
||||||
190
src/axolotl/cli/integrations/convert_differential_transformer.py
Normal file
190
src/axolotl/cli/integrations/convert_differential_transformer.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
"""CLI to convert a transformers model's attns to diff attns."""
|
||||||
|
import logging
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
from time import time
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import torch
|
||||||
|
import yaml
|
||||||
|
from colorama import Fore
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from transformers import HfArgumentParser
|
||||||
|
|
||||||
|
from axolotl.cli import load_cfg, print_axolotl_text_art
|
||||||
|
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
|
||||||
|
from axolotl.integrations.differential_transformer.convert import (
|
||||||
|
convert_to_diff_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Module-level logger named after this module's import path.
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def test_inference(model, tokenizer, prompt="The quick brown fox"):
    """Run a short greedy generation and report how long it took.

    Args:
        model: Object exposing ``device`` and ``generate``.
        tokenizer: Callable tokenizer that also provides ``decode`` and
            ``pad_token_id``.
        prompt: Text fed to the model.

    Returns:
        Tuple ``(elapsed, generated_text)``: the seconds spent in
        ``model.generate`` and the decoded first output sequence.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    # Local import keeps the module's top-level imports untouched.
    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way a wall-clock time() delta can.
    from time import perf_counter  # pylint: disable=import-outside-toplevel

    try:
        inputs = tokenizer(prompt, return_tensors="pt")
        # Move every tokenizer output onto the model's device as int64.
        inputs = {
            k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()
        }

        start = perf_counter()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=20,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=False,
            )
        elapsed = perf_counter() - start

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        LOG.info("Prompt: %s", prompt)
        LOG.info("Generated: %s", generated_text)
        LOG.info("Generation time: %.2fs", elapsed)

        return elapsed, generated_text

    except Exception as exc:
        # %s already stringifies the exception; no explicit str() needed.
        LOG.error("Inference failed: %s", exc)
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def _framed(text):
    """Return ``text`` on its own line between two rules of 50 asterisks."""
    rule = "*" * 50
    return f"{rule}\n{text}\n{rule}\n"


def convert_differential_transformer(cfg, cli_args, config_path):
    """Convert a model's attention layers to differential attention.

    Loads the model and tokenizer, runs the conversion, saves the result when
    ``cfg.output_dir`` is set, and, in debug mode, compares a test generation
    from before and after the conversion.

    Args:
        cfg: Config object; ``device``, ``torch_dtype`` and ``output_dir`` are
            read from it.
        cli_args: CLI arguments; ``debug``, ``zero_init`` and ``sublayer_norm``
            are read from it.
        config_path: YAML config file that is re-read to write the updated
            ``axolotl_config.yml`` next to the saved model.

    Returns:
        Tuple ``(model, debug_info)``. ``debug_info`` is empty unless
        ``cli_args.debug`` is set; in debug mode it always holds
        ``orig_time``, ``orig_text``, ``conv_time``, ``conv_text``,
        ``generations_match`` and ``match_expected``.

    Raises:
        Exception: Conversion failures are logged and re-raised unchanged.
    """
    debug_info = {}

    # Loading is noisy; silence warnings for this step only.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
        model.to(cfg.device, dtype=cfg.torch_dtype)

    LOG.info(
        "Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
        model.config.hidden_size,
        model.config.num_attention_heads,
    )

    # Baseline generation, only collected in debug mode.
    if cli_args.debug:
        LOG.info("Testing original model...")
        debug_info["orig_time"], debug_info["orig_text"] = test_inference(
            model, tokenizer
        )

    LOG.info("Converting to differential attention...")
    try:
        model = convert_to_diff_attention(
            model=model,
            zero_init=cli_args.zero_init,
            sublayer_norm=cli_args.sublayer_norm,
        )
        model.to(cfg.device, dtype=cfg.torch_dtype)
    except Exception as exc:
        LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
        raise

    # Post-conversion generation, only collected in debug mode.
    if cli_args.debug:
        LOG.info("Testing converted model...")
        debug_info["conv_time"], debug_info["conv_text"] = test_inference(
            model, tokenizer
        )

    if cfg.output_dir:
        LOG.info("Saving converted model to %s", cfg.output_dir)
        model.save_pretrained(cfg.output_dir)
        tokenizer.save_pretrained(cfg.output_dir)

        # Write a config that points at the converted checkpoint.
        output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
        LOG.info("Saving updated config to %s", output_config_path)

        with open(config_path, "r", encoding="utf-8") as file:
            data = yaml.safe_load(file) or {}

        data["base_model"] = cfg.output_dir
        data["diff_attention"] = True

        with open(output_config_path, "w", encoding="utf-8") as file:
            yaml.dump(data, file)
    else:
        LOG.info("Not saving converted model to disk")
        LOG.info("Pass --output-dir path/to/save to save model")

    if not cli_args.debug:
        return model, debug_info

    LOG.info(
        Fore.GREEN
        + "Conversion successful!\n"
        + f"Original generation time: {debug_info['orig_time']:.2f}s\n"
        + f"Converted generation time: {debug_info['conv_time']:.2f}s"
        + Fore.RESET
    )

    # Identical generations are only expected with zero-init and no sublayer
    # norm. Record this on every debug run so callers can read the key whether
    # or not the generations matched.
    match_expected = bool(cli_args.zero_init and not cli_args.sublayer_norm)
    generations_match = debug_info["orig_text"] == debug_info["conv_text"]
    debug_info["generations_match"] = generations_match
    debug_info["match_expected"] = match_expected

    if generations_match:
        LOG.info(
            Fore.GREEN
            + "Generations match!\n"
            + "Model generation:\n"
            + _framed(debug_info["orig_text"])
            + Fore.RESET
        )
        return model, debug_info

    message = (
        "Generations do not match.\n"
        + "Original generation:\n"
        + _framed(debug_info["orig_text"])
        + "Converted generation:\n"
        + _framed(debug_info["conv_text"])
    )

    if match_expected:
        # A mismatch here signals a real problem, hence red.
        LOG.info(Fore.RED + message + Fore.RESET)
    else:
        LOG.info(
            Fore.YELLOW
            + message
            + "However, this is expected since --zero-init"
            + " and --no-sublayer-norm were not passed."
            + Fore.RESET
        )

    return model, debug_info
|
||||||
|
|
||||||
|
|
||||||
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
    """Entry point: load the config, parse CLI args and run the conversion.

    Args:
        config: Config location handed to ``load_cfg`` and forwarded to
            ``convert_differential_transformer`` as the config path.
        **kwargs: Extra options forwarded to ``load_cfg``.
    """
    print_axolotl_text_art()

    cfg = load_cfg(config, **kwargs)

    arg_parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
    parsed = arg_parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # Second element holds the strings the parser did not consume; unused here.
    cli_args, _leftover = parsed

    # The (model, debug_info) result is not needed by the CLI.
    convert_differential_transformer(cfg, cli_args, config)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Load variables from a .env file before handing control to the CLI.
    load_dotenv()
    fire.Fire(do_cli)
|
||||||
@@ -249,11 +249,11 @@ def merge_lora(
|
|||||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
|
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
def convert_diff_transformer(config: str, **kwargs):
|
def convert_differential_transformer(config: str, **kwargs):
|
||||||
"""Convert model attention layers to differential attention layers."""
|
"""Convert model attention layers to differential attention layers."""
|
||||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
|
||||||
from axolotl.cli.integrations.convert_diff_transformer import do_cli
|
from axolotl.cli.integrations.convert_differential_transformer import do_cli
|
||||||
|
|
||||||
do_cli(config=config, **kwargs)
|
do_cli(config=config, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ class EvaluateCliArgs:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ConvertDiffTransformerCliArgs:
|
class ConvertDiffTransformerCliArgs:
|
||||||
"""
|
"""
|
||||||
dataclass with arguments for convert-diff-transformer CLI
|
dataclass with arguments for convert-differential-transformer CLI
|
||||||
"""
|
"""
|
||||||
|
|
||||||
debug: bool = field(default=False)
|
debug: bool = field(default=False)
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from .differential_attention import (
|
|||||||
LlamaDifferentialSdpaAttention,
|
LlamaDifferentialSdpaAttention,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ATTENTION_MAPPING = {
|
ATTENTION_MAPPING = {
|
||||||
|
|||||||
0
tests/cli/integrations/__init__.py
Normal file
0
tests/cli/integrations/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Tests for convert-differential-transformer CLI command."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from axolotl.cli.main import cli
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_validation(cli_runner):
    """Check that the command rejects a missing or non-existent config.

    Args:
        cli_runner: CLI runner fixture
    """
    command = "convert-differential-transformer"

    # No CONFIG argument at all.
    missing = cli_runner.invoke(cli, [command])
    assert missing.exit_code != 0
    assert "Error: Missing argument 'CONFIG'." in missing.output

    # CONFIG points at a file that does not exist.
    bad_path = cli_runner.invoke(cli, [command, "nonexistent.yml"])
    assert bad_path.exit_code != 0
    assert "Error: Invalid value for 'CONFIG'" in bad_path.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_basic_execution(cli_runner, tmp_path: Path, valid_test_config: str):
    """Check that the command forwards the config path to ``do_cli``.

    Args:
        cli_runner: CLI runner fixture
        tmp_path: Temporary path fixture
        valid_test_config: Valid config fixture
    """
    config_file = tmp_path / "config.yml"
    config_file.write_text(valid_test_config)

    # Replace the real entry point so no model is loaded.
    target = "axolotl.cli.integrations.convert_differential_transformer.do_cli"
    argv = ["convert-differential-transformer", str(config_file)]
    with patch(target) as mock_do_cli:
        result = cli_runner.invoke(cli, argv)

    assert result.exit_code == 0
    mock_do_cli.assert_called_once()
    assert mock_do_cli.call_args.kwargs["config"] == str(config_file)
|
||||||
127
tests/e2e/integrations/test_convert_differential_transformer.py
Normal file
127
tests/e2e/integrations/test_convert_differential_transformer.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""End-to-end tests for differential transformer conversion."""
|
||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from axolotl.cli import load_cfg
|
||||||
|
from axolotl.cli.integrations.convert_differential_transformer import (
|
||||||
|
convert_differential_transformer,
|
||||||
|
)
|
||||||
|
from axolotl.common.cli import ConvertDiffTransformerCliArgs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def base_config():
    """Return a fresh minimal config dict shared by the conversion tests."""
    alpaca_dataset = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    special_tokens = {"pad_token": "<|endoftext|>"}

    config = {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "datasets": [alpaca_dataset],
        "gradient_accumulation_steps": 1,
        "learning_rate": 1e-4,
        "val_set_size": 0.1,
        "micro_batch_size": 1,
        "sequence_len": 2048,
        "special_tokens": special_tokens,
    }
    return config
|
||||||
|
|
||||||
|
|
||||||
|
def test_conversion_cli_basic(tmp_path: Path, base_config):
    """Convert with default CLI args and check the saved artifacts."""
    converted_dir = tmp_path / "converted"
    base_config["output_dir"] = str(converted_dir)

    yaml_path = tmp_path / "config.yml"
    yaml_path.write_text(yaml.dump(base_config), encoding="utf-8")

    # Mirror do_cli: load the config, build the args, then convert directly.
    cfg = load_cfg(str(yaml_path))
    cli_args = ConvertDiffTransformerCliArgs()
    result = convert_differential_transformer(cfg, cli_args, str(yaml_path))
    _, debug_info = result

    # Debug mode is off, so no debug information is collected.
    assert not debug_info
    for artifact in ("model.safetensors", "config.json", "axolotl_config.yml"):
        assert (converted_dir / artifact).exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_conversion_cli_debug(tmp_path: Path, base_config):
    """Convert in debug mode and check the reported generation comparison."""
    converted_dir = tmp_path / "converted"
    base_config["output_dir"] = str(converted_dir)

    yaml_path = tmp_path / "config.yml"
    yaml_path.write_text(yaml.dump(base_config), encoding="utf-8")

    # Mirror do_cli: load the config, build the args, then convert directly.
    cfg = load_cfg(str(yaml_path))
    cli_args = ConvertDiffTransformerCliArgs(debug=True)
    result = convert_differential_transformer(cfg, cli_args, str(yaml_path))
    _, debug_info = result

    assert not debug_info["generations_match"]
    assert not debug_info["match_expected"]
    for artifact in ("model.safetensors", "config.json", "axolotl_config.yml"):
        assert (converted_dir / artifact).exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_conversion_cli_reproduce(tmp_path: Path, base_config):
    """Zero-init without sublayer norm must give matching generations."""
    converted_dir = tmp_path / "converted"
    base_config["output_dir"] = str(converted_dir)

    yaml_path = tmp_path / "config.yml"
    yaml_path.write_text(yaml.dump(base_config), encoding="utf-8")

    cfg = load_cfg(str(yaml_path))
    cli_args = ConvertDiffTransformerCliArgs(
        debug=True,
        zero_init=True,
        sublayer_norm=False,
    )
    result = convert_differential_transformer(cfg, cli_args, str(yaml_path))
    _, debug_info = result

    assert debug_info["generations_match"] is True
    for artifact in ("model.safetensors", "config.json", "axolotl_config.yml"):
        assert (converted_dir / artifact).exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("attention", ["sdp_attention", "flash_attention"])
def test_conversion_cli_repoduce_attentions(
    tmp_path: Path, base_config, attention: Optional[str]
):
    """Matching generations must also hold with the given attention flag on."""
    converted_dir = tmp_path / "converted"
    base_config["output_dir"] = str(converted_dir)
    # Enable the attention implementation under test via its config flag.
    base_config[attention] = True

    yaml_path = tmp_path / "config.yml"
    yaml_path.write_text(yaml.dump(base_config), encoding="utf-8")

    cfg = load_cfg(str(yaml_path))
    cli_args = ConvertDiffTransformerCliArgs(
        debug=True,
        zero_init=True,
        sublayer_norm=False,
    )
    result = convert_differential_transformer(cfg, cli_args, str(yaml_path))
    _, debug_info = result

    assert debug_info["generations_match"] is True
    for artifact in ("model.safetensors", "config.json", "axolotl_config.yml"):
        assert (converted_dir / artifact).exists()
|
||||||
Reference in New Issue
Block a user