dummy
This commit is contained in:
@@ -32,9 +32,10 @@ def hijack_llama_prepare_4d_mask():
|
||||
attention_mask: Optional[torch.Tensor], **kwargs
|
||||
):
|
||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||
return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
mask_2d_to_4d(attention_mask, dtype=dtype), **kwargs
|
||||
)
|
||||
# return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
# mask_2d_to_4d(attention_mask, dtype=dtype), **kwargs
|
||||
# )
|
||||
return mask_2d_to_4d(attention_mask, dtype=dtype)
|
||||
|
||||
LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access
|
||||
llama_patched_prepare_4d_causal_attention_mask_with_cache_position
|
||||
|
||||
Reference in New Issue
Block a user