Feat: Added Gradio support (#812)

* Added gradio support

* queuing and title

* pre-commit run
This commit is contained in:
Jason Stillerman
2023-11-04 23:59:22 -04:00
committed by GitHub
parent cdc71f73c8
commit 738a057674
4 changed files with 108 additions and 4 deletions

View File

@@ -97,6 +97,10 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference # inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./lora-out" --lora_model_dir="./lora-out"
# gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./lora-out" --gradio
``` ```
## Installation ## Installation
@@ -919,6 +923,10 @@ Pass the appropriate flag to the train command:
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \ cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
--base_model="./completed-model" --prompter=None --load_in_8bit=True --base_model="./completed-model" --prompter=None --load_in_8bit=True
``` ```
-- With gradio hosting
```bash
python -m axolotl.cli.inference examples/your_config.yml --gradio
```
Please use `--sample_packing False` if you have it on and receive the error similar to below: Please use `--sample_packing False` if you have it on and receive the error similar to below:

View File

@@ -31,3 +31,4 @@ scikit-learn==1.2.2
pynvml pynvml
art art
fschat==0.2.29 fschat==0.2.29
gradio

View File

@@ -6,8 +6,10 @@ import os
import random import random
import sys import sys
from pathlib import Path from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import gradio as gr
import torch import torch
import yaml import yaml
@@ -16,7 +18,7 @@ from accelerate.commands.config import config_args
from art import text2art from art import text2art
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextStreamer from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
@@ -153,6 +155,91 @@ def do_inference(
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def do_inference_gradio(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
model = model.to(cfg.device)
def generate(instruction):
if not instruction:
return
if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
all_text = ""
for new_text in streamer:
all_text += new_text
yield all_text
demo = gr.Interface(
fn=generate,
inputs="textbox",
outputs="text",
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(show_api=False, share=True)
def choose_config(path: Path): def choose_config(path: Path):
yaml_files = list(path.glob("*.yml")) yaml_files = list(path.glob("*.yml"))

View File

@@ -6,11 +6,16 @@ from pathlib import Path
import fire import fire
import transformers import transformers
from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art from axolotl.cli import (
do_inference,
do_inference_gradio,
load_cfg,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
def do_cli(config: Path = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art() print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
@@ -21,7 +26,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
) )
parsed_cli_args.inference = True parsed_cli_args.inference = True
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args) if gradio:
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__": if __name__ == "__main__":