From 8c34c651812305f95995e75bbd778e5602cba9bf Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Thu, 23 Jan 2025 14:56:26 -0500 Subject: [PATCH] llama multipack: return mask_2d_to_4d output directly from patched 4D mask helper --- src/axolotl/monkeypatch/llama_patch_multipack.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py index b286175bf..43b11d918 100644 --- a/src/axolotl/monkeypatch/llama_patch_multipack.py +++ b/src/axolotl/monkeypatch/llama_patch_multipack.py @@ -32,9 +32,10 @@ def hijack_llama_prepare_4d_mask(): attention_mask: Optional[torch.Tensor], **kwargs ): dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32 - return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position( - mask_2d_to_4d(attention_mask, dtype=dtype), **kwargs - ) + # return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position( + # mask_2d_to_4d(attention_mask, dtype=dtype), **kwargs + # ) + return mask_2d_to_4d(attention_mask, dtype=dtype) LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access llama_patched_prepare_4d_causal_attention_mask_with_cache_position