Fix set mem_id for inference and refactor
This commit is contained in:
@@ -78,6 +78,9 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
||||
)
|
||||
|
||||
if cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
model.set_mem_cache_args(
|
||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user