From 27532825a9086a3e9433ad0f904fdd52258b2a33 Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Wed, 22 Jan 2025 21:00:34 -0500 Subject: [PATCH] llama sdpa patching WIP - static class function import --- src/axolotl/monkeypatch/llama_patch_multipack.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py index bc844af25..5a8112ac1 100644 --- a/src/axolotl/monkeypatch/llama_patch_multipack.py +++ b/src/axolotl/monkeypatch/llama_patch_multipack.py @@ -27,12 +27,13 @@ def hijack_llama_prepare_4d_mask(): # patched_prepare_4d_causal_attention_mask_for_sdpa # ) + @staticmethod def llama_patched_prepare_4d_causal_attention_mask_with_cache_position( - self, attention_mask: Optional[torch.Tensor], *args, **kwargs + attention_mask: Optional[torch.Tensor], *args, **kwargs ): dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32 return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position( - self, mask_2d_to_4d(attention_mask, dtype=dtype), *args, **kwargs + mask_2d_to_4d(attention_mask, dtype=dtype), *args, **kwargs ) LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access