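Update the forward passes so cu_seqlens and max_seqlen are calculated only once in llama_model_forward and passed down to each decoder layer and attention call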

This commit is contained in:
Wing Lian
2023-08-17 18:02:41 -04:00
parent a6b737d5ff
commit 7565fb9d63

View File

@@ -65,6 +65,8 @@ def flashattn_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel """Input shape: Batch x Time x Channel
@@ -157,18 +159,16 @@ def flashattn_forward(
# only on first autoregressive step q,k,v have same seqlen # only on first autoregressive step q,k,v have same seqlen
is_causal = past_key_value is not None is_causal = past_key_value is not None
if self.training and position_ids.shape[0] == 1: if cu_seqlens and max_seqlen:
# special handling using sample packing # special handling using sample packing
qkv = torch.stack( qkv = torch.stack(
[query_states, key_states, value_states], dim=2 [query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd] ) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...") qkv = rearrange(qkv, "b s ... -> (b s) ...")
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
cu_q_lens = cu_q_lens.squeeze()
output = flash_attn_varlen_qkvpacked_func( output = flash_attn_varlen_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=is_causal qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=is_causal
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape: elif query_states.shape == key_states.shape:
@@ -415,6 +415,8 @@ def llama_model_forward(
past_key_values_length = past_key_values[0][0].shape[2] past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length seq_length_with_past = seq_length_with_past + past_key_values_length
cu_seqlens = None
max_seqlen = None
if position_ids is None: if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange( position_ids = torch.arange(
@@ -426,6 +428,8 @@ def llama_model_forward(
position_ids = position_ids.unsqueeze(0).view(-1, seq_length) position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else: else:
position_ids = position_ids.view(-1, seq_length).long() position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
@@ -480,6 +484,8 @@ def llama_model_forward(
attention_mask, attention_mask,
position_ids, position_ids,
None, None,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -489,6 +495,8 @@ def llama_model_forward(
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
@@ -528,6 +536,8 @@ def llama_decoder_layer_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: Args:
@@ -556,6 +566,8 @@ def llama_decoder_layer_forward(
past_key_value=past_key_value, past_key_value=past_key_value,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states