diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py
deleted file mode 100644
index 6eb00452b..000000000
--- a/src/axolotl/cli/integrations/convert_diff_transformer.py
+++ /dev/null
@@ -1,185 +0,0 @@
-"""CLI to convert a transformers model's attns to diff attns."""
-import logging
-import warnings
-from pathlib import Path
-from time import time
-from typing import Union
-
-import fire
-import torch
-import yaml
-from colorama import Fore
-from dotenv import load_dotenv
-from transformers import HfArgumentParser
-
-from axolotl.cli import load_cfg, print_axolotl_text_art
-from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
-from axolotl.integrations.differential_transformer.convert import (
-    convert_to_diff_attention,
-)
-
-LOG = logging.getLogger("axolotl.cli.convert_attention")
-
-
-def test_inference(model, tokenizer, prompt="The quick brown fox"):
-    """Run test inference and return generation time"""
-    try:
-        inputs = tokenizer(prompt, return_tensors="pt")
-        inputs = {
-            k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()
-        }
-
-        start = time()
-        with torch.no_grad():
-            outputs = model.generate(
-                **inputs,
-                max_new_tokens=20,
-                num_beams=1,
-                do_sample=False,
-                pad_token_id=tokenizer.pad_token_id,
-                use_cache=False,
-            )
-        elapsed = time() - start
-
-        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        LOG.info("Prompt: %s", prompt)
-        LOG.info("Generated: %s", generated_text)
-        LOG.info("Generation time: %.2fs", elapsed)
-
-        return elapsed, generated_text
-
-    except Exception as exc:
-        LOG.error("Inference failed: %s", str(exc))
-        raise
-
-
-def convert_diff_transformer(cfg, cli_args, config_path):
-    try:
-        # Load model and tokenizer
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
-        model.to(cfg.device, dtype=cfg.torch_dtype)
-
-        # Log original model info
-        LOG.info(
-            "Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
-            model.config.hidden_size,
-            model.config.num_attention_heads,
-        )
-
-        # Test original model
-        if cli_args.debug:
-            LOG.info("Testing original model...")
-            orig_time, orig_text = test_inference(model, tokenizer)
-
-        # Convert attention
-        LOG.info("Converting to differential attention...")
-        try:
-            model = convert_to_diff_attention(
-                model=model,
-                zero_init=cli_args.zero_init,
-                sublayer_norm=cli_args.sublayer_norm,
-            )
-            model.to(cfg.device, dtype=cfg.torch_dtype)
-        except Exception as exc:
-            LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
-            raise
-
-        # Test converted model
-        if cli_args.debug:
-            LOG.info("Testing converted model...")
-            conv_time, conv_text = test_inference(model, tokenizer)
-
-        # Save if requested
-        if cfg.output_dir:
-            # Save model and tokenizer
-            LOG.info("Saving converted model to %s", cfg.output_dir)
-            model.save_pretrained(cfg.output_dir)
-            tokenizer.save_pretrained(cfg.output_dir)
-
-            # Modify config to reflect new path / differential attention
-            output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
-            LOG.info("Saving updated config to %s", output_config_path)
-
-            with open(config_path, "r", encoding="utf-8") as file:
-                data = yaml.safe_load(file) or {}
-
-            data["base_model"] = cfg.output_dir
-            data["diff_attention"] = True
-
-            with open(output_config_path, "w", encoding="utf-8") as file:
-                yaml.dump(data, file)
-        else:
-            LOG.info("Not saving converted model to disk")
-            LOG.info("Pass --output-dir path/to/save to save model")
-
-        if cli_args.debug:
-            LOG.info(
-                Fore.GREEN
-                + "Conversion successful!\n"
-                + f"Original generation time: {orig_time:.2f}s\n"
-                + f"Converted generation time: {conv_time:.2f}s"
-                + Fore.RESET
-            )
-
-            if orig_text == conv_text:
-                LOG.info(
-                    Fore.GREEN
-                    + "Generations match!\n"
-                    + "Model generation:\n"
-                    + "*" * 50
-                    + "\n"
-                    + f"{orig_text}\n"
-                    + "*" * 50
-                    + "\n"
-                    + Fore.RESET
-                )
-            else:
-                message = (
-                    "Generations do not match.\n"
-                    + "Original generation:\n"
-                    + "*" * 50
-                    + "\n"
-                    + f"{orig_text}\n"
-                    + "*" * 50
-                    + "\n"
-                    + "Converted generation:\n"
-                    + "*" * 50
-                    + "\n"
-                    + f"{conv_text}\n"
-                    + "*" * 50
-                    + "\n"
-                )
-
-                if cli_args.zero_init and not cli_args.sublayer_norm:
-                    LOG.info(Fore.RED + message + Fore.RESET)
-                else:
-                    LOG.info(
-                        Fore.YELLOW
-                        + message
-                        + "However, this is expected since --zero-init"
-                        + " and --no-sublayer-norm were not passed."
-                        + Fore.RESET
-                    )
-
-        return model
-
-    except Exception as exc:
-        LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc))
-        raise
-
-
-def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
-    print_axolotl_text_art()
-
-    cfg = load_cfg(config, **kwargs)
-    parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
-    cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
-
-    convert_diff_transformer(cfg, cli_args, config)
-
-
-if __name__ == "__main__":
-    load_dotenv()
-    fire.Fire(do_cli)
diff --git a/src/axolotl/cli/integrations/convert_differential_transformer.py b/src/axolotl/cli/integrations/convert_differential_transformer.py
new file mode 100644
index 000000000..8903da6d1
--- /dev/null
+++ b/src/axolotl/cli/integrations/convert_differential_transformer.py
@@ -0,0 +1,190 @@
+"""CLI to convert a transformers model's attns to diff attns."""
+import logging
+import warnings
+from pathlib import Path
+from time import time
+from typing import Union
+
+import fire
+import torch
+import yaml
+from colorama import Fore
+from dotenv import load_dotenv
+from transformers import HfArgumentParser
+
+from axolotl.cli import load_cfg, print_axolotl_text_art
+from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
+from axolotl.integrations.differential_transformer.convert import (
+    convert_to_diff_attention,
+)
+
+LOG = logging.getLogger(__name__)
+
+
+def test_inference(model, tokenizer, prompt="The quick brown fox"):
+    """Run test inference and return generation time"""
+    try:
+        inputs = tokenizer(prompt, return_tensors="pt")
+        inputs = {
+            k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()
+        }
+
+        start = time()
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=20,
+                num_beams=1,
+                do_sample=False,
+                pad_token_id=tokenizer.pad_token_id,
+                use_cache=False,
+            )
+        elapsed = time() - start
+
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        LOG.info("Prompt: %s", prompt)
+        LOG.info("Generated: %s", generated_text)
+        LOG.info("Generation time: %.2fs", elapsed)
+
+        return elapsed, generated_text
+
+    except Exception as exc:
+        LOG.error("Inference failed: %s", str(exc))
+        raise
+
+
+def convert_differential_transformer(cfg, cli_args, config_path):
+    debug_info = {}
+
+    # Load model and tokenizer
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+    model.to(cfg.device, dtype=cfg.torch_dtype)
+
+    # Log original model info
+    LOG.info(
+        "Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
+        model.config.hidden_size,
+        model.config.num_attention_heads,
+    )
+
+    # Test original model
+    if cli_args.debug:
+        LOG.info("Testing original model...")
+        debug_info["orig_time"], debug_info["orig_text"] = test_inference(
+            model, tokenizer
+        )
+
+    # Convert attention
+    LOG.info("Converting to differential attention...")
+    try:
+        model = convert_to_diff_attention(
+            model=model,
+            zero_init=cli_args.zero_init,
+            sublayer_norm=cli_args.sublayer_norm,
+        )
+        model.to(cfg.device, dtype=cfg.torch_dtype)
+    except Exception as exc:
+        LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
+        raise
+
+    # Test converted model
+    if cli_args.debug:
+        LOG.info("Testing converted model...")
+        debug_info["conv_time"], debug_info["conv_text"] = test_inference(
+            model, tokenizer
+        )
+
+    # Save if requested
+    if cfg.output_dir:
+        # Save model and tokenizer
+        LOG.info("Saving converted model to %s", cfg.output_dir)
+        model.save_pretrained(cfg.output_dir)
+        tokenizer.save_pretrained(cfg.output_dir)
+
+        # Modify config to reflect new path / differential attention
+        output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
+        LOG.info("Saving updated config to %s", output_config_path)
+
+        with open(config_path, "r", encoding="utf-8") as file:
+            data = yaml.safe_load(file) or {}
+
+        data["base_model"] = cfg.output_dir
+        data["diff_attention"] = True
+
+        with open(output_config_path, "w", encoding="utf-8") as file:
+            yaml.dump(data, file)
+    else:
+        LOG.info("Not saving converted model to disk")
+        LOG.info("Pass --output-dir path/to/save to save model")
+
+    if cli_args.debug:
+        LOG.info(
+            Fore.GREEN
+            + "Conversion successful!\n"
+            + f"Original generation time: {debug_info['orig_time']:.2f}s\n"
+            + f"Converted generation time: {debug_info['conv_time']:.2f}s"
+            + Fore.RESET
+        )
+
+        if debug_info["orig_text"] == debug_info["conv_text"]:
+            LOG.info(
+                Fore.GREEN
+                + "Generations match!\n"
+                + "Model generation:\n"
+                + "*" * 50
+                + "\n"
+                + f"{debug_info['orig_text']}\n"
+                + "*" * 50
+                + "\n"
+                + Fore.RESET
+            )
+            debug_info["generations_match"] = True
+        else:
+            message = (
+                "Generations do not match.\n"
+                + "Original generation:\n"
+                + "*" * 50
+                + "\n"
+                + f"{debug_info['orig_text']}\n"
+                + "*" * 50
+                + "\n"
+                + "Converted generation:\n"
+                + "*" * 50
+                + "\n"
+                + f"{debug_info['conv_text']}\n"
+                + "*" * 50
+                + "\n"
+            )
+            debug_info["generations_match"] = False
+
+            if cli_args.zero_init and not cli_args.sublayer_norm:
+                LOG.info(Fore.RED + message + Fore.RESET)
+                debug_info["match_expected"] = True
+            else:
+                LOG.info(
+                    Fore.YELLOW
+                    + message
+                    + "However, this is expected since --zero-init"
+                    + " and --no-sublayer-norm were not passed."
+                    + Fore.RESET
+                )
+                debug_info["match_expected"] = False
+
+    return model, debug_info
+
+
+def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
+    print_axolotl_text_art()
+
+    cfg = load_cfg(config, **kwargs)
+    parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
+    cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+
+    convert_differential_transformer(cfg, cli_args, config)
+
+
+if __name__ == "__main__":
+    load_dotenv()
+    fire.Fire(do_cli)
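A note on the new return contract: unlike the deleted module, `convert_differential_transformer` returns a `(model, debug_info)` tuple, where `debug_info` is empty unless `debug` is set. A minimal sketch (not part of the patch) of calling it programmatically, mirroring what `do_cli` does; the config path is a placeholder:

# Sketch only: programmatic use of the new entry point, mirroring do_cli.
from axolotl.cli import load_cfg
from axolotl.cli.integrations.convert_differential_transformer import (
    convert_differential_transformer,
)
from axolotl.common.cli import ConvertDiffTransformerCliArgs

cfg = load_cfg("path/to/config.yml")  # placeholder config path
cli_args = ConvertDiffTransformerCliArgs(debug=True)

# With debug=True, debug_info carries generation times/texts and the
# generations_match / match_expected flags set in the function above.
model, debug_info = convert_differential_transformer(cfg, cli_args, "path/to/config.yml")
print(debug_info["generations_match"], debug_info.get("match_expected"))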
diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index f3032a9d6..5cb88a1ea 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -249,11 +249,11 @@ def merge_lora(
 @click.argument("config", type=click.Path(exists=True, path_type=str))
 @add_options_from_dataclass(ConvertDiffTransformerCliArgs)
 @add_options_from_config(AxolotlInputConfig)
-def convert_diff_transformer(config: str, **kwargs):
+def convert_differential_transformer(config: str, **kwargs):
     """Convert model attention layers to differential attention layers."""
     kwargs = {k: v for k, v in kwargs.items() if v is not None}
 
-    from axolotl.cli.integrations.convert_diff_transformer import do_cli
+    from axolotl.cli.integrations.convert_differential_transformer import do_cli
 
     do_cli(config=config, **kwargs)
diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py
index 9c921e564..2d6a5bb31 100644
--- a/src/axolotl/common/cli.py
+++ b/src/axolotl/common/cli.py
@@ -57,7 +57,7 @@ class EvaluateCliArgs:
 @dataclass
 class ConvertDiffTransformerCliArgs:
     """
-    dataclass with arguments for convert-diff-transformer CLI
+    dataclass with arguments for convert-differential-transformer CLI
     """
 
     debug: bool = field(default=False)
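The renamed Click command can be exercised in-process with click's CliRunner. A hedged sketch, assuming the flag names that `add_options_from_dataclass` derives from the `ConvertDiffTransformerCliArgs` fields (`--zero-init` and `--no-sublayer-norm` appear in the converter's log messages; `--debug` is inferred from the `debug` field):

# Sketch only: invoking the renamed command in-process; flag names inferred.
from click.testing import CliRunner

from axolotl.cli.main import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "convert-differential-transformer",
        "path/to/config.yml",  # placeholder; must exist per click.Path(exists=True)
        "--debug",             # assumed flag from the `debug` dataclass field
        "--zero-init",         # assumed flag, referenced in the log messages
        "--no-sublayer-norm",  # assumed flag, referenced in the log messages
    ],
)
print(result.exit_code, result.output)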
diff --git a/src/axolotl/integrations/differential_transformer/convert.py b/src/axolotl/integrations/differential_transformer/convert.py
index 5620ad199..ce3773037 100644
--- a/src/axolotl/integrations/differential_transformer/convert.py
+++ b/src/axolotl/integrations/differential_transformer/convert.py
@@ -17,7 +17,6 @@ from .differential_attention import (
     LlamaDifferentialSdpaAttention,
 )
 
-logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 ATTENTION_MAPPING = {
diff --git a/tests/cli/integrations/__init__.py b/tests/cli/integrations/__init__.py
new file mode 100644
index 000000000..e69de29bb
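The convert.py change drops the module-level `logging.basicConfig` call, which a library should not make: it installs a root handler and overrides the host application's logging setup. A minimal sketch of the convention this moves toward:

# Sketch only: library modules create a logger; applications configure logging.
import logging

# Library module (like convert.py after this change) does only this:
logger = logging.getLogger(__name__)

# The application entry point configures handlers and levels exactly once:
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    logger.info("configured by the application, not the library")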
diff --git a/tests/cli/integrations/test_cli_convert_differential_transformer.py b/tests/cli/integrations/test_cli_convert_differential_transformer.py
new file mode 100644
index 000000000..cd2a464c6
--- /dev/null
+++ b/tests/cli/integrations/test_cli_convert_differential_transformer.py
@@ -0,0 +1,48 @@
+"""Tests for convert-differential-transformer CLI command."""
+
+from pathlib import Path
+from unittest.mock import patch
+
+from axolotl.cli.main import cli
+
+
+def test_cli_validation(cli_runner):
+    """Test CLI validation for a command.
+
+    Args:
+        cli_runner: CLI runner fixture
+    """
+    # Test missing config file
+    result = cli_runner.invoke(cli, ["convert-differential-transformer"])
+    assert result.exit_code != 0
+    assert "Error: Missing argument 'CONFIG'." in result.output
+
+    # Test non-existent config file
+    result = cli_runner.invoke(
+        cli, ["convert-differential-transformer", "nonexistent.yml"]
+    )
+    assert result.exit_code != 0
+    assert "Error: Invalid value for 'CONFIG'" in result.output
+
+
+def test_basic_execution(cli_runner, tmp_path: Path, valid_test_config: str):
+    """Test basic execution.
+
+    Args:
+        cli_runner: CLI runner fixture
+        tmp_path: Temporary path fixture
+        valid_test_config: Valid config fixture
+    """
+    config_path = tmp_path / "config.yml"
+    config_path.write_text(valid_test_config)
+
+    with patch(
+        "axolotl.cli.integrations.convert_differential_transformer.do_cli"
+    ) as mock_do_cli:
+        result = cli_runner.invoke(
+            cli, ["convert-differential-transformer", str(config_path)]
+        )
+        assert result.exit_code == 0
+
+        mock_do_cli.assert_called_once()
+        assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
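These tests rely on `cli_runner` and `valid_test_config` fixtures from a shared conftest that is not part of this patch. A hypothetical sketch of what the `cli_runner` fixture likely looks like, for orientation only:

# Hypothetical conftest.py fixture (not in this patch); the real definition
# lives in the repository's shared test fixtures.
import pytest
from click.testing import CliRunner


@pytest.fixture
def cli_runner():
    return CliRunner()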
diff --git a/tests/e2e/integrations/test_convert_differential_transformer.py b/tests/e2e/integrations/test_convert_differential_transformer.py
new file mode 100644
index 000000000..da3aac11a
--- /dev/null
+++ b/tests/e2e/integrations/test_convert_differential_transformer.py
@@ -0,0 +1,127 @@
+"""End-to-end tests for differential transformer conversion."""
+# pylint: disable=redefined-outer-name
+
+from pathlib import Path
+from typing import Optional
+
+import pytest
+import yaml
+
+from axolotl.cli import load_cfg
+from axolotl.cli.integrations.convert_differential_transformer import (
+    convert_differential_transformer,
+)
+from axolotl.common.cli import ConvertDiffTransformerCliArgs
+
+
+@pytest.fixture()
+def base_config():
+    """Basic config for testing."""
+    return {
+        "base_model": "HuggingFaceTB/SmolLM2-135M",
+        "datasets": [
+            {
+                "path": "mhenrichsen/alpaca_2k_test",
+                "type": "alpaca",
+            },
+        ],
+        "gradient_accumulation_steps": 1,
+        "learning_rate": 1e-4,
+        "val_set_size": 0.1,
+        "micro_batch_size": 1,
+        "sequence_len": 2048,
+        "special_tokens": {
+            "pad_token": "<|endoftext|>",
+        },
+    }
+
+
+def test_conversion_cli_basic(tmp_path: Path, base_config):
+    output_dir = tmp_path / "converted"
+    base_config["output_dir"] = str(output_dir)
+
+    config_path = tmp_path / "config.yml"
+    with open(config_path, "w", encoding="utf-8") as file:
+        yaml.dump(base_config, file)
+
+    # Load config the same way do_cli does
+    cfg = load_cfg(str(config_path))
+
+    # Create CLI args
+    cli_args = ConvertDiffTransformerCliArgs()
+
+    # Call convert_differential_transformer directly
+    _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
+
+    assert not debug_info
+    assert (output_dir / "model.safetensors").exists()
+    assert (output_dir / "config.json").exists()
+    assert (output_dir / "axolotl_config.yml").exists()
+
+
+def test_conversion_cli_debug(tmp_path: Path, base_config):
+    output_dir = tmp_path / "converted"
+    base_config["output_dir"] = str(output_dir)
+
+    config_path = tmp_path / "config.yml"
+    with open(config_path, "w", encoding="utf-8") as file:
+        yaml.dump(base_config, file)
+
+    # Load config the same way do_cli does
+    cfg = load_cfg(str(config_path))
+
+    # Create CLI args
+    cli_args = ConvertDiffTransformerCliArgs(debug=True)
+
+    # Call convert_differential_transformer directly
+    _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
+
+    assert not debug_info["generations_match"]
+    assert not debug_info["match_expected"]
+    assert (output_dir / "model.safetensors").exists()
+    assert (output_dir / "config.json").exists()
+    assert (output_dir / "axolotl_config.yml").exists()
+
+
+def test_conversion_cli_reproduce(tmp_path: Path, base_config):
+    output_dir = tmp_path / "converted"
+    base_config["output_dir"] = str(output_dir)
+
+    config_path = tmp_path / "config.yml"
+    with open(config_path, "w", encoding="utf-8") as file:
+        yaml.dump(base_config, file)
+
+    cfg = load_cfg(str(config_path))
+    cli_args = ConvertDiffTransformerCliArgs(
+        debug=True, zero_init=True, sublayer_norm=False
+    )
+    _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
+
+    assert debug_info["generations_match"] is True
+    assert (output_dir / "model.safetensors").exists()
+    assert (output_dir / "config.json").exists()
+    assert (output_dir / "axolotl_config.yml").exists()
+
+
+@pytest.mark.parametrize("attention", ["sdp_attention", "flash_attention"])
+def test_conversion_cli_reproduce_attentions(
+    tmp_path: Path, base_config, attention: Optional[str]
+):
+    output_dir = tmp_path / "converted"
+    base_config["output_dir"] = str(output_dir)
+    base_config[attention] = True
+
+    config_path = tmp_path / "config.yml"
+    with open(config_path, "w", encoding="utf-8") as file:
+        yaml.dump(base_config, file)
+
+    cfg = load_cfg(str(config_path))
+    cli_args = ConvertDiffTransformerCliArgs(
+        debug=True, zero_init=True, sublayer_norm=False
+    )
+    _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
+
+    assert debug_info["generations_match"] is True
+    assert (output_dir / "model.safetensors").exists()
+    assert (output_dir / "config.json").exists()
+    assert (output_dir / "axolotl_config.yml").exists()
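The "reproduce" tests can assert exact generation matches because, in the DIFF Transformer formulation, differential attention subtracts a second softmax attention map from the first; with that second map zero-initialized and no extra sublayer norm, the converted layer reduces to standard attention. A minimal numeric sketch of that identity (the actual implementation lives in differential_attention.py, which this patch does not touch):

# Sketch only: why zero_init + no sublayer norm reproduces the base model.
import torch
import torch.nn.functional as F


def diff_attention(q1, k1, q2, k2, v, lam):
    """Differential attention: the difference of two softmax attention maps."""
    scale = q1.size(-1) ** -0.5
    a1 = F.softmax(q1 @ k1.transpose(-2, -1) * scale, dim=-1)
    a2 = F.softmax(q2 @ k2.transpose(-2, -1) * scale, dim=-1)
    return (a1 - lam * a2) @ v


q1 = k1 = q2 = k2 = v = torch.randn(1, 4, 8)
standard = F.softmax(q1 @ k1.transpose(-2, -1) * 8**-0.5, dim=-1) @ v

# With lam == 0 (the zero-init case) the second map drops out entirely,
# so the output equals standard attention and generations match exactly.
assert torch.allclose(diff_attention(q1, k1, q2, k2, v, lam=0.0), standard)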