From 7565fb9d63b2a3a238aa8dfe458d63468ed6a84a Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 17 Aug 2023 18:02:41 -0400
Subject: [PATCH] update forwards so we only calculate cu_seqlens once

---
 .../monkeypatch/llama_attn_hijack_flash.py | 20 +++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 5ef05ace7..6e43b1d53 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -65,6 +65,8 @@ def flashattn_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     """Input shape: Batch x Time x Channel
 
@@ -157,18 +159,16 @@ def flashattn_forward(
     # only on first autoregressive step q,k,v have same seqlen
     is_causal = past_key_value is not None
 
-    if self.training and position_ids.shape[0] == 1:
+    if cu_seqlens is not None and max_seqlen is not None:
         # special handling using sample packing
         qkv = torch.stack(
             [query_states, key_states, value_states], dim=2
         )  # [bsz, nh, 3, q_len, hd]
         qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
-        cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
-        cu_q_lens = cu_q_lens.squeeze()
 
         output = flash_attn_varlen_qkvpacked_func(
-            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=is_causal
+            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=is_causal
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
@@ -415,6 +415,8 @@ def llama_model_forward(
         past_key_values_length = past_key_values[0][0].shape[2]
         seq_length_with_past = seq_length_with_past + past_key_values_length
 
+    cu_seqlens = None
+    max_seqlen = None
     if position_ids is None:
         device = input_ids.device if input_ids is not None else inputs_embeds.device
         position_ids = torch.arange(
@@ -426,6 +428,8 @@ def llama_model_forward(
         position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
     else:
         position_ids = position_ids.view(-1, seq_length).long()
+        cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
+        cu_seqlens = cu_seqlens.squeeze()
 
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
@@ -480,6 +484,8 @@ def llama_model_forward(
                 attention_mask,
                 position_ids,
                 None,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
             )
         else:
             layer_outputs = decoder_layer(
@@ -489,6 +495,8 @@ def llama_model_forward(
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
             )
 
         hidden_states = layer_outputs[0]
@@ -528,6 +536,8 @@ def llama_decoder_layer_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: Optional[bool] = False,
     use_cache: Optional[bool] = False,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """
     Args:
@@ -556,6 +566,8 @@ def llama_decoder_layer_forward(
         past_key_value=past_key_value,
         output_attentions=output_attentions,
         use_cache=use_cache,
+        cu_seqlens=cu_seqlens,
+        max_seqlen=max_seqlen,
     )
     hidden_states = residual + hidden_states
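
Note on the mechanics: the patch computes cu_seqlens/max_seqlen once in
llama_model_forward and threads them through each decoder layer down to
flashattn_forward, instead of recomputing them from position_ids inside every
layer's attention call. The body of get_cu_seqlens_from_pos_ids is not part of
this diff; the sketch below is illustrative only (the function name and
implementation details are assumptions, not the axolotl code). It shows how
cumulative sequence lengths in the format flash_attn_varlen_qkvpacked_func
expects can be recovered from packed position_ids, where each packed sample's
positions restart at 0:

    import torch

    def cu_seqlens_from_pos_ids_sketch(position_ids: torch.Tensor):
        # Hypothetical helper, not the axolotl implementation.
        # position_ids: [1, total_len]; a 0 marks the start of a packed sample,
        # e.g. [0, 1, 2, 0, 1, 0, 1, 2, 3] packs samples of length 3, 2, 4.
        pos = position_ids.squeeze(0)
        starts = torch.nonzero(pos == 0, as_tuple=True)[0]
        ends = torch.cat(
            [starts[1:], torch.tensor([pos.numel()], device=pos.device)]
        )
        seqlens = ends - starts
        # flash_attn_varlen_* expects int32 cumulative lengths starting at 0,
        # shape (num_samples + 1,).
        cu_seqlens = torch.zeros(
            len(seqlens) + 1, dtype=torch.int32, device=pos.device
        )
        cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
        return cu_seqlens, seqlens.max()

    # Example: three packed samples of lengths 3, 2, 4 give
    # cu -> tensor([0, 3, 5, 9], dtype=torch.int32) and mx -> tensor(4).
    cu, mx = cu_seqlens_from_pos_ids_sketch(
        torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
    )

Hoisting this computation out of the per-layer attention forward is the point
of the patch: the boundaries of the packed samples do not change between
layers, so the model forward can derive them once and pass them down.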