diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py index f3b78dc3c..29c9bf7a2 100644 --- a/src/axolotl/monkeypatch/llama_patch_multipack.py +++ b/src/axolotl/monkeypatch/llama_patch_multipack.py @@ -10,16 +10,16 @@ from axolotl.monkeypatch.utils import ( def hijack_llama_prepare_4d_mask(): from transformers import modeling_attn_mask_utils - from transformers.models.llama import modeling_llama + from transformers.models.llama.modeling_llama import LlamaModel # modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access - modeling_llama._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access + LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access patched_prepare_4d_causal_attention_mask_for_sdpa ) modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access patched_prepare_4d_causal_attention_mask_for_sdpa ) - modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access + LlamaModel._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access patched_prepare_4d_causal_attention_mask ) modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access