Merge pull request #177 from NanoCode012/fix/landmark-patch

Fix landmark attention patch
This commit is contained in:
Wing Lian
2023-06-12 08:27:12 -04:00
committed by GitHub
4 changed files with 79 additions and 420 deletions

View File

@@ -77,6 +77,14 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
importlib.import_module("axolotl.prompters"), prompter
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
while True:
print("=" * 80)
# support for multiline inputs
@@ -90,6 +98,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():