From 84960003ed5408a5cc09eb9cf134c5693ff49734 Mon Sep 17 00:00:00 2001
From: Sunny Liu
Date: Thu, 30 Jan 2025 14:40:18 -0500
Subject: [PATCH] reset llama_patch_multipack.py

---
 .../monkeypatch/llama_patch_multipack.py      | 33 +++----------------
 1 file changed, 4 insertions(+), 29 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py
index 43b11d918..cfd525367 100644
--- a/src/axolotl/monkeypatch/llama_patch_multipack.py
+++ b/src/axolotl/monkeypatch/llama_patch_multipack.py
@@ -2,7 +2,6 @@
 Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
 """
-
 from axolotl.monkeypatch.utils import (
     patched_prepare_4d_causal_attention_mask,
     patched_prepare_4d_causal_attention_mask_for_sdpa,
 )
@@ -10,40 +9,16 @@ from axolotl.monkeypatch.utils import (
 
 
 def hijack_llama_prepare_4d_mask():
-    from typing import Optional
-
-    import torch
     from transformers import modeling_attn_mask_utils
-    from transformers.models.llama.modeling_llama import LlamaModel
+    from transformers.models.llama import modeling_llama
 
-    # from transformers.models.llama.modeling_llama.LlamaModel import (
-    #     _prepare_4d_causal_attention_mask_with_cache_position,
-    # )
-    from transformers.utils import is_torch_bf16_gpu_available
-
-    from axolotl.monkeypatch.utils import mask_2d_to_4d
-
-    # modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
-    #     patched_prepare_4d_causal_attention_mask_for_sdpa
-    # )
-
-    @staticmethod
-    def llama_patched_prepare_4d_causal_attention_mask_with_cache_position(
-        attention_mask: Optional[torch.Tensor], **kwargs
-    ):
-        dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
-        # return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position(
-        #     mask_2d_to_4d(attention_mask, dtype=dtype), **kwargs
-        # )
-        return mask_2d_to_4d(attention_mask, dtype=dtype)
-
-    LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = (  # pylint: disable=protected-access
-        llama_patched_prepare_4d_causal_attention_mask_with_cache_position
+    modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
+        patched_prepare_4d_causal_attention_mask_for_sdpa
     )
     modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
         patched_prepare_4d_causal_attention_mask_for_sdpa
     )
-    LlamaModel._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
+    modeling_llama._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
         patched_prepare_4d_causal_attention_mask
     )
     modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
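
A minimal usage sketch of the restored hook (not part of the patch), assuming
the usual axolotl flow where the monkeypatch is applied before the model is
instantiated; the checkpoint name below is illustrative only:

    from transformers import AutoModelForCausalLM

    from axolotl.monkeypatch.llama_patch_multipack import (
        hijack_llama_prepare_4d_mask,
    )

    # Swap in the multipack-aware 4D causal mask builders on both the
    # transformers mask-utils module and the Llama modeling module, so
    # packed samples cannot attend across sequence boundaries.
    hijack_llama_prepare_4d_mask()

    model = AutoModelForCausalLM.from_pretrained(
        "NousResearch/Llama-2-7b-hf",  # illustrative; any Llama checkpoint
        attn_implementation="sdpa",  # the patched SDPA mask path is used here
    )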