fix check for flash attn branching (#377)
This commit is contained in:
@@ -92,7 +92,7 @@ def forward(
|
|||||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||||
)
|
)
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
elif position_ids.shape[0] == 1:
|
elif attention_mask.shape[0] == 1:
|
||||||
# special handling using sample packing
|
# special handling using sample packing
|
||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
|
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user