Fix set mem_id for inference and refactor
This commit is contained in:
@@ -78,6 +78,9 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cfg.landmark_attention:
|
if cfg.landmark_attention:
|
||||||
|
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||||
|
|
||||||
|
set_model_mem_id(model, tokenizer)
|
||||||
model.set_mem_cache_args(
|
model.set_mem_cache_args(
|
||||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import torch
|
|||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
from transformers import LlamaTokenizer
|
||||||
from transformers.modeling_outputs import (
|
from transformers.modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
@@ -1237,3 +1238,12 @@ def patch_llama_with_landmark_attn():
|
|||||||
transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
|
transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
|
||||||
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
||||||
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
||||||
|
|
||||||
|
|
||||||
|
def set_model_mem_id(model: LlamaForCausalLM, tokenizer: LlamaTokenizer) -> None:
    """Register the landmark (memory) token ID on ``model``.

    The ID is resolved through ``get_mem_id`` so that the value stored on the
    model and the value used elsewhere (e.g. when inserting memory tokens into
    a dataset) always come from the same lookup.

    Args:
        model: Model exposing ``set_mem_id``.
        tokenizer: Tokenizer used to map ``MEM_TOKEN`` to its vocabulary ID.
    """
    # NOTE(review): presumably MEM_TOKEN has already been added to the
    # tokenizer's vocabulary before this is called - confirm at call sites.
    model.set_mem_id(get_mem_id(tokenizer))
|
||||||
|
|
||||||
|
|
||||||
|
def get_mem_id(tokenizer: LlamaTokenizer):
    """Return the vocabulary ID that ``tokenizer`` assigns to ``MEM_TOKEN``."""
    mem_token_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
    return mem_token_id
|
||||||
|
|||||||
@@ -239,16 +239,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
if cfg.is_llama_derived_model and cfg.landmark_attention:
|
if cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from axolotl.monkeypatch.llama_landmark_attn import MEM_TOKEN, add_mem_tokens
|
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||||
|
add_mem_tokens,
|
||||||
|
get_mem_id,
|
||||||
|
set_model_mem_id,
|
||||||
|
)
|
||||||
|
|
||||||
mem_id = tokenizer.convert_tokens_to_ids(MEM_TOKEN)
|
set_model_mem_id(model, tokenizer)
|
||||||
model.set_mem_id(mem_id)
|
|
||||||
|
|
||||||
logging.info("Adding landmark attention tokens to dataset")
|
logging.info("Adding landmark attention tokens to dataset")
|
||||||
|
|
||||||
for dataset in [train_dataset, eval_dataset]:
|
for dataset in [train_dataset, eval_dataset]:
|
||||||
dataset = dataset.map(
|
dataset = dataset.map(
|
||||||
partial(add_mem_tokens, mem_freq=50, mem_id=mem_id),
|
partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
|
||||||
batched=False,
|
batched=False,
|
||||||
num_proc=32,
|
num_proc=32,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user