fix sdp attention to use the flash/mem-efficient context manaager

This commit is contained in:
Wing Lian
2023-07-20 01:05:48 -04:00
parent b06d3e3645
commit a032c9f452

View File

@@ -184,6 +184,7 @@ def sdp_attention_forward(
# We only apply sdp attention if we don't need to output the whole attention matrix # We only apply sdp attention if we don't need to output the whole attention matrix
if not output_attentions: if not output_attentions:
with torch.backends.cuda.sdp_kernel():
attn_output = torch.nn.functional.scaled_dot_product_attention( attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states, query_states,
key_states, key_states,