improve inference
This commit is contained in:
@@ -79,31 +79,31 @@ def do_inference(cfg, model, tokenizer):
|
|||||||
|
|
||||||
from axolotl.prompters import ReflectAlpacaPrompter
|
from axolotl.prompters import ReflectAlpacaPrompter
|
||||||
|
|
||||||
instruction = str(input("Give me an instruction: "))
|
while True:
|
||||||
instruction = (
|
instruction = str(input("Give me an instruction: "))
|
||||||
instruction if not instruction else "Tell me a joke about dromedaries."
|
if not instruction:
|
||||||
)
|
return
|
||||||
prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
|
prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
|
||||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# gc = GenerationConfig() # TODO swap out and use this
|
# gc = GenerationConfig() # TODO swap out and use this
|
||||||
generated = model.generate(
|
generated = model.generate(
|
||||||
inputs=batch["input_ids"].to("cuda"),
|
inputs=batch["input_ids"].to("cuda"),
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
repetition_penalty=1.1,
|
repetition_penalty=1.1,
|
||||||
max_new_tokens=100,
|
max_new_tokens=100,
|
||||||
temperature=0.9,
|
temperature=0.9,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
top_k=40,
|
top_k=40,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
)
|
)
|
||||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||||
|
|
||||||
|
|
||||||
def choose_config(path: Path):
|
def choose_config(path: Path):
|
||||||
|
|||||||
@@ -66,22 +66,25 @@ def load_model(
|
|||||||
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
snapshot_download_kwargs = {}
|
try:
|
||||||
if cfg.base_model_ignore_patterns:
|
snapshot_download_kwargs = {}
|
||||||
snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
|
if cfg.base_model_ignore_patterns:
|
||||||
cache_model_path = Path(snapshot_download(base_model, ** snapshot_download_kwargs))
|
snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
|
||||||
files = (
|
cache_model_path = Path(snapshot_download(base_model, ** snapshot_download_kwargs))
|
||||||
list(cache_model_path.glob("*.pt"))
|
files = (
|
||||||
+ list(cache_model_path.glob("*.safetensors"))
|
list(cache_model_path.glob("*.pt"))
|
||||||
+ list(cache_model_path.glob("*.bin"))
|
+ list(cache_model_path.glob("*.safetensors"))
|
||||||
)
|
+ list(cache_model_path.glob("*.bin"))
|
||||||
if len(files) > 0:
|
|
||||||
model_path = str(files[0])
|
|
||||||
else:
|
|
||||||
logging.warning(
|
|
||||||
"unable to find a cached model file, this will likely fail..."
|
|
||||||
)
|
)
|
||||||
model_path = str(cache_model_path)
|
if len(files) > 0:
|
||||||
|
model_path = str(files[0])
|
||||||
|
else:
|
||||||
|
logging.warning(
|
||||||
|
"unable to find a cached model file, this will likely fail..."
|
||||||
|
)
|
||||||
|
model_path = str(cache_model_path)
|
||||||
|
except:
|
||||||
|
model_path = cfg.base_model
|
||||||
model, tokenizer = load_llama_model_4bit_low_ram(
|
model, tokenizer = load_llama_model_4bit_low_ram(
|
||||||
base_model_config if base_model_config else base_model,
|
base_model_config if base_model_config else base_model,
|
||||||
model_path,
|
model_path,
|
||||||
|
|||||||
Reference in New Issue
Block a user