From 974dc00a7d966e1b26c7e69aea378c8d325776c8 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Sun, 11 Jun 2023 14:00:54 +0900 Subject: [PATCH] Fix set mem_id for inference and refactor --- scripts/finetune.py | 3 +++ src/axolotl/monkeypatch/llama_landmark_attn.py | 10 ++++++++++ src/axolotl/utils/trainer.py | 11 +++++++---- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index cdc4b5e0e..4875256ba 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -78,6 +78,9 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"): ) if cfg.landmark_attention: + from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id + + set_model_mem_id(model, tokenizer) model.set_mem_cache_args( max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None ) diff --git a/src/axolotl/monkeypatch/llama_landmark_attn.py b/src/axolotl/monkeypatch/llama_landmark_attn.py index 51f1b90fe..2a4cdbc36 100644 --- a/src/axolotl/monkeypatch/llama_landmark_attn.py +++ b/src/axolotl/monkeypatch/llama_landmark_attn.py @@ -29,6 +29,7 @@ import torch import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss +from transformers import LlamaTokenizer from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -1237,3 +1238,12 @@ def patch_llama_with_landmark_attn(): transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb + + +def set_model_mem_id(model: LlamaForCausalLM, tokenizer: LlamaTokenizer): + mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN) + model.set_mem_id(mem_id) + + +def get_mem_id(tokenizer: LlamaTokenizer): + return tokenizer.convert_tokens_to_ids(MEM_TOKEN) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 9ae1e7e93..1250ad4f6 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -239,16 +239,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): if cfg.is_llama_derived_model and cfg.landmark_attention: from functools import partial - from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens + from axolotl.monkeypatch.llama_landmark_attn import ( + add_mem_tokens, + get_mem_id, + set_model_mem_id, + ) - mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN) - model.set_mem_id(mem_id) + set_model_mem_id(model, tokenizer) logging.info("Adding landmark attention tokens to dataset") for dataset in [train_dataset, eval_dataset]: dataset = dataset.map( - partial(add_mem_tokens, mem_freq=50, mem_id=mem_id), + partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)), batched=False, num_proc=32, )