fix sdp attention to use the flash/mem-efficient context manager
This commit is contained in:
@@ -184,6 +184,7 @@ def sdp_attention_forward(
|
|||||||
|
|
||||||
# We only apply sdp attention if we don't need to output the whole attention matrix
|
# We only apply sdp attention if we don't need to output the whole attention matrix
|
||||||
if not output_attentions:
|
if not output_attentions:
|
||||||
|
with torch.backends.cuda.sdp_kernel():
|
||||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
query_states,
|
query_states,
|
||||||
key_states,
|
key_states,
|
||||||
|
|||||||
Reference in New Issue
Block a user