adding CLI command for convert-diff-transformer
This commit is contained in:
0
src/axolotl/cli/integrations/__init__.py
Normal file
0
src/axolotl/cli/integrations/__init__.py
Normal file
131
src/axolotl/cli/integrations/convert_diff_transformer.py
Normal file
131
src/axolotl/cli/integrations/convert_diff_transformer.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""CLI to convert a transformers model's attns to diff attns."""
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from time import time
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
import torch
|
||||
from colorama import Fore
|
||||
from dotenv import load_dotenv
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from axolotl.cli import load_cfg, print_axolotl_text_art
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.convert_attention")
|
||||
|
||||
|
||||
def test_inference(model, tokenizer, prompt="The quick brown fox"):
|
||||
"""Run test inference and return generation time"""
|
||||
try:
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
inputs = {
|
||||
k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()
|
||||
}
|
||||
|
||||
start = time()
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=20,
|
||||
num_beams=1,
|
||||
do_sample=False,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
use_cache=False,
|
||||
)
|
||||
elapsed = time() - start
|
||||
|
||||
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
LOG.info("Prompt: %s", prompt)
|
||||
LOG.info("Generated: %s", generated_text)
|
||||
LOG.info("Generation time: %.2fs", elapsed)
|
||||
|
||||
return elapsed, generated_text
|
||||
|
||||
except Exception as exc:
|
||||
LOG.error("Inference failed: %s", str(exc))
|
||||
raise
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
print_axolotl_text_art()
|
||||
|
||||
cfg = load_cfg(config, **kwargs)
|
||||
parser = HfArgumentParser(TrainerCliArgs)
|
||||
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||
|
||||
try:
|
||||
# Load model and tokenizer
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
# Log original model info
|
||||
LOG.info(
|
||||
"Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
|
||||
model.config.hidden_size,
|
||||
model.config.num_attention_heads,
|
||||
)
|
||||
|
||||
# Test original model
|
||||
LOG.info("Testing original model...")
|
||||
orig_time, orig_text = test_inference(model, tokenizer)
|
||||
|
||||
# Convert attention
|
||||
LOG.info("Converting to differential attention...")
|
||||
try:
|
||||
model = convert_to_diff_attention(model)
|
||||
model.to(model.device)
|
||||
except Exception as exc:
|
||||
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
|
||||
raise
|
||||
|
||||
# Test converted model
|
||||
LOG.info("Testing converted model...")
|
||||
conv_time, conv_text = test_inference(model, tokenizer)
|
||||
|
||||
# Save if requested
|
||||
if cfg.output_dir:
|
||||
LOG.info("Saving converted model to %s", cfg.output_dir)
|
||||
model.save_pretrained(cfg.output_dir)
|
||||
|
||||
LOG.info(
|
||||
Fore.GREEN
|
||||
+ "Conversion successful!\n"
|
||||
+ f"Original generation time: {orig_time:.2f}s\n"
|
||||
+ f"Converted generation time: {conv_time:.2f}s"
|
||||
+ Fore.RESET
|
||||
)
|
||||
|
||||
if orig_text == conv_text:
|
||||
LOG.info(
|
||||
Fore.GREEN
|
||||
+ "Generations match!\n"
|
||||
+ f"Model generation: {orig_text}\n"
|
||||
+ Fore.RESET
|
||||
)
|
||||
else:
|
||||
LOG.info(
|
||||
Fore.RED
|
||||
+ "Generations do not match.\n"
|
||||
+ f"Original generation: {orig_text}\n"
|
||||
+ f"Converted generation: {conv_text}\n"
|
||||
+ Fore.RESET
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc))
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
load_dotenv()
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
fire.Fire(do_cli)
|
||||
@@ -240,6 +240,24 @@ def merge_lora(
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option(
|
||||
"--output-dir",
|
||||
type=click.Path(path_type=str),
|
||||
help="Directory to save converted model",
|
||||
)
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
def convert_diff_transformer(config: str, **kwargs):
|
||||
"""Convert model attention layers to differential attention layers."""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
from axolotl.cli.integrations.convert_diff_transformer import do_cli
|
||||
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
|
||||
@click.option("--dest", help="Destination directory")
|
||||
|
||||
Reference in New Issue
Block a user