diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 600c5ad54..c4f1fe947 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -7,7 +7,6 @@ from typing import Optional, Tuple import torch import transformers from einops import rearrange -from flash_attn.bert_padding import pad_input, unpad_input try: from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func @@ -19,6 +18,34 @@ except ImportError: from transformers.models.llama.modeling_llama import apply_rotary_pos_emb +def get_cu_seqlens(attn_mask): + device = attn_mask.device + # Exclude zeros to avoid adding their positions to the mask + t_non_zeros = attn_mask[attn_mask != 0] + # Find where the sequence number changes (including the first position) + seq_change = torch.cat( + [ + torch.tensor([1], dtype=torch.int32, device=device), + t_non_zeros[1:] != t_non_zeros[:-1], + ] + ) + # Get the indices where the sequence changes + change_indices = torch.cat( + [ + (seq_change == 1).nonzero(as_tuple=True)[0], + torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device), + ] + ) + # Calculate the sequence lengths + seq_lengths = change_indices[1:] - change_indices[:-1] + # Calculate the cumulative sequence lengths + cu_seqlens = torch.cat( + [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)] + ) + + return cu_seqlens.to(dtype=torch.int32) + + def forward( self, hidden_states: torch.Tensor, @@ -91,35 +118,15 @@ def forward( ) output = rearrange(output, "(b s) ... -> b s ...", b=bsz) else: - nheads = qkv.shape[-2] + qkv = rearrange(qkv, "b s ... -> (b s) ...") + max_s = q_len + cu_q_lens = get_cu_seqlens(key_padding_mask) - # pylint: disable=invalid-name - x = rearrange(qkv, "b s three h d -> b s (three h d)") - x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) - x_unpad = rearrange( - x_unpad, - "nnz (three h d) -> nnz three h d", - three=3, - h=nheads, - ) - output_unpad = flash_attn_varlen_qkvpacked_func( - x_unpad, - cu_q_lens, - max_s, - 0.0, - softmax_scale=None, - causal=True, - ) - output = rearrange( - pad_input( - rearrange(output_unpad, "nnz h d -> nnz (h d)"), - indices, - bsz, - q_len, - ), - "b s (h d) -> b s h d", - h=nheads, + output = flash_attn_varlen_qkvpacked_func( + qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True ) + output = rearrange(output, "(b s) ... -> b s ...", b=bsz) + return ( self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None,