undo bool conversion

This commit is contained in:
Sunny Liu
2025-01-23 17:56:13 -05:00
parent 0149de7fb0
commit 5ca57cb55a

View File

@@ -35,7 +35,7 @@ def hijack_llama_prepare_4d_mask():
# return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position(
# mask_2d_to_4d(attention_mask, dtype=dtype), **kwargs
# )
return mask_2d_to_4d(attention_mask, dtype=dtype).bool()
return mask_2d_to_4d(attention_mask, dtype=dtype)
LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access
llama_patched_prepare_4d_causal_attention_mask_with_cache_position