diff --git a/src/axolotl/integrations/diff_transformer/__init__.py b/src/axolotl/integrations/diff_transformer/__init__.py
index fdacd2b3a..3b98ae246 100644
--- a/src/axolotl/integrations/diff_transformer/__init__.py
+++ b/src/axolotl/integrations/diff_transformer/__init__.py
@@ -6,7 +6,7 @@ from typing import List
 from transformers import PreTrainedModel, TrainerCallback
 
 from axolotl.integrations.base import BasePlugin
-from axolotl.utils.callbacks.differential import (
+from axolotl.utils.callbacks.diff_attn import (
     DifferentialAttentionMixingCallback,
     DifferentialAttentionMonitorCallback,
 )
diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py
index 3744b0df5..b83aa3abf 100644
--- a/src/axolotl/integrations/diff_transformer/diff_attn.py
+++ b/src/axolotl/integrations/diff_transformer/diff_attn.py
@@ -174,7 +174,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
         - Initial lambda value based on layer index.
         - Rotary position embedding layer.
         """
-        self.diff_attn_mix = 0.0  # Default to full mixing
+        self.diff_attn_mix = 1.0  # Default to full mixing
 
         self.lambda_init = nn.Parameter(
             torch.full((), lambda_init_fn(self.layer_idx)),
@@ -383,7 +383,9 @@ class LlamaDifferentialAttentionBase(nn.Module):
             Processed attention output of shape (batch_size, seq_len, hidden_size)
         """
         attn = self.subln(attn)
-        attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
+        # NOTE: this may need to be added back in, but doesn't interact well with
+        # `diff_attn_mix`.
+        # attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
         attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
         return self.o_proj(attn)
 
diff --git a/src/axolotl/utils/callbacks/differential.py b/src/axolotl/utils/callbacks/diff_attn.py
similarity index 100%
rename from src/axolotl/utils/callbacks/differential.py
rename to src/axolotl/utils/callbacks/diff_attn.py