fix the patch to work properly and work with FSDP
This commit is contained in:
@@ -11,6 +11,9 @@ import transformers
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
||||||
|
)
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
@@ -36,12 +39,10 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
|
|||||||
)
|
)
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
|
||||||
if packed:
|
if packed:
|
||||||
|
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
||||||
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
||||||
llama_model_forward
|
llama_model_forward
|
||||||
)
|
)
|
||||||
transformers.models.llama.modeling_llama.LlamaDecoderLayer = (
|
|
||||||
llama_decoder_layer_forward
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
@@ -159,7 +160,7 @@ def flashattn_forward(
|
|||||||
# only on first autoregressive step q,k,v have same seqlen
|
# only on first autoregressive step q,k,v have same seqlen
|
||||||
is_causal = past_key_value is not None
|
is_causal = past_key_value is not None
|
||||||
|
|
||||||
if cu_seqlens and max_seqlen:
|
if cu_seqlens is not None and max_seqlen is not None:
|
||||||
# special handling using sample packing
|
# special handling using sample packing
|
||||||
qkv = torch.stack(
|
qkv = torch.stack(
|
||||||
[query_states, key_states, value_states], dim=2
|
[query_states, key_states, value_states], dim=2
|
||||||
@@ -472,9 +473,9 @@ def llama_model_forward(
|
|||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs, **kwargs):
|
def custom_forward(*inputs):
|
||||||
# None for past_key_value
|
# None for past_key_value
|
||||||
return module(*inputs, output_attentions, None, **kwargs)
|
return module(*inputs)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
@@ -484,8 +485,10 @@ def llama_model_forward(
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
None,
|
None,
|
||||||
cu_seqlens=cu_seqlens,
|
output_attentions,
|
||||||
max_seqlen=max_seqlen,
|
None,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
@@ -528,61 +531,68 @@ def llama_model_forward(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def llama_decoder_layer_forward(
|
class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: Optional[bool] = False,
|
|
||||||
use_cache: Optional[bool] = False,
|
|
||||||
cu_seqlens: Optional[torch.Tensor] = None,
|
|
||||||
max_seqlen: Optional[torch.Tensor] = None,
|
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
|
||||||
"""
|
"""
|
||||||
Args:
|
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
|
||||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
|
||||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
|
||||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
|
||||||
output_attentions (`bool`, *optional*):
|
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
|
||||||
returned tensors for more detail.
|
|
||||||
use_cache (`bool`, *optional*):
|
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
|
||||||
(see `past_key_values`).
|
|
||||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
|
||||||
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
residual = hidden_states
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[
|
||||||
|
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||||
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
|
(see `past_key_values`).
|
||||||
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||||
|
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
|
||||||
|
"""
|
||||||
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
residual = hidden_states
|
||||||
|
|
||||||
# Self Attention
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_value=past_key_value,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
use_cache=use_cache,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
max_seqlen=max_seqlen,
|
|
||||||
)
|
|
||||||
hidden_states = residual + hidden_states
|
|
||||||
|
|
||||||
# Fully Connected
|
# Self Attention
|
||||||
residual = hidden_states
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
hidden_states=hidden_states,
|
||||||
hidden_states = self.mlp(hidden_states)
|
attention_mask=attention_mask,
|
||||||
hidden_states = residual + hidden_states
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
if output_attentions:
|
outputs = (hidden_states,)
|
||||||
outputs += (self_attn_weights,)
|
|
||||||
|
|
||||||
if use_cache:
|
if output_attentions:
|
||||||
outputs += (present_key_value,)
|
outputs += (self_attn_weights,)
|
||||||
|
|
||||||
return outputs
|
if use_cache:
|
||||||
|
outputs += (present_key_value,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|||||||
Reference in New Issue
Block a user