Fixing negative component mixing
```diff
@@ -6,7 +6,7 @@ from typing import List
 from transformers import PreTrainedModel, TrainerCallback
 
 from axolotl.integrations.base import BasePlugin
-from axolotl.utils.callbacks.differential import (
+from axolotl.utils.callbacks.diff_attn import (
     DifferentialAttentionMixingCallback,
     DifferentialAttentionMonitorCallback,
 )
```
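For context on what these callbacks do: a mixing callback of this kind typically anneals `diff_attn_mix` over training so a converted model eases into the differential path. The sketch below is a hypothetical illustration built only on the `TrainerCallback` API this file already imports; the hook choice, linear schedule, and `warmup_steps` parameter are assumptions, not axolotl's actual `DifferentialAttentionMixingCallback`.

```python
from transformers import TrainerCallback


class MixingScheduleCallback(TrainerCallback):
    """Hypothetical sketch: linearly ramp `diff_attn_mix` from 0.0 to 1.0.

    Not axolotl's DifferentialAttentionMixingCallback; an illustration of
    the mechanism using only standard `transformers` callback hooks.
    """

    def __init__(self, warmup_steps: int = 1000):  # warmup_steps is assumed
        self.warmup_steps = warmup_steps

    def on_step_begin(self, args, state, control, model=None, **kwargs):
        if model is None:
            return
        mix = min(1.0, state.global_step / self.warmup_steps)
        # Push the schedule into every module that carries the coefficient.
        for module in model.modules():
            if hasattr(module, "diff_attn_mix"):
                module.diff_attn_mix = mix
```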
```diff
@@ -174,7 +174,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
         - Initial lambda value based on layer index.
         - Rotary position embedding layer.
         """
-        self.diff_attn_mix = 0.0  # Default to full mixing
+        self.diff_attn_mix = 1.0  # Default to full mixing
 
         self.lambda_init = nn.Parameter(
             torch.full((), lambda_init_fn(self.layer_idx)),
```
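The old default contradicted its own comment: `0.0` disables the negative component entirely, so `1.0` is what "full mixing" actually means. As for the `lambda_init_fn` used just below, the Differential Transformer reference implementation uses a depth-dependent schedule; assuming axolotl mirrors it (an assumption, not confirmed by this diff), it looks like:

```python
import math


def lambda_init_fn(depth: int) -> float:
    # Depth-dependent lambda initialization from the Differential
    # Transformer reference code: shallow layers start near 0.2,
    # deeper layers approach 0.8.
    return 0.8 - 0.6 * math.exp(-0.3 * depth)


# e.g. lambda_init_fn(0) == 0.2, lambda_init_fn(10) ≈ 0.77
```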
```diff
@@ -383,7 +383,9 @@ class LlamaDifferentialAttentionBase(nn.Module):
             Processed attention output of shape (batch_size, seq_len, hidden_size)
         """
         attn = self.subln(attn)
-        attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
+        # NOTE: this may need to be added back in, but doesn't interact well with
+        # `diff_attn_mix`.
+        # attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
         attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
 
         return self.o_proj(attn)
```
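Why the rescale was commented out rather than kept: `(1 - lambda_init)` is the paper's fixed rescaling of the differential attention output, but once the whole product is also multiplied by `diff_attn_mix`, any schedule that ramps the mix up from zero silences the attention block completely at the start, rather than falling back to standard attention. A self-contained sketch of that interaction (the numbers are illustrative; `lambda_init = 0.2` matches the reference layer-0 init):

```python
def output_scale(mix: float, lambda_init: float) -> float:
    """Scale the removed line applied to the *entire* attention output."""
    return mix * (1.0 - lambda_init)


for mix in (0.0, 0.5, 1.0):
    print(f"mix={mix}: scale={output_scale(mix, lambda_init=0.2)}")
# mix=0.0: scale=0.0  -> the attention block contributes nothing
# mix=0.5: scale=0.4
# mix=1.0: scale=0.8  -> the intended (1 - lambda_init) rescale
```

One possible reconciliation, not what this commit does, would be to interpolate the scale as `1 - diff_attn_mix * lambda_init`, which recovers plain attention at `mix = 0` and the paper's `(1 - lambda_init)` at `mix = 1`.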