diff --git a/src/axolotl/monkeypatch/attention/xformers.py b/src/axolotl/monkeypatch/attention/xformers.py
index 998081900..5fcbcb55a 100644
--- a/src/axolotl/monkeypatch/attention/xformers.py
+++ b/src/axolotl/monkeypatch/attention/xformers.py
@@ -41,32 +41,10 @@ def xformers_attention_forward(
     attn_bias = xformers.ops.LowerTriangularMask()
 
-    if attention_mask is not None:
-        batch_size = query.shape[0]
-        query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
-            query, key, value, attention_mask, query_length
-        )
-        cu_seqlens_q, cu_seq_lens_k = cu_seq_lens
-        seq_lengths = []
-        for i in range(len(cu_seq_lens_q) - 1):
-            seq_lengths.append(cu_seqlens_q[i + 1] - cu_seq_lens_q[i])
-        attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
-            q_seqlen=seq_lengths,
-            kv_seqlen=seq_lengths,
-        )
-
-        attn_output_unpad = xformers_attention(
-            query,
-            key,
-            value,
-            attn_bias=attn_bias,
-        )
-        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
-
     # If position_ids is provided, check whether all examples contain only one sequence: if the tensor is increasing
     # then we probably have one sequence, otherwise the batch is packed. Additionally check we are in the pre-fill/training stage.
     # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow a padding-free approach
-    elif position_ids is not None and (
+    if position_ids is not None and (
         max_length_q is not None
         or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
     ):
@@ -118,7 +96,27 @@ def xformers_attention_forward(
             value,
             attn_bias=attn_bias,
         )
+    elif attention_mask is not None:
+        batch_size = query.shape[0]
+        query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
+            query, key, value, attention_mask, query_length
+        )
+        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+        seq_lengths = []
+        for i in range(len(cu_seqlens_q) - 1):
+            seq_lengths.append(cu_seqlens_q[i + 1] - cu_seqlens_q[i])
+        attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
+            q_seqlen=seq_lengths,
+            kv_seqlen=seq_lengths,
+        )
+        attn_output_unpad = xformers_attention(
+            query,
+            key,
+            value,
+            attn_bias=attn_bias,
+        )
+        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
     else:
         attn_output = xformers_attention(
             query,
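
A note on the `position_ids` heuristic this patch promotes to the first branch: a decreasing step anywhere in `position_ids` signals that the batch is packed (positions reset at each sequence boundary). A minimal sketch of that check, with made-up tensors; the `looks_packed` helper is illustrative and not part of the patch:

```python
import torch

# Hypothetical position_ids tensors (values chosen for illustration).
single = torch.arange(8).unsqueeze(0)              # one sequence: 0..7, monotonically increasing
packed = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]])  # three packed sequences; positions reset at each boundary

def looks_packed(position_ids: torch.Tensor) -> bool:
    # Mirrors the condition in the patch: any negative step in
    # position_ids means a new sequence started mid-row.
    return not (torch.diff(position_ids, dim=-1) >= 0).all()

print(looks_packed(single))  # False -> a plain lower-triangular causal mask suffices
print(looks_packed(packed))  # True  -> needs per-sequence (block-diagonal) masking
```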
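And a sketch of how the relocated `attention_mask` branch derives per-sequence lengths from the cumulative lengths returned by `_upad_input`. The values below are made up, and `torch.diff` is shown only as an equivalent of the Python loop in the patch, as a possible simplification:

```python
import torch

# Cumulative query lengths for a hypothetical unpadded batch whose three
# sequences have lengths 3, 5, and 2.
cu_seqlens_q = torch.tensor([0, 3, 8, 10], dtype=torch.int32)

# Same result as the loop in the patch, in a single call.
seq_lengths = torch.diff(cu_seqlens_q).tolist()
assert seq_lengths == [3, 5, 2]

# With xformers installed, these lengths feed the block-diagonal causal
# bias so tokens attend only within their own sequence:
#   attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
#       q_seqlen=seq_lengths, kv_seqlen=seq_lengths
#   )
```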