Compare commits
1 Commits
mora
...
llama-drop
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7771498eae |
@@ -12,6 +12,8 @@ import torch.nn.functional as F
|
||||
import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
||||
@@ -78,6 +80,19 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
|
||||
)
|
||||
|
||||
|
||||
class GaussianDropout(nn.Module):
    """Multiplicative Gaussian-noise dropout.

    During training every element of the input is multiplied by noise drawn
    from ``N(1, p / (1 - p))``.  The noise has mean 1, so the expected value
    of the output equals the input and no rescaling is needed at inference
    time.  Outside training mode the input is returned unchanged.

    Args:
        p: Drop-rate-like strength parameter; must satisfy ``0 < p < 1``.
            The noise standard deviation is ``sqrt(p / (1 - p))``.

    Raises:
        ValueError: If ``p`` is not strictly between 0 and 1.
    """

    def __init__(self, p=0.5):
        super().__init__()
        # p == 1 would divide by zero below and p == 0 is a no-op layer, so
        # both endpoints are rejected.
        if p <= 0 or p >= 1:
            raise ValueError("p value should accomplish 0 < p < 1")
        self.p = p

    def forward(self, x):
        """Apply mean-1 multiplicative Gaussian noise to ``x`` when training."""
        if not self.training:
            return x
        stddev = (self.p / (1.0 - self.p)) ** 0.5
        # Noise must be centred on 1, not 0: multiplying by zero-mean noise
        # would replace the activation with pure noise instead of perturbing it.
        noise = 1.0 + torch.randn_like(x) * stddev
        return x * noise
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
@@ -202,7 +217,7 @@ def flashattn_forward(
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||
qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
@@ -571,6 +586,15 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
|
||||
"""
|
||||
|
||||
def __init__(self, config: LlamaConfig):
    """Build the decoder layer and optionally attach Gaussian dropout modules.

    Reads two optional, non-standard fields from ``config``:

    * ``dropout_attn`` -- if truthy, ``self.attn_dropout`` is a
      ``GaussianDropout`` with that ``p``; otherwise it is ``None``.
    * ``dropout_mlp`` -- if truthy, ``self.mlp_dropout`` is a
      ``GaussianDropout`` with that ``p``; otherwise it is ``None``.
    """
    super(LlamaDecoderLayer, self).__init__(config)
    # These fields are only present when the caller added them to the config,
    # so read them defensively: a missing field means "no dropout" rather
    # than an AttributeError.
    attn_p = getattr(config, "dropout_attn", None)
    mlp_p = getattr(config, "dropout_mlp", None)
    self.attn_dropout = GaussianDropout(p=attn_p) if attn_p else None
    self.mlp_dropout = GaussianDropout(p=mlp_p) if mlp_p else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -614,12 +638,16 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
if self.training and self.attn_dropout:
|
||||
hidden_states = self.attn_dropout(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
if self.training and self.mlp_dropout:
|
||||
hidden_states = self.mlp_dropout(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
Reference in New Issue
Block a user