llama sdpa patching WIP - static class function import

This commit is contained in:
Sunny Liu
2025-01-22 20:33:13 -05:00
parent f3bec17917
commit d7b133dc1f

View File

@@ -28,13 +28,11 @@ def hijack_llama_prepare_4d_mask():
# ) # )
def llama_patched_prepare_4d_causal_attention_mask_with_cache_position( def llama_patched_prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: Optional[torch.Tensor], attention_mask: Optional[torch.Tensor], *args, **kwargs
*args,
): ):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32 dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position( return LlamaModel._prepare_4d_causal_attention_mask_with_cache_position(
mask_2d_to_4d(attention_mask, dtype=dtype), mask_2d_to_4d(attention_mask, dtype=dtype), *args, **kwargs
*args,
) )
LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access LlamaModel._prepare_4d_causal_attention_mask_with_cache_position = ( # pylint: disable=protected-access