Fix set mem_id for inference and refactor

This commit is contained in:
NanoCode012
2023-06-11 14:00:54 +09:00
parent 572d1141e6
commit 974dc00a7d
3 changed files with 20 additions and 4 deletions

View File

@@ -78,6 +78,9 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
) )
if cfg.landmark_attention: if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args( model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
) )

View File

@@ -29,6 +29,7 @@ import torch
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from transformers import LlamaTokenizer
from transformers.modeling_outputs import ( from transformers.modeling_outputs import (
BaseModelOutputWithPast, BaseModelOutputWithPast,
CausalLMOutputWithPast, CausalLMOutputWithPast,
@@ -1237,3 +1238,12 @@ def patch_llama_with_landmark_attn():
transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
def set_model_mem_id(model: LlamaForCausalLM, tokenizer: LlamaTokenizer):
mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
model.set_mem_id(mem_id)
def get_mem_id(tokenizer: LlamaTokenizer):
return tokenizer.convert_tokens_to_ids(MEM_TOKEN)

View File

@@ -239,16 +239,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.is_llama_derived_model and cfg.landmark_attention: if cfg.is_llama_derived_model and cfg.landmark_attention:
from functools import partial from functools import partial
from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens from axolotl.monkeypatch.llama_landmark_attn import (
add_mem_tokens,
get_mem_id,
set_model_mem_id,
)
mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN) set_model_mem_id(model, tokenizer)
model.set_mem_id(mem_id)
logging.info("Adding landmark attention tokens to dataset") logging.info("Adding landmark attention tokens to dataset")
for dataset in [train_dataset, eval_dataset]: for dataset in [train_dataset, eval_dataset]:
dataset = dataset.map( dataset = dataset.map(
partial(add_mem_tokens, mem_freq=50, mem_id=mem_id), partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
batched=False, batched=False,
num_proc=32, num_proc=32,
) )