Set mem cache args on inference
This commit is contained in:
@@ -77,6 +77,11 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.landmark_attention:
|
||||||
|
model.set_mem_cache_args(
|
||||||
|
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||||
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
# support for multiline inputs
|
# support for multiline inputs
|
||||||
@@ -90,6 +95,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
|||||||
else:
|
else:
|
||||||
prompt = instruction.strip()
|
prompt = instruction.strip()
|
||||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||||
|
|
||||||
print("=" * 40)
|
print("=" * 40)
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|||||||
Reference in New Issue
Block a user