diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py index a60eda524..f4779d69d 100644 --- a/src/axolotl/monkeypatch/llama_patch_multipack.py +++ b/src/axolotl/monkeypatch/llama_patch_multipack.py @@ -15,9 +15,10 @@ def hijack_llama_prepare_4d_mask(): import torch from transformers import modeling_attn_mask_utils from transformers.models.llama.modeling_llama import LlamaModel - from transformers.models.llama.modeling_llama.LlamaModel import ( - _prepare_4d_causal_attention_mask_with_cache_position, - ) + + # from transformers.models.llama.modeling_llama.LlamaModel import ( + # _prepare_4d_causal_attention_mask_with_cache_position, + # ) from transformers.utils import is_torch_bf16_gpu_available from axolotl.monkeypatch.utils import mask_2d_to_4d @@ -31,7 +32,7 @@ def hijack_llama_prepare_4d_mask(): *args, ): dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32 - return _prepare_4d_causal_attention_mask_with_cache_position( + return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position( mask_2d_to_4d(attention_mask, dtype=dtype), *args, )