diff --git a/scripts/convert_diff_transformer.py b/scripts/convert_diff_transformer.py
new file mode 100644
index 000000000..651c0a229
--- /dev/null
+++ b/scripts/convert_diff_transformer.py
@@ -0,0 +1,127 @@
+"""Test conversion of transformers model attention to differential attention."""
+from typing import Tuple
+
+import torch
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+)
+
+from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention
+
+
+def setup_model(
+    model_name: str, device: str = "cuda"
+) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+    """Load model and tokenizer"""
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float16,
+        device_map=device,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    return model, tokenizer
+
+
+def convert_model_attention(model: AutoModelForCausalLM) -> AutoModelForCausalLM:
+    """Convert model to use differential attention"""
+    try:
+        model = convert_to_diff_attention(model)
+        return model
+    except Exception as exception:
+        print(f"Error during model conversion: {exception}")
+        raise
+
+
+def test_inference(model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> None:
+    """Run test inference"""
+    # Test prompts
+    test_prompts = [
+        "The quick brown fox",
+    ]
+
+    for prompt in test_prompts:
+        try:
+            # Tokenize
+            inputs = tokenizer(prompt, return_tensors="pt")
+            inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+            # Generate
+            from time import time
+
+            start = time()
+            with torch.no_grad():
+                outputs = model.generate(
+                    **inputs,
+                    max_new_tokens=20,
+                    num_beams=1,
+                    do_sample=False,
+                    # temperature=0.7,
+                    pad_token_id=tokenizer.pad_token_id,
+                    use_cache=False,
+                    # use_cache=True,
+                )
+            elapsed = time() - start
+            print(f"generation time: {elapsed}s")
+
+            # Decode
+            print(outputs)
+            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            print(f"\nPrompt: {prompt}")
+            print(f"Generated: {generated_text}\n")
+
+        except Exception as exception:
+            print(f"Error during inference: {str(exception)}")
+            raise
+
+
+def save_converted_model(model: AutoModelForCausalLM, output_dir: str) -> None:
+    """Save the converted model"""
+    print(f"Saving converted model to {output_dir}")
+    model.save_pretrained(output_dir)
+
+
+def main():
+    # Configuration
+    model_name = "HuggingFaceTB/SmolLM2-135M"
+    # model_name = "openlm-research/open_llama_3b_v2"
+    output_dir = "./converted_model"
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"Using device: {device}")
+
+    try:
+        # Load model and tokenizer
+        model, tokenizer = setup_model(model_name, device)
+
+        # Print original model info
+        print("Original model config:")
+        print(f"\t- Hidden size: {model.config.hidden_size}")
+        print(f"\t- Number of attention heads: {model.config.num_attention_heads}")
+
+        # Test the original model
+        test_inference(model, tokenizer)
+
+        # Convert to differential attention
+        model = convert_to_diff_attention(model)
+        model.to(model.device)
+        print("Model conversion completed")
+
+        # Test the converted model
+        test_inference(model, tokenizer)
+
+        # Save converted model
+        save_converted_model(model, output_dir)
+
+    except Exception as exception:
+        print(f"Error during test: {str(exception)}")
+        raise
+
+
+if __name__ == "__main__":
+    main()
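For reference, the zero-init path added in `convert.py` below is intended to make the converted model reproduce the original one exactly. A minimal sketch of how that claim could be checked numerically, reusing the model name from the script above; the helper name and the tolerance handling are illustrative, not part of the PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention


def check_logit_equivalence(model_name: str = "HuggingFaceTB/SmolLM2-135M") -> None:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    original = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
    converted = convert_to_diff_attention(
        AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
    )

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    with torch.no_grad():
        logits_a = original(**inputs).logits
        logits_b = converted(**inputs).logits

    # With Q2/K2 and the lambda parameters zero-initialized, the converted model
    # should reproduce the original logits up to numerical noise.
    print("max abs diff:", (logits_a - logits_b).abs().max().item())
```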
diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py
index 93a8df073..36d97037b 100644
--- a/src/axolotl/integrations/diff_transformer/convert.py
+++ b/src/axolotl/integrations/diff_transformer/convert.py
@@ -1,20 +1,81 @@
 """Differential attention conversion logic for a huggingface pre-trained model."""
 import logging
+from typing import Union
 
+import torch
+from torch import nn
 from transformers import PreTrainedModel
-from transformers.models.llama.modeling_llama import LlamaAttention
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaSdpaAttention
 from transformers.models.mistral.modeling_mistral import MistralAttention
 from transformers.models.mixtral.modeling_mixtral import MixtralAttention
 
-from .multihead_diffattn import DifferentialAttention
+from .multihead_diffattn import (
+    LlamaDifferentialAttention,
+    LlamaDifferentialSdpaAttention,
+)
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 
+def copy_attention_weights(
+    old_attn: Union[LlamaAttention, LlamaSdpaAttention],
+    new_attn: Union[LlamaDifferentialAttention, LlamaDifferentialSdpaAttention],
+    zero_init: bool = True,
+) -> None:
+    """
+    Copy weights from old attention layer to new differential attention layer.
+    Copies old weights to Q1 and K1, zeros out Q2 and K2 for exact equivalence
+    to original attention mechanism.
+    """
+    # For Q projection (Q1 and Q2)
+    new_q = torch.empty_like(new_attn.q_proj.weight.data)
+    new_q[: new_attn.hidden_size] = old_attn.q_proj.weight.data  # Q1
+    if zero_init:
+        new_q[new_attn.hidden_size :] = 0
+    else:
+        nn.init.normal_(new_q[new_attn.hidden_size :], mean=0, std=0.1)
+    new_attn.q_proj.weight.data.copy_(new_q)
+
+    # For K projection (K1 and K2)
+    old_kv_size = old_attn.k_proj.weight.data.size(0)  # Size for the base model's KV heads
+    new_k = torch.empty_like(new_attn.k_proj.weight.data)
+    new_k[:old_kv_size] = old_attn.k_proj.weight.data  # K1
+    if zero_init:
+        new_k[old_kv_size:] = 0
+    else:
+        nn.init.normal_(new_k[old_kv_size:], mean=0, std=0.1)
+    new_attn.k_proj.weight.data.copy_(new_k)
+
+    # For V projection (single V)
+    new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data)
+
+    # Output projection remains the same
+    new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data)
+
+    # Zero out lambda parameters for exact equivalence
+    if zero_init:
+        nn.init.zeros_(new_attn.lambda_q1)
+        nn.init.zeros_(new_attn.lambda_k1)
+        nn.init.zeros_(new_attn.lambda_q2)
+        nn.init.zeros_(new_attn.lambda_k2)
+        new_attn.lambda_init = 0.0
+
+    logger.debug(
+        "Copied attention weights from %s to %s",
+        type(old_attn).__name__,
+        type(new_attn).__name__,
+    )
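The slicing above assumes the doubled projections are laid out as stacked [Q1; Q2] and [K1; K2] blocks along the output dimension, which is how `LlamaDifferentialAttention` later splits them with `.chunk(2, dim=-1)`. A toy illustration of that layout (sizes are arbitrary, not taken from the PR):

```python
# Illustrative only: how the doubled q_proj rows map onto the Q1 / Q2 halves.
import torch
from torch import nn

hidden_size = 8                                              # toy size
old_q = nn.Linear(hidden_size, hidden_size, bias=False)
new_q = nn.Linear(hidden_size, hidden_size * 2, bias=False)  # rows [:8] -> Q1, rows [8:] -> Q2

with torch.no_grad():
    new_q.weight[:hidden_size] = old_q.weight  # Q1 <- original weights
    new_q.weight[hidden_size:] = 0             # Q2 <- zeros

x = torch.randn(2, 3, hidden_size)
q1, q2 = new_q(x).chunk(2, dim=-1)
assert torch.allclose(q1, old_q(x))            # Q1 path reproduces the original projection
assert torch.count_nonzero(q2) == 0            # Q2 path contributes nothing
```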
+
+
 def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
     """Convert a pre-trained model's attention layers to differential attention"""
-    attention_patterns = (LlamaAttention, MistralAttention, MixtralAttention)
+    attention_patterns = (
+        LlamaAttention,
+        LlamaSdpaAttention,
+        MistralAttention,
+        MixtralAttention,
+    )
     layer_idx = 0
 
     # Get model dtype from existing weights
@@ -29,13 +90,22 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel:
                 layer_type = type(child).__name__
                 logger.info(f"Converting attention layer {layer_idx}: {layer_type}")
 
+                # Choose appropriate differential attention class
+                if isinstance(child, LlamaSdpaAttention):
+                    attention_class = LlamaDifferentialSdpaAttention
+                else:
+                    attention_class = LlamaDifferentialAttention
+
                 # Create new diff attn layer
-                new_attention = DifferentialAttention(
+                new_attention = attention_class(
                     config=module.config if hasattr(module, "config") else model.config,
                     layer_idx=layer_idx,
                     dtype=model_dtype,
                 )
 
+                # Copy weights from old attention to new attention
+                copy_attention_weights(child, new_attention)
+
                 # Replace the layer
                 setattr(module, name, new_attention)
                 layer_idx += 1
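Which differential class gets swapped in depends on how the base model was loaded, since the converter dispatches on the existing layer type; SDPA is the transformers default for Llama-family models when available. A small illustrative usage, with the model name borrowed from the test script:

```python
from transformers import AutoModelForCausalLM

from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention

# attn_implementation="eager" -> LlamaAttention -> LlamaDifferentialAttention;
# the default "sdpa" would yield LlamaDifferentialSdpaAttention instead.
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",
    attn_implementation="eager",
)
model = convert_to_diff_attention(model)
print(type(model.model.layers[0].self_attn).__name__)  # LlamaDifferentialAttention
```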
diff --git a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py
index 00462475e..6d3bc7589 100644
--- a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py
+++ b/src/axolotl/integrations/diff_transformer/multihead_diffattn.py
@@ -6,9 +6,9 @@ from typing import Any, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
+import transformers
 from torch import nn
 from transformers.cache_utils import Cache
-from transformers.models.llama.modeling_llama import LlamaRMSNorm as RMSNorm
 from transformers.models.llama.modeling_llama import (
     LlamaRotaryEmbedding,
     apply_rotary_pos_emb,
@@ -34,7 +34,7 @@ def lambda_init_fn(depth):
     return 0.8 - 0.6 * math.exp(-0.3 * depth)
 
 
-class DifferentialAttention(nn.Module):
+class LlamaDifferentialAttention(nn.Module):
     """Differential Attention implementation as described in the Diff Transformer paper.
 
     This implements a modified attention mechanism that computes the difference between
@@ -54,7 +54,6 @@ class DifferentialAttention(nn.Module):
         config: Model configuration object containing hidden size, number of heads etc.
         layer_idx: Index of this layer in the transformer stack
        dtype: Data type for the layer parameters
-        is_causal: Whether to use causal (masked) attention
     """
 
     def __init__(
@@ -62,43 +61,52 @@ class DifferentialAttention(nn.Module):
         config: Any,
         layer_idx: int,
         dtype: torch.dtype,
-        is_causal: bool = True,
     ):
         super().__init__()
-        self.config = config
-        self.layer_idx = layer_idx
+        # Base model dimensions
+        self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.is_causal = is_causal
-        # self.head_dim = self.hidden_size // self.num_heads
-        self.head_dim = self.hidden_size // self.num_heads // 2
-        self.num_key_value_heads = getattr(
-            config, "num_key_value_heads", self.num_heads
-        )
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.max_position_embeddings = config.max_position_embeddings
-        self.scaling = (self.head_dim) ** -0.5
+        self.base_num_heads = config.num_attention_heads
+        self.base_num_kv_heads = config.num_key_value_heads
+        self.head_dim = config.hidden_size // config.num_attention_heads
 
-        # Initialize projections with correct dtype
+        self.scaling = self.head_dim**-0.5
+        self.layer_idx = layer_idx
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.is_causal = True
+
+        # For Q1 and Q2
         self.q_proj = nn.Linear(
-            self.hidden_size, self.hidden_size, bias=False, dtype=dtype
+            self.hidden_size,
+            self.hidden_size * 2,
+            bias=False,
+            dtype=dtype,
         )
+
+        # For K1 and K2
         self.k_proj = nn.Linear(
             self.hidden_size,
-            self.hidden_size // self.num_key_value_groups,
-            bias=False,
-            dtype=dtype,
-        )
-        self.v_proj = nn.Linear(
-            self.hidden_size,
-            self.hidden_size // self.num_key_value_groups,
+            self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
             bias=False,
             dtype=dtype,
         )
+        # Single V projection
+        self.v_proj = nn.Linear(
+            self.hidden_size,
+            self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
+            bias=False,
+            dtype=dtype,
+        )
+
+        # Output projection
         self.o_proj = nn.Linear(
-            self.hidden_size, self.hidden_size, bias=False, dtype=dtype
+            self.hidden_size,
+            self.hidden_size,
+            bias=False,
+            dtype=dtype,
         )
 
         # Initialize differential attention parameters
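To make the doubled projections concrete, here is the arithmetic for a GQA configuration with `hidden_size=576`, 9 query heads and 3 KV heads; the numbers are illustrative (in the same ballpark as the small model used by the test script), not pulled from the PR:

```python
# Worked example of the projection output sizes for a GQA config.
hidden_size, num_heads, num_kv_heads = 576, 9, 3
head_dim = hidden_size // num_heads                      # 64

q_out = hidden_size * 2                                  # 1152 -> Q1 (576) + Q2 (576)
k_out = hidden_size // num_heads * num_kv_heads * 2      # 384  -> K1 (192) + K2 (192)
v_out = hidden_size // num_heads * num_kv_heads          # 192  -> single V
print(q_out, k_out, v_out)
```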
@@ -116,7 +124,6 @@ class DifferentialAttention(nn.Module):
             torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
         )
 
-        self.subln = RMSNorm(2 * self.head_dim, eps=1e-5)
         self.rotary_emb = LlamaRotaryEmbedding(config=config)
 
     def forward(
@@ -126,6 +133,7 @@ class DifferentialAttention(nn.Module):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
+        use_cache: bool = False,  # pylint: disable=unused-argument
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         **kwargs,  # pylint: disable=unused-argument
@@ -134,97 +142,261 @@ class DifferentialAttention(nn.Module):
         Optional[torch.Tensor],
         Optional[tuple[torch.Tensor, torch.Tensor]],
     ]:
-        bsz, tgt_len, _ = hidden_states.size()
+        bsz, q_len, _ = hidden_states.size()
 
-        # Project queries, keys and values
-        q = self.q_proj(hidden_states)
-        k = self.k_proj(hidden_states)
+        # Project to Q1,Q2 and K1,K2
+        qp = self.q_proj(hidden_states)
+        kp = self.k_proj(hidden_states)
         v = self.v_proj(hidden_states)
 
-        # Reshape for attention
-        q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim).transpose(1, 2)
-        k = k.view(bsz, tgt_len, 2 * self.num_key_value_heads, self.head_dim).transpose(
-            1, 2
-        )
-        v = v.view(bsz, tgt_len, self.num_key_value_heads, 2 * self.head_dim).transpose(
-            1, 2
-        )
+        # Split into Q1,Q2 and K1,K2
+        q1, q2 = qp.chunk(2, dim=-1)
+        k1, k2 = kp.chunk(2, dim=-1)
 
-        # Generate or unpack cos, sin for rotary positional embeddings
+        # Reshape Q1,Q2 for attention
+        q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+        # Reshape K1,K2 for attention
+        k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+        # Reshape V
+        v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+        # Apply rotary embeddings
         if position_embeddings is None:
             if position_ids is None:
-                position_ids = torch.arange(
-                    0, tgt_len, dtype=torch.long, device=q.device
-                )
-            cos, sin = self.rotary_emb(q, position_ids)
+                position_ids = torch.arange(q_len, device=q1.device)
+            cos, sin = self.rotary_emb(q1, position_ids)
         else:
             cos, sin = position_embeddings
 
-        # Need to adjust cos, sin to match the halved head_dim
-        cos = cos[..., : self.head_dim]
-        sin = sin[..., : self.head_dim]
-        q, k = apply_rotary_pos_emb(q, k, cos, sin)
+        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
+        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-
-            # Update cache and get back concatenated states
+            k = torch.stack([k1, k2], dim=1)
             k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
+            k1, k2 = k.unbind(dim=1)
 
-        # Prepare for attention
-        k = repeat_kv(k, self.num_key_value_groups)
-        v = repeat_kv(v, self.num_key_value_groups)
+        # Repeat KV heads to match Q heads
+        k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
+        k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
+        v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
 
-        # Scale query
-        q = q * self.scaling
+        # Calculate attention scores for both parts
+        # NOTE(Dan): the Differential Transformers paper scales by a constant scaling factor
+        # instead of sqrt(head_dim). This could be set on the class as `self.scaling`.
+        attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(
+            self.head_dim
+        )
+        attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(
+            self.head_dim
+        )
 
-        # Calculate attention scores
-        attn_weights = torch.matmul(q, k.transpose(-1, -2))
+        if attention_mask is not None:
+            causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
+            attn_weights1 = attn_weights1 + causal_mask
+            attn_weights2 = attn_weights2 + causal_mask
 
-        # Apply causal mask
-        if attention_mask is None:
-            attention_mask = torch.triu(
-                torch.full((tgt_len, tgt_len), float("-inf"), device=q.device),
-                diagonal=1,
-            ).type_as(attn_weights)
-            attn_weights = torch.nan_to_num(attn_weights)
-        attn_weights = attn_weights + attention_mask
-
-        # Apply softmax
-        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(
-            attn_weights
+        # Apply softmax separately as per paper
+        attn_weights1 = F.softmax(attn_weights1, dim=-1, dtype=torch.float32).type_as(
+            attn_weights1
+        )
+        attn_weights2 = F.softmax(attn_weights2, dim=-1, dtype=torch.float32).type_as(
+            attn_weights2
+        )
+        attn_weights1 = F.dropout(
+            attn_weights1, p=self.attention_dropout, training=self.training
+        )
+        attn_weights2 = F.dropout(
+            attn_weights2, p=self.attention_dropout, training=self.training
         )
 
         # Calculate lambda
         lambda_1 = torch.exp(
             torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
-        ).type_as(q)
+        ).type_as(q1)
         lambda_2 = torch.exp(
             torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
-        ).type_as(q)
+        ).type_as(q1)
         lambda_full = lambda_1 - lambda_2 + self.lambda_init
 
-        # Apply differential attention
-        attn_weights = attn_weights.view(
-            bsz, self.num_heads, 2, -1, attn_weights.size(-1)
-        )
-        attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]
+        # Compute differential attention (following paper's formula)
+        attn_weights = attn_weights1 - lambda_full * attn_weights2
 
-        # Apply attention to values
+        # Apply attention weights to values
         attn = torch.matmul(attn_weights, v)
 
-        # Apply sublayer norm
-        attn = self.subln(attn).type_as(attn)
+        # Apply scaling (the paper's sublayer norm is omitted here; see note below)
+        # NOTE(Dan): The differential transformers paper applies sublayer normalization at this
+        # point, but this is typically done outside of the attention layer. It would look something
+        # like: `attn = self.subln(attn).type_as(attn)`, using `LlamaRMSNorm` or similar.
         attn = attn * (1 - self.lambda_init)
 
-        # Reshape and project output
-        attn = attn.transpose(1, 2).reshape(
-            bsz, tgt_len, self.num_heads * 2 * self.head_dim
-        )
+        # Reshape to output
+        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
         attn = self.o_proj(attn)
 
-        # Return in exact format expected by LLaMA
         if output_attentions:
             return attn, attn_weights, past_key_value
         return attn, None, past_key_value
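For reference, the combination computed in the forward pass above corresponds to the paper's formulation (arXiv:2410.05258), written out here from memory with `d` the per-head dimension and `depth` the layer index used by `lambda_init_fn`:

```latex
\mathrm{DiffAttn}(X) =
  \left(\operatorname{softmax}\!\Big(\tfrac{Q_1 K_1^\top}{\sqrt{d}}\Big)
  - \lambda\, \operatorname{softmax}\!\Big(\tfrac{Q_2 K_2^\top}{\sqrt{d}}\Big)\right) V,
\qquad
\lambda = e^{\lambda_{q_1} \cdot \lambda_{k_1}} - e^{\lambda_{q_2} \cdot \lambda_{k_2}} + \lambda_{\mathrm{init}},
\qquad
\lambda_{\mathrm{init}} = 0.8 - 0.6\, e^{-0.3\,\mathrm{depth}}
```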
+
+
+class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
+    """Differential Attention implementation as described in the Diff Transformer paper.
+    This implements the same logic as `LlamaDifferentialAttention`, but uses
+    `scaled_dot_product_attention` instead of "manually" computing it under the hood.
+
+    This implements a modified attention mechanism that computes the difference between
+    two attention patterns, scaled by learned lambda parameters. The mechanism helps
+    reduce noise in the attention weights for irrelevant / less relevant tokens.
+
+    Key components:
+        - Split head dimension for differential computation
+        - Learned lambda parameters that control attention scaling
+        - Sublayer normalization on the attention output
+
+    See:
+        - https://arxiv.org/abs/2410.05258
+        - https://github.com/microsoft/unilm/tree/master/Diff-Transformer
+
+    Args:
+        config: Model configuration object containing hidden size, number of heads etc.
+        layer_idx: Index of this layer in the transformer stack
+        dtype: Data type for the layer parameters
+    """
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,  # [bsz, seq_len, hidden_size]
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,  # pylint: disable=unused-argument
+    ) -> tuple[
+        torch.Tensor,
+        Optional[torch.Tensor],
+        Optional[tuple[torch.Tensor, torch.Tensor]],
+    ]:
+        if output_attentions:
+            transformers.logging.get_logger(__name__).warning_once(
+                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        # Project to Q1,Q2 and K1,K2
+        qp = self.q_proj(hidden_states)
+        kp = self.k_proj(hidden_states)
+        v = self.v_proj(hidden_states)
+
+        # Split into Q1,Q2 and K1,K2
+        q1, q2 = qp.chunk(2, dim=-1)
+        k1, k2 = kp.chunk(2, dim=-1)
+
+        # Reshape Q1,Q2 for attention
+        q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
+        q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
+        # Reshape K1,K2 for attention
+        k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
+        k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
+        # Reshape V
+        v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
+
+        # Apply rotary embeddings
+        if position_embeddings is None:
+            if position_ids is None:
+                position_ids = torch.arange(q_len, device=q1.device)
+            cos, sin = self.rotary_emb(q1, position_ids)
+        else:
+            cos, sin = position_embeddings
+
+        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
+        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            k = torch.stack([k1, k2], dim=1)
+            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
+            k1, k2 = k.unbind(dim=1)
+
+        # Repeat KV heads to match Q heads
+        k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
+        k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
+        v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
+
+        causal_mask = None
+        if attention_mask is not None:
+            causal_mask = attention_mask
+            causal_mask = causal_mask[:, :, :, : k1.shape[-2]]
+
+        # SDPA with memory-efficient backend requires contiguous inputs on CUDA
+        if q1.device.type == "cuda" and causal_mask is not None:
+            q1, q2 = q1.contiguous(), q2.contiguous()
+            k1, k2 = k1.contiguous(), k2.contiguous()
+            v = v.contiguous()
+
+        # Calculate attention using SDPA
+        is_causal = attention_mask is None and q_len > 1
+
+        attn_output1 = F.scaled_dot_product_attention(
+            q1,
+            k1,
+            v,
+            attn_mask=causal_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
+        attn_output2 = F.scaled_dot_product_attention(
+            q2,
+            k2,
+            v,
+            attn_mask=causal_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        # Calculate lambda
+        lambda_1 = torch.exp(
+            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
+        ).type_as(q1)
+        lambda_2 = torch.exp(
+            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
+        ).type_as(q1)
+        lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+        # Combine the attention outputs
+        attn = attn_output1 - lambda_full * attn_output2
+
+        # Apply scaling (sublayer norm omitted; see note in LlamaDifferentialAttention.forward)
+        attn = attn * (1 - self.lambda_init)
+
+        # Reshape to output
+        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
+        attn = self.o_proj(attn)
+
+        if output_attentions:
+            return (
+                attn,
+                None,
+                past_key_value,
+            )  # Note: can't return attn_weights with SDPA
+        return attn, None, past_key_value
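A rough unit-test sketch (not part of this PR) for checking that the eager and SDPA variants stay numerically close when they share weights; the config values, mask construction, and tolerances are arbitrary choices made for illustration:

```python
import torch
from transformers.models.llama.configuration_llama import LlamaConfig

from axolotl.integrations.diff_transformer.multihead_diffattn import (
    LlamaDifferentialAttention,
    LlamaDifferentialSdpaAttention,
)


def test_eager_matches_sdpa():
    torch.manual_seed(0)
    config = LlamaConfig(
        hidden_size=64, num_attention_heads=4, num_key_value_heads=2, attention_dropout=0.0
    )
    eager = LlamaDifferentialAttention(config, layer_idx=0, dtype=torch.float32).eval()
    sdpa = LlamaDifferentialSdpaAttention(config, layer_idx=0, dtype=torch.float32).eval()
    sdpa.load_state_dict(eager.state_dict())  # share projections and lambda parameters

    bsz, seq_len = 2, 8
    hidden_states = torch.randn(bsz, seq_len, config.hidden_size)
    position_ids = torch.arange(seq_len)[None].expand(bsz, -1)

    # 4D additive causal mask, as LlamaModel would prepare before calling attention
    causal_mask = torch.full((seq_len, seq_len), float("-inf")).triu(1)[None, None]

    out_eager, _, _ = eager(hidden_states, attention_mask=causal_mask, position_ids=position_ids)
    out_sdpa, _, _ = sdpa(hidden_states, attention_mask=causal_mask, position_ids=position_ids)
    torch.testing.assert_close(out_eager, out_sdpa, rtol=1e-4, atol=1e-4)
```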