fixing negative component mixing

This commit is contained in:
Dan Saunders
2025-01-13 19:21:55 +00:00
parent 661d71a14b
commit fd8ad6fcbf
3 changed files with 5 additions and 3 deletions


@@ -6,7 +6,7 @@ from typing import List
 from transformers import PreTrainedModel, TrainerCallback
 from axolotl.integrations.base import BasePlugin
-from axolotl.utils.callbacks.differential import (
+from axolotl.utils.callbacks.diff_attn import (
     DifferentialAttentionMixingCallback,
     DifferentialAttentionMonitorCallback,
 )


@@ -174,7 +174,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
             - Initial lambda value based on layer index.
             - Rotary position embedding layer.
         """
-        self.diff_attn_mix = 0.0  # Default to full mixing
+        self.diff_attn_mix = 1.0  # Default to full mixing
         self.lambda_init = nn.Parameter(
             torch.full((), lambda_init_fn(self.layer_idx)),
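For context: `diff_attn_mix` appears to control how much of the differential-attention component is active, with 1.0 meaning fully active (hence "full mixing"), so the old default of 0.0 contradicted its own comment. Below is a minimal sketch of that interpretation; the linear blend and the names `standard_out` / `diff_out` are illustrative assumptions, not the repo's implementation.

import torch

def blend(standard_out: torch.Tensor, diff_out: torch.Tensor, diff_attn_mix: float) -> torch.Tensor:
    # diff_attn_mix = 0.0 -> pure standard attention
    # diff_attn_mix = 1.0 -> pure differential attention ("full mixing")
    # DifferentialAttentionMixingCallback presumably adjusts this value
    # over the course of training.
    return (1.0 - diff_attn_mix) * standard_out + diff_attn_mix * diff_out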
@@ -383,7 +383,9 @@ class LlamaDifferentialAttentionBase(nn.Module):
             Processed attention output of shape (batch_size, seq_len, hidden_size)
         """
         attn = self.subln(attn)
-        attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
+        # NOTE: this may need to be added back in, but doesn't interact well with
+        # `diff_attn_mix`.
+        # attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
         attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
         return self.o_proj(attn)
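A quick numerical illustration of why the removed line was problematic: the Differential Transformer recipe scales the normalized output by (1 - lambda_init), but multiplying by `diff_attn_mix` as well means a small mix value shrinks the entire attention branch toward zero rather than degrading gracefully to standard attention. The shapes and values below are made up for demonstration.

import torch

attn = torch.randn(2, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
lambda_init = 0.8

for mix in (0.0, 0.5, 1.0):
    out = attn * mix * (1 - lambda_init)
    # mix=0.0 zeroes the output entirely, i.e. the attention sublayer
    # contributes nothing to the residual stream.
    print(f"mix={mix}: mean |out| = {out.abs().mean().item():.4f}")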