added modeling code; cleanup + refactor
This commit is contained in:
committed by
Dan Saunders
parent
fcbfa86373
commit
5b90da0be3
@@ -50,7 +50,7 @@ def copy_attention_weights(
|
|||||||
new_attn.q_proj.weight.data.copy_(new_q)
|
new_attn.q_proj.weight.data.copy_(new_q)
|
||||||
|
|
||||||
# For K projection (K1 and K2)
|
# For K projection (K1 and K2)
|
||||||
old_kv_size = old_attn.k_proj.weight.data.size(0) # Size for 3 heads
|
old_kv_size = old_attn.k_proj.weight.data.size(0)
|
||||||
new_k = torch.empty_like(new_attn.k_proj.weight.data)
|
new_k = torch.empty_like(new_attn.k_proj.weight.data)
|
||||||
new_k[:old_kv_size] = old_attn.k_proj.weight.data # K1
|
new_k[:old_kv_size] = old_attn.k_proj.weight.data # K1
|
||||||
if zero_init:
|
if zero_init:
|
||||||
@@ -99,6 +99,7 @@ def convert_to_diff_attn(
|
|||||||
# Iterate through module children, convert any attn layers to diff attn
|
# Iterate through module children, convert any attn layers to diff attn
|
||||||
for name, child in module.named_children():
|
for name, child in module.named_children():
|
||||||
child_class_name = type(child).__name__
|
child_class_name = type(child).__name__
|
||||||
|
|
||||||
if child_class_name in [k.__name__ for k in ATTENTION_MAPPING]:
|
if child_class_name in [k.__name__ for k in ATTENTION_MAPPING]:
|
||||||
# Find matching attention class by name
|
# Find matching attention class by name
|
||||||
for orig_class, diff_class in ATTENTION_MAPPING.items():
|
for orig_class, diff_class in ATTENTION_MAPPING.items():
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Any, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from flash_attn.flash_attn_interface import flash_attn_func
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
@@ -17,7 +16,14 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_func
|
||||||
|
|
||||||
|
FLASH_ATTENTION_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
FLASH_ATTENTION_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
@@ -35,11 +41,12 @@ def lambda_init_fn(depth):
|
|||||||
return 0.8 - 0.6 * math.exp(-0.3 * depth)
|
return 0.8 - 0.6 * math.exp(-0.3 * depth)
|
||||||
|
|
||||||
|
|
||||||
class DifferentialAttentionBase(nn.Module):
|
class LlamaDifferentialAttentionBase(nn.Module):
|
||||||
"""Base class for differential attention implementations."""
|
"""Base class for differential attention implementations."""
|
||||||
|
|
||||||
def __init__(self, config: Any, layer_idx: int):
|
def __init__(self, config: Any, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
self._init_config(config, layer_idx)
|
self._init_config(config, layer_idx)
|
||||||
self._init_projections()
|
self._init_projections()
|
||||||
self._init_differential_params()
|
self._init_differential_params()
|
||||||
@@ -59,9 +66,9 @@ class DifferentialAttentionBase(nn.Module):
|
|||||||
|
|
||||||
if config.split_heads:
|
if config.split_heads:
|
||||||
# Split heads mode - single projections
|
# Split heads mode - single projections
|
||||||
self.head_dim = config.hidden_size // config.num_attention_heads // 2
|
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||||
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
|
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
|
||||||
# implementation, which asserts `self.base_num_heads` is even.
|
# implementation, which asserts `self.base_num_heads` is even
|
||||||
self.heads_per_component = self.base_num_heads // 2
|
self.heads_per_component = self.base_num_heads // 2
|
||||||
self.value_head_dim = 2 * self.head_dim
|
self.value_head_dim = 2 * self.head_dim
|
||||||
else:
|
else:
|
||||||
@@ -110,36 +117,43 @@ class DifferentialAttentionBase(nn.Module):
|
|||||||
self.lambda_k2 = nn.Parameter(
|
self.lambda_k2 = nn.Parameter(
|
||||||
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
|
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
|
||||||
)
|
)
|
||||||
self.rotary_emb = LlamaRotaryEmbedding(
|
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
|
||||||
self.max_position_embeddings, self.head_dim, self.rope_theta
|
|
||||||
)
|
|
||||||
|
|
||||||
def _init_normalization(self, config):
|
def _init_normalization(self, config):
|
||||||
"""Initialize normalization layers."""
|
"""Initialize normalization layers."""
|
||||||
sublayer_norm = getattr(config, "sublayer_norm", True)
|
sublayer_norm = getattr(config, "sublayer_norm", True)
|
||||||
self.subln = (
|
if sublayer_norm:
|
||||||
LlamaRMSNorm(self.value_head_dim, eps=1e-5)
|
self.subln = LlamaRMSNorm(self.value_head_dim, eps=config.rms_norm_eps)
|
||||||
if sublayer_norm
|
else:
|
||||||
else nn.Identity()
|
self.subln = nn.Identity()
|
||||||
)
|
|
||||||
|
|
||||||
def _prepare_attention_inputs(self, hidden_states: torch.Tensor):
|
def _prepare_attention_inputs(self, hidden_states: torch.Tensor):
|
||||||
"""Prepare inputs for attention computation."""
|
"""Prepare inputs for attention computation."""
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
# Project and split
|
# Project and split
|
||||||
qp = self.q_proj(hidden_states)
|
q = self.q_proj(hidden_states)
|
||||||
kp = self.k_proj(hidden_states)
|
k = self.k_proj(hidden_states)
|
||||||
v = self.v_proj(hidden_states)
|
v = self.v_proj(hidden_states)
|
||||||
q1, q2 = qp.chunk(2, dim=-1)
|
q1, q2 = q.chunk(2, dim=-1)
|
||||||
k1, k2 = kp.chunk(2, dim=-1)
|
k1, k2 = k.chunk(2, dim=-1)
|
||||||
|
|
||||||
# Reshape
|
# Reshape
|
||||||
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
q1 = q1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
|
||||||
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
1, 2
|
||||||
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
)
|
||||||
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
|
||||||
v = v.view(bsz, q_len, -1, self.value_head_dim).transpose(1, 2)
|
1, 2
|
||||||
|
)
|
||||||
|
k1 = k1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
|
||||||
|
1, 2
|
||||||
|
)
|
||||||
|
k2 = k2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
|
||||||
|
1, 2
|
||||||
|
)
|
||||||
|
v = v.view(bsz, q_len, self.heads_per_component, self.value_head_dim).transpose(
|
||||||
|
1, 2
|
||||||
|
)
|
||||||
|
|
||||||
return q1, q2, k1, k2, v
|
return q1, q2, k1, k2, v
|
||||||
|
|
||||||
@@ -148,16 +162,16 @@ class DifferentialAttentionBase(nn.Module):
|
|||||||
):
|
):
|
||||||
"""Apply rotary embeddings to queries and keys."""
|
"""Apply rotary embeddings to queries and keys."""
|
||||||
if position_embeddings is None:
|
if position_embeddings is None:
|
||||||
if position_ids is None:
|
LOG.warning(
|
||||||
position_ids = torch.arange(q1.size(-2), device=q1.device)
|
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
|
||||||
|
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
|
||||||
|
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
|
||||||
|
"removed and `position_embeddings` will be mandatory."
|
||||||
|
)
|
||||||
cos, sin = self.rotary_emb(q1, position_ids)
|
cos, sin = self.rotary_emb(q1, position_ids)
|
||||||
else:
|
else:
|
||||||
cos, sin = position_embeddings
|
cos, sin = position_embeddings
|
||||||
|
|
||||||
if self.split_heads:
|
|
||||||
cos, _ = cos.chunk(2, dim=2)
|
|
||||||
sin, _ = sin.chunk(2, dim=2)
|
|
||||||
|
|
||||||
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
|
||||||
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
|
||||||
|
|
||||||
@@ -195,7 +209,7 @@ class DifferentialAttentionBase(nn.Module):
|
|||||||
return self.o_proj(attn)
|
return self.o_proj(attn)
|
||||||
|
|
||||||
|
|
||||||
class LlamaDifferentialAttention(DifferentialAttentionBase):
|
class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
|
||||||
"""Standard implementation of differential attention."""
|
"""Standard implementation of differential attention."""
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -237,15 +251,16 @@ class LlamaDifferentialAttention(DifferentialAttentionBase):
|
|||||||
|
|
||||||
lambda_full = self._compute_lambda(q1)
|
lambda_full = self._compute_lambda(q1)
|
||||||
attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
|
attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
|
||||||
|
|
||||||
attn = self._process_attention_output(attn, bsz, q_len)
|
attn = self._process_attention_output(attn, bsz, q_len)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
return attn, attn1 - lambda_full * attn2, past_key_value
|
attn_weights = attn1 - lambda_full * attn2
|
||||||
|
attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1)
|
||||||
|
return attn, attn_weights, past_key_value
|
||||||
return attn, None, past_key_value
|
return attn, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
|
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
|
||||||
"""SDPA-based implementation of differential attention."""
|
"""SDPA-based implementation of differential attention."""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
@@ -262,6 +277,11 @@ class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
|
|||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
|
LOG.warning(
|
||||||
|
"LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but "
|
||||||
|
+ "`torch.nn.functional.scaled_dot_product_attention` does not support "
|
||||||
|
+ "`output_attentions=True`. Falling back to the eager attention implementation."
|
||||||
|
)
|
||||||
return LlamaDifferentialAttention.forward(
|
return LlamaDifferentialAttention.forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -309,9 +329,18 @@ class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
|
|||||||
return attn, None, past_key_value
|
return attn, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
class LlamaDifferentialFlashAttention2(DifferentialAttentionBase):
|
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
|
||||||
"""Flash Attention 2-based implementation of differential attention."""
|
"""Flash Attention 2-based implementation of differential attention."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
if not FLASH_ATTENTION_AVAILABLE:
|
||||||
|
raise ImportError(
|
||||||
|
"LlamaDifferentialFlashAttention2 requires flash-attn library. "
|
||||||
|
"Please install with `pip install flash-attn --no-build-isolation`"
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -326,6 +355,11 @@ class LlamaDifferentialFlashAttention2(DifferentialAttentionBase):
|
|||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
|
LOG.warning(
|
||||||
|
"LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but "
|
||||||
|
+ "flash attenion does not support `output_attentions=True`. Falling back "
|
||||||
|
+ "to the eager attention implementation."
|
||||||
|
)
|
||||||
return LlamaDifferentialAttention.forward(
|
return LlamaDifferentialAttention.forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
|||||||
370
src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
Normal file
370
src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
Normal file
@@ -0,0 +1,370 @@
|
|||||||
|
"""Modeling for differential transformers."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LlamaMLP,
|
||||||
|
LlamaModel,
|
||||||
|
LlamaPreTrainedModel,
|
||||||
|
LlamaRMSNorm,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .diff_attn import (
|
||||||
|
LlamaDifferentialAttention,
|
||||||
|
LlamaDifferentialAttentionBase,
|
||||||
|
LlamaDifferentialFlashAttention2,
|
||||||
|
LlamaDifferentialSdpaAttention,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaDifferentialConfig(LlamaConfig):
    """LLaMA configuration extended with differential-attention options.

    On top of everything ``LlamaConfig`` accepts, this config records three
    boolean flags and a table translating standard attention-implementation
    names into the names of their differential counterparts.
    """

    def __init__(
        self,
        split_heads: bool = False,
        sublayer_norm: bool = True,
        zero_init: bool = False,
        **kwargs,
    ):
        """Store the differential-attention flags alongside the LLaMA settings.

        Args:
            split_heads: Stored as ``self.split_heads``.
            sublayer_norm: Stored as ``self.sublayer_norm``.
            zero_init: Stored as ``self.zero_init``.
            **kwargs: Passed through unchanged to ``LlamaConfig.__init__``.
        """
        super().__init__(**kwargs)

        self.split_heads = split_heads
        self.sublayer_norm = sublayer_norm
        self.zero_init = zero_init
        self.architectures = ["LlamaDifferentialModel"]

        # Every standard implementation name maps to "differential_<name>".
        standard_names = ("eager", "sdpa", "flash_attention_2")
        self._attn_implementations = {
            name: f"differential_{name}" for name in standard_names
        }
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaDifferentialPreTrainedModel(LlamaPreTrainedModel):
    """Shared base class for the differential LLaMA model classes."""

    config_class = LlamaDifferentialConfig
    base_model_prefix = "llama_differential"

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``module.gradient_checkpointing`` on supported module types.

        Only differential attention layers and ``LlamaModel`` instances are
        modified; any other module is left untouched.
        """
        supported_types = (LlamaDifferentialAttentionBase, LlamaModel)
        if not isinstance(module, supported_types):
            return
        module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
|
||||||
|
def lambda_init_fn(depth: int) -> float:
    """Return the initial lambda value for a layer at the given depth.

    The value is ``0.8 - 0.6 * exp(-0.3 * depth)``: about 0.2 at depth 0 and
    approaching 0.8 as depth grows.
    """
    decay = math.exp(-0.3 * depth)
    return 0.8 - 0.6 * decay
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaDifferentialModel(LlamaDifferentialPreTrainedModel):
|
||||||
|
"""Differential version of the LLaMA model."""
|
||||||
|
|
||||||
|
def __init__(self, config: LlamaDifferentialConfig):
    """Build the token embedding, the decoder stack and the final norm.

    The attention class used by every decoder layer is chosen from
    ``config._attn_implementation`` (defaulting to ``"eager"``), translated
    through ``config._attn_implementations``; unknown names fall back to
    ``LlamaDifferentialAttention``.
    """
    super().__init__(config)

    # Differential implementation name -> attention class.
    self.attn_implementation_to_class = {
        "differential_eager": LlamaDifferentialAttention,
        "differential_sdpa": LlamaDifferentialSdpaAttention,
        "differential_flash_attention_2": LlamaDifferentialFlashAttention2,
    }

    # Translate a standard name ("eager", "sdpa", ...) to its differential
    # name; a name that is not in the table is kept as given.
    requested = getattr(config, "_attn_implementation", "eager")
    resolved = config._attn_implementations.get(requested, requested)
    self.attention_class = self.attn_implementation_to_class.get(
        resolved, LlamaDifferentialAttention
    )

    # Model components, created in the order: embedding, layers, final norm.
    self.embed_tokens = nn.Embedding(
        config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
    )

    decoder_layers = []
    for layer_idx in range(config.num_hidden_layers):
        decoder_layers.append(
            LlamaDifferentialDecoderLayer(
                config=config,
                layer_idx=layer_idx,
                attention_class=self.attention_class,
            )
        )
    self.layers = nn.ModuleList(decoder_layers)

    self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Run the decoder stack over token ids or precomputed embeddings.

    Exactly one of ``input_ids`` and ``inputs_embeds`` must be given. Flags
    left as ``None`` fall back to the corresponding values on ``self.config``.

    Returns:
        A ``BaseModelOutputWithPast`` when ``return_dict`` is truthy; otherwise
        a tuple of the non-``None`` entries among (last hidden state, cache,
        all hidden states, all attentions), in that order.

    Raises:
        ValueError: If both or neither of ``input_ids`` / ``inputs_embeds``
            are provided.
    """
    cfg = self.config
    if output_attentions is None:
        output_attentions = cfg.output_attentions
    if output_hidden_states is None:
        output_hidden_states = cfg.output_hidden_states
    if use_cache is None:
        use_cache = cfg.use_cache
    if return_dict is None:
        return_dict = cfg.use_return_dict

    # Exactly one input source is allowed.
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time"
        )
    if input_ids is None and inputs_embeds is None:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if input_ids is not None:
        batch_size, seq_length = input_ids.shape
        device = input_ids.device
    else:
        batch_size, seq_length, _ = inputs_embeds.shape
        device = inputs_embeds.device

    # Default positions are 0..seq_length-1 with a leading batch axis of 1.
    if position_ids is None:
        position_ids = torch.arange(
            seq_length, dtype=torch.long, device=device
        ).unsqueeze(0)

    # Without a cache, every layer receives None as its past key/value.
    if past_key_values is None:
        past_key_values = (None,) * len(self.layers)

    # Expand the attention mask only when the caller supplied one.
    if attention_mask is not None:
        attention_mask = self._prepare_attention_mask(
            attention_mask, (batch_size, seq_length), device
        )

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    hidden_states = inputs_embeds

    # Optional accumulators; None means "not requested".
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_cache = () if use_cache else None

    for decoder_layer, layer_past in zip(self.layers, past_key_values):
        if output_hidden_states:
            # Record the input to each layer.
            all_hidden_states = all_hidden_states + (hidden_states,)  # type: ignore

        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=layer_past,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = layer_outputs[0]

        if use_cache:
            # The layer appends its cache entry as the last output element.
            next_cache = next_cache + (layer_outputs[-1],)  # type: ignore

        if output_attentions:
            all_self_attns = all_self_attns + (layer_outputs[1],)  # type: ignore

    # Final normalization; its output is also the last recorded hidden state.
    hidden_states = self.norm(hidden_states)
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)  # type: ignore

    if return_dict:
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    candidates = (hidden_states, next_cache, all_hidden_states, all_self_attns)
    return tuple(item for item in candidates if item is not None)
|
||||||
|
|
||||||
|
def _prepare_attention_mask(
|
||||||
|
self,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
input_shape: Tuple[int, int],
|
||||||
|
device: torch.device,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Prepare attention mask for computing attention."""
|
||||||
|
# Create causal mask
|
||||||
|
# [batch_size, seq_length] -> [batch_size, 1, seq_length, seq_length]
|
||||||
|
combined_attention_mask = None
|
||||||
|
_, seq_length = input_shape
|
||||||
|
|
||||||
|
if self.config.is_decoder:
|
||||||
|
seq_ids = torch.arange(seq_length, device=device)
|
||||||
|
causal_mask = (
|
||||||
|
seq_ids[None, None, :].repeat(1, seq_length, 1)
|
||||||
|
<= seq_ids[None, :, None]
|
||||||
|
)
|
||||||
|
causal_mask = causal_mask.to(attention_mask.dtype)
|
||||||
|
|
||||||
|
if causal_mask.shape[1:] != (seq_length, seq_length):
|
||||||
|
causal_mask = causal_mask[:, :seq_length, :seq_length]
|
||||||
|
|
||||||
|
# Extend attention mask
|
||||||
|
combined_attention_mask = (
|
||||||
|
causal_mask[None, None, :, :] * attention_mask[:, None, None, :]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
combined_attention_mask = attention_mask[:, None, None, :]
|
||||||
|
|
||||||
|
return combined_attention_mask
|
||||||
|
|
||||||
|
@classmethod
def from_llama(
    cls,
    llama_model: LlamaModel,
    differential_config: Optional[LlamaDifferentialConfig] = None,
) -> "LlamaDifferentialModel":
    """Convert a standard LLaMA model to use differential attention.

    Args:
        llama_model: Source model whose weights are copied.
        differential_config: Config for the new model. When ``None`` it is
            loaded with ``LlamaDifferentialConfig.from_pretrained`` from the
            source model's ``_name_or_path``.

    Returns:
        A new model of this class with embedding, norm, MLP, layer-norm,
        V/O-projection weights loaded from ``llama_model`` and Q/K-projection
        weights copied according to ``differential_config.split_heads``.
    """
    if differential_config is None:
        # pylint: disable=protected-access
        differential_config = LlamaDifferentialConfig.from_pretrained(
            llama_model.config._name_or_path
        )

    # Create new model
    new_model = cls(differential_config)

    # Copy non-attention weights directly
    new_model.embed_tokens.load_state_dict(llama_model.embed_tokens.state_dict())
    new_model.norm.load_state_dict(llama_model.norm.state_dict())

    # Copy layer weights, handling attention layers specially
    for new_layer, old_layer in zip(new_model.layers, llama_model.layers):
        new_attn = new_layer.self_attn
        old_attn = old_layer.self_attn

        old_q = old_attn.q_proj.weight.data
        old_k = old_attn.k_proj.weight.data
        # Row counts of the source projections. K is measured separately from
        # Q because the K projection can have fewer rows than `hidden_size`,
        # in which case slicing by `hidden_size` would not match `old_k`.
        old_q_size = old_q.size(0)
        old_kv_size = old_k.size(0)

        if differential_config.split_heads:
            # Split heads mode - projections are copied as they are
            new_attn.q_proj.weight.data.copy_(old_q)
            new_attn.k_proj.weight.data.copy_(old_k)
        else:
            # Double projection mode - copy weights to positive components
            new_attn.q_proj.weight.data[:old_q_size].copy_(old_q)
            new_attn.k_proj.weight.data[:old_kv_size].copy_(old_k)

        # Zero out relevant parameters for exact equivalence
        if differential_config.zero_init:
            new_attn.q_proj.weight.data[old_q_size:] = 0
            new_attn.k_proj.weight.data[old_kv_size:] = 0
            nn.init.zeros_(new_attn.lambda_q1)
            nn.init.zeros_(new_attn.lambda_k1)
            nn.init.zeros_(new_attn.lambda_q2)
            nn.init.zeros_(new_attn.lambda_k2)
            nn.init.zeros_(new_attn.lambda_init)

        # Copy remaining attention weights
        new_attn.v_proj.load_state_dict(old_attn.v_proj.state_dict())
        new_attn.o_proj.load_state_dict(old_attn.o_proj.state_dict())

        # Copy MLP and layer norm weights
        new_layer.mlp.load_state_dict(old_layer.mlp.state_dict())
        new_layer.input_layernorm.load_state_dict(
            old_layer.input_layernorm.state_dict()
        )
        new_layer.post_attention_layernorm.load_state_dict(
            old_layer.post_attention_layernorm.state_dict()
        )

    return new_model
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaDifferentialDecoderLayer(nn.Module):
    """Custom decoder layer for the differential Llama model.

    Applies self-attention and an MLP, each preceded by an RMS norm and
    followed by a residual addition.
    """

    def __init__(
        self, config: LlamaDifferentialConfig, layer_idx: int, attention_class
    ):
        """Create the attention, MLP and normalization sub-modules.

        Args:
            config: Model configuration shared by all sub-modules.
            layer_idx: Index of this layer, passed to ``attention_class``.
            attention_class: Callable invoked as
                ``attention_class(config, layer_idx)`` to build the attention.
        """
        super().__init__()
        norm_eps = config.rms_norm_eps

        self.hidden_size = config.hidden_size
        self.self_attn = attention_class(config, layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(
            config.hidden_size, eps=norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """Layer forward pass with differential attention.

        Returns:
            A tuple starting with the layer output, followed by the attention
            weights when ``output_attentions`` is truthy, followed by the
            present key/value when ``use_cache`` is truthy.
        """
        # Self-attention sub-block: norm -> attention -> residual add.
        attn_input = self.input_layernorm(hidden_states)
        attn_output, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block: norm -> MLP -> residual add.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_output

        # Optional extras are appended in a fixed order after the output.
        extras = []
        if output_attentions:
            extras.append(self_attn_weights)
        if use_cache:
            extras.append(present_key_value)

        return (hidden_states, *extras)  # type: ignore
|
||||||
Reference in New Issue
Block a user