diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index d900e897d..6cdd50934 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -92,7 +92,7 @@ def forward( qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True ) output = rearrange(output, "(b s) ... -> b s ...", b=bsz) - elif position_ids.shape[0] == 1: + elif attention_mask.shape[0] == 1: # special handling using sample packing qkv = rearrange(qkv, "b s ... -> (b s) ...") cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)