feat: refactor into modeling code
@@ -6,6 +6,7 @@ from .linear_window_attention_sw import (
    LinearAttentionSlidingWindowCache,
    LolcatsSlidingWindowAttention,
)
+from .linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention
from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
from .linear_window_attention_tk import (
    LinearAttentionTKWindowCache,
@@ -16,7 +16,7 @@ except ImportError:
    fast_causal_dot_product = None

from ..model.feature_map import init_feature_map, init_learned_kernel
-from ..model.rotary import apply_rotary_pos_emb, get_rotary_embeddings
+from ..model.rotary import apply_rotary_pos_emb
from .utils import repeat_kv

# -------------------
@@ -0,0 +1,64 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Linear LLaMA model configuration"""

from transformers import LlamaConfig


class LinearLlamaConfig(LlamaConfig):
    """
    This is the configuration class to store the configuration of a [`LinearLlamaModel`].
    It is a modified LlamaConfig that includes additional parameters for linear attention.

    Args:
        attention_config (`dict`):
            Dictionary containing the configuration for the linear attention mechanism.
            Expected contents:
            `feature_map` (`str`):
                The type of feature map to use for linear attention.
            `feature_map_kwargs` (`dict`):
                Additional arguments for the feature map.
            `learned_kernel` (`str`, *optional*):
                Type of learned kernel to use, if any.
            `learned_kernel_kwargs` (`dict`, *optional*):
                Additional arguments for the learned kernel.
            `tie_qk_kernels` (`bool`, *optional*, defaults to `False`):
                Whether to tie the query and key kernels.
            `rotary_config` (`dict`, *optional*):
                Configuration for rotary embeddings.
            `train_attention` (`bool`, *optional*, defaults to `False`):
                Whether to train the linear attention to match softmax attention.
            `remove_base_attn` (`bool`, *optional*, defaults to `True`):
                Whether to remove the base attention weights after initialization.
            `mask_value` (`int`, *optional*, defaults to 0):
                Value to use for masking.
            `eps` (`float`, *optional*, defaults to 1e-12):
                Epsilon value for numerical stability.
            `fp32_attention` (`bool`, *optional*, defaults to `False`):
                Whether to compute attention in fp32 precision.
            `track_state_grads` (`bool`, *optional*, defaults to `False`):
                Whether to track gradients of attention states.

        **kwargs:
            Additional arguments inherited from LlamaConfig.
    """

    model_type = "linear_llama"

    def __init__(self, attention_config: dict, **kwargs):
        super().__init__(**kwargs)

        # Store the linear attention configuration
        self.attention_config = attention_config
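For illustration, a minimal construction sketch; the `feature_map` name and the other values below are assumed placeholders, not documented defaults:

```python
config = LinearLlamaConfig(
    attention_config={
        "feature_map": "softmax_dim",  # assumed feature-map name
        "feature_map_kwargs": {},
        "tie_qk_kernels": False,
        "train_attention": True,       # distill against softmax attention
        "remove_base_attn": True,
        "eps": 1e-12,
    },
    hidden_size=4096,
    num_hidden_layers=32,
)
```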
src/axolotl/integrations/lolcats/linear_llama/csrc/README.md (new file, 30 lines)
@@ -0,0 +1,30 @@
# Causal linear attention CUDA kernel

Usage:
```bash
cd src/axolotl/integrations/lolcats/linear_llama/csrc

# Edit `setup.py` to point to the correct CUDA capabilities (L40-44)
# nano setup.py

# Build the CUDA kernel
python setup.py install
```

Reference: https://github.com/idiap/fast-transformers/

```bib
@inproceedings{katharopoulos_et_al_2020,
    author = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
    title = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
    booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
    year = {2020}
}

@inproceedings{vyas_et_al_2020,
    author = {Vyas, A. and Katharopoulos, A. and Fleuret, F.},
    title = {Fast Transformers with Clustered Attention},
    booktitle = {Proceedings of the International Conference on Neural Information Processing Systems (NeurIPS)},
    year = {2020}
}
```
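Once installed, the extension can be smoke-tested from Python. A minimal sketch, assuming a CUDA device is available; the tensor layout `(batch, heads, length, dim)` and the in-place `product` argument follow the kernel's C++ signature:

```python
import torch
from causal_attention_cuda import causal_dot_product  # module name set in setup.py

N, H, L, E, M = 2, 4, 128, 64, 64
q = torch.randn(N, H, L, E, device="cuda")
k = torch.randn(N, H, L, E, device="cuda")
v = torch.randn(N, H, L, M, device="cuda")
out = torch.zeros(N, H, L, M, device="cuda")

causal_dot_product(q, k, v, out)  # fills `out` in place
```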
@@ -0,0 +1,6 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
from .causal_attention import causal_dot_product
@@ -0,0 +1,225 @@
//
// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
// Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
// Apoorv Vyas <avyas@idiap.ch>
//

#include <torch/extension.h>


/**
 * Compute a*b^T and save it into out.
 *
 * a \in R^A
 * b \in R^B
 */
inline void vvt_dot(float *a, float *b, float *out, int A, int B) {
    for (int i=0; i<A; i++) {
        float *bi = b;
        for (int j=0; j<B; j++) {
            *out += (*a) * (*bi);
            out++;
            bi++;
        }
        a++;
    }
}


/**
 * Implement a vector matrix product v*m and save it into out.
 *
 * v \in R^A
 * m \in R^{AxB}
 */
inline void vm_dot(float *v, float *m, float *out, int A, int B) {
    // TODO: Consider removing the zeroing part and assuming out already
    // contains 0s
    for (int i=0; i<B; i++) {
        out[i] = 0;
    }

    for (int i=0; i<A; i++) {
        float *oi = out;
        for (int j=0; j<B; j++) {
            *oi += (*v) * (*m);
            oi++;
            m++;
        }
        v++;
    }
}


/**
 * Implement a vector transposed-matrix product and save it into out.
 *
 * v \in R^B
 * m \in R^{AxB}
 */
inline void vmt_dot(float *v, float *m, float *out, int A, int B) {
    for (int i=0; i<A; i++) {
        float *vi = v;
        float s = 0;
        for (int j=0; j<B; j++) {
            s += (*vi) * (*m);
            vi++;
            m++;
        }
        // TODO: Should we be aggregating? See the comment on vm_dot.
        *out = s;
        out++;
    }
}


/**
 * Compute the causally masked dot products of queries, keys and values.
 *
 * Basically compute V_j' = (Q_{0:j} * K_{0:j}^T) * V_{0:j} for all j. The
 * computation is done efficiently by changing the order of the dot products.
 */
void causal_dot_product(
    const torch::Tensor queries,
    const torch::Tensor keys,
    const torch::Tensor values,
    torch::Tensor product
) {
    // Extract some shapes
    int N = queries.size(0);
    int H = queries.size(1);
    int L = queries.size(2);
    int E = queries.size(3);
    int M = values.size(3);

    // Create accessors for all the arguments
    auto qa = queries.accessor<float, 4>();
    auto ka = keys.accessor<float, 4>();
    auto va = values.accessor<float, 4>();
    auto pa = product.accessor<float, 4>();

    #pragma omp parallel for collapse(2)
    for (int n=0; n<N; n++) {
        for (int h=0; h<H; h++) {
            // Running state kv = sum_{j<=l} K_j V_j^T, updated once per position
            auto kv = torch::zeros({E, M}, queries.options());
            float *kvp = kv.data_ptr<float>();
            for (int l=0; l<L; l++) {
                vvt_dot(&ka[n][h][l][0], &va[n][h][l][0], kvp, E, M);
                vm_dot(&qa[n][h][l][0], kvp, &pa[n][h][l][0], E, M);
            }
        }
    }
}
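The loop above keeps a running `E x M` state `kv = sum_{j<=l} K_j V_j^T` and contracts each query row against it; that reordering is what makes the product linear in `L`. For sanity-checking the extension, a hedged quadratic reference in PyTorch (shapes follow the accessors above; `causal_dot_product_ref` is a name introduced here):

```python
import torch

def causal_dot_product_ref(q, k, v):
    # Quadratic reference: causally masked (Q K^T) applied to V, no softmax
    scores = torch.einsum("nhle,nhse->nhls", q, k)
    mask = torch.ones(q.shape[2], k.shape[2], device=q.device, dtype=torch.bool).tril()
    scores = scores * mask  # zero out attention to future positions
    return torch.einsum("nhls,nhsm->nhlm", scores, v)
```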


/**
 * Compute the gradients of queries, keys and values given the gradient of the
 * causal_dot_product output.
 *
 * Make sure that everything is computed in O(N D^2) complexity.
 */
void causal_dot_backward(
    const torch::Tensor queries,
    const torch::Tensor keys,
    const torch::Tensor values,
    const torch::Tensor grad_out,
    torch::Tensor grad_queries,
    torch::Tensor grad_keys,
    torch::Tensor grad_values
) {
    // Extract some shapes
    int N = queries.size(0);
    int H = queries.size(1);
    int L = queries.size(2);
    int E = queries.size(3);
    int M = values.size(3);

    // Create accessors for all the arguments
    auto qa = queries.accessor<float, 4>();
    auto ka = keys.accessor<float, 4>();
    auto va = values.accessor<float, 4>();
    auto ga = grad_out.accessor<float, 4>();
    auto gqa = grad_queries.accessor<float, 4>();
    auto gka = grad_keys.accessor<float, 4>();
    auto gva = grad_values.accessor<float, 4>();

    #pragma omp parallel for collapse(2)
    for (int n=0; n<N; n++) {
        for (int h=0; h<H; h++) {
            auto kv = torch::zeros({E, M}, queries.options());
            float *kvp = kv.data_ptr<float>();

            // Compute the gradient wrt the queries:
            // grad_Q_l = S_l G_l with S_l = sum_{j<=l} K_j V_j^T
            for (int l=0; l<L; l++) {
                vvt_dot(&ka[n][h][l][0], &va[n][h][l][0], kvp, E, M);
                vmt_dot(&ga[n][h][l][0], kvp, &gqa[n][h][l][0], E, M);
            }

            // Compute the gradient wrt the keys and values, sweeping backwards
            // with T_l = sum_{j>=l} Q_j G_j^T
            kv.zero_();
            for (int l=L-1; l>=0; l--) {
                vvt_dot(&qa[n][h][l][0], &ga[n][h][l][0], kvp, E, M);
                vmt_dot(&va[n][h][l][0], kvp, &gka[n][h][l][0], E, M);
                vm_dot(&ka[n][h][l][0], kvp, &gva[n][h][l][0], E, M);
            }
        }
    }
}
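In the same notation, the two sweeps realize `grad_Q_l = S_l G_l` with `S_l = sum_{j<=l} K_j V_j^T`, and `grad_K_l = T_l V_l`, `grad_V_l = K_l^T T_l` with `T_l = sum_{j>=l} Q_j G_j^T`. A hedged, memory-hungry einsum restatement (verification only; `causal_dot_backward_ref` is a name introduced here):

```python
import torch

def causal_dot_backward_ref(q, k, v, g):
    # S_l = cumulative sum of K_j V_j^T over j <= l  -> shape (N, H, L, E, M)
    kv = torch.einsum("nhse,nhsm->nhsem", k, v).cumsum(dim=2)
    grad_q = torch.einsum("nhlm,nhlem->nhle", g, kv)
    # T_l = reversed cumulative sum of Q_j G_j^T over j >= l
    qg = torch.einsum("nhse,nhsm->nhsem", q, g)
    qg_rev = qg.flip(2).cumsum(dim=2).flip(2)
    grad_k = torch.einsum("nhlm,nhlem->nhle", v, qg_rev)
    grad_v = torch.einsum("nhle,nhlem->nhlm", k, qg_rev)
    return grad_q, grad_k, grad_v
```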


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "causal_dot_product",
        &causal_dot_product,
        "Compute the weighted sum of values but attending only to previous "
        "values."
    );
    m.def(
        "causal_dot_backward",
        &causal_dot_backward,
        "Compute the gradient of queries, keys and values given the gradient "
        "of causal_dot_product."
    );
}
@@ -0,0 +1,67 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

import torch

try:
    from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda
    from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda
except ImportError as e:
    print(e)
    causal_dot_product_cuda = causal_dot_backward_cuda = None


class CausalDotProduct(torch.autograd.Function):
    """Compute the weighted sum of values but attending only to previous
    values."""

    dot = {
        # "cpu": causal_dot_product_cpu,
        "cuda": causal_dot_product_cuda,
    }
    dot_backward = {
        # "cpu": causal_dot_backward_cpu,
        "cuda": causal_dot_backward_cuda,
    }

    @staticmethod
    def forward(ctx, Q, K, V):
        # Save the inputs for the gradient computation
        ctx.save_for_backward(Q, K, V)

        # Create the output tensor
        device = Q.device
        N, H, L, _ = Q.shape
        _, _, _, M = V.shape
        product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=device)

        # Actually perform the dot product
        CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product)

        return product

    @staticmethod
    def backward(ctx, grad_out):
        # Extract the saved tensors
        Q, K, V = ctx.saved_tensors

        # Allocate memory for the gradients
        grad_Q = torch.zeros_like(Q)
        grad_K = torch.zeros_like(K)
        grad_V = torch.zeros_like(V)

        # Actually compute the gradients
        CausalDotProduct.dot_backward[Q.device.type](
            Q.data, K.data, V.data, grad_out, grad_Q, grad_K, grad_V
        )

        return grad_Q, grad_K, grad_V


# Alias the autograd function to python style snake case naming
causal_dot_product = CausalDotProduct.apply
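A short usage sketch for the wrapper, assuming the CUDA extension built successfully (the kernel accessors are compiled for `float32`, so float64 `torch.autograd.gradcheck` does not apply; a plain backward call is shown instead):

```python
import torch

q = torch.randn(2, 4, 128, 64, device="cuda", requires_grad=True)
k = torch.randn(2, 4, 128, 64, device="cuda", requires_grad=True)
v = torch.randn(2, 4, 128, 64, device="cuda", requires_grad=True)

out = causal_dot_product(q, k, v)  # CausalDotProduct.apply
out.sum().backward()               # populates q.grad, k.grad, v.grad
```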
File diff suppressed because it is too large
File diff suppressed because it is too large

src/axolotl/integrations/lolcats/linear_llama/csrc/setup.py (new file, 65 lines)
@@ -0,0 +1,65 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#

import subprocess  # nosec

import torch
from setuptools import setup
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension


def get_last_arch_torch():
    arch = torch.cuda.get_arch_list()[-1]
    print(f"Found arch: {arch} from existing torch installation")
    return arch


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True  # nosec
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]
    return raw_output, bare_metal_major, bare_metal_minor


def append_nvcc_threads(nvcc_extra_args):
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    # Multi-threaded nvcc is available from CUDA 11.2 onwards; compare the
    # (major, minor) tuple so that e.g. 12.0 also qualifies
    if (int(bare_metal_major), int(bare_metal_minor)) >= (11, 2):
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args


arch = get_last_arch_torch()
sm_num = arch[-2:]
cc_flag = ["--generate-code=arch=compute_90,code=compute_90"]  # for H100
# cc_flag = ['--generate-code=arch=compute_80,code=compute_80']  # for A100
# cc_flag = ['--generate-code=arch=compute_89,code=compute_89']  # for RTX 6000, 4090
# cc_flag = ['--generate-code=arch=compute_86,code=compute_86']  # for A6000, 3090
# cc_flag = ['--generate-code=arch=compute_75,code=compute_75']

setup(
    name="causal_attention_cuda_cpp",
    ext_modules=[
        CUDAExtension(
            "causal_attention_cuda",
            [
                # 'causal_attention.cpp',
                "causal_attention_cuda.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": append_nvcc_threads(
                    ["-O3", "-lineinfo", "--use_fast_math", "-std=c++17"] + cc_flag
                ),
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
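Note that `sm_num` is parsed from the detected arch but left unused; the flag above is hard-coded for H100. A hedged alternative that derives the flag from the local GPU instead (assumes `arch` has the usual `sm_90`-style form):

```python
# Hypothetical: build for whatever compute capability torch detected
cc_flag = [f"--generate-code=arch=compute_{sm_num},code=compute_{sm_num}"]
```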
@@ -0,0 +1,115 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0

"""Linear LLaMA model implementation."""


from torch import nn
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
)

from axolotl.utils.dict import DictDefault

from .attention import LolcatsLinearAttention
from .configuration_linear_llama import LinearLlamaConfig


class LinearLlamaDecoderLayer(LlamaDecoderLayer):
    """
    Modified LlamaDecoderLayer that uses LinearAttention instead of standard attention.
    """

    def __init__(self, config: LinearLlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Replace the attention layer with our custom attention
        self.self_attn = LolcatsLinearAttention(
            base_attn=self.self_attn,  # type: ignore
            layer_idx=layer_idx,
            **config.attention_config,
        )


class LinearLlamaModel(LlamaModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers.
    Each layer is a [`LinearLlamaDecoderLayer`].

    Args:
        config: LinearLlamaConfig
    """

    config_class = LinearLlamaConfig
    base_model_prefix = "linear_llama"

    def __init__(self, config: LinearLlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                LinearLlamaDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()


class LinearLlamaForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM whose decoder layers use linear attention."""

    def __init__(self, config):
        super().__init__(config)
        self.model = LinearLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @classmethod
    def from_llama(
        cls,
        model: LlamaModel | LlamaForCausalLM,
        config: LinearLlamaConfig,
        train_attention: bool = False,
        remove_base_attn: bool = True,
    ) -> "LinearLlamaForCausalLM":
        """
        Initialize a LinearLlamaForCausalLM from a LlamaModel.
        """

        # Handle LlamaForCausalLM by unwrapping the inner decoder
        if isinstance(model, LlamaForCausalLM):
            model = model.model

        if config is None:
            raise ValueError("Missing config")

        from axolotl.integrations.lolcats.linearize_attention import convert_attention

        new_model = convert_attention(
            model,
            DictDefault(**config.attention_config),
            train_attention=train_attention,
            remove_base_attn=remove_base_attn,
        )

        return new_model
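A hedged end-to-end sketch of the conversion path; the checkpoint id and the `attention_config` keys are placeholders, and the accepted keys are ultimately whatever `convert_attention` consumes:

```python
from transformers import LlamaForCausalLM

base = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder id
config = LinearLlamaConfig(
    attention_config={"feature_map": "softmax_dim", "feature_map_kwargs": {}},  # assumed keys
    **base.config.to_dict(),
)
linear_model = LinearLlamaForCausalLM.from_llama(
    base, config, train_attention=True, remove_base_attn=False
)
```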
@@ -134,41 +134,39 @@ def get_attention(attention_type: str, **kwargs):
    kwargs["attention_type"] = attention_type

    if attention_type == "lolcats_llama":
-        from .linear_attention import LolcatsLinearAttention
+        from .linear_llama.attention import LolcatsLinearAttention

        return partial(LolcatsLinearAttention, **kwargs)

    elif attention_type == "lolcats_llama_window_tk":
-        from .linear_attention import LolcatsTKWindowAttention
+        from .linear_llama.attention import LolcatsTKWindowAttention

        return partial(LolcatsTKWindowAttention, **kwargs)

    elif attention_type == "lolcats_llama_window_sw":
-        from .linear_attention import LolcatsSlidingWindowAttention
+        from .linear_llama.attention import LolcatsSlidingWindowAttention

        return partial(LolcatsSlidingWindowAttention, **kwargs)

    elif attention_type == "lolcats_llama_window_sw_linear":
-        from .linear_attention.linear_window_attention_sw_linear import (
-            LolcatsLinearSlidingWindowAttention,
-        )
+        from .linear_llama.attention import LolcatsLinearSlidingWindowAttention

        return partial(LolcatsLinearSlidingWindowAttention, **kwargs)

    # Experimental chunked linear attentions below
    elif attention_type == "lolcats_long_llama_window_tk":
-        from .linear_attention import LolcatsTKWindowLongAttention
+        from .linear_llama.attention import LolcatsTKWindowLongAttention

        return partial(LolcatsTKWindowLongAttention, **kwargs)

    elif attention_type == "lolcats_long_llama_window_sw":
-        from .linear_attention import LolcatsSlidingWindowLongAttention
+        from .linear_llama.attention import LolcatsSlidingWindowLongAttention

        return partial(LolcatsSlidingWindowLongAttention, **kwargs)

    # TK generation build (requires Thunderkittens)
    elif attention_type == "lolcats_llama_window_tk_gen":
-        from .linear_attention import LolcatsWindowAttentionTKGen
+        from .linear_llama.attention import LolcatsWindowAttentionTKGen

        return partial(LolcatsWindowAttentionTKGen, **kwargs)

@@ -186,30 +184,28 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):

    # LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
    elif "lolcats_llama_window_tk_gen" in attention_type:
-        from .linear_attention import LinearAttentionTKWindowGenerationCache
+        from .linear_llama.attention import LinearAttentionTKWindowGenerationCache

        return LinearAttentionTKWindowGenerationCache()

    elif "llama_window_tk" in attention_type:
-        from .linear_attention import LinearAttentionTKWindowCache
+        from .linear_llama.attention import LinearAttentionTKWindowCache

        return LinearAttentionTKWindowCache()

    elif "llama_window_sw" in attention_type:
-        from .linear_attention import LinearAttentionSlidingWindowCache
+        from .linear_llama.attention import LinearAttentionSlidingWindowCache

        return LinearAttentionSlidingWindowCache()

    elif "llama_window_sw_linear" in attention_type:
-        from .linear_attention import LinearAttentionSlidingWindowCache
+        from .linear_llama.attention import LinearAttentionSlidingWindowCache

        return LinearAttentionSlidingWindowCache()

    # TK generation build (requires Thunderkittens)
    elif attention_type == "lolcats_llama_window_tk_gen":
-        from .linear_attention.linear_window_attention_tk_gen import (
-            LinearAttentionTKWindowGenerationCache,
-        )
+        from .linear_llama.attention import LinearAttentionTKWindowGenerationCache

        return LinearAttentionTKWindowGenerationCache()

@@ -217,6 +213,6 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):
        return past_key_values

    else:
-        from .linear_attention import LinearAttentionState
+        from .linear_llama.attention import LinearAttentionState

        return LinearAttentionState()
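For orientation, a hedged sketch of how these two factories are typically consumed (the `feature_map` value is an assumed placeholder):

```python
# Resolve an attention constructor for a given type, then instantiate per layer
attn_ctor = get_attention("lolcats_llama", feature_map="softmax_dim")
# attn = attn_ctor(base_attn=layer.self_attn, layer_idx=0)

# Resolve a fresh recurrent cache for generation with that attention type
past_key_values = get_attention_cache("lolcats_llama_window_sw")
```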