undo bool conversion
This commit is contained in:
@@ -35,7 +35,7 @@ def hijack_llama_prepare_4d_mask():
|
|||||||
# return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position(
|
# return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
# mask_2d_to_4d(attention_mask, dtype=dtype), **kwargs
|
# mask_2d_to_4d(attention_mask, dtype=dtype), **kwargs
|
||||||
# )
|
# )
|
||||||
return mask_2d_to_4d(attention_mask, dtype=dtype).bool()
|
return mask_2d_to_4d(attention_mask, dtype=dtype)
|
||||||
|
|
||||||
LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access
|
LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access
|
||||||
llama_patched_prepare_4d_causal_attention_mask_with_cache_position
|
llama_patched_prepare_4d_causal_attention_mask_with_cache_position
|
||||||
|
|||||||
Reference in New Issue
Block a user