From fbf49a4770fcd154e7592333656eb25225e1111d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 21 Aug 2023 10:36:26 -0400 Subject: [PATCH 1/2] is_causal fix for evals? --- src/axolotl/monkeypatch/llama_attn_hijack_flash.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index a445c3a5a..79199e34c 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -155,6 +155,8 @@ def flashattn_forward( # during training q,k,v always have same seqlen assert key_states.shape == query_states.shape is_causal = True + elif past_key_value is None: + is_causal = True else: # turn off FA causal mask after first inference autoregressive iteration # only on first autoregressive step q,k,v have same seqlen From a213d9972aeca545e9176c45f0a0bdab04ace277 Mon Sep 17 00:00:00 2001 From: Aman Karmani Date: Mon, 21 Aug 2023 10:40:06 -0700 Subject: [PATCH 2/2] fix eval regression caused in 13f7efaf74fcd3c4514277ccb71914c589873f6a --- src/axolotl/monkeypatch/llama_attn_hijack_flash.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 79199e34c..cb0aa3fe6 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -155,12 +155,10 @@ def flashattn_forward( # during training q,k,v always have same seqlen assert key_states.shape == query_states.shape is_causal = True - elif past_key_value is None: - is_causal = True else: # turn off FA causal mask after first inference autoregressive iteration # only on first autoregressive step q,k,v have same seqlen - is_causal = past_key_value is not None + is_causal = key_states.shape == query_states.shape if cu_seqlens is not None and max_seqlen is not None: # special handling using sample packing