diff --git a/src/axolotl/monkeypatch/attention/xformers.py b/src/axolotl/monkeypatch/attention/xformers.py
index 5fcbcb55a..28eab6ec1 100644
--- a/src/axolotl/monkeypatch/attention/xformers.py
+++ b/src/axolotl/monkeypatch/attention/xformers.py
@@ -10,9 +10,10 @@ import xformers.ops.fmha
 from flash_attn.bert_padding import pad_input
 from transformers.modeling_flash_attention_utils import (
     _upad_input,
-    prepare_fa2_from_position_ids,
 )
 
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+
 xformers_attention = xformers.ops.fmha.memory_efficient_attention
 
 
@@ -51,22 +52,14 @@ def xformers_attention_forward(
     batch_size = query.size(0)
 
     if cu_seq_lens_q is None or cu_seq_lens_k is None:
-        _, _, _, indices_q, cu_seq_lens, _ = prepare_fa2_from_position_ids(
-            query, key, value, position_ids
-        )
-
-        cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
-        seq_lengths = []
-        for i in range(len(cu_seq_lens_q) - 1):
-            seq_lengths.append(
-                cu_seq_lens_q[i + 1].item() - cu_seq_lens_q[i].item()
-            )
+        cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
+        cu_seq_lens_q = cu_seq_lens_q.squeeze()
+        seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
         attn_bias = (
             xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
-                q_seqlen=seq_lengths,
+                q_seqlen=seq_lengths.tolist(),
             )
         )
-
     else:
         query = query.reshape(-1, query.size(-2), query.size(-1))
         key = key.reshape(-1, key.size(-2), key.size(-1))
@@ -101,10 +94,10 @@ def xformers_attention_forward(
         query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
             query, key, value, attention_mask, query_length
         )
-        cu_seqlens_q, cu_seq_lens_k = cu_seq_lens
+        cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
         seq_lengths = []
         for i in range(len(cu_seq_lens_q) - 1):
-            seq_lengths.append(cu_seqlens_q[i + 1] - cu_seq_lens_q[i])
+            seq_lengths.append(cu_seq_lens_q[i + 1] - cu_seq_lens_q[i])
         attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
             q_seqlen=seq_lengths,
             kv_seqlen=seq_lengths,
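
For readers unfamiliar with the cumulative-sequence-length bookkeeping the patch switches to, here is a minimal, self-contained sketch (not the axolotl implementation; the helper name `cu_seqlens_from_position_ids` is hypothetical). It assumes each packed sequence's position ids restart at 0, derives cumulative lengths from `position_ids`, and then mirrors the `cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]` difference the new code passes to `BlockDiagonalCausalMask.from_seqlens`:

```python
import torch


def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for get_cu_seqlens_from_pos_ids.

    With sample packing, position_ids restart at 0 at every sequence
    boundary, e.g. [0, 1, 2, 0, 1, 0, 1, 2, 3] packs sequences of
    length 3, 2 and 4. This returns the cumulative lengths [0, 3, 5, 9].
    """
    flat = position_ids.flatten()
    # Every position equal to 0 marks the start of a new packed sequence.
    starts = torch.nonzero(flat == 0).flatten()
    total = torch.tensor([flat.numel()], device=flat.device)
    return torch.cat([starts, total]).to(torch.int32)


position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
cu_seq_lens_q = cu_seqlens_from_position_ids(position_ids)
# Adjacent cumulative lengths give the per-sequence lengths, the same
# difference the patched code feeds to BlockDiagonalCausalMask.from_seqlens.
seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
print(cu_seq_lens_q.tolist())  # [0, 3, 5, 9]
print(seq_lengths.tolist())    # [3, 2, 4]
```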