diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index d172d302d..7a892802d 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -12,6 +12,8 @@ import torch.nn.functional as F
 import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
+from torch import nn
+from transformers import LlamaConfig
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer as OriginalLlamaDecoderLayer,
@@ -78,6 +80,19 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
     )
 
 
+class GaussianDropout(nn.Module):
+    def __init__(self, p=0.5):
+        super(GaussianDropout, self).__init__()
+        if p <= 0 or p >= 1:
+            raise ValueError("p must satisfy 0 < p < 1")
+        self.p = p
+
+    def forward(self, x):
+        stddev = (self.p / (1.0 - self.p)) ** 0.5
+        epsilon = torch.randn_like(x) * stddev  # zero-mean noise with variance p / (1 - p)
+        return x * (1.0 + epsilon)  # multiplicative noise centered at 1, so E[output] == x
+
+
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
 # requires the attention mask to be the same as the key_padding_mask
 def _prepare_decoder_attention_mask(
@@ -202,7 +217,7 @@ def flashattn_forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
 
         output = flash_attn_varlen_qkvpacked_func(
-            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
+            qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, causal=True
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
@@ -571,6 +586,15 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
     patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
     """
 
+    def __init__(self, config: LlamaConfig):
+        super(LlamaDecoderLayer, self).__init__(config)
+        self.attn_dropout = None
+        self.mlp_dropout = None
+        if getattr(config, "dropout_attn", None):
+            self.attn_dropout = GaussianDropout(p=config.dropout_attn)
+        if getattr(config, "dropout_mlp", None):
+            self.mlp_dropout = GaussianDropout(p=config.dropout_mlp)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -614,12 +638,16 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
         )
+        if self.training and self.attn_dropout:
+            hidden_states = self.attn_dropout(hidden_states)
         hidden_states = residual + hidden_states
 
         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
+        if self.training and self.mlp_dropout:
+            hidden_states = self.mlp_dropout(hidden_states)
         hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
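
For reference, the GaussianDropout added above scales activations by noise drawn from N(1, p / (1 - p)) while the model is training, and the patched LlamaDecoderLayer gates the call on self.training so evaluation and inference are unaffected. Below is a minimal, standalone sketch of that behaviour; it re-declares the class rather than importing axolotl, and the p=0.1 value and tensor shape are illustrative assumptions, not values taken from the patch.

import torch
from torch import nn


class GaussianDropout(nn.Module):
    """Multiplicative Gaussian noise with variance p / (1 - p), mirroring the patch above."""

    def __init__(self, p=0.5):
        super().__init__()
        if p <= 0 or p >= 1:
            raise ValueError("p must satisfy 0 < p < 1")
        self.p = p

    def forward(self, x):
        stddev = (self.p / (1.0 - self.p)) ** 0.5
        epsilon = torch.randn_like(x) * stddev   # zero-mean noise
        return x * (1.0 + epsilon)               # mean-preserving: E[output] == x


decoder_training = True                          # stands in for LlamaDecoderLayer.training
attn_dropout = GaussianDropout(p=0.1)            # illustrative stand-in for config.dropout_attn

hidden_states = torch.randn(2, 16, 4096)         # (batch, seq_len, hidden_size)
if decoder_training and attn_dropout:            # same gate as the patched forward()
    hidden_states = attn_dropout(hidden_states)

# With decoder_training = False the gate is skipped, so hidden_states pass through
# untouched and inference behaves exactly like the unpatched layer.

Because the noise is multiplicative and centered at 1, no rescaling is needed at inference time, unlike the 1 / (1 - p) scaling used by standard Bernoulli dropout.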