From 5ca57cb55a20694edca2cce07ae4586aab78a7f5 Mon Sep 17 00:00:00 2001 From: Sunny Liu Date: Thu, 23 Jan 2025 17:56:13 -0500 Subject: [PATCH] undo bool conversion --- src/axolotl/monkeypatch/llama_patch_multipack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py index ad72af385..43b11d918 100644 --- a/src/axolotl/monkeypatch/llama_patch_multipack.py +++ b/src/axolotl/monkeypatch/llama_patch_multipack.py @@ -35,7 +35,7 @@ def hijack_llama_prepare_4d_mask(): # return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position( # mask_2d_to_4d(attention_mask, dtype=dtype), **kwargs # ) - return mask_2d_to_4d(attention_mask, dtype=dtype).bool() + return mask_2d_to_4d(attention_mask, dtype=dtype) LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access llama_patched_prepare_4d_causal_attention_mask_with_cache_position