Refactor landmark attention patch

This commit is contained in:
NanoCode012
2023-06-10 08:09:29 +09:00
parent d9f713e4e3
commit 919727b4d7
2 changed files with 20 additions and 12 deletions

View File

@@ -1593,3 +1593,12 @@ def add_mem_tokens(example, mem_freq, mem_id):
ret.extend(x[prev_idx:])
# drop attention_mask
return {"input_ids": ret}
def patch_llama_with_landmark_attn():
import transformers
transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
transformers.models.llama.modeling_llama.LlamaModel = LlamaModel
transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer

View File

@@ -19,15 +19,6 @@ from transformers import ( # noqa: F401
LlamaConfig,
)
try:
from transformers import ( # pylint: disable=unused-import # noqa: F401
LlamaForCausalLM,
)
except ImportError:
logging.warning(
"This version of transformers does not support Llama. Consider upgrading."
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
if TYPE_CHECKING:
@@ -118,14 +109,15 @@ def load_model(
logging.info("patching with sdp attention")
hijack_llama_sdp_attention()
elif cfg.is_llama_derived_model and cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import ( # pylint: disable=redefined-outer-name # noqa: F811
from axolotl.monkeypatch.llama_landmark_attn import (
MEM_TOKEN,
LlamaForCausalLM,
patch_llama_with_landmark_attn,
)
logging.info("patching with landmark attention")
patch_llama_with_landmark_attn()
# TODO: Check if this would overwrite previous additional_special_tokens
# Note: This might overwrite previous additional_special_tokens
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
if cfg.is_llama_derived_model and cfg.xpos_rope:
@@ -211,6 +203,13 @@ def load_model(
)
load_in_8bit = False
elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
try:
from transformers import LlamaForCausalLM
except ImportError:
logging.warning(
"This version of transformers does not support Llama. Consider upgrading."
)
config = LlamaConfig.from_pretrained(base_model_config)
model = LlamaForCausalLM.from_pretrained(
base_model,