Set mem cache args on inference
@@ -77,6 +77,11 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         importlib.import_module("axolotl.prompters"), prompter
     )
 
+    if cfg.landmark_attention:
+        model.set_mem_cache_args(
+            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
+        )
+
     while True:
         print("=" * 80)
         # support for multiline inputs
@@ -90,6 +95,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         else:
             prompt = instruction.strip()
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
         print("=" * 40)
         model.eval()
         with torch.no_grad():
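For context, a minimal sketch of what the new guard does on the inference path, assuming the model has already been patched for landmark attention so that it exposes set_mem_cache_args. The cfg.landmark_attention flag and the argument values are copied from the diff above; the prepare_for_inference helper is hypothetical and only illustrates where the call sits relative to model.eval().

def prepare_for_inference(cfg, model):
    """Hypothetical helper mirroring the guard added in this commit."""
    if cfg.landmark_attention:
        # set_mem_cache_args is only available on models patched for
        # landmark attention; values mirror the ones hard-coded above.
        model.set_mem_cache_args(
            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
        )
    model.eval()  # generation then runs under torch.no_grad(), as in the second hunk
    return model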