fix the patch to work properly and work with FSDP
This commit is contained in:
@@ -11,6 +11,9 @@ import transformers
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
||||||
|
)
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
@@ -36,12 +39,10 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
|
|||||||
)
|
)
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
|
||||||
if packed:
|
if packed:
|
||||||
|
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
||||||
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
||||||
llama_model_forward
|
llama_model_forward
|
||||||
)
|
)
|
||||||
transformers.models.llama.modeling_llama.LlamaDecoderLayer = (
|
|
||||||
llama_decoder_layer_forward
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
@@ -159,7 +160,7 @@ def flashattn_forward(
|
|||||||
# only on first autoregressive step q,k,v have same seqlen
|
# only on first autoregressive step q,k,v have same seqlen
|
||||||
is_causal = past_key_value is not None
|
is_causal = past_key_value is not None
|
||||||
|
|
||||||
if cu_seqlens and max_seqlen:
|
if cu_seqlens is not None and max_seqlen is not None:
|
||||||
# special handling using sample packing
|
# special handling using sample packing
|
||||||
qkv = torch.stack(
|
qkv = torch.stack(
|
||||||
[query_states, key_states, value_states], dim=2
|
[query_states, key_states, value_states], dim=2
|
||||||
@@ -472,9 +473,9 @@ def llama_model_forward(
|
|||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs, **kwargs):
|
def custom_forward(*inputs):
|
||||||
# None for past_key_value
|
# None for past_key_value
|
||||||
return module(*inputs, output_attentions, None, **kwargs)
|
return module(*inputs)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
@@ -484,8 +485,10 @@ def llama_model_forward(
|
|||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
None,
|
None,
|
||||||
cu_seqlens=cu_seqlens,
|
output_attentions,
|
||||||
max_seqlen=max_seqlen,
|
None,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
@@ -528,7 +531,12 @@ def llama_model_forward(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def llama_decoder_layer_forward(
|
class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||||
|
"""
|
||||||
|
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
@@ -538,7 +546,9 @@ def llama_decoder_layer_forward(
|
|||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
cu_seqlens: Optional[torch.Tensor] = None,
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
max_seqlen: Optional[torch.Tensor] = None,
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
) -> Tuple[
|
||||||
|
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
||||||
|
]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
|||||||
Reference in New Issue
Block a user