llama sdpa patching WIP - static class function import
This commit is contained in:
@@ -28,13 +28,11 @@ def hijack_llama_prepare_4d_mask():
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
def llama_patched_prepare_4d_causal_attention_mask_with_cache_position(
|
def llama_patched_prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
attention_mask: Optional[torch.Tensor],
|
attention_mask: Optional[torch.Tensor], *args, **kwargs
|
||||||
*args,
|
|
||||||
):
|
):
|
||||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||||
return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position(
|
return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
mask_2d_to_4d(attention_mask, dtype=dtype), *args, **kwargs
|
||||||
*args,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access
|
LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access
|
||||||
|
|||||||
Reference in New Issue
Block a user