Refactor landmark attention patch
This commit is contained in:
@@ -1593,3 +1593,12 @@ def add_mem_tokens(example, mem_freq, mem_id):
|
|||||||
ret.extend(x[prev_idx:])
|
ret.extend(x[prev_idx:])
|
||||||
# drop attention_mask
|
# drop attention_mask
|
||||||
return {"input_ids": ret}
|
return {"input_ids": ret}
|
||||||
|
|
||||||
|
|
||||||
|
def patch_llama_with_landmark_attn():
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
|
||||||
|
transformers.models.llama.modeling_llama.LlamaModel = LlamaModel
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
|
||||||
|
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
||||||
|
|||||||
@@ -19,15 +19,6 @@ from transformers import ( # noqa: F401
|
|||||||
LlamaConfig,
|
LlamaConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
|
||||||
from transformers import ( # pylint: disable=unused-import # noqa: F401
|
|
||||||
LlamaForCausalLM,
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
logging.warning(
|
|
||||||
"This version of transformers does not support Llama. Consider upgrading."
|
|
||||||
)
|
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -118,14 +109,15 @@ def load_model(
|
|||||||
logging.info("patching with sdp attention")
|
logging.info("patching with sdp attention")
|
||||||
hijack_llama_sdp_attention()
|
hijack_llama_sdp_attention()
|
||||||
elif cfg.is_llama_derived_model and cfg.landmark_attention:
|
elif cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||||
from axolotl.monkeypatch.llama_landmark_attn import ( # pylint: disable=redefined-outer-name # noqa: F811
|
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||||
MEM_TOKEN,
|
MEM_TOKEN,
|
||||||
LlamaForCausalLM,
|
patch_llama_with_landmark_attn,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("patching with landmark attention")
|
logging.info("patching with landmark attention")
|
||||||
|
patch_llama_with_landmark_attn()
|
||||||
|
|
||||||
# TODO: Check if this would overwrite previous additional_special_tokens
|
# Note: This might overwrite previous additional_special_tokens
|
||||||
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||||
@@ -211,6 +203,13 @@ def load_model(
|
|||||||
)
|
)
|
||||||
load_in_8bit = False
|
load_in_8bit = False
|
||||||
elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
|
elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
|
||||||
|
try:
|
||||||
|
from transformers import LlamaForCausalLM
|
||||||
|
except ImportError:
|
||||||
|
logging.warning(
|
||||||
|
"This version of transformers does not support Llama. Consider upgrading."
|
||||||
|
)
|
||||||
|
|
||||||
config = LlamaConfig.from_pretrained(base_model_config)
|
config = LlamaConfig.from_pretrained(base_model_config)
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
|
|||||||
Reference in New Issue
Block a user