Compare commits

...

1 Commits

Author SHA1 Message Date
Wing Lian
7771498eae add guassian dropout support 2023-09-25 14:50:39 -04:00

View File

@@ -12,6 +12,8 @@ import torch.nn.functional as F
import transformers import transformers
from einops import rearrange from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.bert_padding import pad_input, unpad_input
from torch import nn
from transformers import LlamaConfig
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer, LlamaDecoderLayer as OriginalLlamaDecoderLayer,
@@ -78,6 +80,19 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
) )
class GaussianDropout(nn.Module):
def __init__(self, p=0.5):
super(GaussianDropout, self).__init__()
if p <= 0 or p >= 1:
raise Exception("p value should accomplish 0 < p < 1")
self.p = p
def forward(self, x):
stddev = (self.p / (1.0 - self.p)) ** 0.5
epsilon = torch.randn_like(x) * stddev
return x * epsilon
# Disable the transformation of the attention mask in LlamaModel as the flash attention # Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask # requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask( def _prepare_decoder_attention_mask(
@@ -202,7 +217,7 @@ def flashattn_forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...") qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func( output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, causal=True
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape: elif query_states.shape == key_states.shape:
@@ -571,6 +586,15 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
""" """
def __init__(self, config: LlamaConfig):
super(LlamaDecoderLayer, self).__init__(config)
self.attn_dropout = None
self.mlp_dropout = None
if config.dropout_attn:
self.attn_dropout = GaussianDropout(p=config.dropout_attn)
if config.dropout_mlp:
self.mlp_dropout = GaussianDropout(p=config.dropout_mlp)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -614,12 +638,16 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
) )
if self.training and self.attn_dropout:
hidden_states = self.attn_dropout(hidden_states)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
# Fully Connected # Fully Connected
residual = hidden_states residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
if self.training and self.mlp_dropout:
hidden_states = self.mlp_dropout(hidden_states)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
outputs = (hidden_states,) outputs = (hidden_states,)