diff --git a/scripts/finetune.py b/scripts/finetune.py index 898f88c2c..0f17054ce 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union import fire import torch import yaml -from transformers import GenerationConfig +from transformers import GenerationConfig, TextStreamer from axolotl.utils.data import load_prepare_datasets from axolotl.utils.dict import DictDefault @@ -64,13 +64,21 @@ def get_multi_line_input() -> Optional[str]: def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): - tokenizer.add_special_tokens({"unk_token": ""}) - tokenizer.add_special_tokens({"bos_token": ""}) - tokenizer.add_special_tokens({"eos_token": ""}) + default_tokens = { + "unk_token": "", + "bos_token": "", + "eos_token": "" + } + + for token, symbol in default_tokens.items(): + # If the token isn't already specified in the config, add it + if not (cfg.special_tokens and token in cfg.special_tokens): + tokenizer.add_special_tokens({token: symbol}) prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter) while True: + print("=" * 80) # support for multiline inputs instruction = get_multi_line_input() if not instruction: @@ -79,7 +87,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): prompter_module().build_prompt(instruction=instruction.strip("\n")) ) batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) - + print("=" * 40) model.eval() with torch.no_grad(): generation_config = GenerationConfig( @@ -98,10 +106,13 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): output_hidden_states=False, output_scores=False, ) + streamer = TextStreamer(tokenizer) generated = model.generate( inputs=batch["input_ids"].to(cfg.device), generation_config=generation_config, + streamer=streamer, ) + print("=" * 40) print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))