reorder the packing check
This commit is contained in:
@@ -41,32 +41,10 @@ def xformers_attention_forward(
|
|||||||
|
|
||||||
attn_bias = xformers.ops.LowerTriangularMask()
|
attn_bias = xformers.ops.LowerTriangularMask()
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
batch_size = query.shape[0]
|
|
||||||
query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
|
|
||||||
query, key, value, attention_mask, query_length
|
|
||||||
)
|
|
||||||
cu_seqlens_q, cu_seq_lens_k = cu_seq_lens
|
|
||||||
seq_lengths = []
|
|
||||||
for i in range(len(cu_seq_lens_q) - 1):
|
|
||||||
seq_lengths.append(cu_seqlens_q[i + 1] - cu_seq_lens_q[i])
|
|
||||||
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
|
||||||
q_seqlen=seq_lengths,
|
|
||||||
kv_seqlen=seq_lengths,
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output_unpad = xformers_attention(
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
attn_bias=attn_bias,
|
|
||||||
)
|
|
||||||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
|
||||||
|
|
||||||
# If position_ids is provided, check that the examples do not all contain only one sequence: if the tensor is non-decreasing,
|
# If position_ids is provided, check that the examples do not all contain only one sequence: if the tensor is non-decreasing,
|
||||||
# then we probably have one sequence; otherwise it is packed. Additionally, check that we are in the pre-fill/training stage.
|
# then we probably have one sequence; otherwise it is packed. Additionally, check that we are in the pre-fill/training stage.
|
||||||
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow a padding-free approach.
|
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow a padding-free approach.
|
||||||
elif position_ids is not None and (
|
if position_ids is not None and (
|
||||||
max_length_q is not None
|
max_length_q is not None
|
||||||
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
|
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
|
||||||
):
|
):
|
||||||
@@ -118,7 +96,27 @@ def xformers_attention_forward(
|
|||||||
value,
|
value,
|
||||||
attn_bias=attn_bias,
|
attn_bias=attn_bias,
|
||||||
)
|
)
|
||||||
|
elif attention_mask is not None:
|
||||||
|
batch_size = query.shape[0]
|
||||||
|
query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
|
||||||
|
query, key, value, attention_mask, query_length
|
||||||
|
)
|
||||||
|
cu_seqlens_q, cu_seq_lens_k = cu_seq_lens
|
||||||
|
seq_lengths = []
|
||||||
|
for i in range(len(cu_seq_lens_q) - 1):
|
||||||
|
seq_lengths.append(cu_seqlens_q[i + 1] - cu_seq_lens_q[i])
|
||||||
|
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
||||||
|
q_seqlen=seq_lengths,
|
||||||
|
kv_seqlen=seq_lengths,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output_unpad = xformers_attention(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attn_bias=attn_bias,
|
||||||
|
)
|
||||||
|
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
||||||
else:
|
else:
|
||||||
attn_output = xformers_attention(
|
attn_output = xformers_attention(
|
||||||
query,
|
query,
|
||||||
|
|||||||
Reference in New Issue
Block a user