Compare commits
1 commit
NanoCode01...llama-drop

| Author | SHA1 | Date |
|---|---|---|
| | 7771498eae | |
@@ -12,6 +12,8 @@ import torch.nn.functional as F
 import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
+from torch import nn
+from transformers import LlamaConfig
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer as OriginalLlamaDecoderLayer,
@@ -78,6 +80,19 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
     )
+
+
+class GaussianDropout(nn.Module):
+    def __init__(self, p=0.5):
+        super(GaussianDropout, self).__init__()
+        if p <= 0 or p >= 1:
+            raise Exception("p value should accomplish 0 < p < 1")
+        self.p = p
+
+    def forward(self, x):
+        stddev = (self.p / (1.0 - self.p)) ** 0.5
+        epsilon = torch.randn_like(x) * stddev
+        return x * epsilon
 
 
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
 # requires the attention mask to be the same as the key_padding_mask
 def _prepare_decoder_attention_mask(
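An aside, not part of the commit: the module above multiplies each activation by zero-mean Gaussian noise with standard deviation sqrt(p / (1 - p)). Note that classic multiplicative Gaussian dropout (Srivastava et al.) multiplies by mean-one noise, `x * (1 + epsilon)`, which preserves the expected activation; as committed, `x * epsilon` makes the expected output zero. A minimal standalone sketch of the committed behavior:

```python
import torch

# Standalone sketch mirroring the GaussianDropout added in this commit.
p = 0.1
stddev = (p / (1.0 - p)) ** 0.5          # sqrt(p/(1-p)) ~= 0.333 for p=0.1

x = torch.ones(1_000_000)
epsilon = torch.randn_like(x) * stddev   # zero-mean noise, std ~= 0.333
y = x * epsilon                          # what forward() returns

print(y.mean().item())                   # ~= 0.0  (noise is zero-mean)
print(y.std().item())                    # ~= 0.333
```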
@@ -202,7 +217,7 @@ def flashattn_forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
 
         output = flash_attn_varlen_qkvpacked_func(
-            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
+            qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, causal=True
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
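The only change in this hunk is that the fourth positional argument to `flash_attn_varlen_qkvpacked_func` is now passed by keyword; assuming the flash-attn signature `(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, causal=False, ...)`, the call's behavior is unchanged.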
@@ -571,6 +586,15 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
     patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
     """
 
+    def __init__(self, config: LlamaConfig):
+        super(LlamaDecoderLayer, self).__init__(config)
+        self.attn_dropout = None
+        self.mlp_dropout = None
+        if config.dropout_attn:
+            self.attn_dropout = GaussianDropout(p=config.dropout_attn)
+        if config.dropout_mlp:
+            self.mlp_dropout = GaussianDropout(p=config.dropout_mlp)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
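Note that the new `__init__` reads `config.dropout_attn` and `config.dropout_mlp` unconditionally, so the config must define both fields (a falsy value such as `0.0` disables the corresponding dropout; a missing attribute would raise `AttributeError`). A hypothetical usage sketch, relying on the fact that Hugging Face's `PretrainedConfig` stores unrecognized keyword arguments as attributes:

```python
from transformers import LlamaConfig

# Hypothetical configuration: dropout_attn and dropout_mlp are the custom
# fields this commit introduces; they are not standard LlamaConfig arguments.
config = LlamaConfig(
    dropout_attn=0.1,  # GaussianDropout(p=0.1) after self-attention
    dropout_mlp=0.0,   # falsy -> mlp_dropout stays None
)

print(config.dropout_attn, config.dropout_mlp)  # 0.1 0.0
```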
@@ -614,12 +638,16 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
         )
+        if self.training and self.attn_dropout:
+            hidden_states = self.attn_dropout(hidden_states)
         hidden_states = residual + hidden_states
 
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
+        if self.training and self.mlp_dropout:
+            hidden_states = self.mlp_dropout(hidden_states)
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
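Because both dropout calls are gated on `self.training`, the noise fires only under `model.train()` and is skipped entirely at inference. A toy illustration of that gating pattern (`TinyBlock` is a hypothetical stand-in, not the actual decoder layer):

```python
import torch
from torch import nn

class GaussianDropout(nn.Module):
    """Same noise rule as the module added in this commit."""
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        stddev = (self.p / (1.0 - self.p)) ** 0.5
        return x * torch.randn_like(x) * stddev

class TinyBlock(nn.Module):
    """Hypothetical stand-in for the patched layer's residual branch."""
    def __init__(self, p=0.1):
        super().__init__()
        self.attn_dropout = GaussianDropout(p) if p else None

    def forward(self, x):
        residual = x
        # Mirrors the commit: dropout only fires in training mode.
        if self.training and self.attn_dropout:
            x = self.attn_dropout(x)
        return residual + x

block = TinyBlock()
x = torch.ones(4)
block.train()
print(block(x))  # noisy: 1 + zero-mean noise
block.eval()
print(block(x))  # exactly 2.0 everywhere: dropout branch skipped
```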