From fd8ad6fcbf52b2f74549bc7752bcf6db8212b498 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 13 Jan 2025 19:21:55 +0000 Subject: [PATCH] fixing negative component mixing --- src/axolotl/integrations/diff_transformer/__init__.py | 2 +- src/axolotl/integrations/diff_transformer/diff_attn.py | 6 ++++-- .../utils/callbacks/{differential.py => diff_attn.py} | 0 3 files changed, 5 insertions(+), 3 deletions(-) rename src/axolotl/utils/callbacks/{differential.py => diff_attn.py} (100%) diff --git a/src/axolotl/integrations/diff_transformer/__init__.py b/src/axolotl/integrations/diff_transformer/__init__.py index fdacd2b3a..3b98ae246 100644 --- a/src/axolotl/integrations/diff_transformer/__init__.py +++ b/src/axolotl/integrations/diff_transformer/__init__.py @@ -6,7 +6,7 @@ from typing import List from transformers import PreTrainedModel, TrainerCallback from axolotl.integrations.base import BasePlugin -from axolotl.utils.callbacks.differential import ( +from axolotl.utils.callbacks.diff_attn import ( DifferentialAttentionMixingCallback, DifferentialAttentionMonitorCallback, ) diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py index 3744b0df5..b83aa3abf 100644 --- a/src/axolotl/integrations/diff_transformer/diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/diff_attn.py @@ -174,7 +174,7 @@ class LlamaDifferentialAttentionBase(nn.Module): - Initial lambda value based on layer index. - Rotary position embedding layer. """ - self.diff_attn_mix = 0.0 # Default to full mixing + self.diff_attn_mix = 1.0 # Default to full mixing self.lambda_init = nn.Parameter( torch.full((), lambda_init_fn(self.layer_idx)), @@ -383,7 +383,9 @@ class LlamaDifferentialAttentionBase(nn.Module): Processed attention output of shape (batch_size, seq_len, hidden_size) """ attn = self.subln(attn) - attn = attn * self.diff_attn_mix * (1 - self.lambda_init) + # NOTE: this may need to be added back in, but doesn't interact well with + # `diff_attn_mix`. + # attn = attn * self.diff_attn_mix * (1 - self.lambda_init) attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size) return self.o_proj(attn) diff --git a/src/axolotl/utils/callbacks/differential.py b/src/axolotl/utils/callbacks/diff_attn.py similarity index 100% rename from src/axolotl/utils/callbacks/differential.py rename to src/axolotl/utils/callbacks/diff_attn.py