From 5f9f77f3843fb45731221c84e0181c726e85c70a Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Wed, 22 Jan 2025 11:29:28 -0500 Subject: [PATCH] llama patch --- src/axolotl/monkeypatch/llama_patch_multipack.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py index f3b78dc3c..29c9bf7a2 100644 --- a/src/axolotl/monkeypatch/llama_patch_multipack.py +++ b/src/axolotl/monkeypatch/llama_patch_multipack.py @@ -10,16 +10,16 @@ from axolotl.monkeypatch.utils import ( def hijack_llama_prepare_4d_mask(): from transformers import modeling_attn_mask_utils - from transformers.models.llama import modeling_llama + from transformers.models.llama.modeling_llama import LlamaModel # modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access - modeling_llama._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access + LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access patched_prepare_4d_causal_attention_mask_for_sdpa ) modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access patched_prepare_4d_causal_attention_mask_for_sdpa ) - modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access + LlamaModel._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access patched_prepare_4d_causal_attention_mask ) modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access