From e44c9e0b3e40c6b46f4617d60cbad68d23d32e10 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Fri, 9 Jun 2023 14:27:24 +0900
Subject: [PATCH] Fix patching via import instead of hijacking

---
 src/axolotl/utils/models.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 3a806c3b6..bbb72446a 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -20,7 +20,9 @@ from transformers import (  # noqa: F401
 )
 try:
-    from transformers import LlamaForCausalLM
+    from transformers import (  # pylint: disable=unused-import # noqa: F401
+        LlamaForCausalLM,
+    )
 except ImportError:
     logging.warning(
         "This version of transformers does not support Llama. Consider upgrading."
     )
@@ -115,15 +117,15 @@ def load_model(
         logging.info("patching with sdp attention")
         hijack_llama_sdp_attention()
     elif cfg.is_llama_derived_model and cfg.landmark_attention:
-        from axolotl.monkeypatch.llama_landmark_attn import (
+        from axolotl.monkeypatch.llama_landmark_attn import (  # pylint: disable=redefined-outer-name # noqa: F811
             MEM_TOKEN,
-            hijack_llama_landmark_attn,
+            LlamaForCausalLM,
         )
 
         logging.info("patching with landmark attention")
-        hijack_llama_landmark_attn()
-        tokenizer.add_special_tokens({"mem_token": MEM_TOKEN})
+        # TODO: Check if this would overwrite previous additional_special_tokens
+        tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
 
     if cfg.bf16:
         torch_dtype = torch.bfloat16
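
Note (reviewer sketch, not part of the patch): the TODO in the second hunk asks whether
add_special_tokens({"additional_special_tokens": [MEM_TOKEN]}) would overwrite tokens that
were registered earlier. In the transformers releases contemporary with this patch, passing
an "additional_special_tokens" entry sets the tokenizer's additional_special_tokens attribute
outright, replacing any previously registered list. A minimal sketch of one way to guard
against that is below; add_mem_token is a hypothetical helper for illustration only, not an
axolotl or transformers function.

    from transformers import PreTrainedTokenizerBase

    def add_mem_token(tokenizer: PreTrainedTokenizerBase, mem_token: str) -> int:
        """Register mem_token without dropping previously added special tokens."""
        # Current list of extra special tokens, if any have been registered already.
        existing = list(tokenizer.additional_special_tokens or [])
        if mem_token in existing:
            return 0  # already registered, nothing to add
        # Merge the current list back in, since the call replaces the attribute wholesale.
        return tokenizer.add_special_tokens(
            {"additional_special_tokens": existing + [mem_token]}
        )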