diff --git a/src/axolotl/monkeypatch/llama_landmark_attn.py b/src/axolotl/monkeypatch/llama_landmark_attn.py
index 18e913f09..1a130f755 100644
--- a/src/axolotl/monkeypatch/llama_landmark_attn.py
+++ b/src/axolotl/monkeypatch/llama_landmark_attn.py
@@ -1593,3 +1593,12 @@ def add_mem_tokens(example, mem_freq, mem_id):
         ret.extend(x[prev_idx:])
     # drop attention_mask
     return {"input_ids": ret}
+
+
+def patch_llama_with_landmark_attn():
+    import transformers
+
+    transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
+    transformers.models.llama.modeling_llama.LlamaModel = LlamaModel
+    transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
+    transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index fb363952c..b84597076 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -19,15 +19,6 @@ from transformers import (  # noqa: F401
     LlamaConfig,
 )
 
-try:
-    from transformers import (  # pylint: disable=unused-import # noqa: F401
-        LlamaForCausalLM,
-    )
-except ImportError:
-    logging.warning(
-        "This version of transformers does not support Llama. Consider upgrading."
-    )
-
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
 if TYPE_CHECKING:
@@ -118,14 +109,15 @@ def load_model(
         logging.info("patching with sdp attention")
         hijack_llama_sdp_attention()
     elif cfg.is_llama_derived_model and cfg.landmark_attention:
-        from axolotl.monkeypatch.llama_landmark_attn import (  # pylint: disable=redefined-outer-name # noqa: F811
+        from axolotl.monkeypatch.llama_landmark_attn import (
             MEM_TOKEN,
-            LlamaForCausalLM,
+            patch_llama_with_landmark_attn,
         )
 
         logging.info("patching with landmark attention")
+        patch_llama_with_landmark_attn()
 
-        # TODO: Check if this would overwrite previous additional_special_tokens
+        # Note: This might overwrite previous additional_special_tokens
         tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
 
     if cfg.is_llama_derived_model and cfg.xpos_rope:
@@ -211,6 +203,13 @@ def load_model(
         )
         load_in_8bit = False
     elif cfg.is_llama_derived_model and "LlamaForCausalLM" in globals():
+        try:
+            from transformers import LlamaForCausalLM
+        except ImportError:
+            logging.warning(
+                "This version of transformers does not support Llama. Consider upgrading."
+            )
+
        config = LlamaConfig.from_pretrained(base_model_config)
        model = LlamaForCausalLM.from_pretrained(
            base_model,
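
For context on the patching approach, here is a minimal sketch of the module-attribute rebinding pattern that `patch_llama_with_landmark_attn()` relies on. The `PatchedLlamaAttention` class below is a hypothetical stand-in, not the real landmark-attention implementation defined in `llama_landmark_attn.py`, and it assumes a transformers version where `LlamaDecoderLayer` resolves `LlamaAttention` through the module namespace at construction time. Rebinding names on `transformers.models.llama.modeling_llama` only affects lookups made after the patch runs; names imported elsewhere beforehand keep pointing at the original classes, which is consistent with the change above deferring `from transformers import LlamaForCausalLM` until after patching.

```python
# Sketch of the monkeypatch pattern used above. PatchedLlamaAttention is a
# hypothetical stand-in for the landmark-attention classes.
import transformers.models.llama.modeling_llama as llama_mod


class PatchedLlamaAttention(llama_mod.LlamaAttention):
    """Stand-in subclass; the real patch substitutes landmark attention."""


def patch_llama_attention():
    # Rebind the module-level name so that any code resolving
    # llama_mod.LlamaAttention *after* this call (e.g. decoder layers built
    # during from_pretrained) gets the patched class. References bound
    # before the patch still point at the original class.
    llama_mod.LlamaAttention = PatchedLlamaAttention


patch_llama_attention()
# Instantiate the model only after patching so the swap takes effect, e.g.:
# model = llama_mod.LlamaForCausalLM.from_pretrained(...)
```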