diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index a445c3a5a..cb0aa3fe6 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -158,7 +158,7 @@ def flashattn_forward( else: # turn off FA causal mask after first inference autoregressive iteration # only on first autoregressive step q,k,v have same seqlen - is_causal = past_key_value is not None + is_causal = key_states.shape == query_states.shape if cu_seqlens is not None and max_seqlen is not None: # special handling using sample packing