fix seq lens calc to drop hanging sequences
This commit is contained in:
@@ -10,9 +10,10 @@ import xformers.ops.fmha
|
|||||||
from flash_attn.bert_padding import pad_input
|
from flash_attn.bert_padding import pad_input
|
||||||
from transformers.modeling_flash_attention_utils import (
|
from transformers.modeling_flash_attention_utils import (
|
||||||
_upad_input,
|
_upad_input,
|
||||||
prepare_fa2_from_position_ids,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
|
||||||
xformers_attention = xformers.ops.fmha.memory_efficient_attention
|
xformers_attention = xformers.ops.fmha.memory_efficient_attention
|
||||||
|
|
||||||
|
|
||||||
@@ -51,22 +52,14 @@ def xformers_attention_forward(
|
|||||||
batch_size = query.size(0)
|
batch_size = query.size(0)
|
||||||
|
|
||||||
if cu_seq_lens_q is None or cu_seq_lens_k is None:
|
if cu_seq_lens_q is None or cu_seq_lens_k is None:
|
||||||
_, _, _, indices_q, cu_seq_lens, _ = prepare_fa2_from_position_ids(
|
cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
|
||||||
query, key, value, position_ids
|
cu_seq_lens_q = cu_seq_lens_q.squeeze()
|
||||||
)
|
seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
|
||||||
|
|
||||||
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
|
|
||||||
seq_lengths = []
|
|
||||||
for i in range(len(cu_seq_lens_q) - 1):
|
|
||||||
seq_lengths.append(
|
|
||||||
cu_seq_lens_q[i + 1].item() - cu_seq_lens_q[i].item()
|
|
||||||
)
|
|
||||||
attn_bias = (
|
attn_bias = (
|
||||||
xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
||||||
q_seqlen=seq_lengths,
|
q_seqlen=seq_lengths.tolist(),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
query = query.reshape(-1, query.size(-2), query.size(-1))
|
query = query.reshape(-1, query.size(-2), query.size(-1))
|
||||||
key = key.reshape(-1, key.size(-2), key.size(-1))
|
key = key.reshape(-1, key.size(-2), key.size(-1))
|
||||||
@@ -101,10 +94,10 @@ def xformers_attention_forward(
|
|||||||
query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
|
query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
|
||||||
query, key, value, attention_mask, query_length
|
query, key, value, attention_mask, query_length
|
||||||
)
|
)
|
||||||
cu_seqlens_q, cu_seq_lens_k = cu_seq_lens
|
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
|
||||||
seq_lengths = []
|
seq_lengths = []
|
||||||
for i in range(len(cu_seq_lens_q) - 1):
|
for i in range(len(cu_seq_lens_q) - 1):
|
||||||
seq_lengths.append(cu_seqlens_q[i + 1] - cu_seq_lens_q[i])
|
seq_lengths.append(cu_seq_lens_q[i + 1] - cu_seq_lens_q[i])
|
||||||
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
||||||
q_seqlen=seq_lengths,
|
q_seqlen=seq_lengths,
|
||||||
kv_seqlen=seq_lengths,
|
kv_seqlen=seq_lengths,
|
||||||
|
|||||||
Reference in New Issue
Block a user