diff --git a/src/axolotl/integrations/lolcats/linear_llama/linear_attention.py b/src/axolotl/integrations/lolcats/linear_llama/linear_attention.py
index 2bd6afa5d..a3a699ae2 100644
--- a/src/axolotl/integrations/lolcats/linear_llama/linear_attention.py
+++ b/src/axolotl/integrations/lolcats/linear_llama/linear_attention.py
@@ -15,9 +15,9 @@ try:
 except ImportError:
     fast_causal_dot_product = None
 
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
 from .feature_map import init_feature_map, init_learned_kernel
-from .rotary import apply_rotary_pos_emb
-from .utils import repeat_kv
 
 # -------------------
 # Attention functions
diff --git a/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_sw_linear.py b/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_sw_linear.py
index 9ea6c9a90..5e4f0b879 100644
--- a/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_sw_linear.py
+++ b/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_sw_linear.py
@@ -23,13 +23,14 @@ try:
 except ModuleNotFoundError:
     _flash_attention_forward = None  # Transformers v4.36
 
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
 # Causal linear attention dot product CUDA kernel from fast-transformers
 from .linear_attention import (
     LinearAttentionState,
     LolcatsLinearAttention,
     causal_dot_product,
 )
-from .rotary import apply_rotary_pos_emb
 
 LOG = logging.getLogger(__name__)
 
diff --git a/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_tk_long.py b/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_tk_long.py
index 79ac8f21c..9c750fadc 100644
--- a/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_tk_long.py
+++ b/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_tk_long.py
@@ -22,9 +22,10 @@ try:
 except ModuleNotFoundError:
     _flash_attention_forward = None  # Transformers v4.36
 
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
 from .linear_attention import softmax_attention
 from .linear_window_attention_tk import LolcatsTKWindowAttention
-from .rotary import apply_rotary_pos_emb
 
 LOG = logging.getLogger(
     "axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"
diff --git a/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py
index abe29db5a..fa30ba5ab 100644
--- a/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py
+++ b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py
@@ -357,5 +357,3 @@ def register_linear_llama():
     LinearLlamaConfig.register_for_auto_class("AutoConfig")
     LinearLlamaModel.register_for_auto_class("AutoModel")
     LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")
-
-    print("registered transformers")
diff --git a/src/axolotl/integrations/lolcats/linear_llama/rotary.py b/src/axolotl/integrations/lolcats/linear_llama/rotary.py
deleted file mode 100644
index ed885dcbc..000000000
--- a/src/axolotl/integrations/lolcats/linear_llama/rotary.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# coding=utf-8
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
-#
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Rotary embeddings. Same as usual for Transformer models.
-
-Note these are modified from HF Transformers v4.36, from:
-- transformers/models/llama/modeling_llama.py or transformers/models/mistral/modeling_mistral.py
-- i.e., https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L123
-"""
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-
-def get_rotary_embeddings(
-    rope_scaling_type: Optional[str] = None,
-    head_dim: int = 128,
-    max_position_embeddings: int = 4096,
-    rope_theta: float = 10000.0,
-    rope_scaling_factor: float = 1.0,
-    device: Optional[torch.device] = None,
-) -> nn.Module:
-    """Return rotary embedding object"""
-    if rope_scaling_type is None:
-        return RotaryEmbedding(
-            head_dim,
-            max_position_embeddings=max_position_embeddings,
-            base=rope_theta,
-            device=device,
-        )
-    elif rope_scaling_type == "linear":
-        return LinearScalingRotaryEmbedding(
-            head_dim,
-            max_position_embeddings=max_position_embeddings,
-            scaling_factor=rope_scaling_factor,
-            base=rope_theta,
-            device=device,
-        )
-    elif rope_scaling_type == "dynamic":
-        return DynamicNTKScalingRotaryEmbedding(
-            head_dim,
-            max_position_embeddings=max_position_embeddings,
-            scaling_factor=rope_scaling_factor,
-            base=rope_theta,
-            device=device,
-        )
-    else:
-        raise NotImplementedError(
-            f'Sorry rope_scaling_type == "{rope_scaling_type}" not implemented.'
-        )
-
-
-# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors."""
-    if position_ids is not None:
-        cos, sin = cos[position_ids], sin[position_ids]
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-# Modified from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
-class RotaryEmbedding(nn.Module):
-    """Original Rotary Embeddings from RoFormer https://arxiv.org/abs/2104.09864"""
-
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-        super().__init__()
-
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (
-            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
-        )
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings,
-            device=self.inv_freq.device,
-            dtype=torch.get_default_dtype(),
-        )
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(
-            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
-        )
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
-        """
-        Compute rotary embeddings
-        """
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
-        )
-
-
-# Copied from transformers/models/llama/modeling_llama.py at v4.36
-class LinearScalingRotaryEmbedding(RotaryEmbedding):
-    """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
-    def __init__(
-        self,
-        dim,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-    ):
-        self.scaling_factor = scaling_factor
-        super().__init__(dim, max_position_embeddings, base, device)
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(
-            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
-        )
-        t = t / self.scaling_factor
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-
-# Copied from transformers/models/llama/modeling_llama.py at v4.36
-class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
-    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(
-        self,
-        dim,
-        max_position_embeddings=2048,
-        base=10000,
-        device=None,
-        scaling_factor=1.0,
-    ):
-        self.scaling_factor = scaling_factor
-        super().__init__(dim, max_position_embeddings, base, device)
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-
-        if seq_len > self.max_position_embeddings:
-            base = self.base * (
-                (self.scaling_factor * seq_len / self.max_position_embeddings)
-                - (self.scaling_factor - 1)
-            ) ** (self.dim / (self.dim - 2))
-            inv_freq = 1.0 / (
-                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
-            )
-            self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        t = torch.arange(
-            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
-        )
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
diff --git a/src/axolotl/integrations/lolcats/linear_llama/utils.py b/src/axolotl/integrations/lolcats/linear_llama/utils.py
deleted file mode 100644
index 4e0314ce0..000000000
--- a/src/axolotl/integrations/lolcats/linear_llama/utils.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""
-Shared attention helpers
-"""
-
-import torch
-
-
-# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
-    The hidden states go from:
-    (batch, num_key_value_heads, seqlen, head_dim) to
-    (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(
-        batch, num_key_value_heads, n_rep, slen, head_dim
-    )
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-def mask_attention(
-    qk_dot: torch.Tensor, attn_mask: torch.Tensor, mask_value: float = -10000
-) -> torch.Tensor:
-    """
-    Apply attention mask (e.g., for padding)
-    """
-    if len(attn_mask.shape) == 4:  # attn_mask either (b, h, l, d) or (b, l)
-        return qk_dot.masked_fill(~attn_mask.bool(), mask_value)
-    else:
-        return qk_dot.masked_fill(~attn_mask[:, None, None, :].bool(), mask_value)
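The deleted `rotary.py` and `utils.py` were local copies of helpers that already ship with `transformers`, so the remaining modules now import `apply_rotary_pos_emb` and `repeat_kv` from `transformers.models.llama.modeling_llama` directly. As a sanity check, here is a minimal, hypothetical smoke test of those upstream helpers; it is not part of the patch, the config values are illustrative, and the config-driven `LlamaRotaryEmbedding` constructor assumes a reasonably recent `transformers` release (older releases built it from `dim`/`base` arguments instead).

```python
# Hypothetical smoke test (not part of the patch) for the upstream helpers
# that replace the deleted local copies. Shapes and config values are
# illustrative only.
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    repeat_kv,
)

config = LlamaConfig(hidden_size=256, num_attention_heads=8, num_key_value_heads=2)
head_dim = config.hidden_size // config.num_attention_heads

batch, seq_len = 2, 16
q = torch.randn(batch, config.num_attention_heads, seq_len, head_dim)
k = torch.randn(batch, config.num_key_value_heads, seq_len, head_dim)

# cos/sin are precomputed per position and passed in, as the attention
# forward passes do after this change.
rope = LlamaRotaryEmbedding(config=config)  # config-driven ctor; recent releases
position_ids = torch.arange(seq_len)[None, :].expand(batch, -1)
cos, sin = rope(q, position_ids)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)

# repeat_kv expands grouped KV heads to match the query heads:
# (batch, num_key_value_heads, seq, head_dim) -> (batch, num_attention_heads, seq, head_dim)
k_full = repeat_kv(k_rot, config.num_attention_heads // config.num_key_value_heads)
assert q_rot.shape == q.shape
assert k_full.shape == (batch, config.num_attention_heads, seq_len, head_dim)
```

One compatibility note: upstream's `apply_rotary_pos_emb` changed around v4.38, where `position_ids` became optional and deprecated in favor of precomputed `cos`/`sin`; the deleted local copy already treated `position_ids` as optional, so call sites that pass only `cos`/`sin` behave the same either way, but the plugin's minimum supported `transformers` version is worth double-checking.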