Fix set mem_id for inference and refactor
@@ -29,6 +29,7 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
+from transformers import LlamaTokenizer
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -1237,3 +1238,12 @@ def patch_llama_with_landmark_attn():
     transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
     transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
     transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
+
+
+def set_model_mem_id(model: LlamaForCausalLM, tokenizer: LlamaTokenizer):
+    mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
+    model.set_mem_id(mem_id)
+
+
+def get_mem_id(tokenizer: LlamaTokenizer):
+    return tokenizer.convert_tokens_to_ids(MEM_TOKEN)
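Per the commit title, these helpers centralize the memory-token lookup so inference code can import them instead of re-deriving mem_id by hand. A minimal sketch of that inference-side usage, assuming the monkeypatch module also exposes its own LlamaForCausalLM and that the checkpoint's tokenizer already contains MEM_TOKEN; the checkpoint path is a placeholder, and nothing below is part of this commit:

from axolotl.monkeypatch.llama_landmark_attn import (
    LlamaForCausalLM,  # the patched class defining set_mem_id (assumed exported)
    patch_llama_with_landmark_attn,
    set_model_mem_id,
)

# Patch first: the monkeypatch swaps LlamaAttention/LlamaDecoderLayer inside
# transformers, so it must run before from_pretrained builds the model.
patch_llama_with_landmark_attn()

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("path/to/checkpoint")  # placeholder path
model = LlamaForCausalLM.from_pretrained("path/to/checkpoint")  # placeholder path

# The exported helper registers the memory-token id on the model before
# generation, replacing the manual convert_tokens_to_ids/set_mem_id dance.
set_model_mem_id(model, tokenizer)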
@@ -239,16 +239,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.is_llama_derived_model and cfg.landmark_attention:
         from functools import partial

-        from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
+        from axolotl.monkeypatch.llama_landmark_attn import (
+            add_mem_tokens,
+            get_mem_id,
+            set_model_mem_id,
+        )

-        mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
-        model.set_mem_id(mem_id)
+        set_model_mem_id(model, tokenizer)

         logging.info("Adding landmark attention tokens to dataset")

         for dataset in [train_dataset, eval_dataset]:
             dataset = dataset.map(
-                partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
+                partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
                 batched=False,
                 num_proc=32,
             )
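For context on the mapped function: add_mem_tokens lives in the monkeypatch module and is not shown in this diff. A rough sketch of the landmark-attention preprocessing it plausibly performs, assuming each example carries an input_ids list and that the memory token follows every mem_freq-th token (the field name and exact placement are assumptions, not taken from the source):

def add_mem_tokens(example, mem_freq, mem_id):
    # Sketch only: interleave the landmark/memory token into the sequence
    # every mem_freq tokens so the patched attention can route long-range
    # lookups through the landmarks.
    input_ids = example["input_ids"]  # field name assumed
    out = []
    for i, token_id in enumerate(input_ids, start=1):
        out.append(token_id)
        if i % mem_freq == 0:
            out.append(mem_id)
    return {"input_ids": out}

With mem_freq=50 as above, one memory token is inserted per 50 real tokens; batched=False means the function sees one example at a time, and num_proc=32 fans the map across 32 worker processes.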