adding CLI command for convert-diff-transformer

This commit is contained in:
Dan Saunders
2024-12-11 23:11:19 -05:00
parent 7be0d7496c
commit df1504ae14
3 changed files with 149 additions and 0 deletions

View File

View File

@@ -0,0 +1,131 @@
"""CLI to convert a transformers model's attns to diff attns."""
import logging
import warnings
from pathlib import Path
from time import time
from typing import Union
import fire
import torch
from colorama import Fore
from dotenv import load_dotenv
from transformers import HfArgumentParser
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention
LOG = logging.getLogger("axolotl.cli.convert_attention")
def test_inference(model, tokenizer, prompt="The quick brown fox"):
"""Run test inference and return generation time"""
try:
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {
k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()
}
start = time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=20,
num_beams=1,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
)
elapsed = time() - start
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
LOG.info("Prompt: %s", prompt)
LOG.info("Generated: %s", generated_text)
LOG.info("Generation time: %.2fs", elapsed)
return elapsed, generated_text
except Exception as exc:
LOG.error("Inference failed: %s", str(exc))
raise
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
print_axolotl_text_art()
cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
try:
# Load model and tokenizer
with warnings.catch_warnings():
warnings.simplefilter("ignore")
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
model.to(cfg.device, dtype=cfg.torch_dtype)
# Log original model info
LOG.info(
"Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
model.config.hidden_size,
model.config.num_attention_heads,
)
# Test original model
LOG.info("Testing original model...")
orig_time, orig_text = test_inference(model, tokenizer)
# Convert attention
LOG.info("Converting to differential attention...")
try:
model = convert_to_diff_attention(model)
model.to(model.device)
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
raise
# Test converted model
LOG.info("Testing converted model...")
conv_time, conv_text = test_inference(model, tokenizer)
# Save if requested
if cfg.output_dir:
LOG.info("Saving converted model to %s", cfg.output_dir)
model.save_pretrained(cfg.output_dir)
LOG.info(
Fore.GREEN
+ "Conversion successful!\n"
+ f"Original generation time: {orig_time:.2f}s\n"
+ f"Converted generation time: {conv_time:.2f}s"
+ Fore.RESET
)
if orig_text == conv_text:
LOG.info(
Fore.GREEN
+ "Generations match!\n"
+ f"Model generation: {orig_text}\n"
+ Fore.RESET
)
else:
LOG.info(
Fore.RED
+ "Generations do not match.\n"
+ f"Original generation: {orig_text}\n"
+ f"Converted generation: {conv_text}\n"
+ Fore.RESET
)
except Exception as exc:
LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc))
raise
if __name__ == "__main__":
load_dotenv()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
fire.Fire(do_cli)

View File

@@ -240,6 +240,24 @@ def merge_lora(
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--output-dir",
type=click.Path(path_type=str),
help="Directory to save converted model",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.integrations.convert_diff_transformer import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")