fix: remove redundant files
This commit is contained in:
@@ -15,9 +15,9 @@ try:
|
||||
except ImportError:
|
||||
fast_causal_dot_product = None
|
||||
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
from .feature_map import init_feature_map, init_learned_kernel
|
||||
from .rotary import apply_rotary_pos_emb
|
||||
from .utils import repeat_kv
|
||||
|
||||
# -------------------
|
||||
# Attention functions
|
||||
|
||||
@@ -23,13 +23,14 @@ try:
|
||||
except ModuleNotFoundError:
|
||||
_flash_attention_forward = None # Transformers v4.36
|
||||
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
# Causal linear attention dot product CUDA kernel from fast-transformers
|
||||
from .linear_attention import (
|
||||
LinearAttentionState,
|
||||
LolcatsLinearAttention,
|
||||
causal_dot_product,
|
||||
)
|
||||
from .rotary import apply_rotary_pos_emb
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -22,9 +22,10 @@ try:
|
||||
except ModuleNotFoundError:
|
||||
_flash_attention_forward = None # Transformers v4.36
|
||||
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
from .linear_attention import softmax_attention
|
||||
from .linear_window_attention_tk import LolcatsTKWindowAttention
|
||||
from .rotary import apply_rotary_pos_emb
|
||||
|
||||
LOG = logging.getLogger(
|
||||
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"
|
||||
|
||||
@@ -357,5 +357,3 @@ def register_linear_llama():
|
||||
LinearLlamaConfig.register_for_auto_class("AutoConfig")
|
||||
LinearLlamaModel.register_for_auto_class("AutoModel")
|
||||
LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
||||
|
||||
print("registered transformers")
|
||||
|
||||
@@ -1,204 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Rotary embeddings. Same as usual for Transformer models.
|
||||
|
||||
Note these are modified from HF Transformers v4.36, from:
|
||||
- transformers/models/llama/modeling_llama.py or transformers/models/mistral/modeling_mistral.py
|
||||
- i.e., https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L123
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def get_rotary_embeddings(
    rope_scaling_type: Optional[str] = None,
    head_dim: int = 128,
    max_position_embeddings: int = 4096,
    rope_theta: float = 10000.0,
    rope_scaling_factor: float = 1.0,
    device: Optional[torch.device] = None,
) -> nn.Module:
    """Construct the rotary embedding module selected by ``rope_scaling_type``.

    Args:
        rope_scaling_type: ``None`` for plain rotary embeddings, ``"linear"``
            for linear scaling, or ``"dynamic"`` for dynamic NTK scaling.
        head_dim: Forwarded as the ``dim`` argument of the embedding class.
        max_position_embeddings: Forwarded unchanged to the embedding class.
        rope_theta: Forwarded as the ``base`` argument of the embedding class.
        rope_scaling_factor: Forwarded as ``scaling_factor``; only used by the
            ``"linear"`` and ``"dynamic"`` variants.
        device: Forwarded unchanged to the embedding class.

    Returns:
        The freshly constructed embedding module.

    Raises:
        NotImplementedError: If ``rope_scaling_type`` is none of the values
            listed above.
    """
    # Arguments every variant receives.
    shared_kwargs = {
        "max_position_embeddings": max_position_embeddings,
        "base": rope_theta,
        "device": device,
    }

    if rope_scaling_type is None:
        return RotaryEmbedding(head_dim, **shared_kwargs)

    if rope_scaling_type == "linear":
        return LinearScalingRotaryEmbedding(
            head_dim, scaling_factor=rope_scaling_factor, **shared_kwargs
        )

    if rope_scaling_type == "dynamic":
        return DynamicNTKScalingRotaryEmbedding(
            head_dim, scaling_factor=rope_scaling_factor, **shared_kwargs
        )

    raise NotImplementedError(
        f'Sorry rope_scaling_type == "{rope_scaling_type}" not implemented.'
    )
|
||||
|
||||
|
||||
# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
|
||||
def rotate_half(x):
    """Swap the two halves of the last dimension of ``x``, negating the second.

    The last dimension is split at ``x.shape[-1] // 2``; the result is the
    negated back part followed by the front part, i.e. ``[a, b] -> [-b, a]``.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat([-back, front], dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to the query and key tensors.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table; indexed by ``position_ids`` when those are given.
        sin: Sine table; indexed by ``position_ids`` when those are given.
        position_ids: Optional index into ``cos`` / ``sin``. When ``None`` the
            tables are used as passed in.
        unsqueeze_dim: Dimension at which a singleton axis is inserted into
            ``cos`` and ``sin`` before they are multiplied with ``q`` and ``k``.

    Returns:
        A ``(q_embed, k_embed)`` tuple of the rotated query and key.
    """
    if position_ids is not None:
        cos = cos[position_ids]
        sin = sin[position_ids]

    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same formula for query and key: x * cos + rotate_half(x) * sin.
        return (states * cos) + (rotate_half(states) * sin)

    return _rotate(q), _rotate(k)
|
||||
|
||||
|
||||
# Modified from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
|
||||
class RotaryEmbedding(nn.Module):
    """Original Rotary Embeddings from RoFormer https://arxiv.org/abs/2104.09864

    Keeps a table of cosines and sines (``cos_cached`` / ``sin_cached``) for
    positions ``0 .. max_seq_len_cached - 1`` as non-persistent buffers and
    rebuilds it whenever a longer sequence is requested.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """Create the inverse frequencies and the initial cos/sin cache.

        Args:
            dim: Size of the rotary dimension; frequencies are generated for
                every second index in ``range(0, dim, 2)``.
            max_position_embeddings: Number of positions cached up front.
            base: Base of the geometric frequency progression.
            device: Device on which ``inv_freq`` (and the cache) is created.
        """
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build ``cos_cached`` / ``sin_cached`` for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the cos/sin tables for the first ``seq_len`` positions.

        Args:
            x: Tensor laid out as [bs, num_attention_heads, seq_len, head_size].
                Its dtype and device determine those of the returned tables
                (and of a rebuilt cache).
            seq_len: Number of positions to return. When omitted it is taken
                from ``x.shape[-2]``.

        Returns:
            A ``(cos, sin)`` tuple, each sliced to ``seq_len`` rows and cast
            to ``x.dtype``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The declared default used to crash on the `>` comparison below;
            # fall back to the sequence axis of `x` instead.
            seq_len = x.shape[-2]

        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
|
||||
|
||||
|
||||
# Copied from transformers/models/llama/modeling_llama.py at v4.36
|
||||
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev

    Position indices are divided by ``scaling_factor`` before the cos/sin
    tables are computed.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be assigned before the parent constructor runs: the parent
        # builds the cache right away via `_set_cos_sin_cache`, which reads it.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin cache for ``seq_len`` scaled positions."""
        self.max_seq_len_cached = seq_len

        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        scaled_positions = positions / self.scaling_factor
        angles = torch.outer(scaled_positions, self.inv_freq)

        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        cos_table = table.cos().to(dtype)
        sin_table = table.sin().to(dtype)
        self.register_buffer("cos_cached", cos_table, persistent=False)
        self.register_buffer("sin_cached", sin_table, persistent=False)
|
||||
|
||||
|
||||
# Copied from transformers/models/llama/modeling_llama.py at v4.36
|
||||
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla

    When the requested length exceeds ``max_position_embeddings`` the
    frequency base is enlarged and ``inv_freq`` is recomputed before the
    cos/sin tables are rebuilt.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be assigned before the parent constructor runs: the parent
        # builds the cache right away via `_set_cos_sin_cache`, which may read it.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin cache, rescaling the base for long inputs."""
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            growth = (self.scaling_factor * seq_len / self.max_position_embeddings) - (
                self.scaling_factor - 1
            )
            ntk_base = self.base * growth ** (self.dim / (self.dim - 2))
            exponents = torch.arange(0, self.dim, 2).float().to(device) / self.dim
            self.register_buffer(
                "inv_freq", 1.0 / (ntk_base**exponents), persistent=False
            )

        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)

        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        table = torch.cat((angles, angles), dim=-1)
        cos_table = table.cos().to(dtype)
        sin_table = table.sin().to(dtype)
        self.register_buffer("cos_cached", cos_table, persistent=False)
        self.register_buffer("sin_cached", sin_table, persistent=False)
|
||||
@@ -1,34 +0,0 @@
|
||||
"""
|
||||
Shared attention helpers
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head dimension.

    Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep); the
    hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim).
    When ``n_rep`` is 1 the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states

    # Insert a repeat axis after the head axis, broadcast it, then fold it
    # back into the head axis.
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
|
||||
|
||||
|
||||
def mask_attention(
    qk_dot: torch.Tensor, attn_mask: torch.Tensor, mask_value: float = -10000
) -> torch.Tensor:
    """
    Fill masked-out positions of the attention scores (e.g., for padding).

    Positions where ``attn_mask`` is falsy are replaced by ``mask_value``.
    A 4-dimensional mask is applied as is; any other mask is indexed as
    ``attn_mask[:, None, None, :]`` first (intended for a (b, l) mask).
    """
    if len(attn_mask.shape) == 4:
        keep = attn_mask
    else:
        keep = attn_mask[:, None, None, :]
    return qk_dot.masked_fill(~keep.bool(), mask_value)
|
||||
Reference in New Issue
Block a user