fix seq lens calc to drop hanging sequences
This commit is contained in:
@@ -10,9 +10,10 @@ import xformers.ops.fmha
|
||||
from flash_attn.bert_padding import pad_input
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_upad_input,
|
||||
prepare_fa2_from_position_ids,
|
||||
)
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
xformers_attention = xformers.ops.fmha.memory_efficient_attention
|
||||
|
||||
|
||||
@@ -51,22 +52,14 @@ def xformers_attention_forward(
|
||||
batch_size = query.size(0)
|
||||
|
||||
if cu_seq_lens_q is None or cu_seq_lens_k is None:
|
||||
_, _, _, indices_q, cu_seq_lens, _ = prepare_fa2_from_position_ids(
|
||||
query, key, value, position_ids
|
||||
)
|
||||
|
||||
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
|
||||
seq_lengths = []
|
||||
for i in range(len(cu_seq_lens_q) - 1):
|
||||
seq_lengths.append(
|
||||
cu_seq_lens_q[i + 1].item() - cu_seq_lens_q[i].item()
|
||||
)
|
||||
cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
|
||||
cu_seq_lens_q = cu_seq_lens_q.squeeze()
|
||||
seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
|
||||
attn_bias = (
|
||||
xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
||||
q_seqlen=seq_lengths,
|
||||
q_seqlen=seq_lengths.tolist(),
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
query = query.reshape(-1, query.size(-2), query.size(-1))
|
||||
key = key.reshape(-1, key.size(-2), key.size(-1))
|
||||
@@ -101,10 +94,10 @@ def xformers_attention_forward(
|
||||
query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
|
||||
query, key, value, attention_mask, query_length
|
||||
)
|
||||
cu_seqlens_q, cu_seq_lens_k = cu_seq_lens
|
||||
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
|
||||
seq_lengths = []
|
||||
for i in range(len(cu_seq_lens_q) - 1):
|
||||
seq_lengths.append(cu_seqlens_q[i + 1] - cu_seq_lens_q[i])
|
||||
seq_lengths.append(cu_seq_lens_q[i + 1] - cu_seq_lens_q[i])
|
||||
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
||||
q_seqlen=seq_lengths,
|
||||
kv_seqlen=seq_lengths,
|
||||
|
||||
Reference in New Issue
Block a user