diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
index c6bdafb89..8fa00f43b 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
@@ -184,14 +184,15 @@ def sdp_attention_forward(
     # We only apply sdp attention if we don't need to output the whole attention matrix
     if not output_attentions:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            is_causal=False,
-        )
-        attn_weights = None
+        with torch.backends.cuda.sdp_kernel():
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=False,
+            )
+            attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states, key_states.transpose(2, 3)