Fix set mem_id for inference and refactor

This commit is contained in:
NanoCode012
2023-06-11 14:00:54 +09:00
parent 572d1141e6
commit 974dc00a7d
3 changed files with 20 additions and 4 deletions

View File

@@ -29,6 +29,7 @@ import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import LlamaTokenizer
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -1237,3 +1238,12 @@ def patch_llama_with_landmark_attn():
transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
def set_model_mem_id(model: LlamaForCausalLM, tokenizer: LlamaTokenizer):
mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
model.set_mem_id(mem_id)
def get_mem_id(tokenizer: LlamaTokenizer):
return tokenizer.convert_tokens_to_ids(MEM_TOKEN)

View File

@@ -239,16 +239,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.is_llama_derived_model and cfg.landmark_attention:
from functools import partial
from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
from axolotl.monkeypatch.llama_landmark_attn import (
add_mem_tokens,
get_mem_id,
set_model_mem_id,
)
mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
model.set_mem_id(mem_id)
set_model_mem_id(model, tokenizer)
logging.info("Adding landmark attention tokens to dataset")
for dataset in [train_dataset, eval_dataset]:
dataset = dataset.map(
partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
batched=False,
num_proc=32,
)