fix sdp attention to use the flash/mem-efficient context manager

Wing Lian
2023-07-20 01:05:48 -04:00
parent b06d3e3645
commit a032c9f452


@@ -184,14 +184,15 @@ def sdp_attention_forward(
     # We only apply sdp attention if we don't need to output the whole attention matrix
     if not output_attentions:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            is_causal=False,
-        )
-        attn_weights = None
+        with torch.backends.cuda.sdp_kernel():
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=False,
+            )
+            attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states, key_states.transpose(2, 3)
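
For context (not part of this commit): torch.backends.cuda.sdp_kernel is a context manager that controls which scaled_dot_product_attention backends are eligible. Called with no arguments, as in the change above, it leaves all backends (flash, memory-efficient, and math) enabled in the PyTorch versions current at the time. Below is a minimal, illustrative sketch of restricting SDPA to the flash and memory-efficient kernels by disabling the math fallback; the tensor shapes and dtype are placeholder assumptions.

import torch

# Illustrative sketch only; requires a CUDA device and a PyTorch build
# that provides the flash / memory-efficient SDPA kernels.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Allow only the flash and memory-efficient kernels inside this block.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_mem_efficient=True, enable_math=False
):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)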