From b7deb5241c68208891e6e9ca334849ba0e474a5e Mon Sep 17 00:00:00 2001
From: Sunny Liu
Date: Wed, 22 Jan 2025 20:16:27 -0500
Subject: [PATCH] llama sdpa patching WIP

---
 src/axolotl/monkeypatch/llama_patch_multipack.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py
index c200f75cb..a60eda524 100644
--- a/src/axolotl/monkeypatch/llama_patch_multipack.py
+++ b/src/axolotl/monkeypatch/llama_patch_multipack.py
@@ -4,7 +4,6 @@ Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
 
 from axolotl.monkeypatch.utils import (
-    mask_2d_to_4d,
     patched_prepare_4d_causal_attention_mask,
     patched_prepare_4d_causal_attention_mask_for_sdpa,
 )
 
@@ -21,6 +20,8 @@ def hijack_llama_prepare_4d_mask():
     )
     from transformers.utils import is_torch_bf16_gpu_available
 
+    from axolotl.monkeypatch.utils import mask_2d_to_4d
+
     # modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
     #     patched_prepare_4d_causal_attention_mask_for_sdpa
     # )
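
The patch removes the module-level import of mask_2d_to_4d and re-imports it inside hijack_llama_prepare_4d_mask(), so the symbol is only resolved when the SDPA monkeypatch is actually installed rather than at module import time. Below is a minimal sketch of the affected region of llama_patch_multipack.py after applying this patch; the function body is abbreviated with "..." and any surrounding imports not shown in the diff context are assumptions, not the full file.

"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""

from axolotl.monkeypatch.utils import (
    # mask_2d_to_4d is no longer imported at module level (removed by this patch)
    patched_prepare_4d_causal_attention_mask,
    patched_prepare_4d_causal_attention_mask_for_sdpa,
)


def hijack_llama_prepare_4d_mask():
    # context line from the diff; other imports in the real function are omitted here
    from transformers.utils import is_torch_bf16_gpu_available

    # deferred import added by this patch: mask_2d_to_4d is only pulled in
    # when the monkeypatch is installed, not when the module is imported
    from axolotl.monkeypatch.utils import mask_2d_to_4d

    ...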