From 343ac84e5ae4f907e0dc1f67207d708a3b5fa3ab Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 12 Aug 2023 22:48:08 -0400 Subject: [PATCH] fix check for flash attn branching (#377) --- src/axolotl/monkeypatch/llama_attn_hijack_flash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index d900e897d..6cdd50934 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -92,7 +92,7 @@ def forward( qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True ) output = rearrange(output, "(b s) ... -> b s ...", b=bsz) - elif position_ids.shape[0] == 1: + elif attention_mask.shape[0] == 1: # special handling using sample packing qkv = rearrange(qkv, "b s ... -> (b s) ...") cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)