diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 6e43b1d53..14056fa54 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -11,6 +11,9 @@ import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
 from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama.modeling_llama import (
+    LlamaDecoderLayer as OriginalLlamaDecoderLayer,
+)
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
 
 from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
@@ -36,12 +39,10 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
     )
     transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
     if packed:
+        transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
         transformers.models.llama.modeling_llama.LlamaModel.forward = (
             llama_model_forward
         )
-        transformers.models.llama.modeling_llama.LlamaDecoderLayer = (
-            llama_decoder_layer_forward
-        )
 
 
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
@@ -159,7 +160,7 @@ def flashattn_forward(
     # only on first autoregressive step q,k,v have same seqlen
     is_causal = past_key_value is not None
 
-    if cu_seqlens and max_seqlen:
+    if cu_seqlens is not None and max_seqlen is not None:
         # special handling using sample packing
         qkv = torch.stack(
             [query_states, key_states, value_states], dim=2
@@ -472,9 +473,9 @@ def llama_model_forward(
         if self.gradient_checkpointing and self.training:
 
             def create_custom_forward(module):
-                def custom_forward(*inputs, **kwargs):
+                def custom_forward(*inputs):
                     # None for past_key_value
-                    return module(*inputs, output_attentions, None, **kwargs)
+                    return module(*inputs)
 
                 return custom_forward
 
@@ -484,8 +485,10 @@ def llama_model_forward(
                 attention_mask,
                 position_ids,
                 None,
-                cu_seqlens=cu_seqlens,
-                max_seqlen=max_seqlen,
+                output_attentions,
+                None,
+                cu_seqlens,
+                max_seqlen,
             )
         else:
             layer_outputs = decoder_layer(
@@ -528,61 +531,68 @@ def llama_model_forward(
     )
 
 
-def llama_decoder_layer_forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
-    output_attentions: Optional[bool] = False,
-    use_cache: Optional[bool] = False,
-    cu_seqlens: Optional[torch.Tensor] = None,
-    max_seqlen: Optional[torch.Tensor] = None,
-) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
     """
-    Args:
-        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
-        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-            `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-            returned tensors for more detail.
-        use_cache (`bool`, *optional*):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
-            (see `past_key_values`).
-        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
-        cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
+    patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
     """
-    residual = hidden_states
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[torch.Tensor] = None,
+    ) -> Tuple[
+        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+            cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
+        """
 
-    hidden_states = self.input_layernorm(hidden_states)
+        residual = hidden_states
 
-    # Self Attention
-    hidden_states, self_attn_weights, present_key_value = self.self_attn(
-        hidden_states=hidden_states,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_value=past_key_value,
-        output_attentions=output_attentions,
-        use_cache=use_cache,
-        cu_seqlens=cu_seqlens,
-        max_seqlen=max_seqlen,
-    )
-    hidden_states = residual + hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
 
-    # Fully Connected
-    residual = hidden_states
-    hidden_states = self.post_attention_layernorm(hidden_states)
-    hidden_states = self.mlp(hidden_states)
-    hidden_states = residual + hidden_states
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
+        )
+        hidden_states = residual + hidden_states
 
-    outputs = (hidden_states,)
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
 
-    if output_attentions:
-        outputs += (self_attn_weights,)
+        outputs = (hidden_states,)
 
-    if use_cache:
-        outputs += (present_key_value,)
+        if output_attentions:
+            outputs += (self_attn_weights,)
 
-    return outputs
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
 
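
A note on how the patched pieces fit together, with a small sketch of the `cu_seqlens`/`max_seqlen` values that the patch threads through the decoder layers. The sketch is illustrative only: the example lengths are made up, and in axolotl the real values come from `get_cu_seqlens_from_pos_ids` applied to the packed `position_ids`. The commented usage at the end assumes the patch is applied before the model is instantiated so the reassigned `LlamaDecoderLayer` and `llama_model_forward` are the ones actually used.

```python
import torch

# Illustrative only: what flash-attn style cu_seqlens look like for one packed row.
# Two samples of length 3 and 5 packed into a single sequence of length 8 get
# position_ids that restart at 0 at each sample boundary:
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]])

# Cumulative sequence lengths are the running offsets the varlen attention kernels
# expect, and max_seqlen is the longest packed sample. (axolotl derives these from
# position_ids via get_cu_seqlens_from_pos_ids; this only shows the idea.)
lengths = torch.tensor([3, 5])
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), lengths.cumsum(0).to(torch.int32)]
)
max_seqlen = int(lengths.max())
print(cu_seqlens.tolist(), max_seqlen)  # [0, 3, 8] 5

# Assumed usage: apply the monkeypatch before building the model, e.g.
# from axolotl.monkeypatch.llama_attn_hijack_flash import replace_llama_attn_with_flash_attn
# replace_llama_attn_with_flash_attn(packed=True)
# model = LlamaForCausalLM.from_pretrained(...)  # now picks up the patched classes
```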