Fix patching via import instead of hijacking
@@ -20,7 +20,9 @@ from transformers import ( # noqa: F401
 )
 
 try:
-    from transformers import LlamaForCausalLM
+    from transformers import ( # pylint: disable=unused-import # noqa: F401
+        LlamaForCausalLM,
+    )
 except ImportError:
     logging.warning(
         "This version of transformers does not support Llama. Consider upgrading."
@@ -115,15 +117,15 @@ def load_model(
         logging.info("patching with sdp attention")
         hijack_llama_sdp_attention()
     elif cfg.is_llama_derived_model and cfg.landmark_attention:
-        from axolotl.monkeypatch.llama_landmark_attn import (
+        from axolotl.monkeypatch.llama_landmark_attn import ( # pylint: disable=redefined-outer-name # noqa: F811
             MEM_TOKEN,
-            hijack_llama_landmark_attn,
+            LlamaForCausalLM,
         )
 
         logging.info("patching with landmark attention")
-        hijack_llama_landmark_attn()
 
-        tokenizer.add_special_tokens({"mem_token": MEM_TOKEN})
+        # TODO: Check if this would overwrite previous additional_special_tokens
+        tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
 
     if cfg.bf16:
         torch_dtype = torch.bfloat16
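The tokenizer change in the second hunk also fixes a latent bug: "mem_token" is not one of the special-token attributes that transformers' add_special_tokens() accepts, whereas "additional_special_tokens" is. The new TODO flags a real question, since depending on the transformers version a later call with "additional_special_tokens" may replace the previously registered list rather than extend it. Below is a minimal, hedged sketch for checking that behavior locally; the model name and token strings are illustrative placeholders, not part of this commit.

from transformers import AutoTokenizer

# Illustrative tokenizer; any Llama-style checkpoint would do.
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

# Register one special token, then a second one in a separate call.
tokenizer.add_special_tokens({"additional_special_tokens": ["<first>"]})
print(tokenizer.additional_special_tokens)

# Depending on the transformers version, this second call may replace the
# previously registered list instead of appending to it, which is the
# overwrite concern the TODO in the diff raises.
tokenizer.add_special_tokens({"additional_special_tokens": ["<mem>"]})
print(tokenizer.additional_special_tokens)

If overwriting turns out to be an issue, merging the existing tokenizer.additional_special_tokens with the new token before the call is one way to preserve earlier entries.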