Fix the patch so that it works properly, including with FSDP
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled

This commit is contained in:
Wing Lian
2023-08-17 20:33:56 -04:00
parent 6c306d9186
commit 587dbbfc02

View File

@@ -11,6 +11,9 @@ import transformers
from einops import rearrange from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
)
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
@@ -36,12 +39,10 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
) )
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
if packed: if packed:
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaModel.forward = ( transformers.models.llama.modeling_llama.LlamaModel.forward = (
llama_model_forward llama_model_forward
) )
transformers.models.llama.modeling_llama.LlamaDecoderLayer = (
llama_decoder_layer_forward
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention # Disable the transformation of the attention mask in LlamaModel as the flash attention
@@ -159,7 +160,7 @@ def flashattn_forward(
# only on first autoregressive step q,k,v have same seqlen # only on first autoregressive step q,k,v have same seqlen
is_causal = past_key_value is not None is_causal = past_key_value is not None
if cu_seqlens and max_seqlen: if cu_seqlens is not None and max_seqlen is not None:
# special handling using sample packing # special handling using sample packing
qkv = torch.stack( qkv = torch.stack(
[query_states, key_states, value_states], dim=2 [query_states, key_states, value_states], dim=2
@@ -472,9 +473,9 @@ def llama_model_forward(
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs, **kwargs): def custom_forward(*inputs):
# None for past_key_value # None for past_key_value
return module(*inputs, output_attentions, None, **kwargs) return module(*inputs)
return custom_forward return custom_forward
@@ -484,8 +485,10 @@ def llama_model_forward(
attention_mask, attention_mask,
position_ids, position_ids,
None, None,
cu_seqlens=cu_seqlens, output_attentions,
max_seqlen=max_seqlen, None,
cu_seqlens,
max_seqlen,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
@@ -528,7 +531,12 @@ def llama_model_forward(
) )
def llama_decoder_layer_forward( class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
"""
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
"""
def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
@@ -538,7 +546,9 @@ def llama_decoder_layer_forward(
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cu_seqlens: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None, max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
""" """
Args: Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`