diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index a445c3a5a..79199e34c 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -155,6 +155,8 @@ def flashattn_forward( # during training q,k,v always have same seqlen assert key_states.shape == query_states.shape is_causal = True + elif past_key_value is None: + is_causal = True else: # turn off FA causal mask after first inference autoregressive iteration # only on first autoregressive step q,k,v have same seqlen