Fix sequence-length calculation to drop hanging sequences

This commit is contained in:
Wing Lian
2025-05-03 21:12:25 -04:00
parent 372fd08548
commit c7f38ba96b

View File

@@ -10,9 +10,10 @@ import xformers.ops.fmha
from flash_attn.bert_padding import pad_input
from transformers.modeling_flash_attention_utils import (
_upad_input,
prepare_fa2_from_position_ids,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
xformers_attention = xformers.ops.fmha.memory_efficient_attention
@@ -51,22 +52,14 @@ def xformers_attention_forward(
batch_size = query.size(0)
if cu_seq_lens_q is None or cu_seq_lens_k is None:
_, _, _, indices_q, cu_seq_lens, _ = prepare_fa2_from_position_ids(
query, key, value, position_ids
)
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
seq_lengths = []
for i in range(len(cu_seq_lens_q) - 1):
seq_lengths.append(
cu_seq_lens_q[i + 1].item() - cu_seq_lens_q[i].item()
)
cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
cu_seq_lens_q = cu_seq_lens_q.squeeze()
seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
attn_bias = (
xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=seq_lengths,
q_seqlen=seq_lengths.tolist(),
)
)
else:
query = query.reshape(-1, query.size(-2), query.size(-1))
key = key.reshape(-1, key.size(-2), key.size(-1))
@@ -101,10 +94,10 @@ def xformers_attention_forward(
query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
query, key, value, attention_mask, query_length
)
cu_seqlens_q, cu_seq_lens_k = cu_seq_lens
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
seq_lengths = []
for i in range(len(cu_seq_lens_q) - 1):
seq_lengths.append(cu_seqlens_q[i + 1] - cu_seq_lens_q[i])
seq_lengths.append(cu_seq_lens_q[i + 1] - cu_seq_lens_q[i])
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=seq_lengths,
kv_seqlen=seq_lengths,