Fix: set mem_id for inference, and refactor

This commit is contained in:
NanoCode012
2023-06-11 14:00:54 +09:00
parent 572d1141e6
commit 974dc00a7d
3 changed files with 20 additions and 4 deletions

View File

@@ -78,6 +78,9 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)