fix check for flash attn branching (#377)

Author:    Wing Lian
Date:      2023-08-12 22:48:08 -04:00
Committer: GitHub
Parent:    0c967279ce
Commit:    343ac84e5a


@@ -92,7 +92,7 @@ def forward(
             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-    elif position_ids.shape[0] == 1:
+    elif attention_mask.shape[0] == 1:
         # special handling using sample packing
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
         cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
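
The corrected branch detects sample packing from the attention mask's batch dimension, then rebuilds the cumulative sequence lengths from position_ids. As context for why position_ids can carry that information, below is a hypothetical sketch of what a helper like get_cu_seqlens_from_pos_ids can look like; it is not the repo's actual implementation. The idea: with sample packing, each packed sample's position_ids restart at 0, so every zero marks a sample boundary, and the boundaries yield the cu_seqlens/max_seqlen pair that flash-attn's varlen kernels expect.

import torch
import torch.nn.functional as F

def get_cu_seqlens_from_pos_ids(position_ids):
    # Hypothetical sketch of the helper the fixed branch calls; the
    # repo's real implementation may differ. With sample packing,
    # each packed sample's position_ids restart at 0, so every zero
    # marks the start of a new sample.
    pos = position_ids.flatten()
    starts = torch.nonzero(pos == 0, as_tuple=False).flatten()
    ends = torch.cat(
        [starts[1:], torch.tensor([pos.numel()], device=pos.device)]
    )
    seqlens = (ends - starts).to(torch.int32)
    # flash-attn's varlen API wants cumulative lengths with a leading
    # zero, plus the length of the longest individual sequence
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    return cu_seqlens, max_seqlen

For example, a packed row with position_ids [0, 1, 2, 0, 1, 0, 1, 2, 3] would give cu_seqlens [0, 3, 5, 9] and max_seqlen 4.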