fix sdp attention to use the flash/mem-efficient context manaager
This commit is contained in:
@@ -184,14 +184,15 @@ def sdp_attention_forward(
|
|||||||
|
|
||||||
# We only apply sdp attention if we don't need to output the whole attention matrix
|
# We only apply sdp attention if we don't need to output the whole attention matrix
|
||||||
if not output_attentions:
|
if not output_attentions:
|
||||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
with torch.backends.cuda.sdp_kernel():
|
||||||
query_states,
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
key_states,
|
query_states,
|
||||||
value_states,
|
key_states,
|
||||||
attn_mask=attention_mask,
|
value_states,
|
||||||
is_causal=False,
|
attn_mask=attention_mask,
|
||||||
)
|
is_causal=False,
|
||||||
attn_weights = None
|
)
|
||||||
|
attn_weights = None
|
||||||
else:
|
else:
|
||||||
attn_weights = torch.matmul(
|
attn_weights = torch.matmul(
|
||||||
query_states, key_states.transpose(2, 3)
|
query_states, key_states.transpose(2, 3)
|
||||||
|
|||||||
Reference in New Issue
Block a user