fix: remove redundant files

NanoCode012
2025-02-05 19:34:06 +07:00
parent 0f82bd2d18
commit c4cb622590
6 changed files with 6 additions and 244 deletions

View File

@@ -15,9 +15,9 @@ try:
except ImportError:
    fast_causal_dot_product = None
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from .feature_map import init_feature_map, init_learned_kernel
from .rotary import apply_rotary_pos_emb
from .utils import repeat_kv
# -------------------
# Attention functions

View File

@@ -23,13 +23,14 @@ try:
except ModuleNotFoundError:
    _flash_attention_forward = None  # Transformers v4.36
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
# Causal linear attention dot product CUDA kernel from fast-transformers
from .linear_attention import (
    LinearAttentionState,
    LolcatsLinearAttention,
    causal_dot_product,
)
from .rotary import apply_rotary_pos_emb
LOG = logging.getLogger(__name__)

View File

@@ -22,9 +22,10 @@ try:
except ModuleNotFoundError:
    _flash_attention_forward = None  # Transformers v4.36
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from .linear_attention import softmax_attention
from .linear_window_attention_tk import LolcatsTKWindowAttention
from .rotary import apply_rotary_pos_emb
LOG = logging.getLogger(
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"

View File

@@ -357,5 +357,3 @@ def register_linear_llama():
    LinearLlamaConfig.register_for_auto_class("AutoConfig")
    LinearLlamaModel.register_for_auto_class("AutoModel")
    LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")
    print("registered transformers")

View File

@@ -1,204 +0,0 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Rotary embeddings. Same as usual for Transformer models.
Note these are modified from HF Transformers v4.36, from:
- transformers/models/llama/modeling_llama.py or transformers/models/mistral/modeling_mistral.py
- i.e., https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L123
"""
from typing import Optional

import torch
import torch.nn as nn


def get_rotary_embeddings(
    rope_scaling_type: Optional[str] = None,
    head_dim: int = 128,
    max_position_embeddings: int = 4096,
    rope_theta: float = 10000.0,
    rope_scaling_factor: float = 1.0,
    device: Optional[torch.device] = None,
) -> nn.Module:
    """Return rotary embedding object"""
    if rope_scaling_type is None:
        return RotaryEmbedding(
            head_dim,
            max_position_embeddings=max_position_embeddings,
            base=rope_theta,
            device=device,
        )
    elif rope_scaling_type == "linear":
        return LinearScalingRotaryEmbedding(
            head_dim,
            max_position_embeddings=max_position_embeddings,
            scaling_factor=rope_scaling_factor,
            base=rope_theta,
            device=device,
        )
    elif rope_scaling_type == "dynamic":
        return DynamicNTKScalingRotaryEmbedding(
            head_dim,
            max_position_embeddings=max_position_embeddings,
            scaling_factor=rope_scaling_factor,
            base=rope_theta,
            device=device,
        )
    else:
        raise NotImplementedError(
            f'Sorry rope_scaling_type == "{rope_scaling_type}" not implemented.'
        )


# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors."""
    if position_ids is not None:
        cos, sin = cos[position_ids], sin[position_ids]
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# Modified from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
class RotaryEmbedding(nn.Module):
    """Original Rotary Embeddings from RoFormer https://arxiv.org/abs/2104.09864"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """
        Compute rotary embeddings
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


# Copied from transformers/models/llama/modeling_llama.py at v4.36
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        t = t / self.scaling_factor
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)


# Copied from transformers/models/llama/modeling_llama.py at v4.36
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
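
Since the whole module above is deleted, a short usage sketch of the removed API may help reviewers confirm the upstream transformers helpers cover the same ground; dimensions are made up, and note that the dynamic-NTK variant simply rescales the base by ((s * L / L_max) - (s - 1)) ** (d / (d - 2)) once the sequence length L exceeds max_position_embeddings:

# Usage sketch for the removed module (illustrative shapes only).
import torch

bsz, n_heads, seq_len, head_dim = 2, 8, 16, 64
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)

rotary = get_rotary_embeddings(
    rope_scaling_type="dynamic",  # None, "linear", or "dynamic"
    head_dim=head_dim,
    max_position_embeddings=4096,
    rope_theta=10000.0,
    rope_scaling_factor=2.0,
)
cos, sin = rotary(q, seq_len=seq_len)          # each (seq_len, head_dim)
position_ids = torch.arange(seq_len)[None, :]  # (1, seq_len), broadcast over the batch
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)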

View File

@@ -1,34 +0,0 @@
"""
Shared attention helpers
"""
import torch


# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from:
    (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def mask_attention(
    qk_dot: torch.Tensor, attn_mask: torch.Tensor, mask_value: float = -10000
) -> torch.Tensor:
    """
    Apply attention mask (e.g., for padding)
    """
    if len(attn_mask.shape) == 4:  # attn_mask either (b, h, l, d) or (b, l)
        return qk_dot.masked_fill(~attn_mask.bool(), mask_value)
    else:
        return qk_dot.masked_fill(~attn_mask[:, None, None, :].bool(), mask_value)
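
To make this deletion easier to review, here is a brief, hedged sketch of how the removed helpers were used; the shapes are illustrative, not taken from the calling code:

# Usage sketch for the removed helpers (illustrative shapes only).
import torch

bsz, n_kv_heads, n_rep, seq_len, head_dim = 2, 2, 4, 16, 64
k = torch.randn(bsz, n_kv_heads, seq_len, head_dim)
k_full = repeat_kv(k, n_rep)  # (2, 8, 16, 64): KV heads expanded for grouped-query attention

scores = torch.randn(bsz, n_kv_heads * n_rep, seq_len, seq_len)
attn_mask = torch.ones(bsz, seq_len, dtype=torch.bool)
attn_mask[:, -4:] = False                   # pretend the last four keys are padding
scores = mask_attention(scores, attn_mask)  # padded keys filled with -10000 before softmax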