diff --git a/scripts/finetune.py b/scripts/finetune.py
index 7b2dc77c8..e807456d8 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Union
 import fire
 import torch
 import yaml
+from transformers import GenerationConfig

 from axolotl.utils.data import load_prepare_datasets
 from axolotl.utils.dict import DictDefault
@@ -73,26 +74,33 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         instruction = get_multi_line_input()
         if not instruction:
             return
-        prompt: str = next(prompter_module().build_prompt(instruction=instruction))
+        prompt: str = next(
+            prompter_module().build_prompt(instruction=instruction.strip("\n"))
+        )
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

         model.eval()
         with torch.no_grad():
-            # gc = GenerationConfig()  # TODO swap out and use this
-            generated = model.generate(
-                inputs=batch["input_ids"].to(cfg.device),
-                do_sample=True,
-                use_cache=True,
+            generation_config = GenerationConfig(
                 repetition_penalty=1.1,
-                max_new_tokens=100,
+                max_new_tokens=1024,
                 temperature=0.9,
                 top_p=0.95,
                 top_k=40,
+                bos_token_id=tokenizer.bos_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+                pad_token_id=tokenizer.pad_token_id,
+                do_sample=True,
+                use_cache=True,
                 return_dict_in_generate=True,
                 output_attentions=False,
                 output_hidden_states=False,
                 output_scores=False,
             )
+            generated = model.generate(
+                inputs=batch["input_ids"].to(cfg.device),
+                generation_config=generation_config,
+            )
         print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
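
Note: for reference, below is a minimal standalone sketch of the pattern this patch adopts: constructing a transformers.GenerationConfig up front and passing it to model.generate() instead of scattering sampling kwargs across the call site. The checkpoint name ("gpt2") and the prompt string are placeholders for illustration only, and pad_token_id falling back to eos_token_id is an assumption for models without a pad token; the axolotl code above instead pulls all three special-token ids from its own tokenizer.

# Sketch only: illustrates the GenerationConfig pattern, not the axolotl code path.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Placeholder checkpoint; any causal LM works the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Collect sampling settings in one reusable config object,
# mirroring the values used in the patch above.
generation_config = GenerationConfig(
    repetition_penalty=1.1,
    max_new_tokens=64,
    temperature=0.9,
    top_p=0.95,
    top_k=40,
    do_sample=True,
    use_cache=True,
    eos_token_id=tokenizer.eos_token_id,
    # gpt2 has no pad token; reuse eos here (assumption for this sketch).
    pad_token_id=tokenizer.eos_token_id,
    return_dict_in_generate=True,
)

batch = tokenizer("Hello, world", return_tensors="pt")
model.eval()
with torch.no_grad():
    generated = model.generate(
        inputs=batch["input_ids"],
        generation_config=generation_config,
    )
print(tokenizer.decode(generated["sequences"][0].tolist()))

One practical upside of this refactor is that the config is serializable, so the same sampling settings can be saved alongside a checkpoint or reused across call sites instead of being duplicated in every generate() call.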