From a032c9f452c7aa2c65078170fec2496a72222f03 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 20 Jul 2023 01:05:48 -0400
Subject: [PATCH] fix sdp attention to use the flash/mem-efficient context
 manager

---
 .../monkeypatch/llama_attn_hijack_xformers.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
index c6bdafb89..8fa00f43b 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
@@ -184,14 +184,15 @@ def sdp_attention_forward(
     # We only apply sdp attention if we don't need to output the whole attention matrix
     if not output_attentions:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            is_causal=False,
-        )
-        attn_weights = None
+        with torch.backends.cuda.sdp_kernel():
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=False,
+            )
+            attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states, key_states.transpose(2, 3)
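
Note (not part of the patch): on PyTorch 2.0+, torch.backends.cuda.sdp_kernel() is a context
manager that controls which scaled_dot_product_attention backends may be dispatched; with no
arguments, the flash, memory-efficient, and math backends all remain enabled. A minimal sketch of
how the same context manager can be used to force the fused flash/memory-efficient kernels by
disabling the math fallback follows. The tensor shapes are illustrative only and are not taken
from the patched module.

    import torch

    # Hypothetical query/key/value tensors: (batch, heads, seq_len, head_dim)
    q = torch.randn(1, 32, 128, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 32, 128, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 32, 128, 64, device="cuda", dtype=torch.float16)

    # Allow only the flash / memory-efficient kernels; raise instead of
    # silently falling back to the unfused math implementation.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_mem_efficient=True, enable_math=False
    ):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)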