fix sdp attention to use the flash/mem-efficient context manager

This commit is contained in:
Wing Lian
2023-07-20 01:05:48 -04:00
parent b06d3e3645
commit a032c9f452

View File

@@ -184,14 +184,15 @@ def sdp_attention_forward(
# We only apply sdp attention if we don't need to output the whole attention matrix # We only apply sdp attention if we don't need to output the whole attention matrix
if not output_attentions: if not output_attentions:
attn_output = torch.nn.functional.scaled_dot_product_attention( with torch.backends.cuda.sdp_kernel():
query_states, attn_output = torch.nn.functional.scaled_dot_product_attention(
key_states, query_states,
value_states, key_states,
attn_mask=attention_mask, value_states,
is_causal=False, attn_mask=attention_mask,
) is_causal=False,
attn_weights = None )
attn_weights = None
else: else:
attn_weights = torch.matmul( attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3) query_states, key_states.transpose(2, 3)