From cee310dcfa6b1571582253e9fe1692af5b0e814c Mon Sep 17 00:00:00 2001
From: Sunny Liu
Date: Wed, 22 Jan 2025 20:15:23 -0500
Subject: [PATCH] llama sdpa patching WIP

---
 src/axolotl/monkeypatch/llama_patch_multipack.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py
index e8e7a0409..c200f75cb 100644
--- a/src/axolotl/monkeypatch/llama_patch_multipack.py
+++ b/src/axolotl/monkeypatch/llama_patch_multipack.py
@@ -2,10 +2,6 @@
 Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
 """

-from typing import Optional
-
-import torch
-from transformers.utils import is_torch_bf16_gpu_available

 from axolotl.monkeypatch.utils import (
     mask_2d_to_4d,
@@ -15,11 +11,15 @@ from axolotl.monkeypatch.utils import (


 def hijack_llama_prepare_4d_mask():
+    from typing import Optional
+
+    import torch
     from transformers import modeling_attn_mask_utils
     from transformers.models.llama.modeling_llama import LlamaModel
     from transformers.models.llama.modeling_llama.LlamaModel import (
         _prepare_4d_causal_attention_mask_with_cache_position,
     )
+    from transformers.utils import is_torch_bf16_gpu_available

 #     modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
 #         patched_prepare_4d_causal_attention_mask_for_sdpa
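
Review note: one of the imports inside hijack_llama_prepare_4d_mask() cannot
execute as written. transformers.models.llama.modeling_llama.LlamaModel is a
class, not a module, so `from ...LlamaModel import (...)` raises
ModuleNotFoundError; on transformers versions that provide it, the static
method has to be read as an attribute instead, e.g.
LlamaModel._prepare_4d_causal_attention_mask_with_cache_position.
The commented-out assignment hints at where this WIP is headed: rebinding
transformers' private SDPA mask builder so axolotl's mask_2d_to_4d can expand
the packed 2D sample-id mask into a 4D block-diagonal mask, keeping packed
sequences from attending to one another. Below is a minimal sketch of that
direction, assuming a transformers version around 4.36 in which
modeling_llama still imports _prepare_4d_causal_attention_mask_for_sdpa and
the helper passes prebuilt 4D masks through, and assuming a
mask_2d_to_4d(mask, dtype=...) signature; the patched function name is
illustrative, not confirmed axolotl API.

from typing import Optional

import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.monkeypatch.utils import mask_2d_to_4d


def patched_prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask: Optional[torch.Tensor],
    *args,
):
    # Expand the packed 2D mask (per-sample integer ids) into a 4D
    # block-diagonal mask, then let the stock helper finish the job
    # (assumes its custom-4D-mask passthrough, transformers >= 4.36).
    # bf16 on supported GPUs mirrors the usual training dtype.
    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
    return _prepare_4d_causal_attention_mask_for_sdpa(
        mask_2d_to_4d(attention_mask, dtype=dtype),
        *args,
    )


def hijack_llama_prepare_4d_mask():
    # Rebind the helper in both places: the shared utils module and
    # modeling_llama, which copies the symbol via a from-import at module
    # load time, so patching only one binding would leave the other live.
    from transformers import modeling_attn_mask_utils
    from transformers.models.llama import modeling_llama

    modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
        patched_prepare_4d_causal_attention_mask_for_sdpa
    )
    modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
        patched_prepare_4d_causal_attention_mask_for_sdpa
    )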