Fix negative component mixing
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import List
|
|||||||
from transformers import PreTrainedModel, TrainerCallback
|
from transformers import PreTrainedModel, TrainerCallback
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
from axolotl.utils.callbacks.differential import (
|
from axolotl.utils.callbacks.diff_attn import (
|
||||||
DifferentialAttentionMixingCallback,
|
DifferentialAttentionMixingCallback,
|
||||||
DifferentialAttentionMonitorCallback,
|
DifferentialAttentionMonitorCallback,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
- Initial lambda value based on layer index.
|
- Initial lambda value based on layer index.
|
||||||
- Rotary position embedding layer.
|
- Rotary position embedding layer.
|
||||||
"""
|
"""
|
||||||
self.diff_attn_mix = 0.0 # Default to full mixing
|
self.diff_attn_mix = 1.0 # Default to full mixing
|
||||||
|
|
||||||
self.lambda_init = nn.Parameter(
|
self.lambda_init = nn.Parameter(
|
||||||
torch.full((), lambda_init_fn(self.layer_idx)),
|
torch.full((), lambda_init_fn(self.layer_idx)),
|
||||||
@@ -383,7 +383,9 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
|||||||
Processed attention output of shape (batch_size, seq_len, hidden_size)
|
Processed attention output of shape (batch_size, seq_len, hidden_size)
|
||||||
"""
|
"""
|
||||||
attn = self.subln(attn)
|
attn = self.subln(attn)
|
||||||
attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
|
# NOTE: this may need to be added back in, but doesn't interact well with
|
||||||
|
# `diff_attn_mix`.
|
||||||
|
# attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
|
||||||
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
|
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
|
||||||
|
|
||||||
return self.o_proj(attn)
|
return self.o_proj(attn)
|
||||||
|
|||||||
Reference in New Issue
Block a user