Fix patching via import instead of hijacking

commit e44c9e0b3e
parent 55b8542de8
Author: NanoCode012
Date:   2023-06-09 14:27:24 +09:00


@@ -20,7 +20,9 @@ from transformers import ( # noqa: F401
 )
 
 try:
-    from transformers import LlamaForCausalLM
+    from transformers import ( # pylint: disable=unused-import # noqa: F401
+        LlamaForCausalLM,
+    )
 except ImportError:
     logging.warning(
         "This version of transformers does not support Llama. Consider upgrading."
@@ -115,15 +117,15 @@ def load_model(
         logging.info("patching with sdp attention")
         hijack_llama_sdp_attention()
     elif cfg.is_llama_derived_model and cfg.landmark_attention:
-        from axolotl.monkeypatch.llama_landmark_attn import (
+        from axolotl.monkeypatch.llama_landmark_attn import ( # pylint: disable=redefined-outer-name # noqa: F811
             MEM_TOKEN,
-            hijack_llama_landmark_attn,
+            LlamaForCausalLM,
         )
 
         logging.info("patching with landmark attention")
-        hijack_llama_landmark_attn()
 
-        tokenizer.add_special_tokens({"mem_token": MEM_TOKEN})
+        # TODO: Check if this would overwrite previous additional_special_tokens
+        tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
 
     if cfg.bf16:
         torch_dtype = torch.bfloat16
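
The second hunk is the substance of the commit: instead of calling hijack_llama_landmark_attn() to monkeypatch transformers in place, the landmark-attention variant of LlamaForCausalLM is imported directly from axolotl.monkeypatch.llama_landmark_attn, shadowing the name imported from transformers earlier (hence the redefined-outer-name / F811 pragmas), and MEM_TOKEN is registered through the tokenizer's additional_special_tokens key, which is the key add_special_tokens() actually supports for extra tokens. A minimal sketch of that pattern, assuming a cfg object and tokenizer like the ones in the diff; the helper function itself is illustrative, not the real load_model():

    import logging


    def select_llama_class(cfg, tokenizer):
        """Illustrative stand-in for the relevant part of load_model()."""
        # Default: the stock transformers implementation.
        from transformers import LlamaForCausalLM

        if cfg.is_llama_derived_model and cfg.landmark_attention:
            # Patch via import: pull in a class that already carries the
            # landmark-attention changes instead of hijacking transformers
            # in place. The re-import shadows the name bound above, which
            # is why the real diff carries the redefined-outer-name / F811
            # pragmas.
            from axolotl.monkeypatch.llama_landmark_attn import (  # noqa: F811
                MEM_TOKEN,
                LlamaForCausalLM,
            )

            logging.info("patching with landmark attention")
            # add_special_tokens() only understands the standard keys, so
            # the memory token goes in under "additional_special_tokens"
            # rather than an ad-hoc "mem_token" key.
            tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})

        return LlamaForCausalLM

The caller would then use whichever class comes back for from_pretrained(), so the choice between the stock and the patched Llama implementation is made once, at import time, rather than by mutating transformers globally.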