diff --git a/scripts/finetune.py b/scripts/finetune.py
index 8a458890c..cdc4b5e0e 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -77,6 +77,11 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         importlib.import_module("axolotl.prompters"), prompter
     )
 
+    if cfg.landmark_attention:
+        model.set_mem_cache_args(
+            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
+        )
+
     while True:
         print("=" * 80)
         # support for multiline inputs
@@ -90,6 +95,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
         else:
             prompt = instruction.strip()
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+        print("=" * 40)
         model.eval()
         with torch.no_grad():