improve inference

This commit is contained in:
Wing Lian
2023-04-19 12:57:27 -04:00
parent 5749eb0a1c
commit d65385912e
2 changed files with 42 additions and 39 deletions

View File

@@ -79,31 +79,31 @@ def do_inference(cfg, model, tokenizer):
from axolotl.prompters import ReflectAlpacaPrompter from axolotl.prompters import ReflectAlpacaPrompter
instruction = str(input("Give me an instruction: ")) while True:
instruction = ( instruction = str(input("Give me an instruction: "))
instruction if not instruction else "Tell me a joke about dromedaries." if not instruction:
) return
prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction) prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
# gc = GenerationConfig() # TODO swap out and use this # gc = GenerationConfig() # TODO swap out and use this
generated = model.generate( generated = model.generate(
inputs=batch["input_ids"].to("cuda"), inputs=batch["input_ids"].to("cuda"),
do_sample=True, do_sample=True,
use_cache=True, use_cache=True,
repetition_penalty=1.1, repetition_penalty=1.1,
max_new_tokens=100, max_new_tokens=100,
temperature=0.9, temperature=0.9,
top_p=0.95, top_p=0.95,
top_k=40, top_k=40,
return_dict_in_generate=True, return_dict_in_generate=True,
output_attentions=False, output_attentions=False,
output_hidden_states=False, output_hidden_states=False,
output_scores=False, output_scores=False,
) )
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def choose_config(path: Path): def choose_config(path: Path):

View File

@@ -66,22 +66,25 @@ def load_model(
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
snapshot_download_kwargs = {} try:
if cfg.base_model_ignore_patterns: snapshot_download_kwargs = {}
snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns if cfg.base_model_ignore_patterns:
cache_model_path = Path(snapshot_download(base_model, ** snapshot_download_kwargs)) snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
files = ( cache_model_path = Path(snapshot_download(base_model, ** snapshot_download_kwargs))
list(cache_model_path.glob("*.pt")) files = (
+ list(cache_model_path.glob("*.safetensors")) list(cache_model_path.glob("*.pt"))
+ list(cache_model_path.glob("*.bin")) + list(cache_model_path.glob("*.safetensors"))
) + list(cache_model_path.glob("*.bin"))
if len(files) > 0:
model_path = str(files[0])
else:
logging.warning(
"unable to find a cached model file, this will likely fail..."
) )
model_path = str(cache_model_path) if len(files) > 0:
model_path = str(files[0])
else:
logging.warning(
"unable to find a cached model file, this will likely fail..."
)
model_path = str(cache_model_path)
except:
model_path = cfg.base_model
model, tokenizer = load_llama_model_4bit_low_ram( model, tokenizer = load_llama_model_4bit_low_ram(
base_model_config if base_model_config else base_model, base_model_config if base_model_config else base_model,
model_path, model_path,