fix the patch to work properly and work with FSDP
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled

This commit is contained in:
Wing Lian
2023-08-17 20:33:56 -04:00
parent 6c306d9186
commit 587dbbfc02

View File

@@ -11,6 +11,9 @@ import transformers
from einops import rearrange from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
)
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
@@ -36,12 +39,10 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
) )
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
if packed: if packed:
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaModel.forward = ( transformers.models.llama.modeling_llama.LlamaModel.forward = (
llama_model_forward llama_model_forward
) )
transformers.models.llama.modeling_llama.LlamaDecoderLayer = (
llama_decoder_layer_forward
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention # Disable the transformation of the attention mask in LlamaModel as the flash attention
@@ -159,7 +160,7 @@ def flashattn_forward(
# only on first autoregressive step q,k,v have same seqlen # only on first autoregressive step q,k,v have same seqlen
is_causal = past_key_value is not None is_causal = past_key_value is not None
if cu_seqlens and max_seqlen: if cu_seqlens is not None and max_seqlen is not None:
# special handling using sample packing # special handling using sample packing
qkv = torch.stack( qkv = torch.stack(
[query_states, key_states, value_states], dim=2 [query_states, key_states, value_states], dim=2
@@ -472,9 +473,9 @@ def llama_model_forward(
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs, **kwargs): def custom_forward(*inputs):
# None for past_key_value # None for past_key_value
return module(*inputs, output_attentions, None, **kwargs) return module(*inputs)
return custom_forward return custom_forward
@@ -484,8 +485,10 @@ def llama_model_forward(
attention_mask, attention_mask,
position_ids, position_ids,
None, None,
cu_seqlens=cu_seqlens, output_attentions,
max_seqlen=max_seqlen, None,
cu_seqlens,
max_seqlen,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -528,61 +531,68 @@ def llama_model_forward(
) )
def llama_decoder_layer_forward( class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
""" """
Args: patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
""" """
residual = hidden_states def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
"""
hidden_states = self.input_layernorm(hidden_states) residual = hidden_states
# Self Attention hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected # Self Attention
residual = hidden_states hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states=hidden_states,
hidden_states = self.mlp(hidden_states) attention_mask=attention_mask,
hidden_states = residual + hidden_states position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
outputs = (hidden_states,) # Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
if output_attentions: outputs = (hidden_states,)
outputs += (self_attn_weights,)
if use_cache: if output_attentions:
outputs += (present_key_value,) outputs += (self_attn_weights,)
return outputs if use_cache:
outputs += (present_key_value,)
return outputs