diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index da864582a..431f31a91 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -1,24 +1,20 @@
 """CLI definition for various axolotl commands."""
-
 # pylint: disable=redefined-outer-name
+
 import subprocess  # nosec B404
 from typing import Optional
 
 import click
 
 import axolotl
+from axolotl.cli.plugins import setup_plugin_commands
 from axolotl.cli.utils import (
     add_options_from_config,
     add_options_from_dataclass,
     build_command,
     fetch_from_github,
 )
-from axolotl.common.cli import (
-    ConvertDiffTransformerCliArgs,
-    EvaluateCliArgs,
-    PreprocessCliArgs,
-    TrainerCliArgs,
-)
+from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
 from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 
@@ -249,24 +245,6 @@ def merge_lora(
     do_cli(config=config, **kwargs)
 
 
-@cli.command()
-@click.argument("config", type=click.Path(exists=True, path_type=str))
-@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
-@add_options_from_config(AxolotlInputConfig)
-def convert_diff_transformer(config: str, **kwargs):
-    """Convert model attention layers to differential attention layers."""
-    kwargs = {k: v for k, v in kwargs.items() if v is not None}
-
-    try:
-        from axolotl_diff_transformer.convert_diff_transformer import do_cli
-    except ImportError as exc:
-        raise ImportError(
-            "axolotl-diff-transformer not found, please install it: https://github.com/axolotl-ai-cloud/diff-transformer"
-        ) from exc
-
-    do_cli(config=config, **kwargs)
-
-
 @cli.command()
 @click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
 @click.option("--dest", help="Destination directory")
@@ -281,6 +259,9 @@ def fetch(directory: str, dest: Optional[str]):
     fetch_from_github(f"{directory}/", dest)
 
 
+setup_plugin_commands(cli)
+
+
 def main():
     cli()
 
diff --git a/src/axolotl/cli/plugins.py b/src/axolotl/cli/plugins.py
new file mode 100644
index 000000000..7f0a4e6fd
--- /dev/null
+++ b/src/axolotl/cli/plugins.py
@@ -0,0 +1,36 @@
+"""Module for adding click CLI commands from axolotl plugins."""
+
+import logging
+
+import click
+
+from axolotl.cli.utils import add_options_from_config, add_options_from_dataclass
+from axolotl.logging_config import configure_logging
+from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
+
+configure_logging()
+LOG = logging.getLogger(__name__)
+
+
+def setup_plugin_commands(cli: click.core.Group) -> None:
+    """
+    Setup CLI commands for available plugins.
+
+    Args:
+        cli: Click CLI object to add plugin CLI options to.
+    """
+    try:
+        from axolotl_diff_transformer.convert_diff_transformer import do_cli
+        from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
+
+        @cli.command()
+        @click.argument("config", type=click.Path(exists=True, path_type=str))
+        @add_options_from_dataclass(ConvertDiffTransformerCliArgs)
+        @add_options_from_config(AxolotlInputConfig)
+        def convert_diff_transformer(config: str, **kwargs):
+            """Convert model attention layers to differential attention layers."""
+            kwargs = {k: v for k, v in kwargs.items() if v is not None}
+            do_cli(config=config, **kwargs)
+
+    except ImportError as exc:
+        LOG.debug("axolotl-diff-transformer not found: %s", exc)
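The new `setup_plugin_commands` hook keeps optional plugin commands out of the core CLI import path: the plugin package is imported inside a `try` block, and registration is silently skipped (with a debug log) when it is absent. Below is a minimal sketch of how this guarded-registration pattern could generalize to several optional plugins; the `PLUGIN_COMMANDS` table, `register_optional_commands` name, and second module path are illustrative assumptions, not part of axolotl's actual API:

```python
import importlib
import logging

import click

LOG = logging.getLogger(__name__)

# (module path, click command attribute) pairs -- hypothetical examples
PLUGIN_COMMANDS = [
    ("axolotl_diff_transformer.plugin.cli", "convert_diff_transformer"),
    ("some_other_plugin.cli", "other_command"),
]


def register_optional_commands(cli: click.Group) -> None:
    """Attach each plugin's click command only if its package imports cleanly."""
    for module_path, command_name in PLUGIN_COMMANDS:
        try:
            module = importlib.import_module(module_path)
        except ImportError as exc:
            # Missing plugins are expected; log at debug level and move on.
            LOG.debug("plugin %s not available: %s", module_path, exc)
            continue
        cli.add_command(getattr(module, command_name))
```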
+ """ + try: + from axolotl_diff_transformer.convert_diff_transformer import do_cli + from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs + + @cli.command() + @click.argument("config", type=click.Path(exists=True, path_type=str)) + @add_options_from_dataclass(ConvertDiffTransformerCliArgs) + @add_options_from_config(AxolotlInputConfig) + def convert_diff_transformer(config: str, **kwargs): + """Convert model attention layers to differential attention layers.""" + kwargs = {k: v for k, v in kwargs.items() if v is not None} + do_cli(config=config, **kwargs) + + except ImportError as exc: + LOG.debug("axolotl-diff-transformer not found: %s", exc) diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index 8b31b52b5..bebd6b00e 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -4,13 +4,19 @@ shared module for cli specific things import logging from dataclasses import dataclass, field -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 from axolotl.logging_config import configure_logging from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer +if TYPE_CHECKING: + try: + from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs + except: # noqa: E722 # pylint: disable=bare-except # nosec B110 + pass + configure_logging() LOG = logging.getLogger(__name__) @@ -48,21 +54,10 @@ class EvaluateCliArgs: debug_num_examples: int = field(default=0) -@dataclass -class ConvertDiffTransformerCliArgs: - """dataclass with arguments for convert-diff-transformer CLI""" - - debug: bool = field(default=False) - zero_init: bool = field(default=False) - sublayer_norm: bool = field(default=True) - split_heads: bool = field(default=False) - mirror_weights: bool = field(default=False) - - def load_model_and_tokenizer( *, cfg: DictDefault, - cli_args: Union[TrainerCliArgs, EvaluateCliArgs, ConvertDiffTransformerCliArgs], + cli_args: Union[TrainerCliArgs, EvaluateCliArgs, "ConvertDiffTransformerCliArgs"], ): LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") tokenizer = load_tokenizer(cfg) diff --git a/src/axolotl/integrations/diff_transformer/README.md b/src/axolotl/integrations/diff_transformer/README.md deleted file mode 100644 index ea27e0291..000000000 --- a/src/axolotl/integrations/diff_transformer/README.md +++ /dev/null @@ -1,40 +0,0 @@ -# Differential Transformer - -### Installation - -```shell -pip install git+https://github.com/axolotl-ai-cloud/diff-transformer.git -``` - -Editable: - -```shell -git clone git@github.com:axolotl-ai-cloud/diff-transformer.git -cd diff-transformer -pip install -e . -``` - -### Usage - -**Note:** The following with be set in the model config output by the `axolotl convert-diff-transformer` command. 
diff --git a/src/axolotl/integrations/diff_transformer/README.md b/src/axolotl/integrations/diff_transformer/README.md
deleted file mode 100644
index ea27e0291..000000000
--- a/src/axolotl/integrations/diff_transformer/README.md
+++ /dev/null
@@ -1,40 +0,0 @@
-# Differential Transformer
-
-### Installation
-
-```shell
-pip install git+https://github.com/axolotl-ai-cloud/diff-transformer.git
-```
-
-Editable:
-
-```shell
-git clone git@github.com:axolotl-ai-cloud/diff-transformer.git
-cd diff-transformer
-pip install -e .
-```
-
-### Usage
-
-**Note:** The following with be set in the model config output by the `axolotl convert-diff-transformer` command.
-
-```yaml
-plugins:
-  - axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
-
-diff_attention: true
-```
-
-Additional, optional arguments include:
-
-```yaml
-# How often to log diffential attention-related metrics to wandb
-diff_attn_log_every: 100
-
-# How many differential attention layers to monitor (strided from 0..k..num_layers)
-diff_attn_num_monitor_layers: 3
-
-# How many steps to "warmup" the mixing parameter for the negative component of differential attention
-# Follows a linear warmup schedule from 0 to 1; if not specified, the mixing component is set to 1
-diff_attn_warmup_steps: 1000
-```
diff --git a/src/axolotl/integrations/diff_transformer/__init__.py b/src/axolotl/integrations/diff_transformer/__init__.py
deleted file mode 100644
index 622cb5e1d..000000000
--- a/src/axolotl/integrations/diff_transformer/__init__.py
+++ /dev/null
@@ -1,69 +0,0 @@
-"""Definition of differential transformer plugin."""
-
-import logging
-from typing import List
-
-from transformers import PreTrainedModel, TrainerCallback
-
-from axolotl.integrations.base import BasePlugin
-from axolotl.utils.callbacks.diff_attn import (
-    DifferentialAttentionMixingCallback,
-    DifferentialAttentionMonitorCallback,
-)
-from axolotl.utils.dict import DictDefault
-
-LOG = logging.getLogger(__name__)
-
-
-class DifferentialTransformerPlugin(BasePlugin):
-    """Plugin for differential transformer integration with Axolotl."""
-
-    def __init__(self) -> None:
-        """
-        Constructor for differential transformers plugin. Calls `register_diff_attn`
-        to register differential attention custom modeling implementation to `AutoConfig`
-        and `AutoModel`.
-        """
-        from axolotl_diff_transformer.modeling.modeling_diff_attn import (
-            register_diff_attn,
-        )
-
-        register_diff_attn()
-
-    def get_input_args(self) -> str:
-        """Returns module path to diff transformer plugin args for `axolotl` config."""
-        return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
-
-    # pylint: disable=unused-argument
-    def add_callbacks_pre_trainer(
-        self, cfg: DictDefault, model: PreTrainedModel
-    ) -> List[TrainerCallback]:
-        """
-        Returns `DifferentialAttentionMonitorCallback` to be added to the list of
-        callbacks for the `axolotl` trainer if wandb usage is enabled.
-
-        Parameters:
-            cfg: Dictionary mapping `axolotl` config keys to values.
-            model: The loaded mfodel.
-
-        Returns:
-            A list (possibly) containing an instantiated `DifferentialAttentionMonitorCallback`.
-        """
-        callbacks = []
-        if cfg.use_wandb:
-            callbacks.append(
-                DifferentialAttentionMonitorCallback(
-                    log_every=cfg.diff_attn_log_every,
-                    num_monitor_layers=cfg.diff_attn_num_monitor_layers,
-                    warmup_steps=cfg.diff_attn_warmup_steps,
-                )
-            )
-
-        if cfg.diff_attn_warmup_steps:
-            callbacks.append(
-                DifferentialAttentionMixingCallback(
-                    warmup_steps=cfg.diff_attn_warmup_steps
-                )
-            )
-
-        return callbacks
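The deleted plugin constructor delegates to `register_diff_attn()`, which wires the custom modeling code into the `transformers` Auto classes. A hedged sketch of what such a registration typically looks like, using the standard Auto-class API; the `DiffLlama*` class names and the `diff_llama` model_type key are illustrative assumptions, not the plugin's actual identifiers:

```python
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM


class DiffLlamaConfig(LlamaConfig):
    model_type = "diff_llama"  # hypothetical model_type key


class DiffLlamaForCausalLM(LlamaForCausalLM):
    config_class = DiffLlamaConfig


# After registration, AutoConfig / AutoModelForCausalLM can resolve checkpoints
# whose config.json declares "model_type": "diff_llama".
AutoConfig.register("diff_llama", DiffLlamaConfig)
AutoModelForCausalLM.register(DiffLlamaConfig, DiffLlamaForCausalLM)
```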
diff --git a/src/axolotl/integrations/diff_transformer/args.py b/src/axolotl/integrations/diff_transformer/args.py
deleted file mode 100644
index ebd4d03a1..000000000
--- a/src/axolotl/integrations/diff_transformer/args.py
+++ /dev/null
@@ -1,27 +0,0 @@
-"""Module for handling differential transfomer input arguments."""
-
-import logging
-from typing import Optional
-
-from pydantic import BaseModel
-
-LOG = logging.getLogger(__name__)
-
-
-class DifferentialTransformerArgs(BaseModel):
-    """
-    Input args for differential transformer.
-
-    Attributes:
-        diff_attention: Whether to use differential attention layers.
-        diff_attn_log_every: How often to log differential attention statistics.
-        diff_attn_num_monitor_layers: Number of layers to monitor for attention stats.
-        diff_attn_warmup_steps: Number of steps to linearly increase negative attention
-            mixing weight from 0 to 1. If specified, will reach full mixing at this
-            step. If `None`, negative attention has full weight from the start.
-    """
-
-    diff_attention: Optional[bool] = None
-    diff_attn_log_every: Optional[int] = 100
-    diff_attn_num_monitor_layers: Optional[int] = 3
-    diff_attn_warmup_steps: Optional[int] = None
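Because `get_input_args` (in the plugin `__init__.py` above) points axolotl at this pydantic model, the plugin's config keys are validated like any other input config field: unset keys take their defaults and mistyped values fail at parse time. A quick self-contained sketch of that behavior, reusing the deleted model's fields:

```python
from typing import Optional

from pydantic import BaseModel, ValidationError


class DifferentialTransformerArgs(BaseModel):
    diff_attention: Optional[bool] = None
    diff_attn_log_every: Optional[int] = 100
    diff_attn_num_monitor_layers: Optional[int] = 3
    diff_attn_warmup_steps: Optional[int] = None


args = DifferentialTransformerArgs(diff_attention=True)
assert args.diff_attn_log_every == 100  # unset key falls back to its default

try:
    DifferentialTransformerArgs(diff_attn_warmup_steps="soon")
except ValidationError as err:
    print(err)  # "soon" cannot be coerced to an integer
```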
- """ - if not is_main_process() or state.global_step % self.log_every != 0: - return - - assert self.monitor_layers is not None - - # Aggregate stats across all layers - all_q1_norms = [] - all_q2_norms = [] - all_k1_norms = [] - all_k2_norms = [] - all_lambda1 = [] - all_lambda2 = [] - all_lambda_full = [] - - metrics = {} - for layer_idx, layer in enumerate(model.model.layers): - attn = layer.self_attn - - # Collect stats for aggregation - all_q1_norms.append(attn.lambda_q1.norm().item()) - all_q2_norms.append(attn.lambda_q2.norm().item()) - all_k1_norms.append(attn.lambda_k1.norm().item()) - all_k2_norms.append(attn.lambda_k2.norm().item()) - - lambda1 = torch.exp(torch.sum(attn.lambda_q1 * attn.lambda_k1)).item() - lambda2 = torch.exp(torch.sum(attn.lambda_q2 * attn.lambda_k2)).item() - all_lambda1.append(lambda1) - all_lambda2.append(lambda2) - all_lambda_full.append(attn.lambda_full) - - # Log detailed metrics for monitored layers - if layer_idx in self.monitor_layers: - metrics.update( - { - f"layer_{layer_idx}/lambda_q1_norm": attn.lambda_q1.norm().item(), - f"layer_{layer_idx}/lambda_k1_norm": attn.lambda_k1.norm().item(), - f"layer_{layer_idx}/lambda_q2_norm": attn.lambda_q2.norm().item(), - f"layer_{layer_idx}/lambda_k2_norm": attn.lambda_k2.norm().item(), - f"layer_{layer_idx}/lambda1": lambda1, - f"layer_{layer_idx}/lambda2": lambda2, - f"layer_{layer_idx}/lambda_init": attn.lambda_init.item(), - f"layer_{layer_idx}/lambda_full": lambda1 - - lambda2 - + attn.lambda_init.item(), - f"layer_{layer_idx}/attn1_mean": attn.attn1.mean().item(), - f"layer_{layer_idx}/attn2_mean": attn.attn2.mean().item(), - f"layer_{layer_idx}/attn1_std": attn.attn1.std().item(), - f"layer_{layer_idx}/attn2_std": attn.attn2.std().item(), - } - ) - - # Add aggregate metrics - metrics.update( - { - "aggregate/lambda_q1_norm_mean": torch.tensor(all_q1_norms) - .mean() - .item(), - "aggregate/lambda_q1_norm_std": torch.tensor(all_q1_norms).std().item(), - "aggregate/lambda_q2_norm_mean": torch.tensor(all_q2_norms) - .mean() - .item(), - "aggregate/lambda_q2_norm_std": torch.tensor(all_q2_norms).std().item(), - "aggregate/lambda_k1_norm_mean": torch.tensor(all_k1_norms) - .mean() - .item(), - "aggregate/lambda_k1_norm_std": torch.tensor(all_k1_norms).std().item(), - "aggregate/lambda_k2_norm_mean": torch.tensor(all_k2_norms) - .mean() - .item(), - "aggregate/lambda_k2_norm_std": torch.tensor(all_k2_norms).std().item(), - "aggregate/lambda1_mean": torch.tensor(all_lambda1).mean().item(), - "aggregate/lambda1_std": torch.tensor(all_lambda1).std().item(), - "aggregate/lambda2_mean": torch.tensor(all_lambda2).mean().item(), - "aggregate/lambda2_std": torch.tensor(all_lambda2).std().item(), - "aggregate/lambda_full_mean": torch.tensor(all_lambda_full) - .mean() - .item(), - "aggregate/lambda_full_std": torch.tensor(all_lambda_full).std().item(), - } - ) - - if self.warmup_steps: - metrics["aggregate/diff_attn_mix"] = attn.diff_attn_mix - - wandb.log(metrics, step=state.global_step) - - -class DifferentialAttentionMixingCallback(TrainerCallback): - """ - Callback to gradually increase the weight of negative attention components during - training. - """ - - def __init__(self, warmup_steps: int): - """ - Args: - warmup_steps: Number of steps to linearly increase negative attention - weight from 0 to 1. If `None`, negative attention has full weight from - start. 
- """ - self.warmup_steps = warmup_steps - self.diff_attention_layers: list[nn.Module] | None = None - - # pylint: disable=unused-argument - def on_train_begin( - self, - args: Any, - state: Any, - control: Any, - model: torch.nn.Module, - **kwargs, - ) -> None: - """Cache the differential attention layers at the start of training.""" - if model is not None: - # Get the actual model if it's wrapped - if hasattr(model, "module"): - model = model.module - - # Cache all differential attention layers - self.diff_attention_layers = [ - module for module in model.modules() if hasattr(module, "diff_attn_mix") - ] - - def on_step_begin( - self, - args: Any, - state: Any, - control: Any, - model: torch.nn.Module = None, - **kwargs, - ) -> None: - if self.diff_attention_layers and self.warmup_steps: - # Calculate mixing parameter (0 to 1) - mix = min(1.0, state.global_step / self.warmup_steps) - - # Update cached layers - for layer in self.diff_attention_layers: - layer.diff_attn_mix = mix diff --git a/tests/e2e/integrations/convert_diff_transformer/__init__.py b/tests/e2e/integrations/convert_diff_transformer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/e2e/integrations/convert_diff_transformer/conftest.py b/tests/e2e/integrations/convert_diff_transformer/conftest.py deleted file mode 100644 index 3964df052..000000000 --- a/tests/e2e/integrations/convert_diff_transformer/conftest.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Shared fixtures for differential transformer conversion tests.""" - -import pytest -from click.testing import CliRunner - - -@pytest.fixture(scope="class") -def base_config(): - """Basic config for testing.""" - return { - "base_model": "HuggingFaceTB/SmolLM2-135M", - "datasets": [ - { - "path": "axolotl-ai-co/alpaca_100_test", - "type": "alpaca", - }, - ], - "gradient_accumulation_steps": 1, - "learning_rate": 1e-4, - "val_set_size": 0.1, - "micro_batch_size": 1, - "sequence_len": 2048, - "special_tokens": { - "pad_token": "<|endoftext|>", - }, - } - - -@pytest.fixture(scope="class") -def cli_runner(): - return CliRunner() diff --git a/tests/e2e/integrations/convert_diff_transformer/test_convert_and_evaluate.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_and_evaluate.py deleted file mode 100644 index d5915f8a5..000000000 --- a/tests/e2e/integrations/convert_diff_transformer/test_convert_and_evaluate.py +++ /dev/null @@ -1,51 +0,0 @@ -"""End-to-end tests for differential transformer conversion and evaluation.""" -# pylint: disable=duplicate-code - -from pathlib import Path - -import yaml -from pytest import approx - -from axolotl.cli import load_cfg -from axolotl.cli.evaluate import do_evaluate -from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer -from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs - - -def test_conversion_and_eval_cli(tmp_path: Path, base_config): - output_dir = tmp_path / "converted" - base_config["output_dir"] = str(output_dir) - - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) - - cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs( - debug=True, zero_init=True, sublayer_norm=False - ) - _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - - assert debug_info["generations_match"] is True - assert (output_dir / "model.safetensors").exists() - assert (output_dir / "config.json").exists() - assert (output_dir / 
"axolotl_config.yml").exists() - - eval_cfg = load_cfg(str(output_dir)) - eval_cli_args = EvaluateCliArgs() - all_metrics = do_evaluate(eval_cfg, eval_cli_args) - - assert list(all_metrics.keys()) == [ - "train_loss", - "train_model_preparation_time", - "train_runtime", - "train_samples_per_second", - "train_steps_per_second", - "eval_loss", - "eval_model_preparation_time", - "eval_runtime", - "eval_samples_per_second", - "eval_steps_per_second", - ] - assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4) - assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4) diff --git a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py deleted file mode 100644 index e1ad31fdd..000000000 --- a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py +++ /dev/null @@ -1,150 +0,0 @@ -"""End-to-end tests for differential transformer conversion.""" -# pylint: disable=redefined-outer-name -# pylint: disable=duplicate-code - -from pathlib import Path -from typing import Optional -from unittest.mock import patch - -import pytest -import yaml - -from axolotl.cli import load_cfg -from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer -from axolotl.cli.main import cli -from axolotl.common.cli import ConvertDiffTransformerCliArgs - - -def test_cli_validation(cli_runner): - # Test missing config file - result = cli_runner.invoke(cli, ["convert-diff-transformer"]) - assert result.exit_code != 0 - assert "Error: Missing argument 'CONFIG'." in result.output - - # Test non-existent config file - result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"]) - assert result.exit_code != 0 - assert "Error: Invalid value for 'CONFIG'" in result.output - - -def test_basic_execution(cli_runner, tmp_path: Path, base_config): - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) - - with patch( - "axolotl.cli.integrations.convert_diff_transformer.do_cli" - ) as mock_do_cli: - result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)]) - assert result.exit_code == 0 - - mock_do_cli.assert_called_once() - assert mock_do_cli.call_args.kwargs["config"] == str(config_path) - - -def test_conversion_cli_basic(tmp_path: Path, base_config): - output_dir = tmp_path / "converted" - base_config["output_dir"] = str(output_dir) - - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) - - cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs() - _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - - assert not debug_info - assert (output_dir / "model.safetensors").exists() - assert (output_dir / "config.json").exists() - assert (output_dir / "axolotl_config.yml").exists() - - -def test_conversion_cli_debug(tmp_path: Path, base_config): - output_dir = tmp_path / "converted" - base_config["output_dir"] = str(output_dir) - - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) - - cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs(debug=True) - _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - - assert not debug_info["generations_match"] - assert not debug_info["match_expected"] - assert (output_dir / 
"model.safetensors").exists() - assert (output_dir / "config.json").exists() - assert (output_dir / "axolotl_config.yml").exists() - - -def test_conversion_cli_reproduce(tmp_path: Path, base_config): - output_dir = tmp_path / "converted" - base_config["output_dir"] = str(output_dir) - - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) - - cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs( - debug=True, zero_init=True, sublayer_norm=False - ) - _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - - assert debug_info["generations_match"] is True - assert (output_dir / "model.safetensors").exists() - assert (output_dir / "config.json").exists() - assert (output_dir / "axolotl_config.yml").exists() - - -@pytest.mark.parametrize( - "attention", ["eager_attention", "sdp_attention", "flash_attention"] -) -def test_conversion_cli_repoduce_attentions( - tmp_path: Path, base_config, attention: Optional[str] -): - output_dir = tmp_path / "converted" - base_config["output_dir"] = str(output_dir) - base_config[attention] = True - - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) - - cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs( - debug=True, zero_init=True, sublayer_norm=False - ) - _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - - assert debug_info["generations_match"] is True - assert (output_dir / "model.safetensors").exists() - assert (output_dir / "config.json").exists() - assert (output_dir / "axolotl_config.yml").exists() - - -@pytest.mark.parametrize( - "attention", ["eager_attention", "sdp_attention", "flash_attention"] -) -def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str): - output_dir = tmp_path / "converted" - - # Smallest model with an even number of attention heads - base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B" - base_config["output_dir"] = str(output_dir) - base_config[attention] = True - - config_path = tmp_path / "config.yml" - with open(config_path, "w", encoding="utf-8") as file: - yaml.dump(base_config, file) - - cfg = load_cfg(str(config_path)) - cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True) - _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) - - assert debug_info["generations_match"] is False - assert (output_dir / "model.safetensors").exists() - assert (output_dir / "config.json").exists() - assert (output_dir / "axolotl_config.yml").exists()