From f3bec179170845fd6bb38e00423caf7597937d1a Mon Sep 17 00:00:00 2001
From: Sunny Liu
Date: Wed, 22 Jan 2025 20:25:26 -0500
Subject: [PATCH] llama sdpa patching WIP - static class function import

---
 src/axolotl/monkeypatch/llama_patch_multipack.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py
index a60eda524..f4779d69d 100644
--- a/src/axolotl/monkeypatch/llama_patch_multipack.py
+++ b/src/axolotl/monkeypatch/llama_patch_multipack.py
@@ -15,9 +15,10 @@ def hijack_llama_prepare_4d_mask():
     import torch
     from transformers import modeling_attn_mask_utils
     from transformers.models.llama.modeling_llama import LlamaModel
-    from transformers.models.llama.modeling_llama.LlamaModel import (
-        _prepare_4d_causal_attention_mask_with_cache_position,
-    )
+
+    # from transformers.models.llama.modeling_llama.LlamaModel import (
+    #     _prepare_4d_causal_attention_mask_with_cache_position,
+    # )
     from transformers.utils import is_torch_bf16_gpu_available
 
     from axolotl.monkeypatch.utils import mask_2d_to_4d
@@ -31,7 +32,7 @@ def hijack_llama_prepare_4d_mask():
         *args,
     ):
         dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
-        return _prepare_4d_causal_attention_mask_with_cache_position(
+        return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position(
             mask_2d_to_4d(attention_mask, dtype=dtype),
             *args,
         )
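
Background on the change: `from transformers.models.llama.modeling_llama.LlamaModel import ...`
is not valid Python, because `LlamaModel` is a class rather than a module and cannot appear in the
dotted path of a `from ... import` statement. Since `_prepare_4d_causal_attention_mask_with_cache_position`
appears to be defined as a staticmethod on `LlamaModel` in the transformers versions this WIP patch
targets (per the commit subject), the patch references it through the class instead. A minimal sketch
of the pattern, using hypothetical names (`SomeModel`, `_prepare_mask`) rather than the real
transformers API:

    # Minimal sketch of the import-vs-attribute-access pattern used in the patch.
    # `SomeModel` and `_prepare_mask` are hypothetical stand-ins, not transformers APIs.
    class SomeModel:
        @staticmethod
        def _prepare_mask(mask, *args):
            # Placeholder for the real mask-preparation logic.
            return mask

    # Invalid: a class cannot be treated as a module segment in an import path;
    # an import like the following raises ModuleNotFoundError at import time.
    # from some_package.some_module.SomeModel import _prepare_mask

    # Works: reference the staticmethod through the class that defines it.
    prepared = SomeModel._prepare_mask([[1, 0], [1, 1]])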