adding CLI command for convert-diff-transformer
This commit is contained in:
0
src/axolotl/cli/integrations/__init__.py
Normal file
0
src/axolotl/cli/integrations/__init__.py
Normal file
131
src/axolotl/cli/integrations/convert_diff_transformer.py
Normal file
131
src/axolotl/cli/integrations/convert_diff_transformer.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""CLI to convert a transformers model's attns to diff attns."""
|
||||||
|
import logging
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
from time import time
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import torch
|
||||||
|
from colorama import Fore
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from transformers import HfArgumentParser
|
||||||
|
|
||||||
|
from axolotl.cli import load_cfg, print_axolotl_text_art
|
||||||
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
|
from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.cli.convert_attention")
|
||||||
|
|
||||||
|
|
||||||
|
def test_inference(model, tokenizer, prompt="The quick brown fox"):
|
||||||
|
"""Run test inference and return generation time"""
|
||||||
|
try:
|
||||||
|
inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
inputs = {
|
||||||
|
k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
start = time()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=20,
|
||||||
|
num_beams=1,
|
||||||
|
do_sample=False,
|
||||||
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
elapsed = time() - start
|
||||||
|
|
||||||
|
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||||
|
LOG.info("Prompt: %s", prompt)
|
||||||
|
LOG.info("Generated: %s", generated_text)
|
||||||
|
LOG.info("Generation time: %.2fs", elapsed)
|
||||||
|
|
||||||
|
return elapsed, generated_text
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
LOG.error("Inference failed: %s", str(exc))
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
|
print_axolotl_text_art()
|
||||||
|
|
||||||
|
cfg = load_cfg(config, **kwargs)
|
||||||
|
parser = HfArgumentParser(TrainerCliArgs)
|
||||||
|
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load model and tokenizer
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||||
|
model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
|
||||||
|
# Log original model info
|
||||||
|
LOG.info(
|
||||||
|
"Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
|
||||||
|
model.config.hidden_size,
|
||||||
|
model.config.num_attention_heads,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test original model
|
||||||
|
LOG.info("Testing original model...")
|
||||||
|
orig_time, orig_text = test_inference(model, tokenizer)
|
||||||
|
|
||||||
|
# Convert attention
|
||||||
|
LOG.info("Converting to differential attention...")
|
||||||
|
try:
|
||||||
|
model = convert_to_diff_attention(model)
|
||||||
|
model.to(model.device)
|
||||||
|
except Exception as exc:
|
||||||
|
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Test converted model
|
||||||
|
LOG.info("Testing converted model...")
|
||||||
|
conv_time, conv_text = test_inference(model, tokenizer)
|
||||||
|
|
||||||
|
# Save if requested
|
||||||
|
if cfg.output_dir:
|
||||||
|
LOG.info("Saving converted model to %s", cfg.output_dir)
|
||||||
|
model.save_pretrained(cfg.output_dir)
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
Fore.GREEN
|
||||||
|
+ "Conversion successful!\n"
|
||||||
|
+ f"Original generation time: {orig_time:.2f}s\n"
|
||||||
|
+ f"Converted generation time: {conv_time:.2f}s"
|
||||||
|
+ Fore.RESET
|
||||||
|
)
|
||||||
|
|
||||||
|
if orig_text == conv_text:
|
||||||
|
LOG.info(
|
||||||
|
Fore.GREEN
|
||||||
|
+ "Generations match!\n"
|
||||||
|
+ f"Model generation: {orig_text}\n"
|
||||||
|
+ Fore.RESET
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.info(
|
||||||
|
Fore.RED
|
||||||
|
+ "Generations do not match.\n"
|
||||||
|
+ f"Original generation: {orig_text}\n"
|
||||||
|
+ f"Converted generation: {conv_text}\n"
|
||||||
|
+ Fore.RESET
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc))
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
load_dotenv()
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||||
|
)
|
||||||
|
fire.Fire(do_cli)
|
||||||
@@ -240,6 +240,24 @@ def merge_lora(
|
|||||||
do_cli(config=config, **kwargs)
|
do_cli(config=config, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
|
@click.option(
|
||||||
|
"--output-dir",
|
||||||
|
type=click.Path(path_type=str),
|
||||||
|
help="Directory to save converted model",
|
||||||
|
)
|
||||||
|
@add_options_from_dataclass(TrainerCliArgs)
|
||||||
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
|
def convert_diff_transformer(config: str, **kwargs):
|
||||||
|
"""Convert model attention layers to differential attention layers."""
|
||||||
|
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
|
||||||
|
from axolotl.cli.integrations.convert_diff_transformer import do_cli
|
||||||
|
|
||||||
|
do_cli(config=config, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
|
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
|
||||||
@click.option("--dest", help="Destination directory")
|
@click.option("--dest", help="Destination directory")
|
||||||
|
|||||||
Reference in New Issue
Block a user