moving out all diff attn code to plugin repo
@@ -1,24 +1,20 @@
"""CLI definition for various axolotl commands."""

# pylint: disable=redefined-outer-name

import subprocess  # nosec B404
from typing import Optional

import click

import axolotl
from axolotl.cli.plugins import setup_plugin_commands
from axolotl.cli.utils import (
    add_options_from_config,
    add_options_from_dataclass,
    build_command,
    fetch_from_github,
)
from axolotl.common.cli import (
    ConvertDiffTransformerCliArgs,
    EvaluateCliArgs,
    PreprocessCliArgs,
    TrainerCliArgs,
)
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -249,24 +245,6 @@ def merge_lora(
    do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
    """Convert model attention layers to differential attention layers."""
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    try:
        from axolotl_diff_transformer.convert_diff_transformer import do_cli
    except ImportError as exc:
        raise ImportError(
            "axolotl-diff-transformer not found, please install it: https://github.com/axolotl-ai-cloud/diff-transformer"
        ) from exc

    do_cli(config=config, **kwargs)


@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")

@@ -281,6 +259,9 @@ def fetch(directory: str, dest: Optional[str]):
    fetch_from_github(f"{directory}/", dest)


setup_plugin_commands(cli)


def main():
    cli()
36  src/axolotl/cli/plugins.py  Normal file
@@ -0,0 +1,36 @@
"""Module for adding click CLI commands from axolotl plugins."""

import logging

import click

from axolotl.cli.utils import add_options_from_config, add_options_from_dataclass
from axolotl.logging_config import configure_logging
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig

configure_logging()
LOG = logging.getLogger(__name__)


def setup_plugin_commands(cli: click.core.Group) -> None:
    """
    Set up CLI commands for available plugins.

    Args:
        cli: Click CLI object to add plugin CLI options to.
    """
    try:
        from axolotl_diff_transformer.convert_diff_transformer import do_cli
        from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs

        @cli.command()
        @click.argument("config", type=click.Path(exists=True, path_type=str))
        @add_options_from_dataclass(ConvertDiffTransformerCliArgs)
        @add_options_from_config(AxolotlInputConfig)
        def convert_diff_transformer(config: str, **kwargs):
            """Convert model attention layers to differential attention layers."""
            kwargs = {k: v for k, v in kwargs.items() if v is not None}
            do_cli(config=config, **kwargs)

    except ImportError as exc:
        LOG.debug("axolotl-diff-transformer not found: %s", exc)
@@ -4,13 +4,19 @@ shared module for cli specific things
import logging
from dataclasses import dataclass, field
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union

import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer

if TYPE_CHECKING:
    try:
        from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
    except:  # noqa: E722 # pylint: disable=bare-except # nosec B110
        pass

configure_logging()
LOG = logging.getLogger(__name__)
@@ -48,21 +54,10 @@ class EvaluateCliArgs:
    debug_num_examples: int = field(default=0)


@dataclass
class ConvertDiffTransformerCliArgs:
    """dataclass with arguments for convert-diff-transformer CLI"""

    debug: bool = field(default=False)
    zero_init: bool = field(default=False)
    sublayer_norm: bool = field(default=True)
    split_heads: bool = field(default=False)
    mirror_weights: bool = field(default=False)


def load_model_and_tokenizer(
    *,
    cfg: DictDefault,
    cli_args: Union[TrainerCliArgs, EvaluateCliArgs, ConvertDiffTransformerCliArgs],
    cli_args: Union[TrainerCliArgs, EvaluateCliArgs, "ConvertDiffTransformerCliArgs"],
):
    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
    tokenizer = load_tokenizer(cfg)
@@ -1,40 +0,0 @@
# Differential Transformer

### Installation

```shell
pip install git+https://github.com/axolotl-ai-cloud/diff-transformer.git
```

Editable:

```shell
git clone git@github.com:axolotl-ai-cloud/diff-transformer.git
cd diff-transformer
pip install -e .
```

### Usage

**Note:** The following will be set in the model config output by the `axolotl convert-diff-transformer` command.

```yaml
plugins:
  - axolotl.integrations.diff_transformer.DifferentialTransformerPlugin

diff_attention: true
```

Additional optional arguments include:

```yaml
# How often to log differential attention-related metrics to wandb
diff_attn_log_every: 100

# How many differential attention layers to monitor (strided from 0..k..num_layers)
diff_attn_num_monitor_layers: 3

# How many steps to "warmup" the mixing parameter for the negative component of differential attention
# Follows a linear warmup schedule from 0 to 1; if not specified, the mixing component is set to 1
diff_attn_warmup_steps: 1000
```
@@ -1,69 +0,0 @@
"""Definition of differential transformer plugin."""

import logging
from typing import List

from transformers import PreTrainedModel, TrainerCallback

from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.diff_attn import (
    DifferentialAttentionMixingCallback,
    DifferentialAttentionMonitorCallback,
)
from axolotl.utils.dict import DictDefault

LOG = logging.getLogger(__name__)


class DifferentialTransformerPlugin(BasePlugin):
    """Plugin for differential transformer integration with Axolotl."""

    def __init__(self) -> None:
        """
        Constructor for differential transformers plugin. Calls `register_diff_attn`
        to register differential attention custom modeling implementation to `AutoConfig`
        and `AutoModel`.
        """
        from axolotl_diff_transformer.modeling.modeling_diff_attn import (
            register_diff_attn,
        )

        register_diff_attn()

    def get_input_args(self) -> str:
        """Returns module path to diff transformer plugin args for `axolotl` config."""
        return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"

    # pylint: disable=unused-argument
    def add_callbacks_pre_trainer(
        self, cfg: DictDefault, model: PreTrainedModel
    ) -> List[TrainerCallback]:
        """
        Returns `DifferentialAttentionMonitorCallback` to be added to the list of
        callbacks for the `axolotl` trainer if wandb usage is enabled.

        Parameters:
            cfg: Dictionary mapping `axolotl` config keys to values.
            model: The loaded model.

        Returns:
            A list (possibly) containing an instantiated `DifferentialAttentionMonitorCallback`.
        """
        callbacks = []
        if cfg.use_wandb:
            callbacks.append(
                DifferentialAttentionMonitorCallback(
                    log_every=cfg.diff_attn_log_every,
                    num_monitor_layers=cfg.diff_attn_num_monitor_layers,
                    warmup_steps=cfg.diff_attn_warmup_steps,
                )
            )

        if cfg.diff_attn_warmup_steps:
            callbacks.append(
                DifferentialAttentionMixingCallback(
                    warmup_steps=cfg.diff_attn_warmup_steps
                )
            )

        return callbacks
@@ -1,27 +0,0 @@
"""Module for handling differential transformer input arguments."""

import logging
from typing import Optional

from pydantic import BaseModel

LOG = logging.getLogger(__name__)


class DifferentialTransformerArgs(BaseModel):
    """
    Input args for differential transformer.

    Attributes:
        diff_attention: Whether to use differential attention layers.
        diff_attn_log_every: How often to log differential attention statistics.
        diff_attn_num_monitor_layers: Number of layers to monitor for attention stats.
        diff_attn_warmup_steps: Number of steps to linearly increase negative attention
            mixing weight from 0 to 1. If specified, will reach full mixing at this
            step. If `None`, negative attention has full weight from the start.
    """

    diff_attention: Optional[bool] = None
    diff_attn_log_every: Optional[int] = 100
    diff_attn_num_monitor_layers: Optional[int] = 3
    diff_attn_warmup_steps: Optional[int] = None
@@ -1,234 +0,0 @@
"""
Monitor and log differential attention components during training.

This module provides a callback for tracking the behavior of differential attention
mechanisms, including lambda parameters and attention statistics.
"""

from typing import Any

import torch
import wandb
from torch import nn
from transformers import TrainerCallback

from axolotl.utils.distributed import is_main_process


class DifferentialAttentionMonitorCallback(TrainerCallback):
    """
    Callback to monitor differential attention components and lambda parameters.

    This callback tracks attention statistics across all layers and provides detailed
    monitoring for a specified number of layers evenly spaced through the model.
    """

    def __init__(
        self,
        log_every: int = 250,
        num_monitor_layers: int = 3,
        warmup_steps: int | None = None,
    ):
        """
        Initialize the differential attention monitor.

        Args:
            log_every: Number of steps between logging events.
            num_monitor_layers: Number of individual layers to monitor in detail.
            warmup_steps: Optional parameter for negative attention component warmup.
        """
        self.log_every = log_every
        self.num_monitor_layers = num_monitor_layers
        self.warmup_steps = warmup_steps
        self.monitor_layers: list[int] | None = None  # Will be set in on_train_begin

    # pylint: disable=unused-argument
    def on_train_begin(
        self,
        args: Any,
        state: Any,
        control: Any,
        model: torch.nn.Module,
        **kwargs,
    ) -> None:
        """
        Set up layer monitoring at the start of training.

        Args:
            args: Training arguments.
            state: Training state.
            control: Training control object.
            model: The model being trained.
            **kwargs: Additional arguments passed by the trainer.
        """
        if is_main_process():
            num_layers = len(model.model.layers)
            self.num_monitor_layers = min(self.num_monitor_layers, num_layers)

            stride = (
                (num_layers - 1) / (self.num_monitor_layers - 1)
                if self.num_monitor_layers > 1
                else 0
            )
            self.monitor_layers = [
                round(i * stride) for i in range(self.num_monitor_layers)
            ]
            print(f"Monitoring layers {self.monitor_layers} in detail")

    # pylint: disable=unused-argument
    def on_step_end(
        self, args: Any, state: Any, control: Any, model: torch.nn.Module, **kwargs
    ) -> None:
        """
        Log attention metrics at the end of each step.

        Collects and logs:
            - Lambda parameter norms and values.
            - Attention statistics (mean and std).
            - Both per-layer and aggregate metrics.

        Args:
            args: Training arguments.
            state: Training state.
            control: Training control object.
            model: The model being trained.
            **kwargs: Additional arguments passed by the trainer.
        """
        if not is_main_process() or state.global_step % self.log_every != 0:
            return

        assert self.monitor_layers is not None

        # Aggregate stats across all layers
        all_q1_norms = []
        all_q2_norms = []
        all_k1_norms = []
        all_k2_norms = []
        all_lambda1 = []
        all_lambda2 = []
        all_lambda_full = []

        metrics = {}
        for layer_idx, layer in enumerate(model.model.layers):
            attn = layer.self_attn

            # Collect stats for aggregation
            all_q1_norms.append(attn.lambda_q1.norm().item())
            all_q2_norms.append(attn.lambda_q2.norm().item())
            all_k1_norms.append(attn.lambda_k1.norm().item())
            all_k2_norms.append(attn.lambda_k2.norm().item())

            lambda1 = torch.exp(torch.sum(attn.lambda_q1 * attn.lambda_k1)).item()
            lambda2 = torch.exp(torch.sum(attn.lambda_q2 * attn.lambda_k2)).item()
            all_lambda1.append(lambda1)
            all_lambda2.append(lambda2)
            all_lambda_full.append(attn.lambda_full)

            # Log detailed metrics for monitored layers
            if layer_idx in self.monitor_layers:
                metrics.update(
                    {
                        f"layer_{layer_idx}/lambda_q1_norm": attn.lambda_q1.norm().item(),
                        f"layer_{layer_idx}/lambda_k1_norm": attn.lambda_k1.norm().item(),
                        f"layer_{layer_idx}/lambda_q2_norm": attn.lambda_q2.norm().item(),
                        f"layer_{layer_idx}/lambda_k2_norm": attn.lambda_k2.norm().item(),
                        f"layer_{layer_idx}/lambda1": lambda1,
                        f"layer_{layer_idx}/lambda2": lambda2,
                        f"layer_{layer_idx}/lambda_init": attn.lambda_init.item(),
                        f"layer_{layer_idx}/lambda_full": lambda1
                        - lambda2
                        + attn.lambda_init.item(),
                        f"layer_{layer_idx}/attn1_mean": attn.attn1.mean().item(),
                        f"layer_{layer_idx}/attn2_mean": attn.attn2.mean().item(),
                        f"layer_{layer_idx}/attn1_std": attn.attn1.std().item(),
                        f"layer_{layer_idx}/attn2_std": attn.attn2.std().item(),
                    }
                )

        # Add aggregate metrics
        metrics.update(
            {
                "aggregate/lambda_q1_norm_mean": torch.tensor(all_q1_norms)
                .mean()
                .item(),
                "aggregate/lambda_q1_norm_std": torch.tensor(all_q1_norms).std().item(),
                "aggregate/lambda_q2_norm_mean": torch.tensor(all_q2_norms)
                .mean()
                .item(),
                "aggregate/lambda_q2_norm_std": torch.tensor(all_q2_norms).std().item(),
                "aggregate/lambda_k1_norm_mean": torch.tensor(all_k1_norms)
                .mean()
                .item(),
                "aggregate/lambda_k1_norm_std": torch.tensor(all_k1_norms).std().item(),
                "aggregate/lambda_k2_norm_mean": torch.tensor(all_k2_norms)
                .mean()
                .item(),
                "aggregate/lambda_k2_norm_std": torch.tensor(all_k2_norms).std().item(),
                "aggregate/lambda1_mean": torch.tensor(all_lambda1).mean().item(),
                "aggregate/lambda1_std": torch.tensor(all_lambda1).std().item(),
                "aggregate/lambda2_mean": torch.tensor(all_lambda2).mean().item(),
                "aggregate/lambda2_std": torch.tensor(all_lambda2).std().item(),
                "aggregate/lambda_full_mean": torch.tensor(all_lambda_full)
                .mean()
                .item(),
                "aggregate/lambda_full_std": torch.tensor(all_lambda_full).std().item(),
            }
        )

        if self.warmup_steps:
            metrics["aggregate/diff_attn_mix"] = attn.diff_attn_mix

        wandb.log(metrics, step=state.global_step)


class DifferentialAttentionMixingCallback(TrainerCallback):
    """
    Callback to gradually increase the weight of negative attention components during
    training.
    """

    def __init__(self, warmup_steps: int):
        """
        Args:
            warmup_steps: Number of steps to linearly increase negative attention
                weight from 0 to 1. If `None`, negative attention has full weight from
                the start.
        """
        self.warmup_steps = warmup_steps
        self.diff_attention_layers: list[nn.Module] | None = None

    # pylint: disable=unused-argument
    def on_train_begin(
        self,
        args: Any,
        state: Any,
        control: Any,
        model: torch.nn.Module,
        **kwargs,
    ) -> None:
        """Cache the differential attention layers at the start of training."""
        if model is not None:
            # Get the actual model if it's wrapped
            if hasattr(model, "module"):
                model = model.module

            # Cache all differential attention layers
            self.diff_attention_layers = [
                module for module in model.modules() if hasattr(module, "diff_attn_mix")
            ]

    def on_step_begin(
        self,
        args: Any,
        state: Any,
        control: Any,
        model: torch.nn.Module = None,
        **kwargs,
    ) -> None:
        if self.diff_attention_layers and self.warmup_steps:
            # Calculate mixing parameter (0 to 1)
            mix = min(1.0, state.global_step / self.warmup_steps)

            # Update cached layers
            for layer in self.diff_attention_layers:
                layer.diff_attn_mix = mix
@@ -1,31 +0,0 @@
"""Shared fixtures for differential transformer conversion tests."""

import pytest
from click.testing import CliRunner


@pytest.fixture(scope="class")
def base_config():
    """Basic config for testing."""
    return {
        "base_model": "HuggingFaceTB/SmolLM2-135M",
        "datasets": [
            {
                "path": "axolotl-ai-co/alpaca_100_test",
                "type": "alpaca",
            },
        ],
        "gradient_accumulation_steps": 1,
        "learning_rate": 1e-4,
        "val_set_size": 0.1,
        "micro_batch_size": 1,
        "sequence_len": 2048,
        "special_tokens": {
            "pad_token": "<|endoftext|>",
        },
    }


@pytest.fixture(scope="class")
def cli_runner():
    return CliRunner()
@@ -1,51 +0,0 @@
"""End-to-end tests for differential transformer conversion and evaluation."""
# pylint: disable=duplicate-code

from pathlib import Path

import yaml
from pytest import approx

from axolotl.cli import load_cfg
from axolotl.cli.evaluate import do_evaluate
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs


def test_conversion_and_eval_cli(tmp_path: Path, base_config):
    output_dir = tmp_path / "converted"
    base_config["output_dir"] = str(output_dir)

    config_path = tmp_path / "config.yml"
    with open(config_path, "w", encoding="utf-8") as file:
        yaml.dump(base_config, file)

    cfg = load_cfg(str(config_path))
    cli_args = ConvertDiffTransformerCliArgs(
        debug=True, zero_init=True, sublayer_norm=False
    )
    _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))

    assert debug_info["generations_match"] is True
    assert (output_dir / "model.safetensors").exists()
    assert (output_dir / "config.json").exists()
    assert (output_dir / "axolotl_config.yml").exists()

    eval_cfg = load_cfg(str(output_dir))
    eval_cli_args = EvaluateCliArgs()
    all_metrics = do_evaluate(eval_cfg, eval_cli_args)

    assert list(all_metrics.keys()) == [
        "train_loss",
        "train_model_preparation_time",
        "train_runtime",
        "train_samples_per_second",
        "train_steps_per_second",
        "eval_loss",
        "eval_model_preparation_time",
        "eval_runtime",
        "eval_samples_per_second",
        "eval_steps_per_second",
    ]
    assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4)
    assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4)
@@ -1,150 +0,0 @@
"""End-to-end tests for differential transformer conversion."""
# pylint: disable=redefined-outer-name
# pylint: disable=duplicate-code

from pathlib import Path
from typing import Optional
from unittest.mock import patch

import pytest
import yaml

from axolotl.cli import load_cfg
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
from axolotl.cli.main import cli
from axolotl.common.cli import ConvertDiffTransformerCliArgs


def test_cli_validation(cli_runner):
    # Test missing config file
    result = cli_runner.invoke(cli, ["convert-diff-transformer"])
    assert result.exit_code != 0
    assert "Error: Missing argument 'CONFIG'." in result.output

    # Test non-existent config file
    result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
    assert result.exit_code != 0
    assert "Error: Invalid value for 'CONFIG'" in result.output


def test_basic_execution(cli_runner, tmp_path: Path, base_config):
    config_path = tmp_path / "config.yml"
    with open(config_path, "w", encoding="utf-8") as file:
        yaml.dump(base_config, file)

    with patch(
        "axolotl.cli.integrations.convert_diff_transformer.do_cli"
    ) as mock_do_cli:
        result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)])
        assert result.exit_code == 0

        mock_do_cli.assert_called_once()
        assert mock_do_cli.call_args.kwargs["config"] == str(config_path)


def test_conversion_cli_basic(tmp_path: Path, base_config):
    output_dir = tmp_path / "converted"
    base_config["output_dir"] = str(output_dir)

    config_path = tmp_path / "config.yml"
    with open(config_path, "w", encoding="utf-8") as file:
        yaml.dump(base_config, file)

    cfg = load_cfg(str(config_path))
    cli_args = ConvertDiffTransformerCliArgs()
    _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))

    assert not debug_info
    assert (output_dir / "model.safetensors").exists()
    assert (output_dir / "config.json").exists()
    assert (output_dir / "axolotl_config.yml").exists()


def test_conversion_cli_debug(tmp_path: Path, base_config):
    output_dir = tmp_path / "converted"
    base_config["output_dir"] = str(output_dir)

    config_path = tmp_path / "config.yml"
    with open(config_path, "w", encoding="utf-8") as file:
        yaml.dump(base_config, file)

    cfg = load_cfg(str(config_path))
    cli_args = ConvertDiffTransformerCliArgs(debug=True)
    _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))

    assert not debug_info["generations_match"]
    assert not debug_info["match_expected"]
    assert (output_dir / "model.safetensors").exists()
    assert (output_dir / "config.json").exists()
    assert (output_dir / "axolotl_config.yml").exists()


def test_conversion_cli_reproduce(tmp_path: Path, base_config):
    output_dir = tmp_path / "converted"
    base_config["output_dir"] = str(output_dir)

    config_path = tmp_path / "config.yml"
    with open(config_path, "w", encoding="utf-8") as file:
        yaml.dump(base_config, file)

    cfg = load_cfg(str(config_path))
    cli_args = ConvertDiffTransformerCliArgs(
        debug=True, zero_init=True, sublayer_norm=False
    )
    _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))

    assert debug_info["generations_match"] is True
    assert (output_dir / "model.safetensors").exists()
    assert (output_dir / "config.json").exists()
    assert (output_dir / "axolotl_config.yml").exists()


@pytest.mark.parametrize(
    "attention", ["eager_attention", "sdp_attention", "flash_attention"]
)
def test_conversion_cli_repoduce_attentions(
    tmp_path: Path, base_config, attention: Optional[str]
):
    output_dir = tmp_path / "converted"
    base_config["output_dir"] = str(output_dir)
    base_config[attention] = True

    config_path = tmp_path / "config.yml"
    with open(config_path, "w", encoding="utf-8") as file:
        yaml.dump(base_config, file)

    cfg = load_cfg(str(config_path))
    cli_args = ConvertDiffTransformerCliArgs(
        debug=True, zero_init=True, sublayer_norm=False
    )
    _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))

    assert debug_info["generations_match"] is True
    assert (output_dir / "model.safetensors").exists()
    assert (output_dir / "config.json").exists()
    assert (output_dir / "axolotl_config.yml").exists()


@pytest.mark.parametrize(
    "attention", ["eager_attention", "sdp_attention", "flash_attention"]
)
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
    output_dir = tmp_path / "converted"

    # Smallest model with an even number of attention heads
    base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
    base_config["output_dir"] = str(output_dir)
    base_config[attention] = True

    config_path = tmp_path / "config.yml"
    with open(config_path, "w", encoding="utf-8") as file:
        yaml.dump(base_config, file)

    cfg = load_cfg(str(config_path))
    cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
    _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))

    assert debug_info["generations_match"] is False
    assert (output_dir / "model.safetensors").exists()
    assert (output_dir / "config.json").exists()
    assert (output_dir / "axolotl_config.yml").exists()