From df1504ae1481fd09fbf60075b71d0b5feab5608b Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Wed, 11 Dec 2024 23:11:19 -0500
Subject: [PATCH] adding CLI command for convert-diff-transformer

---
 src/axolotl/cli/integrations/__init__.py      |   0
 .../integrations/convert_diff_transformer.py  | 131 ++++++++++++++++++
 src/axolotl/cli/main.py                       |  18 +++
 3 files changed, 149 insertions(+)
 create mode 100644 src/axolotl/cli/integrations/__init__.py
 create mode 100644 src/axolotl/cli/integrations/convert_diff_transformer.py

diff --git a/src/axolotl/cli/integrations/__init__.py b/src/axolotl/cli/integrations/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py
new file mode 100644
index 000000000..8886c4946
--- /dev/null
+++ b/src/axolotl/cli/integrations/convert_diff_transformer.py
@@ -0,0 +1,131 @@
+"""CLI to convert a transformers model's attns to diff attns."""
+import logging
+import warnings
+from pathlib import Path
+from time import time
+from typing import Union
+
+import fire
+import torch
+from colorama import Fore
+from dotenv import load_dotenv
+from transformers import HfArgumentParser
+
+from axolotl.cli import load_cfg, print_axolotl_text_art
+from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
+from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention
+
+LOG = logging.getLogger("axolotl.cli.convert_attention")
+
+
+def test_inference(model, tokenizer, prompt="The quick brown fox"):
+    """Run test inference and return generation time"""
+    try:
+        inputs = tokenizer(prompt, return_tensors="pt")
+        inputs = {
+            k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()
+        }
+
+        start = time()
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=20,
+                num_beams=1,
+                do_sample=False,
+                pad_token_id=tokenizer.pad_token_id,
+                use_cache=False,
+            )
+        elapsed = time() - start
+
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        LOG.info("Prompt: %s", prompt)
+        LOG.info("Generated: %s", generated_text)
+        LOG.info("Generation time: %.2fs", elapsed)
+
+        return elapsed, generated_text
+
+    except Exception as exc:
+        LOG.error("Inference failed: %s", str(exc))
+        raise
+
+
+def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
+    print_axolotl_text_art()
+
+    cfg = load_cfg(config, **kwargs)
+    parser = HfArgumentParser(TrainerCliArgs)
+    cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+
+    try:
+        # Load model and tokenizer
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+        model.to(cfg.device, dtype=cfg.torch_dtype)
+
+        # Log original model info
+        LOG.info(
+            "Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
+            model.config.hidden_size,
+            model.config.num_attention_heads,
+        )
+
+        # Test original model
+        LOG.info("Testing original model...")
+        orig_time, orig_text = test_inference(model, tokenizer)
+
+        # Convert attention
+        LOG.info("Converting to differential attention...")
+        try:
+            model = convert_to_diff_attention(model)
+            model.to(model.device)
+        except Exception as exc:
+            LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
+            raise
+
+        # Test converted model
+        LOG.info("Testing converted model...")
+        conv_time, conv_text = test_inference(model, tokenizer)
+
+        # Save if requested
+        if cfg.output_dir:
+            LOG.info("Saving converted model to %s", cfg.output_dir)
+            model.save_pretrained(cfg.output_dir)
+
+        LOG.info(
+            Fore.GREEN
+            + "Conversion successful!\n"
+            + f"Original generation time: {orig_time:.2f}s\n"
+            + f"Converted generation time: {conv_time:.2f}s"
+            + Fore.RESET
+        )
+
+        if orig_text == conv_text:
+            LOG.info(
+                Fore.GREEN
+                + "Generations match!\n"
+                + f"Model generation: {orig_text}\n"
+                + Fore.RESET
+            )
+        else:
+            LOG.info(
+                Fore.RED
+                + "Generations do not match.\n"
+                + f"Original generation: {orig_text}\n"
+                + f"Converted generation: {conv_text}\n"
+                + Fore.RESET
+            )
+
+    except Exception as exc:
+        LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc))
+        raise
+
+
+if __name__ == "__main__":
+    load_dotenv()
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s - %(levelname)s - %(message)s",
+    )
+    fire.Fire(do_cli)
diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index 14803e43b..7743d5017 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -240,6 +240,24 @@ def merge_lora(
     do_cli(config=config, **kwargs)
 
 
+@cli.command()
+@click.argument("config", type=click.Path(exists=True, path_type=str))
+@click.option(
+    "--output-dir",
+    type=click.Path(path_type=str),
+    help="Directory to save converted model",
+)
+@add_options_from_dataclass(TrainerCliArgs)
+@add_options_from_config(AxolotlInputConfig)
+def convert_diff_transformer(config: str, **kwargs):
+    """Convert model attention layers to differential attention layers."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+    from axolotl.cli.integrations.convert_diff_transformer import do_cli
+
+    do_cli(config=config, **kwargs)
+
+
 @cli.command()
 @click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
 @click.option("--dest", help="Destination directory")
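
Usage sketch (illustrative, not part of the patch): with this change applied, the new
subcommand is registered on axolotl's click CLI, so it should be invocable roughly as
below. The config path and output directory are placeholders, and the dashed command
name assumes click's default underscore-to-dash naming of commands:

    # paths are illustrative; any valid axolotl config YAML should work
    axolotl convert-diff-transformer path/to/config.yml --output-dir ./converted-model

The new module also defines a fire entry point, so it can presumably be run directly:

    # same placeholder paths as above
    python -m axolotl.cli.integrations.convert_diff_transformer path/to/config.yml --output_dir=./converted-model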