From a5360c172c31cd1ce7f4ccdc15ff4198caece305 Mon Sep 17 00:00:00 2001
From: Sunny Liu
Date: Fri, 17 Jan 2025 15:54:03 -0500
Subject: [PATCH] llama hijacking

---
 src/axolotl/monkeypatch/llama_patch_multipack.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py
index cfd525367..f3b78dc3c 100644
--- a/src/axolotl/monkeypatch/llama_patch_multipack.py
+++ b/src/axolotl/monkeypatch/llama_patch_multipack.py
@@ -12,7 +12,8 @@ def hijack_llama_prepare_4d_mask():
     from transformers import modeling_attn_mask_utils
     from transformers.models.llama import modeling_llama
 
-    modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
+    # modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
+    modeling_llama._prepare_4d_causal_attention_mask_with_cache_position = (  # pylint: disable=protected-access
         patched_prepare_4d_causal_attention_mask_for_sdpa
     )
     modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access