diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py index b83aa3abf..6ee043d8c 100644 --- a/src/axolotl/integrations/diff_transformer/diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/diff_attn.py @@ -384,7 +384,7 @@ class LlamaDifferentialAttentionBase(nn.Module): """ attn = self.subln(attn) # NOTE: this may need to be added back in, but doesn't interact well with - # `diff_attn_mix`. + # `diff_attn_mix`, and doesn't allow us to preserve the original model output. # attn = attn * self.diff_attn_mix * (1 - self.lambda_init) attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)