From adeefc1991bea3fdd78abad4356cdd90d36f522a Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 4 Feb 2025 19:29:42 +0700 Subject: [PATCH] feat: refactor into modeling code --- .../lolcats/linear_llama/__init__.py | 0 .../attention}/__init__.py | 1 + .../attention}/linear_attention.py | 2 +- .../attention}/linear_window_attention_sw.py | 0 .../linear_window_attention_sw_linear.py | 0 .../linear_window_attention_sw_long.py | 0 .../attention}/linear_window_attention_tk.py | 0 .../linear_window_attention_tk_gen.py | 0 .../linear_window_attention_tk_long.py | 0 .../attention}/utils.py | 0 .../configuration_linear_llama.py | 64 + .../lolcats/linear_llama/csrc/README.md | 30 + .../lolcats/linear_llama/csrc/__init__.py | 6 + .../linear_llama/csrc/causal_attention.cpp | 225 +++ .../linear_llama/csrc/causal_attention.py | 67 + .../csrc/causal_attention_cuda.cu | 1483 +++++++++++++++++ .../csrc/causal_attention_kv_cuda.cu | 1483 +++++++++++++++++ .../lolcats/linear_llama/csrc/setup.py | 65 + .../lolcats/linear_llama/model/__init__.py | 0 .../{ => linear_llama}/model/feature_map.py | 0 .../{ => linear_llama}/model/rotary.py | 0 .../linear_llama/modeling_linear_llama.py | 115 ++ .../lolcats/linearize_attention.py | 30 +- 23 files changed, 3553 insertions(+), 18 deletions(-) create mode 100644 src/axolotl/integrations/lolcats/linear_llama/__init__.py rename src/axolotl/integrations/lolcats/{linear_attention => linear_llama/attention}/__init__.py (89%) rename src/axolotl/integrations/lolcats/{linear_attention => linear_llama/attention}/linear_attention.py (99%) rename src/axolotl/integrations/lolcats/{linear_attention => linear_llama/attention}/linear_window_attention_sw.py (100%) rename src/axolotl/integrations/lolcats/{linear_attention => linear_llama/attention}/linear_window_attention_sw_linear.py (100%) rename src/axolotl/integrations/lolcats/{linear_attention => linear_llama/attention}/linear_window_attention_sw_long.py (100%) rename src/axolotl/integrations/lolcats/{linear_attention => linear_llama/attention}/linear_window_attention_tk.py (100%) rename src/axolotl/integrations/lolcats/{linear_attention => linear_llama/attention}/linear_window_attention_tk_gen.py (100%) rename src/axolotl/integrations/lolcats/{linear_attention => linear_llama/attention}/linear_window_attention_tk_long.py (100%) rename src/axolotl/integrations/lolcats/{linear_attention => linear_llama/attention}/utils.py (100%) create mode 100644 src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py create mode 100644 src/axolotl/integrations/lolcats/linear_llama/csrc/README.md create mode 100644 src/axolotl/integrations/lolcats/linear_llama/csrc/__init__.py create mode 100644 src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention.cpp create mode 100644 src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention.py create mode 100644 src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention_cuda.cu create mode 100644 src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention_kv_cuda.cu create mode 100644 src/axolotl/integrations/lolcats/linear_llama/csrc/setup.py create mode 100644 src/axolotl/integrations/lolcats/linear_llama/model/__init__.py rename src/axolotl/integrations/lolcats/{ => linear_llama}/model/feature_map.py (100%) rename src/axolotl/integrations/lolcats/{ => linear_llama}/model/rotary.py (100%) create mode 100644 src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py diff --git a/src/axolotl/integrations/lolcats/linear_llama/__init__.py 
b/src/axolotl/integrations/lolcats/linear_llama/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/integrations/lolcats/linear_attention/__init__.py b/src/axolotl/integrations/lolcats/linear_llama/attention/__init__.py similarity index 89% rename from src/axolotl/integrations/lolcats/linear_attention/__init__.py rename to src/axolotl/integrations/lolcats/linear_llama/attention/__init__.py index 9f94414d6..3d815b7d0 100644 --- a/src/axolotl/integrations/lolcats/linear_attention/__init__.py +++ b/src/axolotl/integrations/lolcats/linear_llama/attention/__init__.py @@ -6,6 +6,7 @@ from .linear_window_attention_sw import ( LinearAttentionSlidingWindowCache, LolcatsSlidingWindowAttention, ) +from .linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention from .linear_window_attention_tk import ( LinearAttentionTKWindowCache, diff --git a/src/axolotl/integrations/lolcats/linear_attention/linear_attention.py b/src/axolotl/integrations/lolcats/linear_llama/attention/linear_attention.py similarity index 99% rename from src/axolotl/integrations/lolcats/linear_attention/linear_attention.py rename to src/axolotl/integrations/lolcats/linear_llama/attention/linear_attention.py index 352459641..0d77f1e44 100644 --- a/src/axolotl/integrations/lolcats/linear_attention/linear_attention.py +++ b/src/axolotl/integrations/lolcats/linear_llama/attention/linear_attention.py @@ -16,7 +16,7 @@ except ImportError: fast_causal_dot_product = None from ..model.feature_map import init_feature_map, init_learned_kernel -from ..model.rotary import apply_rotary_pos_emb, get_rotary_embeddings +from ..model.rotary import apply_rotary_pos_emb from .utils import repeat_kv # ------------------- diff --git a/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_sw.py b/src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_sw.py similarity index 100% rename from src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_sw.py rename to src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_sw.py diff --git a/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_sw_linear.py b/src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_sw_linear.py similarity index 100% rename from src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_sw_linear.py rename to src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_sw_linear.py diff --git a/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_sw_long.py b/src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_sw_long.py similarity index 100% rename from src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_sw_long.py rename to src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_sw_long.py diff --git a/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_tk.py b/src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_tk.py similarity index 100% rename from src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_tk.py rename to src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_tk.py diff --git a/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_tk_gen.py 
b/src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_tk_gen.py similarity index 100% rename from src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_tk_gen.py rename to src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_tk_gen.py diff --git a/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_tk_long.py b/src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_tk_long.py similarity index 100% rename from src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_tk_long.py rename to src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_tk_long.py diff --git a/src/axolotl/integrations/lolcats/linear_attention/utils.py b/src/axolotl/integrations/lolcats/linear_llama/attention/utils.py similarity index 100% rename from src/axolotl/integrations/lolcats/linear_attention/utils.py rename to src/axolotl/integrations/lolcats/linear_llama/attention/utils.py diff --git a/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py b/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py new file mode 100644 index 000000000..31e81a274 --- /dev/null +++ b/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py @@ -0,0 +1,64 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Linear LLaMA model configuration""" + +from transformers import LlamaConfig + + +class LinearLlamaConfig(LlamaConfig): + """ + This is the configuration class to store the configuration of a [`LinearLlamaModel`]. + It is a modified LlamaConfig that includes additional parameters for linear attention. + + Args: + attention_config (`dict`): + Dictionary containing the configuration for linear attention mechanism. + Expected contents: + `feature_map` (`str`): + The type of feature map to use for linear attention. + `feature_map_kwargs` (`dict`): + Additional arguments for the feature map. + `learned_kernel` (`str`, *optional*): + Type of learned kernel to use, if any. + `learned_kernel_kwargs` (`dict`, *optional*): + Additional arguments for the learned kernel. + `tie_qk_kernels` (`bool`, *optional*, defaults to False): + Whether to tie query and key kernels. + `rotary_config` (`dict`, *optional*): + Configuration for rotary embeddings. + `train_attention` (`bool`, *optional*, defaults to False): + Whether to train attention to match softmax attention. + `remove_base_attn` (`bool`, *optional*, defaults to True): + Whether to remove base attention after initialization. + `mask_value` (`int`, *optional*, defaults to 0): + Value to use for masking. + `eps` (`float`, *optional*, defaults to 1e-12): + Epsilon value for numerical stability. + `fp32_attention` (`bool`, *optional*, defaults to False): + Whether to use fp32 precision for attention computation. 
+        `track_state_grads` (`bool`, *optional*, defaults to False):
+            Whether to track gradients of attention states.
+
+        **kwargs:
+            Additional arguments inherited from LlamaConfig.
+    """
+
+    model_type = "linear_llama"
+
+    def __init__(self, attention_config: dict, **kwargs):
+        super().__init__(**kwargs)
+
+        # Store the linear attention configuration
+        self.attention_config = attention_config
diff --git a/src/axolotl/integrations/lolcats/linear_llama/csrc/README.md b/src/axolotl/integrations/lolcats/linear_llama/csrc/README.md
new file mode 100644
index 000000000..ee7425f05
--- /dev/null
+++ b/src/axolotl/integrations/lolcats/linear_llama/csrc/README.md
@@ -0,0 +1,30 @@
+# Causal linear attention CUDA kernel
+
+Usage:
+```bash
+cd src/axolotl/integrations/lolcats/linear_llama/csrc
+
+# Edit `setup.py` (lines 40-44) to target the correct CUDA compute capabilities
+# nano setup.py
+
+# Build the CUDA kernel
+python setup.py install
+```
+
+Reference: https://github.com/idiap/fast-transformers/
+
+```bib
+@inproceedings{katharopoulos_et_al_2020,
+    author = {Katharopoulos, A. and Vyas, A. and Pappas, N. and Fleuret, F.},
+    title = {Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention},
+    booktitle = {Proceedings of the International Conference on Machine Learning (ICML)},
+    year = {2020}
+}
+
+@inproceedings{vyas_et_al_2020,
+    author={Vyas, A. and Katharopoulos, A. and Fleuret, F.},
+    title={Fast Transformers with Clustered Attention},
+    booktitle = {Proceedings of the International Conference on Neural Information Processing Systems (NeurIPS)},
+    year={2020}
+}
+```
diff --git a/src/axolotl/integrations/lolcats/linear_llama/csrc/__init__.py b/src/axolotl/integrations/lolcats/linear_llama/csrc/__init__.py
new file mode 100644
index 000000000..e5caecd24
--- /dev/null
+++ b/src/axolotl/integrations/lolcats/linear_llama/csrc/__init__.py
@@ -0,0 +1,6 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos ,
+# Apoorv Vyas
+#
+from .causal_attention import causal_dot_product
diff --git a/src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention.cpp b/src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention.cpp
new file mode 100644
index 000000000..744844d59
--- /dev/null
+++ b/src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention.cpp
@@ -0,0 +1,225 @@
+//
+// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+// Written by Angelos Katharopoulos ,
+// Apoorv Vyas
+//
+
+#include <torch/extension.h>
+
+
+/**
+ * Compute a*b^T and save it into out.
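+ * Here out is treated as a row-major A x B matrix.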
+ * + * a \in R^A + * b \in R^B + */ +inline void vvt_dot(float *a, float *b, float *out, int A, int B) { + for (int i=0; i(); + auto ka = keys.accessor(); + auto va = values.accessor(); + auto pa = product.accessor(); + + #pragma omp parallel for collapse(2) + for (int n=0; n(); + for (int l=0; l(); + auto ka = keys.accessor(); + auto va = values.accessor(); + auto ga = grad_out.accessor(); + auto gqa = grad_queries.accessor(); + auto gka = grad_keys.accessor(); + auto gva = grad_values.accessor(); + + #pragma omp parallel for collapse(2) + for (int n=0; n(); + + // Compute the gradient wrt the queries + for (int l=0; l=0; l--) { + vvt_dot( + &qa[n][h][l][0], + &ga[n][h][l][0], + kvp, + E, + M + ); + vmt_dot( + &va[n][h][l][0], + kvp, + &gka[n][h][l][0], + E, + M + ); + vm_dot( + &ka[n][h][l][0], + kvp, + &gva[n][h][l][0], + E, + M + ); + } + } + } +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "causal_dot_product", + &causal_dot_product, + "Compute the weighted sum of values but attending only to previous " + "values." + ); + m.def( + "causal_dot_backward", + &causal_dot_backward, + "Compute the gradient of queries, keys and values given the gradient " + "of causal_dot_product." + ); +} diff --git a/src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention.py b/src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention.py new file mode 100644 index 000000000..4ad289f07 --- /dev/null +++ b/src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +# Written by Angelos Katharopoulos , +# Apoorv Vyas +# + +import torch + +try: + from causal_attention_cuda import causal_dot_backward as causal_dot_backward_cuda + from causal_attention_cuda import causal_dot_product as causal_dot_product_cuda +except ImportError as e: + print(e) + causal_dot_product_cuda = causal_dot_backward_cuda = None + + +class CausalDotProduct(torch.autograd.Function): + """Compute the weighted sum of values but attending only to previous + values.""" + + dot = { + # "cpu": causal_dot_product_cpu, + "cuda": causal_dot_product_cuda + } + dot_backward = { + # "cpu": causal_dot_backward_cpu, + "cuda": causal_dot_backward_cuda + } + + @staticmethod + def forward(ctx, Q, K, V): + # Save the inputs for the gradient computation + ctx.save_for_backward(Q, K, V) + + # Create the output tensor + device = Q.device + N, H, L, _ = Q.shape + _, _, _, M = V.shape + product = torch.zeros((N, H, L, M), dtype=Q.dtype, device=device) + + # Actually perform the dot product + CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product) + # breakpoint() + # CausalDotProduct.dot[device.type](Q.data, K.data, V.data, product) + + return product + + @staticmethod + def backward(ctx, grad_out): + # Extract the saved tensors + Q, K, V = ctx.saved_tensors + + # Allocate memory for the gradients + grad_Q = torch.zeros_like(Q) + grad_K = torch.zeros_like(K) + grad_V = torch.zeros_like(V) + + # Actually compute the gradients + CausalDotProduct.dot_backward[Q.device.type]( + Q.data, K.data, V.data, grad_out, grad_Q, grad_K, grad_V + ) + + return grad_Q, grad_K, grad_V + + +# Alias the autograd functions to python style snake case naming +causal_dot_product = CausalDotProduct.apply diff --git a/src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention_cuda.cu b/src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention_cuda.cu new file mode 100644 index 000000000..ab8f92c4f --- /dev/null +++ 
b/src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention_cuda.cu @@ -0,0 +1,1483 @@ +// +// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +// Written by Angelos Katharopoulos , +// Apoorv Vyas +// + +// +// For modifications made inside namespace nvidia (authored by jdemouth): +// +// Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// + +#include +#include +#include + +#define ENABLE_NVIDIA_OPTIMIZATIONS + +#ifdef ENABLE_NVIDIA_OPTIMIZATIONS +namespace nvidia { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr int THREADS_PER_WARP = 32; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr int LOW_OCCUPANCY_THRESHOLD = 40; // TODO: Make it HW specific (like 1/2 SMs). + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ __host__ int div_up(int m, int n) { + return (m + n-1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ __host__ int round_up(int m, int n) { + return div_up(m, n) * n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T > +struct Lmha_params { + + // The output buffer. Dimensions [B, H, L, M]. + T *out; + + // The input Qs. Dimensions [B, H, L, E]. + const T *q; + // The input Ks. Dimensions [B, H, L, E]. + const T *k; + // The input Vs. Dimensions [B, H, L, M]. + const T *v; + + // The different dimensions. + int B, L, H, E, M; + + // The strides for the different tensors. + int q_stride_B, q_stride_H, q_stride_L; + int k_stride_B, k_stride_H, k_stride_L; + int v_stride_B, v_stride_H, v_stride_L; + int o_stride_B, o_stride_H, o_stride_L; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, bool GO_BACKWARD, int WARPS, int COLS_PER_THREAD = 4 > +__global__ __launch_bounds__(WARPS * THREADS_PER_WARP) +void lmha_low_occupancy_kernel(Lmha_params params) { + + // The number of threads per block. + constexpr int THREADS_PER_BLOCK = WARPS * THREADS_PER_WARP; + // The number of rows per thread. + constexpr int ROWS_PER_THREAD = E / THREADS_PER_WARP; + // The number of steps per iteration. 
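+ // (each pass of the main timestep loop below advances by COLS_PER_ITER timesteps)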
+ constexpr int COLS_PER_ITER = WARPS * COLS_PER_THREAD; + + // Make sure E is a multiple of the warp size. + static_assert(E % THREADS_PER_WARP == 0, ""); + + // Shared memory to store V/O. + __shared__ float smem_v[COLS_PER_ITER], smem_o[COLS_PER_ITER]; + // Shared memory buffer to performance the reductions. + __shared__ float smem_reds[E * WARPS]; + + // The sequence processed by that block. + const int bi = blockIdx.z; + // The head processed by that block. + const int hi = blockIdx.y; + // The hidden cell in the V/output buffers. + const int vi = blockIdx.x; + + // The linear index of the thread. + const int tidx = threadIdx.x; + + // Decompose the block in warp/lane. + const int warp = tidx / THREADS_PER_WARP; + const int lane = tidx % THREADS_PER_WARP; + + // The base offset loaded by the thread in Q and K. + int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + lane; + int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + lane; + + // If we walk backward, account for the extra offset. + if( GO_BACKWARD ) { + offset_q += (params.L-1)*params.q_stride_L; + offset_k += (params.L-1)*params.k_stride_L; + } + + // Position the warp at the beginning of the proper timestep. + if( GO_BACKWARD ) { + offset_q -= warp*COLS_PER_THREAD*params.q_stride_L; + offset_k -= warp*COLS_PER_THREAD*params.k_stride_L; + } else { + offset_q += warp*COLS_PER_THREAD*params.q_stride_L; + offset_k += warp*COLS_PER_THREAD*params.k_stride_L; + } + + // Determine the base pointers for Q and K. + const float *ptr_q = ¶ms.q[offset_q]; + const float *ptr_k = ¶ms.k[offset_k]; + + // Is a given row valid? + int valid_qk[ROWS_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < ROWS_PER_THREAD; ++ii ) { + valid_qk[ii] = lane + ii*THREADS_PER_WARP < params.E; + } + + // The offset to the position loaded by the thread in V. + int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + vi; + int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + vi; + + // If we walk backward, account for the extra offset. + if( GO_BACKWARD ) { + offset_v += (params.L-1)*params.v_stride_L; + offset_o += (params.L-1)*params.o_stride_L; + } + + // We load/store a strided matrix of COLS_PER_ITER x OUTPUTS_PER_BLOCK. + if( GO_BACKWARD ) { + offset_v -= tidx*params.v_stride_L; + offset_o -= tidx*params.o_stride_L; + } else { + offset_v += tidx*params.v_stride_L; + offset_o += tidx*params.o_stride_L; + } + + // Determine the base pointer for V. + const float *ptr_v = ¶ms.v[offset_v]; + // The output pointer. + float *ptr_o = ¶ms.out[offset_o]; + + // The running KVs. + float running_kv[ROWS_PER_THREAD]; + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + running_kv[ri] = 0.f; + } + + // Iterate over the timesteps. TODO: Use params.loop_count!!! + for( int iter = 0; iter < params.L; iter += COLS_PER_ITER ) { + + // Each thread loads a matrix of elements. + float q[ROWS_PER_THREAD][COLS_PER_THREAD], k[ROWS_PER_THREAD][COLS_PER_THREAD]; + + // Trigger the memory loads for Q and K. + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + + // For Q/K, each warp loads from various timesteps. + int ti = iter + warp*COLS_PER_THREAD; + if( GO_BACKWARD ) { + ti = params.L - 1 - ti; + } + + // Is it a valid access? + int valid; + if( GO_BACKWARD ) { + valid = valid_qk[ri] && ti - ci >= 0; + } else { + valid = valid_qk[ri] && ti + ci < params.L; + } + + // The extra offset to add. 
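+ // (relative to ptr_q/ptr_k, which already point at this warp's base timestep and lane)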
+ if( GO_BACKWARD ) { + offset_q = ri*THREADS_PER_WARP - ci*params.q_stride_L; + offset_k = ri*THREADS_PER_WARP - ci*params.k_stride_L; + } else { + offset_q = ri*THREADS_PER_WARP + ci*params.q_stride_L; + offset_k = ri*THREADS_PER_WARP + ci*params.k_stride_L; + } + + // Load Q/K if they are valid. + q[ri][ci] = valid ? ptr_q[offset_q] : 0.f; + k[ri][ci] = valid ? ptr_k[offset_k] : 0.f; + } + } + + // For the V tensor, we assign contiguous thread to different loads. So, ti is different. + int ti = iter + tidx; + if( GO_BACKWARD ) { + ti = params.L - 1 - ti; + } + + // Is it a valid access? + int valid_vo = tidx < COLS_PER_ITER; + if( GO_BACKWARD ) { + valid_vo &= ti >= 0; + } else { + valid_vo &= ti < params.L; + } + + // Trigger the loads for V. + float ldg_v = valid_vo ? *ptr_v : 0.f; + + // Move the load pointers. + if( GO_BACKWARD ) { + ptr_q -= COLS_PER_ITER*params.q_stride_L; + ptr_k -= COLS_PER_ITER*params.k_stride_L; + ptr_v -= COLS_PER_ITER*params.v_stride_L; + } else { + ptr_q += COLS_PER_ITER*params.q_stride_L; + ptr_k += COLS_PER_ITER*params.k_stride_L; + ptr_v += COLS_PER_ITER*params.v_stride_L; + } + + // Store to shared memory. + if( tidx < COLS_PER_ITER ) { + smem_v[tidx] = ldg_v; + } + + // Make sure V is in shared memory. + __syncthreads(); + + // Read V from shared memory. + float v[COLS_PER_THREAD]; + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + v[ci] = smem_v[warp*COLS_PER_THREAD + ci]; + } + + // Each thread computes local K*V products. + float kv[ROWS_PER_THREAD][COLS_PER_THREAD]; + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + kv[ri][ci] = 0.f; + } + } + + // Update the K*V^T product. + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + kv[ri][ci] += k[ri][ci] * v[ci]; + } + } + + // We must perform the prefix sums within the thread-block. Start with the thread. + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + #pragma unroll + for( int ci = 1; ci < COLS_PER_THREAD; ++ci ) { + kv[ri][ci] += kv[ri][ci-1]; + } + } + + // Store the partial sums to shared memory. Unless we have no inter-warp reduction to perform. + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + smem_reds[warp*E + ri*THREADS_PER_WARP + lane] = kv[ri][COLS_PER_THREAD-1]; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Each thread deals with one or more column(s) of the matrix. + constexpr int SUMS_PER_THREAD = (E + THREADS_PER_BLOCK-1) / THREADS_PER_BLOCK; + #pragma unroll + for( int ii = 0, idx = tidx; ii < SUMS_PER_THREAD; ++ii, idx += THREADS_PER_BLOCK ) { + if( idx < E ) { + float sum = smem_reds[idx]; + #pragma unroll + for( int jj = 1; jj < WARPS; ++jj ) { + smem_reds[idx + jj*E] = sum += smem_reds[idx + jj*E]; + } + } + } + + // Make sure the reductions are stored in shared memory. + __syncthreads(); + + // Each thread updates his partial products. + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + float sum = running_kv[ri]; + if( warp > 0 ) { + sum += smem_reds[(warp-1)*E + lane + ri*THREADS_PER_WARP]; + } + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + kv[ri][ci] += sum; + } + } + + // Compute the partial output values for that thread. 
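+ // sum[ci] holds this thread's ROWS_PER_THREAD-row slice of q[ti] . kv for column ci;
+ // the warp shuffle below completes the reduction over the full E dimension.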
+ float sum[COLS_PER_THREAD]; + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + sum[ci] = q[0][ci] * kv[0][ci]; + #pragma unroll + for( int ri = 1; ri < ROWS_PER_THREAD; ++ri ) { + sum[ci] += q[ri][ci] * kv[ri][ci]; + } + } + + // Run the parallel reductions inside the warp. + #pragma unroll + for( int mask = THREADS_PER_WARP / 2; mask >= 1; mask /= 2 ) { + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + sum[ci] += __shfl_xor_sync(uint32_t(-1), sum[ci], mask); + } + } + + // Store the final output to shared memory. + if( lane == 0 ) { + #pragma unroll + for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) { + smem_o[warp*COLS_PER_THREAD + ci] = sum[ci]; + } + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Store the output. + if( valid_vo ) { + *ptr_o = smem_o[tidx]; + } + + // Each thread updates his running kv. + #pragma unroll + for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) { + running_kv[ri] += smem_reds[(WARPS-1)*E + lane + ri*THREADS_PER_WARP]; + } + + // Move to next location. + if( GO_BACKWARD ) { + ptr_o -= COLS_PER_ITER*params.o_stride_L; + } else { + ptr_o += COLS_PER_ITER*params.o_stride_L; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, bool GO_BACKWARD, int WARPS > +int lmha_low_occupancy_(const Lmha_params ¶ms) { + + // Make sure we are not going to launch an invalid grid. + if( params.H > 65535 || params.B > 65535 ) { + return 1; + } + + // Prepare the grid and trigger the CUDA kernel. + dim3 grid; + grid.x = params.M; + grid.y = params.H; + grid.z = params.B; + lmha_low_occupancy_kernel<<>>(params); + return 0; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, bool GO_BACKWARD > +int lmha_low_occupancy_(const Lmha_params ¶ms, int blocks) { + if( params.M * blocks >= 8*LOW_OCCUPANCY_THRESHOLD ) { + return lmha_low_occupancy_(params); + } else if( params.M * blocks >= 4*LOW_OCCUPANCY_THRESHOLD ) { + return lmha_low_occupancy_(params); + } else { + return lmha_low_occupancy_(params); + } + return 1; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, typename Params > +static inline __device__ __host__ int smem_buffer_elts_(const Params ¶ms) { + int M = round_up(params.M, 4); + return 2*E + 2*M; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD > +__global__ +void lmha_kernel(Lmha_params params) { + + // Make sure E is a multiple of 4. + static_assert(E % 4 == 0, ""); + + // The amount of shared memory per buffer (2 buffers for double-buffering). + const int smem_buffer_elts = smem_buffer_elts_(params); + // The M dimension for shared memory. + const int M = round_up(params.M, 4); + + // Shared memory to store Q, K and V. Size is 2*smem_buffer_elts. + extern __shared__ float smem_[]; + + // The various shared memory buffers. + float *smem_q = &smem_[0*E]; + float *smem_k = &smem_[1*E]; + float *smem_v = &smem_[2*E]; + float *smem_o = &smem_[2*E + M]; + + // The index of the shared memory buffer (for double-buffering). + int smem_curr = 0; + + // The sequence processed by that block. + const int bi = blockIdx.y; + // The head processed by that block. + const int hi = blockIdx.x; + + // The linear index of the thread. 
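+ // (the launcher lmha_ below sizes the block as max(E, M*THREADS_PER_HEAD) threads, rounded up to a full warp)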
+ const int tidx = threadIdx.x; + + // The offset to the position loaded by the thread in Q. + int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + tidx; + // The offset to the position loaded by the thread in K. + int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + tidx; + + // If we walk backward, account for the extra offset. + if( GO_BACKWARD ) { + offset_q += (params.L-1)*params.q_stride_L; + offset_k += (params.L-1)*params.k_stride_L; + } + + // Determine the base pointers for Q and K. + const float *ptr_q = ¶ms.q[offset_q]; + const float *ptr_k = ¶ms.k[offset_k]; + + // The offset to the position loaded by the thread in V and O. + int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + tidx; + int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + tidx; + + // If we walk backward, account for the extra offset. + if( GO_BACKWARD ) { + offset_v += (params.L-1)*params.v_stride_L; + offset_o += (params.L-1)*params.o_stride_L; + } + + // Determine the base pointers for V. + const float *ptr_v = ¶ms.v[offset_v]; + + // Is it an active Q/K thread? + const int active_qk = tidx < params.E; + + // Trigger the memory loads for Q and K. + float ldg_q = 0.f, ldg_k = 0.f; + if( active_qk ) { + ldg_q = *ptr_q; + ldg_k = *ptr_k; + } + + // Is it an active V thread? + const int active_v = tidx < params.M; + + // Trigger the memory loads for V. + float ldg_v = 0.f; + if( active_v ) { + ldg_v = *ptr_v; + } + + // Move the load pointers. + if( GO_BACKWARD ) { + ptr_q -= params.q_stride_L; + ptr_k -= params.k_stride_L; + ptr_v -= params.v_stride_L; + } else { + ptr_q += params.q_stride_L; + ptr_k += params.k_stride_L; + ptr_v += params.v_stride_L; + } + + // The number of FLOAT4s per head. + constexpr int FLOAT4s_PER_HEAD = E / 4; + // The number of FLOAT4s per thread. + constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD; + + // The storage for the K*V^T values. + float4 kv[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + kv[ii] = make_float4(0.f, 0.f, 0.f, 0.f); + } + + // The output pointer. + float *out_ptr = ¶ms.out[offset_o]; + + // Store to shared memory Q and K. + if( tidx < E ) { + smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q; + smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k; + } + + // Store to shared memory V. All threads store valid values. + if( tidx < M ) { + smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v; + } + + // The position of the thread in the V dimension. + int vo = tidx / THREADS_PER_HEAD; + int vi = tidx % THREADS_PER_HEAD; + + // Iterate over the timesteps. + for( int ti = 0; ti < params.L; ++ti ) { + + // Is it the last iteration? + int is_last = ti == params.L - 1; + + // Trigger the next loads for Q and K. + if( !is_last && active_qk ) { + ldg_q = *ptr_q; + ldg_k = *ptr_k; + } + + // Trigger the next loads for V. + if( !is_last && active_v ) { + ldg_v = *ptr_v; + } + + // Move the load pointers. + if( GO_BACKWARD ) { + ptr_q -= params.q_stride_L; + ptr_k -= params.k_stride_L; + ptr_v -= params.v_stride_L; + } else { + ptr_q += params.q_stride_L; + ptr_k += params.k_stride_L; + ptr_v += params.v_stride_L; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Each thread loads 4 values from K. 
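+ // (vectorized float4 reads; the static_assert above guarantees E is a multiple of 4)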
+ float4 k[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + int ki = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4; + k[ii] = *reinterpret_cast(&smem_k[smem_curr*smem_buffer_elts + ki]); + } + + // Each thread loads a single V value. + float v = 0.f; + if( vo < params.M ) { + v = *reinterpret_cast(&smem_v[smem_curr*smem_buffer_elts + vo]); + } + + // Update the K*V^T product. + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + kv[ii].x += k[ii].x * v; + kv[ii].y += k[ii].y * v; + kv[ii].z += k[ii].z * v; + kv[ii].w += k[ii].w * v; + } + + // Load the Q values from shared memory. + float4 q[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + int qi = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4; + q[ii] = *reinterpret_cast(&smem_q[smem_curr*smem_buffer_elts + qi]); + } + + // Compute the partial output value for that thread. + float sum = 0.f; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + sum += q[ii].x * kv[ii].x; + sum += q[ii].y * kv[ii].y; + sum += q[ii].z * kv[ii].z; + sum += q[ii].w * kv[ii].w; + } + + // Finalize the computation of the sum (if we have more than 1 thread per head). + if( THREADS_PER_HEAD > 1 ) { + + // Finalize the sum for each head. + #pragma unroll + for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Store to shared memory. + if( vo < M && vi == 0 ) { + smem_o[smem_curr*smem_buffer_elts + vo] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Active threads read the data to store. + if( active_v ) { + sum = smem_o[smem_curr*smem_buffer_elts + tidx]; + } + + } // THREADS_PER_HEAD > 1. + + // Store the output. All the threads are active. + if( active_v ) { + *out_ptr = sum; + } + + // Move to next location. + if( GO_BACKWARD ) { + out_ptr -= params.o_stride_L; + } else { + out_ptr += params.o_stride_L; + } + + // Move the shared memory buffer. + smem_curr = (smem_curr + 1) % 2; + + // Store to shared memory for Q and K. + if( !is_last && tidx < E ) { + smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q; + smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k; + } + + // Store to shared memory for V. + if( !is_last && tidx < M ) { + smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD > +int lmha_(const Lmha_params ¶ms) { + // The M dimension rounded up to 4. + int M = round_up(params.M, 4); + + // The number of threads in the block. + int block = round_up(max(E, M*THREADS_PER_HEAD), 32); + if( block > 512 || params.B > 65535 ) { + return 1; + } + + // Prepare the kernel. 
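+ // One thread block per (head, sequence) pair; the dynamic shared memory holds two buffers for double-buffering.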
+ dim3 grid(params.H, params.B); + size_t smem = smem_buffer_elts_(params)*2*sizeof(float); + lmha_kernel<<>>(params); + return 0; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< bool GO_BACKWARD > +int lmha(const Lmha_params ¶ms) { + int blocks = params.B * params.H; + int res = 1; + if( blocks < LOW_OCCUPANCY_THRESHOLD ) { + if( params.E <= 32 ) { + res = lmha_low_occupancy_< 32, GO_BACKWARD>(params, blocks); + } else if( params.E <= 64 ) { + res = lmha_low_occupancy_< 64, GO_BACKWARD>(params, blocks); + } else if( params.E <= 128 ) { + res = lmha_low_occupancy_<128, GO_BACKWARD>(params, blocks); + } else if( params.E <= 256 ) { + res = lmha_low_occupancy_<256, GO_BACKWARD>(params, blocks); + } + } else { + if( params.E <= 32 ) { + res = lmha_< 32, 1, GO_BACKWARD>(params); + } else if( params.E <= 48 ) { + res = lmha_< 48, 1, GO_BACKWARD>(params); + } else if( params.E <= 64 ) { + res = lmha_< 64, 1, GO_BACKWARD>(params); + } else if( params.E <= 128 ) { + res = lmha_<128, 2, GO_BACKWARD>(params); + } else if( params.E <= 256 ) { + res = lmha_<256, 4, GO_BACKWARD>(params); + } + } + return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T > +inline void set_params(Lmha_params ¶ms, + const torch::Tensor q, + const torch::Tensor k, + const torch::Tensor v, + torch::Tensor o) { + + // Define the pointers. + params.out = o.data_ptr(); + params.q = q.data_ptr(); + params.k = k.data_ptr(); + params.v = v.data_ptr(); + + // Define the strides. + params.q_stride_B = (int) q.stride(0); + params.q_stride_H = (int) q.stride(1); + params.q_stride_L = (int) q.stride(2); + params.k_stride_B = (int) k.stride(0); + params.k_stride_H = (int) k.stride(1); + params.k_stride_L = (int) k.stride(2); + params.v_stride_B = (int) v.stride(0); + params.v_stride_H = (int) v.stride(1); + params.v_stride_L = (int) v.stride(2); + params.o_stride_B = (int) o.stride(0); + params.o_stride_H = (int) o.stride(1); + params.o_stride_L = (int) o.stride(2); + + // Extract the dimensions. + int N = q.size(0); + int H = q.size(1); + int L = q.size(2); + int E = q.size(3); + int M = v.size(3); + + params.B = N; + params.L = L; + params.H = H; + params.E = E; + params.M = M; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +int lmha_fwd(const torch::Tensor queries, + const torch::Tensor keys, + const torch::Tensor values, + torch::Tensor product) { + + // Make sure that we are using the correct GPU device + torch::DeviceGuard _guard(queries.device()); + + // Make sure the inner-most dimension of the tensors is packed. + assert(queries.stride(3) == 1); + assert(keys .stride(3) == 1); + assert(values .stride(3) == 1); + assert(product.stride(3) == 1); + + // Extract the dimensions. + int N = queries.size(0); + int H = queries.size(1); + int L = queries.size(2); + int E = queries.size(3); + int M = values.size (3); + + // The structure of params. + Lmha_params params; + set_params(params, queries, keys, values, product); + + // Launch the kernel. + return lmha(params); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T > +struct Lmha_bwd_params { + + // The output buffer for K. Dimensions [B, H, L, D]. + T *out_k; + // The output buffer for V. Dimensions [B, H, L, D]. + T *out_v; + + // The input Qs. Dimensions [B, H, L, D]. 
+ const T *q; + // The input Ks. Dimensions [B, H, L, D]. + const T *k; + // The input Vs. Dimensions [B, H, L, D]. + const T *v; + // The input Gs. Dimensions [B, H, L, D]. + const T *g; + + // The dimensions. + int B, L, H, M, E; + + // The strides for the input tensors. + int q_stride_B, q_stride_L, q_stride_H; + int k_stride_B, k_stride_L, k_stride_H; + int v_stride_B, v_stride_L, v_stride_H; + int g_stride_B, g_stride_L, g_stride_H; + + // The strides for the outputs. + int out_k_stride_B, out_k_stride_L, out_k_stride_H; + int out_v_stride_B, out_v_stride_L, out_v_stride_H; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int D, int THREADS_PER_HEAD > +__global__ __launch_bounds__(D*THREADS_PER_HEAD*2) +void lmha_bwd_kernel(Lmha_bwd_params params) { + + // Make sure D is a multiple of 4. + static_assert(D % 4 == 0, ""); + + // The shared memory buffers. + __shared__ struct Smem { float qg[2*D], kv[2*D], out_kv[2*D]; } smem_[2]; + + // The index of the shared memory buffer (for double-buffering). + int smem_curr = 0; + + // The sequence processed by that block. + const int bi = blockIdx.y; + // The head processed by that block. + const int hi = blockIdx.x; + + // The linear index of the thread. + const int tidx = threadIdx.x; + + // Split the threads into two slices. + int so = tidx / (D*THREADS_PER_HEAD); + int si = tidx % (D*THREADS_PER_HEAD); + + // The strides for B/L/H for the Q/G tensors. + int qg_stride_B, qg_stride_L, qg_stride_H; + if( so == 0 ) { + qg_stride_B = params.q_stride_B; + qg_stride_L = params.q_stride_L; + qg_stride_H = params.q_stride_H; + } else { + qg_stride_B = params.g_stride_B; + qg_stride_L = params.g_stride_L; + qg_stride_H = params.g_stride_H; + } + + // The strides for B/L/H for the K/V tensors. + int kv_stride_B, kv_stride_L, kv_stride_H; + if( so == 0 ) { + kv_stride_B = params.k_stride_B; + kv_stride_L = params.k_stride_L; + kv_stride_H = params.k_stride_H; + } else { + kv_stride_B = params.v_stride_B; + kv_stride_L = params.v_stride_L; + kv_stride_H = params.v_stride_H; + } + + // The hidden size. + int hidden_size_per_head = 0; + if( so == 0 ) { + hidden_size_per_head = params.E; + } else { + hidden_size_per_head = params.M; + } + + // Where to start reading from. + int offset_qg = bi*qg_stride_B + hi*qg_stride_H + si; + int offset_kv = bi*kv_stride_B + hi*kv_stride_H + si; + + // We walk backward, account for the extra offset. + offset_qg += (params.L-1)*qg_stride_L; + offset_kv += (params.L-1)*kv_stride_L; + + // Determine the base pointers for Q, K, V and G. + const float *ptr_qg = &(so == 0 ? params.q : params.g)[offset_qg]; + const float *ptr_kv = &(so == 0 ? params.k : params.v)[offset_kv]; + + // Is it an active thread? + const int active = si < hidden_size_per_head; + + // Trigger the memory loads for Q, K, V and G. + float ldg_qg = 0.f, ldg_kv = 0.f; + if( active ) { + ldg_qg = *ptr_qg; + ldg_kv = *ptr_kv; + } + + // Move the load pointers (backward). + ptr_qg -= qg_stride_L; + ptr_kv -= kv_stride_L; + + // The number of FLOAT4s per head. + constexpr int FLOAT4s_PER_HEAD = D / 4; + // The number of FLOAT4s per thread. + constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD; + + // The storage for the G*Q^T or Q^T*G values. + float4 gq[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + gq[ii] = make_float4(0.f, 0.f, 0.f, 0.f); + } + + // The strides for B/L/H for the K/V tensors. 
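+ // (more precisely, for the dK/dV output tensors: slice 0 writes dK, slice 1 writes dV)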
+ int out_kv_stride_B, out_kv_stride_L, out_kv_stride_H; + if( so == 0 ) { + out_kv_stride_B = params.out_k_stride_B; + out_kv_stride_L = params.out_k_stride_L; + out_kv_stride_H = params.out_k_stride_H; + } else { + out_kv_stride_B = params.out_v_stride_B; + out_kv_stride_L = params.out_v_stride_L; + out_kv_stride_H = params.out_v_stride_H; + } + + // Where to start reading from. + int offset_out_kv = bi*out_kv_stride_B + hi*out_kv_stride_H + si; + + // We walk backward, account for the extra offset. + offset_out_kv += (params.L-1)*out_kv_stride_L; + + // The output pointer. + float *ptr_out_kv = &(so == 0 ? params.out_k : params.out_v)[offset_out_kv]; + + // Store to shared memory. + if( si < D ) { + smem_[smem_curr].qg[so*D + si] = ldg_qg; + smem_[smem_curr].kv[so*D + si] = ldg_kv; + } + + // The position of the thread in the output dimension. + int oo = si / THREADS_PER_HEAD % D; + int oi = si % THREADS_PER_HEAD * 4; + + // Iterate over the timesteps. + for( int ti = 0; ti < params.L; ++ti ) { + + // Is it the last iteration? + int is_last = ti == params.L - 1; + + // Trigger the next loads. + if( !is_last && active ) { + ldg_qg = *ptr_qg; + ldg_kv = *ptr_kv; + } + + // Move the load pointers. + ptr_qg -= qg_stride_L; + ptr_kv -= kv_stride_L; + + // Make sure the data is in shared memory. + __syncthreads(); + + // Each thread loads 4 values from G or Q. + float4 g[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + float *smem_ptr = &smem_[smem_curr].qg[(so^1)*D + oi]; + g[ii] = *reinterpret_cast(&smem_ptr[ii*THREADS_PER_HEAD*4]); + } + + // Each thread loads a single from Q or G value. + float q = smem_[smem_curr].qg[so*D + oo]; + + // Update the G*Q^T or Q*G^T product. + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + gq[ii].x += g[ii].x * q; + gq[ii].y += g[ii].y * q; + gq[ii].z += g[ii].z * q; + gq[ii].w += g[ii].w * q; + } + + // Load the V or K values from shared memory. + float4 v[FLOAT4s_PER_THREAD]; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + float *smem_ptr = &smem_[smem_curr].kv[(so^1)*D + oi]; + v[ii] = *reinterpret_cast(&smem_ptr[ii*THREADS_PER_HEAD*4]); + } + + // Compute the partial output value for that thread. + float sum = 0.f; + #pragma unroll + for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) { + sum += v[ii].x * gq[ii].x; + sum += v[ii].y * gq[ii].y; + sum += v[ii].z * gq[ii].z; + sum += v[ii].w * gq[ii].w; + } + + // Finalize the computation of the sum (if we have more than 1 thread per head). + if( THREADS_PER_HEAD > 1 ) { + + // Finalize the sum for each head. + #pragma unroll + for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Store to shared memory. + if( oi == 0 ) { + smem_[smem_curr].out_kv[so*D + oo] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // Active threads read the data to store. + if( si < hidden_size_per_head ) { + sum = smem_[smem_curr].out_kv[so*D + si]; + } + + } // THREADS_PER_HEAD > 1. + + // Store the output. All the threads are active. + if( si < hidden_size_per_head ) { + *ptr_out_kv = sum; + } + + // Move to next location. + ptr_out_kv -= out_kv_stride_L; + + // Move the shared memory buffer. + smem_curr = (smem_curr + 1) % 2; + + // Store to shared memory for Q and K. 
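+ // (slice 0 stages the next Q and K values; slice 1 stages the next G and V values)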
+ if( !is_last && si < D ) { + smem_[smem_curr].qg[so*D + si] = ldg_qg; + smem_[smem_curr].kv[so*D + si] = ldg_kv; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int D, int THREADS_PER_HEAD > +int lmha_bwd_(const Lmha_bwd_params ¶ms) { + int block = D*THREADS_PER_HEAD*2; + if( block >= 1024 || params.B > 65535 ) { + return 1; + } + dim3 grid(params.H, params.B); + lmha_bwd_kernel<<>>(params); + return 0; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +int lmha_bwd(const Lmha_bwd_params ¶ms) { + int blocks = params.B * params.H; + if( blocks < LOW_OCCUPANCY_THRESHOLD ) { + return 1; + } + + int hidden_size_per_head = max(params.E, params.M); + int res = 1; + if( hidden_size_per_head <= 32 ) { + res = lmha_bwd_< 32, 1>(params); + } else if( hidden_size_per_head <= 64 ) { + res = lmha_bwd_< 64, 1>(params); + } else if( hidden_size_per_head <= 128 ) { + res = lmha_bwd_<128, 2>(params); + } else if( hidden_size_per_head <= 256 ) { + res = lmha_bwd_<256, 4>(params); + } + return res; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +int lmha_bwd(const torch::Tensor queries, + const torch::Tensor keys, + const torch::Tensor values, + const torch::Tensor grad_out, + torch::Tensor grad_queries, + torch::Tensor grad_keys, + torch::Tensor grad_values) { + + // Make sure that we are using the correct GPU device + torch::DeviceGuard _guard(queries.device()); + + // Make sure the inner-most dimension of the tensors is packed. + assert(queries .stride(3) == 1); + assert(keys .stride(3) == 1); + assert(values .stride(3) == 1); + assert(grad_out .stride(3) == 1); + assert(grad_queries.stride(3) == 1); + assert(grad_keys .stride(3) == 1); + assert(grad_values .stride(3) == 1); + + // Extract the dimensions. + int N = queries.size(0); + int H = queries.size(1); + int L = queries.size(2); + int E = queries.size(3); + int M = values.size (3); + + // Gradient on Q. + + // The structure of params. + Lmha_params params; + set_params(params, grad_out, values, keys, grad_queries); + + // Launch the kernel. + int res = lmha(params); + if( res ) { + return res; + } + + // Gradient on K and V together. 
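+ // A single fused kernel walks the sequence backward and produces dK and dV in one
+ // pass; if it cannot be launched, we fall back to two separate lmha() calls below.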
+ + Lmha_bwd_params bwd_params; + bwd_params.out_k = grad_keys.data_ptr(); + bwd_params.out_v = grad_values.data_ptr(); + bwd_params.q = queries.data_ptr(); + bwd_params.k = keys.data_ptr(); + bwd_params.v = values.data_ptr(); + bwd_params.g = grad_out.data_ptr(); + + bwd_params.B = N; + bwd_params.L = L; + bwd_params.H = H; + bwd_params.E = E; + bwd_params.M = M; + + bwd_params.q_stride_B = queries.stride(0); + bwd_params.q_stride_H = queries.stride(1); + bwd_params.q_stride_L = queries.stride(2); + bwd_params.k_stride_B = keys.stride(0); + bwd_params.k_stride_H = keys.stride(1); + bwd_params.k_stride_L = keys.stride(2); + bwd_params.v_stride_B = values.stride(0); + bwd_params.v_stride_H = values.stride(1); + bwd_params.v_stride_L = values.stride(2); + bwd_params.g_stride_B = grad_out.stride(0); + bwd_params.g_stride_H = grad_out.stride(1); + bwd_params.g_stride_L = grad_out.stride(2); + + bwd_params.out_k_stride_B = grad_keys.stride(0); + bwd_params.out_k_stride_H = grad_keys.stride(1); + bwd_params.out_k_stride_L = grad_keys.stride(2); + bwd_params.out_v_stride_B = grad_values.stride(0); + bwd_params.out_v_stride_H = grad_values.stride(1); + bwd_params.out_v_stride_L = grad_values.stride(2); + + // Try to run the fused kernel. + int fallback = lmha_bwd(bwd_params); + + // If it failed, fallback on separate kernels for K and V. + if( fallback ) { + + // Gradient on K. + + // Launch the kernel. + set_params(params, values, grad_out, queries, grad_keys); + res = lmha(params); + if( res ) { + return res; + } + + // Gradient on V. + + // Launch the kernel. + set_params(params, keys, queries, grad_out, grad_values); + return lmha(params); + } + + // It worked... + return 0; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace nvidia +#endif // #ifdef ENABLE_NVIDIA_OPTIMIZATIONS + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +typedef torch::PackedTensorAccessor32 float_accessor; + +#define E_BLOCK_SIZE 8 + +__global__ void causal_dot_product_kernel( + const float_accessor queries, + const float_accessor keys, + const float_accessor values, + float_accessor result, + const int N, + const int H, + const int L, + const int E, + const int M +) { + int n = blockIdx.y; + int h = blockIdx.z; + + int e_start = blockIdx.x * E_BLOCK_SIZE; + int m = threadIdx.x % M; + + extern __shared__ float shared_mem[]; + float* shared_kv = shared_mem; + + for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) { + shared_kv[m + e_local * M] = 0; + } + + for (int t=0; t>>( + queries.packed_accessor32(), + keys.packed_accessor32(), + values.packed_accessor32(), + product.packed_accessor32(), + N, H, L, E, M + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void causal_dot_product(const torch::Tensor queries, + const torch::Tensor keys, + const torch::Tensor values, + torch::Tensor product) { +#ifdef ENABLE_NVIDIA_OPTIMIZATIONS + int fallback = nvidia::lmha_fwd(queries, keys, values, product); +#else + int fallback = 1; +#endif + if( fallback ) { + causal_dot_product_(queries, keys, values, product); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define M_BLOCK_SIZE 4 + +// we need shared memory to store +// kv +// Backward direction +// kv_backwards +// Shared memory usage +__global__ void causal_dot_backward_query_key_kernel( + 
const float_accessor queries, + const float_accessor keys, + const float_accessor values, + const float_accessor grad_out, + float_accessor grad_queries, + float_accessor grad_keys, + int N, + int H, + int L, + int E, + int M +) { + int n = blockIdx.y; + int h = blockIdx.z; + + int m_start = blockIdx.x * M_BLOCK_SIZE; + int e = threadIdx.x % E; + + extern __shared__ float shared_mem[]; + const int shared_kv_size = M_BLOCK_SIZE * E; + float* shared_kv = shared_mem; + float* shared_kv_bw = shared_mem + shared_kv_size; + + for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) { + shared_kv[m_local * E + e] = 0; + shared_kv_bw[m_local * E + e] = 0; + } + + for (int l=0; l>>( + queries.packed_accessor32(), + keys.packed_accessor32(), + values.packed_accessor32(), + grad_out.packed_accessor32(), + grad_queries.packed_accessor32(), + grad_keys.packed_accessor32(), + N, H, L, E, M + ); + + const int blocks_per_sequence_value = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE; + + dim3 blockDimv(M, 1, 1); + dim3 gridDimv(blocks_per_sequence_value, N, H); + const int shared_mem_v_backward = E_BLOCK_SIZE * M * sizeof(float); + causal_dot_backward_value_kernel<<>>( + queries.packed_accessor32(), + keys.packed_accessor32(), + values.packed_accessor32(), + grad_out.packed_accessor32(), + grad_keys.packed_accessor32(), + grad_values.packed_accessor32(), + N, H, L, E, M + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void causal_dot_backward(const torch::Tensor queries, + const torch::Tensor keys, + const torch::Tensor values, + const torch::Tensor grad_out, + torch::Tensor grad_queries, + torch::Tensor grad_keys, + torch::Tensor grad_values) { +#ifdef ENABLE_NVIDIA_OPTIMIZATIONS + int fallback = nvidia::lmha_bwd(queries, + keys, + values, + grad_out, + grad_queries, + grad_keys, + grad_values); +#else + int fallback = 1; +#endif + if( fallback ) { + // Make sure that the gradient tensors are 0. This is needed because the + // bwd pass might have partially executed and filled in some values in + // grad_queries or grad_keys. + // + // This adds a small overhead every time we have to fall back to the old + // kernel for the backward pass. + grad_queries.zero_(); + grad_keys.zero_(); + causal_dot_backward_(queries, keys, values, grad_out, grad_queries, grad_keys, grad_values); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "causal_dot_product", + &causal_dot_product, + "Compute the weighted sum of values but attending only to previous " + "values." + ); + m.def( + "causal_dot_backward", + &causal_dot_backward, + "Compute the gradients for the causal dot product." + ); +} diff --git a/src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention_kv_cuda.cu b/src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention_kv_cuda.cu new file mode 100644 index 000000000..ab8f92c4f --- /dev/null +++ b/src/axolotl/integrations/lolcats/linear_llama/csrc/causal_attention_kv_cuda.cu @@ -0,0 +1,1483 @@ +// +// Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/ +// Written by Angelos Katharopoulos , +// Apoorv Vyas +// + +// +// For modifications made inside namespace nvidia (authored by jdemouth): +// +// Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +// FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +// IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// + +#include +#include +#include + +#define ENABLE_NVIDIA_OPTIMIZATIONS + +#ifdef ENABLE_NVIDIA_OPTIMIZATIONS +namespace nvidia { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr int THREADS_PER_WARP = 32; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr int LOW_OCCUPANCY_THRESHOLD = 40; // TODO: Make it HW specific (like 1/2 SMs). + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ __host__ int div_up(int m, int n) { + return (m + n-1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ __host__ int round_up(int m, int n) { + return div_up(m, n) * n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< typename T > +struct Lmha_params { + + // The output buffer. Dimensions [B, H, L, M]. + T *out; + + // The input Qs. Dimensions [B, H, L, E]. + const T *q; + // The input Ks. Dimensions [B, H, L, E]. + const T *k; + // The input Vs. Dimensions [B, H, L, M]. + const T *v; + + // The different dimensions. + int B, L, H, E, M; + + // The strides for the different tensors. + int q_stride_B, q_stride_H, q_stride_L; + int k_stride_B, k_stride_H, k_stride_L; + int v_stride_B, v_stride_H, v_stride_L; + int o_stride_B, o_stride_H, o_stride_L; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< int E, bool GO_BACKWARD, int WARPS, int COLS_PER_THREAD = 4 > +__global__ __launch_bounds__(WARPS * THREADS_PER_WARP) +void lmha_low_occupancy_kernel(Lmha_params params) { + + // The number of threads per block. + constexpr int THREADS_PER_BLOCK = WARPS * THREADS_PER_WARP; + // The number of rows per thread. + constexpr int ROWS_PER_THREAD = E / THREADS_PER_WARP; + // The number of steps per iteration. + constexpr int COLS_PER_ITER = WARPS * COLS_PER_THREAD; + + // Make sure E is a multiple of the warp size. + static_assert(E % THREADS_PER_WARP == 0, ""); + + // Shared memory to store V/O. + __shared__ float smem_v[COLS_PER_ITER], smem_o[COLS_PER_ITER]; + // Shared memory buffer to performance the reductions. + __shared__ float smem_reds[E * WARPS]; + + // The sequence processed by that block. 
+  const int bi = blockIdx.z;
+  // The head processed by that block.
+  const int hi = blockIdx.y;
+  // The hidden cell in the V/output buffers.
+  const int vi = blockIdx.x;
+
+  // The linear index of the thread.
+  const int tidx = threadIdx.x;
+
+  // Decompose the block into warp/lane.
+  const int warp = tidx / THREADS_PER_WARP;
+  const int lane = tidx % THREADS_PER_WARP;
+
+  // The base offset loaded by the thread in Q and K.
+  int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + lane;
+  int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + lane;
+
+  // If we walk backward, account for the extra offset.
+  if( GO_BACKWARD ) {
+    offset_q += (params.L-1)*params.q_stride_L;
+    offset_k += (params.L-1)*params.k_stride_L;
+  }
+
+  // Position the warp at the beginning of the proper timestep.
+  if( GO_BACKWARD ) {
+    offset_q -= warp*COLS_PER_THREAD*params.q_stride_L;
+    offset_k -= warp*COLS_PER_THREAD*params.k_stride_L;
+  } else {
+    offset_q += warp*COLS_PER_THREAD*params.q_stride_L;
+    offset_k += warp*COLS_PER_THREAD*params.k_stride_L;
+  }
+
+  // Determine the base pointers for Q and K.
+  const float *ptr_q = &params.q[offset_q];
+  const float *ptr_k = &params.k[offset_k];
+
+  // Is a given row valid?
+  int valid_qk[ROWS_PER_THREAD];
+  #pragma unroll
+  for( int ii = 0; ii < ROWS_PER_THREAD; ++ii ) {
+    valid_qk[ii] = lane + ii*THREADS_PER_WARP < params.E;
+  }
+
+  // The offset to the position loaded by the thread in V.
+  int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + vi;
+  int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + vi;
+
+  // If we walk backward, account for the extra offset.
+  if( GO_BACKWARD ) {
+    offset_v += (params.L-1)*params.v_stride_L;
+    offset_o += (params.L-1)*params.o_stride_L;
+  }
+
+  // We load/store a strided matrix of COLS_PER_ITER x OUTPUTS_PER_BLOCK.
+  if( GO_BACKWARD ) {
+    offset_v -= tidx*params.v_stride_L;
+    offset_o -= tidx*params.o_stride_L;
+  } else {
+    offset_v += tidx*params.v_stride_L;
+    offset_o += tidx*params.o_stride_L;
+  }
+
+  // Determine the base pointer for V.
+  const float *ptr_v = &params.v[offset_v];
+  // The output pointer.
+  float *ptr_o = &params.out[offset_o];
+
+  // The running KVs.
+  float running_kv[ROWS_PER_THREAD];
+  #pragma unroll
+  for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
+    running_kv[ri] = 0.f;
+  }
+
+  // Iterate over the timesteps. TODO: Use params.loop_count!!!
+  for( int iter = 0; iter < params.L; iter += COLS_PER_ITER ) {
+
+    // Each thread loads a matrix of elements.
+    float q[ROWS_PER_THREAD][COLS_PER_THREAD], k[ROWS_PER_THREAD][COLS_PER_THREAD];
+
+    // Trigger the memory loads for Q and K.
+    #pragma unroll
+    for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
+      #pragma unroll
+      for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
+
+        // For Q/K, each warp loads from various timesteps.
+        int ti = iter + warp*COLS_PER_THREAD;
+        if( GO_BACKWARD ) {
+          ti = params.L - 1 - ti;
+        }
+
+        // Is it a valid access?
+        int valid;
+        if( GO_BACKWARD ) {
+          valid = valid_qk[ri] && ti - ci >= 0;
+        } else {
+          valid = valid_qk[ri] && ti + ci < params.L;
+        }
+
+        // The extra offset to add.
+        if( GO_BACKWARD ) {
+          offset_q = ri*THREADS_PER_WARP - ci*params.q_stride_L;
+          offset_k = ri*THREADS_PER_WARP - ci*params.k_stride_L;
+        } else {
+          offset_q = ri*THREADS_PER_WARP + ci*params.q_stride_L;
+          offset_k = ri*THREADS_PER_WARP + ci*params.k_stride_L;
+        }
+
+        // Load Q/K if they are valid.
+        q[ri][ci] = valid ? ptr_q[offset_q] : 0.f;
+        k[ri][ci] = valid ? ptr_k[offset_k] : 0.f;
+      }
+    }
+
+    // For the V tensor, we assign contiguous threads to different loads. So, ti is different.
+    int ti = iter + tidx;
+    if( GO_BACKWARD ) {
+      ti = params.L - 1 - ti;
+    }
+
+    // Is it a valid access?
+    int valid_vo = tidx < COLS_PER_ITER;
+    if( GO_BACKWARD ) {
+      valid_vo &= ti >= 0;
+    } else {
+      valid_vo &= ti < params.L;
+    }
+
+    // Trigger the loads for V.
+    float ldg_v = valid_vo ? *ptr_v : 0.f;
+
+    // Move the load pointers.
+    if( GO_BACKWARD ) {
+      ptr_q -= COLS_PER_ITER*params.q_stride_L;
+      ptr_k -= COLS_PER_ITER*params.k_stride_L;
+      ptr_v -= COLS_PER_ITER*params.v_stride_L;
+    } else {
+      ptr_q += COLS_PER_ITER*params.q_stride_L;
+      ptr_k += COLS_PER_ITER*params.k_stride_L;
+      ptr_v += COLS_PER_ITER*params.v_stride_L;
+    }
+
+    // Store to shared memory.
+    if( tidx < COLS_PER_ITER ) {
+      smem_v[tidx] = ldg_v;
+    }
+
+    // Make sure V is in shared memory.
+    __syncthreads();
+
+    // Read V from shared memory.
+    float v[COLS_PER_THREAD];
+    #pragma unroll
+    for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
+      v[ci] = smem_v[warp*COLS_PER_THREAD + ci];
+    }
+
+    // Each thread computes local K*V products.
+    float kv[ROWS_PER_THREAD][COLS_PER_THREAD];
+    #pragma unroll
+    for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
+      #pragma unroll
+      for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
+        kv[ri][ci] = 0.f;
+      }
+    }
+
+    // Update the K*V^T product.
+    #pragma unroll
+    for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
+      #pragma unroll
+      for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
+        kv[ri][ci] += k[ri][ci] * v[ci];
+      }
+    }
+
+    // We must perform the prefix sums within the thread-block. Start with the thread.
+    #pragma unroll
+    for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
+      #pragma unroll
+      for( int ci = 1; ci < COLS_PER_THREAD; ++ci ) {
+        kv[ri][ci] += kv[ri][ci-1];
+      }
+    }
+
+    // Store the partial sums to shared memory. Unless we have no inter-warp reduction to perform.
+    #pragma unroll
+    for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
+      smem_reds[warp*E + ri*THREADS_PER_WARP + lane] = kv[ri][COLS_PER_THREAD-1];
+    }
+
+    // Make sure the data is in shared memory.
+    __syncthreads();
+
+    // Each thread deals with one or more column(s) of the matrix.
+    constexpr int SUMS_PER_THREAD = (E + THREADS_PER_BLOCK-1) / THREADS_PER_BLOCK;
+    #pragma unroll
+    for( int ii = 0, idx = tidx; ii < SUMS_PER_THREAD; ++ii, idx += THREADS_PER_BLOCK ) {
+      if( idx < E ) {
+        float sum = smem_reds[idx];
+        #pragma unroll
+        for( int jj = 1; jj < WARPS; ++jj ) {
+          smem_reds[idx + jj*E] = sum += smem_reds[idx + jj*E];
+        }
+      }
+    }
+
+    // Make sure the reductions are stored in shared memory.
+    __syncthreads();
+
+    // Each thread updates its partial products.
+    #pragma unroll
+    for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
+      float sum = running_kv[ri];
+      if( warp > 0 ) {
+        sum += smem_reds[(warp-1)*E + lane + ri*THREADS_PER_WARP];
+      }
+      #pragma unroll
+      for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
+        kv[ri][ci] += sum;
+      }
+    }
+
+    // Compute the partial output values for that thread.
+    float sum[COLS_PER_THREAD];
+    #pragma unroll
+    for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
+      sum[ci] = q[0][ci] * kv[0][ci];
+      #pragma unroll
+      for( int ri = 1; ri < ROWS_PER_THREAD; ++ri ) {
+        sum[ci] += q[ri][ci] * kv[ri][ci];
+      }
+    }
+
+    // Run the parallel reductions inside the warp.
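+    // (Butterfly XOR-shuffle: after log2(THREADS_PER_WARP) rounds every lane holds
+    // the full sum over the E dimension for its COLS_PER_THREAD outputs.)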
+    #pragma unroll
+    for( int mask = THREADS_PER_WARP / 2; mask >= 1; mask /= 2 ) {
+      #pragma unroll
+      for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
+        sum[ci] += __shfl_xor_sync(uint32_t(-1), sum[ci], mask);
+      }
+    }
+
+    // Store the final output to shared memory.
+    if( lane == 0 ) {
+      #pragma unroll
+      for( int ci = 0; ci < COLS_PER_THREAD; ++ci ) {
+        smem_o[warp*COLS_PER_THREAD + ci] = sum[ci];
+      }
+    }
+
+    // Make sure the data is in shared memory.
+    __syncthreads();
+
+    // Store the output.
+    if( valid_vo ) {
+      *ptr_o = smem_o[tidx];
+    }
+
+    // Each thread updates its running kv.
+    #pragma unroll
+    for( int ri = 0; ri < ROWS_PER_THREAD; ++ri ) {
+      running_kv[ri] += smem_reds[(WARPS-1)*E + lane + ri*THREADS_PER_WARP];
+    }
+
+    // Move to next location.
+    if( GO_BACKWARD ) {
+      ptr_o -= COLS_PER_ITER*params.o_stride_L;
+    } else {
+      ptr_o += COLS_PER_ITER*params.o_stride_L;
+    }
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< int E, bool GO_BACKWARD, int WARPS >
+int lmha_low_occupancy_(const Lmha_params<float> &params) {
+
+  // Make sure we are not going to launch an invalid grid.
+  if( params.H > 65535 || params.B > 65535 ) {
+    return 1;
+  }
+
+  // Prepare the grid and trigger the CUDA kernel.
+  dim3 grid;
+  grid.x = params.M;
+  grid.y = params.H;
+  grid.z = params.B;
+  lmha_low_occupancy_kernel<E, GO_BACKWARD, WARPS><<<grid, WARPS*THREADS_PER_WARP>>>(params);
+  return 0;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< int E, bool GO_BACKWARD >
+int lmha_low_occupancy_(const Lmha_params<float> &params, int blocks) {
+  // (Warp counts below are heuristic: fewer resident blocks get more warps per block.)
+  if( params.M * blocks >= 8*LOW_OCCUPANCY_THRESHOLD ) {
+    return lmha_low_occupancy_<E, GO_BACKWARD, 4>(params);
+  } else if( params.M * blocks >= 4*LOW_OCCUPANCY_THRESHOLD ) {
+    return lmha_low_occupancy_<E, GO_BACKWARD, 8>(params);
+  } else {
+    return lmha_low_occupancy_<E, GO_BACKWARD, 16>(params);
+  }
+  return 1;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< int E, typename Params >
+static inline __device__ __host__ int smem_buffer_elts_(const Params &params) {
+  int M = round_up(params.M, 4);
+  return 2*E + 2*M;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
+__global__
+void lmha_kernel(Lmha_params<float> params) {
+
+  // Make sure E is a multiple of 4.
+  static_assert(E % 4 == 0, "");
+
+  // The amount of shared memory per buffer (2 buffers for double-buffering).
+  const int smem_buffer_elts = smem_buffer_elts_<E>(params);
+  // The M dimension for shared memory.
+  const int M = round_up(params.M, 4);
+
+  // Shared memory to store Q, K and V. Size is 2*smem_buffer_elts.
+  extern __shared__ float smem_[];
+
+  // The various shared memory buffers.
+  float *smem_q = &smem_[0*E];
+  float *smem_k = &smem_[1*E];
+  float *smem_v = &smem_[2*E];
+  float *smem_o = &smem_[2*E + M];
+
+  // The index of the shared memory buffer (for double-buffering).
+  int smem_curr = 0;
+
+  // The sequence processed by that block.
+  const int bi = blockIdx.y;
+  // The head processed by that block.
+  const int hi = blockIdx.x;
+
+  // The linear index of the thread.
+  const int tidx = threadIdx.x;
+
+  // The offset to the position loaded by the thread in Q.
+  int offset_q = bi*params.q_stride_B + hi*params.q_stride_H + tidx;
+  // The offset to the position loaded by the thread in K.
+  int offset_k = bi*params.k_stride_B + hi*params.k_stride_H + tidx;
+
+  // If we walk backward, account for the extra offset.
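+  // (GO_BACKWARD starts the walk at timestep L-1 and strides toward 0; the gradient
+  // computation reuses this same kernel by reversing the direction of the walk.)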
+  if( GO_BACKWARD ) {
+    offset_q += (params.L-1)*params.q_stride_L;
+    offset_k += (params.L-1)*params.k_stride_L;
+  }
+
+  // Determine the base pointers for Q and K.
+  const float *ptr_q = &params.q[offset_q];
+  const float *ptr_k = &params.k[offset_k];
+
+  // The offset to the position loaded by the thread in V and O.
+  int offset_v = bi*params.v_stride_B + hi*params.v_stride_H + tidx;
+  int offset_o = bi*params.o_stride_B + hi*params.o_stride_H + tidx;
+
+  // If we walk backward, account for the extra offset.
+  if( GO_BACKWARD ) {
+    offset_v += (params.L-1)*params.v_stride_L;
+    offset_o += (params.L-1)*params.o_stride_L;
+  }
+
+  // Determine the base pointers for V.
+  const float *ptr_v = &params.v[offset_v];
+
+  // Is it an active Q/K thread?
+  const int active_qk = tidx < params.E;
+
+  // Trigger the memory loads for Q and K.
+  float ldg_q = 0.f, ldg_k = 0.f;
+  if( active_qk ) {
+    ldg_q = *ptr_q;
+    ldg_k = *ptr_k;
+  }
+
+  // Is it an active V thread?
+  const int active_v = tidx < params.M;
+
+  // Trigger the memory loads for V.
+  float ldg_v = 0.f;
+  if( active_v ) {
+    ldg_v = *ptr_v;
+  }
+
+  // Move the load pointers.
+  if( GO_BACKWARD ) {
+    ptr_q -= params.q_stride_L;
+    ptr_k -= params.k_stride_L;
+    ptr_v -= params.v_stride_L;
+  } else {
+    ptr_q += params.q_stride_L;
+    ptr_k += params.k_stride_L;
+    ptr_v += params.v_stride_L;
+  }
+
+  // The number of FLOAT4s per head.
+  constexpr int FLOAT4s_PER_HEAD = E / 4;
+  // The number of FLOAT4s per thread.
+  constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;
+
+  // The storage for the K*V^T values.
+  float4 kv[FLOAT4s_PER_THREAD];
+  #pragma unroll
+  for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
+    kv[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
+  }
+
+  // The output pointer.
+  float *out_ptr = &params.out[offset_o];
+
+  // Store to shared memory Q and K.
+  if( tidx < E ) {
+    smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
+    smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
+  }
+
+  // Store to shared memory V. All threads store valid values.
+  if( tidx < M ) {
+    smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
+  }
+
+  // The position of the thread in the V dimension.
+  int vo = tidx / THREADS_PER_HEAD;
+  int vi = tidx % THREADS_PER_HEAD;
+
+  // Iterate over the timesteps.
+  for( int ti = 0; ti < params.L; ++ti ) {
+
+    // Is it the last iteration?
+    int is_last = ti == params.L - 1;
+
+    // Trigger the next loads for Q and K.
+    if( !is_last && active_qk ) {
+      ldg_q = *ptr_q;
+      ldg_k = *ptr_k;
+    }
+
+    // Trigger the next loads for V.
+    if( !is_last && active_v ) {
+      ldg_v = *ptr_v;
+    }
+
+    // Move the load pointers.
+    if( GO_BACKWARD ) {
+      ptr_q -= params.q_stride_L;
+      ptr_k -= params.k_stride_L;
+      ptr_v -= params.v_stride_L;
+    } else {
+      ptr_q += params.q_stride_L;
+      ptr_k += params.k_stride_L;
+      ptr_v += params.v_stride_L;
+    }
+
+    // Make sure the data is in shared memory.
+    __syncthreads();
+
+    // Each thread loads 4 values from K.
+    float4 k[FLOAT4s_PER_THREAD];
+    #pragma unroll
+    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
+      int ki = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
+      k[ii] = *reinterpret_cast<const float4*>(&smem_k[smem_curr*smem_buffer_elts + ki]);
+    }
+
+    // Each thread loads a single V value.
+    float v = 0.f;
+    if( vo < params.M ) {
+      v = *reinterpret_cast<const float*>(&smem_v[smem_curr*smem_buffer_elts + vo]);
+    }
+
+    // Update the K*V^T product.
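+    // (This is the linear-attention recurrence: fold this timestep's k_t * v_t slice
+    // into the running kv registers before contracting with q_t below.)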
+    #pragma unroll
+    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
+      kv[ii].x += k[ii].x * v;
+      kv[ii].y += k[ii].y * v;
+      kv[ii].z += k[ii].z * v;
+      kv[ii].w += k[ii].w * v;
+    }
+
+    // Load the Q values from shared memory.
+    float4 q[FLOAT4s_PER_THREAD];
+    #pragma unroll
+    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
+      int qi = tidx % THREADS_PER_HEAD * 4 + ii * THREADS_PER_HEAD * 4;
+      q[ii] = *reinterpret_cast<const float4*>(&smem_q[smem_curr*smem_buffer_elts + qi]);
+    }
+
+    // Compute the partial output value for that thread.
+    float sum = 0.f;
+    #pragma unroll
+    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
+      sum += q[ii].x * kv[ii].x;
+      sum += q[ii].y * kv[ii].y;
+      sum += q[ii].z * kv[ii].z;
+      sum += q[ii].w * kv[ii].w;
+    }
+
+    // Finalize the computation of the sum (if we have more than 1 thread per head).
+    if( THREADS_PER_HEAD > 1 ) {
+
+      // Finalize the sum for each head.
+      #pragma unroll
+      for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
+        sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+      }
+
+      // Store to shared memory.
+      if( vo < M && vi == 0 ) {
+        smem_o[smem_curr*smem_buffer_elts + vo] = sum;
+      }
+
+      // Make sure the data is in shared memory.
+      __syncthreads();
+
+      // Active threads read the data to store.
+      if( active_v ) {
+        sum = smem_o[smem_curr*smem_buffer_elts + tidx];
+      }
+
+    } // THREADS_PER_HEAD > 1.
+
+    // Store the output for the active threads.
+    if( active_v ) {
+      *out_ptr = sum;
+    }
+
+    // Move to next location.
+    if( GO_BACKWARD ) {
+      out_ptr -= params.o_stride_L;
+    } else {
+      out_ptr += params.o_stride_L;
+    }
+
+    // Move the shared memory buffer.
+    smem_curr = (smem_curr + 1) % 2;
+
+    // Store to shared memory for Q and K.
+    if( !is_last && tidx < E ) {
+      smem_q[smem_curr*smem_buffer_elts + tidx] = ldg_q;
+      smem_k[smem_curr*smem_buffer_elts + tidx] = ldg_k;
+    }
+
+    // Store to shared memory for V.
+    if( !is_last && tidx < M ) {
+      smem_v[smem_curr*smem_buffer_elts + tidx] = ldg_v;
+    }
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< int E, int THREADS_PER_HEAD, bool GO_BACKWARD >
+int lmha_(const Lmha_params<float> &params) {
+  // The M dimension rounded up to 4.
+  int M = round_up(params.M, 4);
+
+  // The number of threads in the block.
+  int block = round_up(max(E, M*THREADS_PER_HEAD), 32);
+  if( block > 512 || params.B > 65535 ) {
+    return 1;
+  }
+
+  // Prepare the kernel.
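+  // (One thread block per (head, sequence) pair; the dynamic shared memory covers
+  // the two double-buffered copies of the Q/K/V/O staging buffers.)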
+  dim3 grid(params.H, params.B);
+  size_t smem = smem_buffer_elts_<E>(params)*2*sizeof(float);
+  lmha_kernel<E, THREADS_PER_HEAD, GO_BACKWARD><<<grid, block, smem>>>(params);
+  return 0;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< bool GO_BACKWARD >
+int lmha(const Lmha_params<float> &params) {
+  int blocks = params.B * params.H;
+  int res = 1;
+  if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
+    if( params.E <= 32 ) {
+      res = lmha_low_occupancy_< 32, GO_BACKWARD>(params, blocks);
+    } else if( params.E <= 64 ) {
+      res = lmha_low_occupancy_< 64, GO_BACKWARD>(params, blocks);
+    } else if( params.E <= 128 ) {
+      res = lmha_low_occupancy_<128, GO_BACKWARD>(params, blocks);
+    } else if( params.E <= 256 ) {
+      res = lmha_low_occupancy_<256, GO_BACKWARD>(params, blocks);
+    }
+  } else {
+    if( params.E <= 32 ) {
+      res = lmha_< 32, 1, GO_BACKWARD>(params);
+    } else if( params.E <= 48 ) {
+      res = lmha_< 48, 1, GO_BACKWARD>(params);
+    } else if( params.E <= 64 ) {
+      res = lmha_< 64, 1, GO_BACKWARD>(params);
+    } else if( params.E <= 128 ) {
+      res = lmha_<128, 2, GO_BACKWARD>(params);
+    } else if( params.E <= 256 ) {
+      res = lmha_<256, 4, GO_BACKWARD>(params);
+    }
+  }
+  return res;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< typename T >
+inline void set_params(Lmha_params<T> &params,
+                       const torch::Tensor q,
+                       const torch::Tensor k,
+                       const torch::Tensor v,
+                       torch::Tensor o) {
+
+  // Define the pointers.
+  params.out = o.data_ptr<T>();
+  params.q = q.data_ptr<T>();
+  params.k = k.data_ptr<T>();
+  params.v = v.data_ptr<T>();
+
+  // Define the strides.
+  params.q_stride_B = (int) q.stride(0);
+  params.q_stride_H = (int) q.stride(1);
+  params.q_stride_L = (int) q.stride(2);
+  params.k_stride_B = (int) k.stride(0);
+  params.k_stride_H = (int) k.stride(1);
+  params.k_stride_L = (int) k.stride(2);
+  params.v_stride_B = (int) v.stride(0);
+  params.v_stride_H = (int) v.stride(1);
+  params.v_stride_L = (int) v.stride(2);
+  params.o_stride_B = (int) o.stride(0);
+  params.o_stride_H = (int) o.stride(1);
+  params.o_stride_L = (int) o.stride(2);
+
+  // Extract the dimensions.
+  int N = q.size(0);
+  int H = q.size(1);
+  int L = q.size(2);
+  int E = q.size(3);
+  int M = v.size(3);
+
+  params.B = N;
+  params.L = L;
+  params.H = H;
+  params.E = E;
+  params.M = M;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+int lmha_fwd(const torch::Tensor queries,
+             const torch::Tensor keys,
+             const torch::Tensor values,
+             torch::Tensor product) {
+
+  // Make sure that we are using the correct GPU device
+  torch::DeviceGuard _guard(queries.device());
+
+  // Make sure the inner-most dimension of the tensors is packed.
+  assert(queries.stride(3) == 1);
+  assert(keys   .stride(3) == 1);
+  assert(values .stride(3) == 1);
+  assert(product.stride(3) == 1);
+
+  // Extract the dimensions.
+  int N = queries.size(0);
+  int H = queries.size(1);
+  int L = queries.size(2);
+  int E = queries.size(3);
+  int M = values.size (3);
+
+  // The structure of params.
+  Lmha_params<float> params;
+  set_params(params, queries, keys, values, product);
+
+  // Launch the kernel.
+  return lmha<false>(params);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< typename T >
+struct Lmha_bwd_params {
+
+  // The output buffer for K. Dimensions [B, H, L, D].
+  T *out_k;
+  // The output buffer for V. Dimensions [B, H, L, D].
+  T *out_v;
+
+  // The input Qs. Dimensions [B, H, L, D].
+  const T *q;
+  // The input Ks. Dimensions [B, H, L, D].
+  const T *k;
+  // The input Vs. Dimensions [B, H, L, D].
+  const T *v;
+  // The input Gs. Dimensions [B, H, L, D].
+  const T *g;
+
+  // The dimensions.
+  int B, L, H, M, E;
+
+  // The strides for the input tensors.
+  int q_stride_B, q_stride_L, q_stride_H;
+  int k_stride_B, k_stride_L, k_stride_H;
+  int v_stride_B, v_stride_L, v_stride_H;
+  int g_stride_B, g_stride_L, g_stride_H;
+
+  // The strides for the outputs.
+  int out_k_stride_B, out_k_stride_L, out_k_stride_H;
+  int out_v_stride_B, out_v_stride_L, out_v_stride_H;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< int D, int THREADS_PER_HEAD >
+__global__ __launch_bounds__(D*THREADS_PER_HEAD*2)
+void lmha_bwd_kernel(Lmha_bwd_params<float> params) {
+
+  // Make sure D is a multiple of 4.
+  static_assert(D % 4 == 0, "");
+
+  // The shared memory buffers.
+  __shared__ struct Smem { float qg[2*D], kv[2*D], out_kv[2*D]; } smem_[2];
+
+  // The index of the shared memory buffer (for double-buffering).
+  int smem_curr = 0;
+
+  // The sequence processed by that block.
+  const int bi = blockIdx.y;
+  // The head processed by that block.
+  const int hi = blockIdx.x;
+
+  // The linear index of the thread.
+  const int tidx = threadIdx.x;
+
+  // Split the threads into two slices.
+  int so = tidx / (D*THREADS_PER_HEAD);
+  int si = tidx % (D*THREADS_PER_HEAD);
+
+  // The strides for B/L/H for the Q/G tensors.
+  int qg_stride_B, qg_stride_L, qg_stride_H;
+  if( so == 0 ) {
+    qg_stride_B = params.q_stride_B;
+    qg_stride_L = params.q_stride_L;
+    qg_stride_H = params.q_stride_H;
+  } else {
+    qg_stride_B = params.g_stride_B;
+    qg_stride_L = params.g_stride_L;
+    qg_stride_H = params.g_stride_H;
+  }
+
+  // The strides for B/L/H for the K/V tensors.
+  int kv_stride_B, kv_stride_L, kv_stride_H;
+  if( so == 0 ) {
+    kv_stride_B = params.k_stride_B;
+    kv_stride_L = params.k_stride_L;
+    kv_stride_H = params.k_stride_H;
+  } else {
+    kv_stride_B = params.v_stride_B;
+    kv_stride_L = params.v_stride_L;
+    kv_stride_H = params.v_stride_H;
+  }
+
+  // The hidden size.
+  int hidden_size_per_head = 0;
+  if( so == 0 ) {
+    hidden_size_per_head = params.E;
+  } else {
+    hidden_size_per_head = params.M;
+  }
+
+  // Where to start reading from.
+  int offset_qg = bi*qg_stride_B + hi*qg_stride_H + si;
+  int offset_kv = bi*kv_stride_B + hi*kv_stride_H + si;
+
+  // We walk backward, account for the extra offset.
+  offset_qg += (params.L-1)*qg_stride_L;
+  offset_kv += (params.L-1)*kv_stride_L;
+
+  // Determine the base pointers for Q, K, V and G.
+  const float *ptr_qg = &(so == 0 ? params.q : params.g)[offset_qg];
+  const float *ptr_kv = &(so == 0 ? params.k : params.v)[offset_kv];
+
+  // Is it an active thread?
+  const int active = si < hidden_size_per_head;
+
+  // Trigger the memory loads for Q, K, V and G.
+  float ldg_qg = 0.f, ldg_kv = 0.f;
+  if( active ) {
+    ldg_qg = *ptr_qg;
+    ldg_kv = *ptr_kv;
+  }
+
+  // Move the load pointers (backward).
+  ptr_qg -= qg_stride_L;
+  ptr_kv -= kv_stride_L;
+
+  // The number of FLOAT4s per head.
+  constexpr int FLOAT4s_PER_HEAD = D / 4;
+  // The number of FLOAT4s per thread.
+  constexpr int FLOAT4s_PER_THREAD = FLOAT4s_PER_HEAD / THREADS_PER_HEAD;
+
+  // The storage for the G*Q^T or Q^T*G values.
+  float4 gq[FLOAT4s_PER_THREAD];
+  #pragma unroll
+  for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
+    gq[ii] = make_float4(0.f, 0.f, 0.f, 0.f);
+  }
+
+  // The strides for B/L/H for the output K/V gradient tensors.
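+  // (Thread slice so == 0 accumulates the running G^T*Q state and emits grad_K;
+  // slice so == 1 accumulates Q^T*G and emits grad_V. Each slice reads the other
+  // slice's staged values through the (so^1) shared-memory offsets below.)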
+  int out_kv_stride_B, out_kv_stride_L, out_kv_stride_H;
+  if( so == 0 ) {
+    out_kv_stride_B = params.out_k_stride_B;
+    out_kv_stride_L = params.out_k_stride_L;
+    out_kv_stride_H = params.out_k_stride_H;
+  } else {
+    out_kv_stride_B = params.out_v_stride_B;
+    out_kv_stride_L = params.out_v_stride_L;
+    out_kv_stride_H = params.out_v_stride_H;
+  }
+
+  // Where to start reading from.
+  int offset_out_kv = bi*out_kv_stride_B + hi*out_kv_stride_H + si;
+
+  // We walk backward, account for the extra offset.
+  offset_out_kv += (params.L-1)*out_kv_stride_L;
+
+  // The output pointer.
+  float *ptr_out_kv = &(so == 0 ? params.out_k : params.out_v)[offset_out_kv];
+
+  // Store to shared memory.
+  if( si < D ) {
+    smem_[smem_curr].qg[so*D + si] = ldg_qg;
+    smem_[smem_curr].kv[so*D + si] = ldg_kv;
+  }
+
+  // The position of the thread in the output dimension.
+  int oo = si / THREADS_PER_HEAD % D;
+  int oi = si % THREADS_PER_HEAD * 4;
+
+  // Iterate over the timesteps.
+  for( int ti = 0; ti < params.L; ++ti ) {
+
+    // Is it the last iteration?
+    int is_last = ti == params.L - 1;
+
+    // Trigger the next loads.
+    if( !is_last && active ) {
+      ldg_qg = *ptr_qg;
+      ldg_kv = *ptr_kv;
+    }
+
+    // Move the load pointers.
+    ptr_qg -= qg_stride_L;
+    ptr_kv -= kv_stride_L;
+
+    // Make sure the data is in shared memory.
+    __syncthreads();
+
+    // Each thread loads 4 values from G or Q.
+    float4 g[FLOAT4s_PER_THREAD];
+    #pragma unroll
+    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
+      float *smem_ptr = &smem_[smem_curr].qg[(so^1)*D + oi];
+      g[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
+    }
+
+    // Each thread loads a single value from Q or G.
+    float q = smem_[smem_curr].qg[so*D + oo];
+
+    // Update the G*Q^T or Q*G^T product.
+    #pragma unroll
+    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
+      gq[ii].x += g[ii].x * q;
+      gq[ii].y += g[ii].y * q;
+      gq[ii].z += g[ii].z * q;
+      gq[ii].w += g[ii].w * q;
+    }
+
+    // Load the V or K values from shared memory.
+    float4 v[FLOAT4s_PER_THREAD];
+    #pragma unroll
+    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
+      float *smem_ptr = &smem_[smem_curr].kv[(so^1)*D + oi];
+      v[ii] = *reinterpret_cast<const float4*>(&smem_ptr[ii*THREADS_PER_HEAD*4]);
+    }
+
+    // Compute the partial output value for that thread.
+    float sum = 0.f;
+    #pragma unroll
+    for( int ii = 0; ii < FLOAT4s_PER_THREAD; ++ii ) {
+      sum += v[ii].x * gq[ii].x;
+      sum += v[ii].y * gq[ii].y;
+      sum += v[ii].z * gq[ii].z;
+      sum += v[ii].w * gq[ii].w;
+    }
+
+    // Finalize the computation of the sum (if we have more than 1 thread per head).
+    if( THREADS_PER_HEAD > 1 ) {
+
+      // Finalize the sum for each head.
+      #pragma unroll
+      for( int mask = THREADS_PER_HEAD / 2; mask >= 1; mask /= 2 ) {
+        sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+      }
+
+      // Store to shared memory.
+      if( oi == 0 ) {
+        smem_[smem_curr].out_kv[so*D + oo] = sum;
+      }
+
+      // Make sure the data is in shared memory.
+      __syncthreads();
+
+      // Active threads read the data to store.
+      if( si < hidden_size_per_head ) {
+        sum = smem_[smem_curr].out_kv[so*D + si];
+      }
+
+    } // THREADS_PER_HEAD > 1.
+
+    // Store the output for the active threads.
+    if( si < hidden_size_per_head ) {
+      *ptr_out_kv = sum;
+    }
+
+    // Move to next location.
+    ptr_out_kv -= out_kv_stride_L;
+
+    // Move the shared memory buffer.
+    smem_curr = (smem_curr + 1) % 2;
+
+    // Stage the next Q/G and K/V values in shared memory.
+    if( !is_last && si < D ) {
+      smem_[smem_curr].qg[so*D + si] = ldg_qg;
+      smem_[smem_curr].kv[so*D + si] = ldg_kv;
+    }
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template< int D, int THREADS_PER_HEAD >
+int lmha_bwd_(const Lmha_bwd_params<float> &params) {
+  int block = D*THREADS_PER_HEAD*2;
+  if( block >= 1024 || params.B > 65535 ) {
+    return 1;
+  }
+  dim3 grid(params.H, params.B);
+  lmha_bwd_kernel<D, THREADS_PER_HEAD><<<grid, block>>>(params);
+  return 0;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+int lmha_bwd(const Lmha_bwd_params<float> &params) {
+  int blocks = params.B * params.H;
+  if( blocks < LOW_OCCUPANCY_THRESHOLD ) {
+    return 1;
+  }
+
+  int hidden_size_per_head = max(params.E, params.M);
+  int res = 1;
+  if( hidden_size_per_head <= 32 ) {
+    res = lmha_bwd_< 32, 1>(params);
+  } else if( hidden_size_per_head <= 64 ) {
+    res = lmha_bwd_< 64, 1>(params);
+  } else if( hidden_size_per_head <= 128 ) {
+    res = lmha_bwd_<128, 2>(params);
+  } else if( hidden_size_per_head <= 256 ) {
+    res = lmha_bwd_<256, 4>(params);
+  }
+  return res;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+int lmha_bwd(const torch::Tensor queries,
+             const torch::Tensor keys,
+             const torch::Tensor values,
+             const torch::Tensor grad_out,
+             torch::Tensor grad_queries,
+             torch::Tensor grad_keys,
+             torch::Tensor grad_values) {
+
+  // Make sure that we are using the correct GPU device
+  torch::DeviceGuard _guard(queries.device());
+
+  // Make sure the inner-most dimension of the tensors is packed.
+  assert(queries     .stride(3) == 1);
+  assert(keys        .stride(3) == 1);
+  assert(values      .stride(3) == 1);
+  assert(grad_out    .stride(3) == 1);
+  assert(grad_queries.stride(3) == 1);
+  assert(grad_keys   .stride(3) == 1);
+  assert(grad_values .stride(3) == 1);
+
+  // Extract the dimensions.
+  int N = queries.size(0);
+  int H = queries.size(1);
+  int L = queries.size(2);
+  int E = queries.size(3);
+  int M = values.size (3);
+
+  // Gradient on Q.
+
+  // The structure of params.
+  Lmha_params<float> params;
+  set_params(params, grad_out, values, keys, grad_queries);
+
+  // Launch the kernel.
+  int res = lmha<false>(params);
+  if( res ) {
+    return res;
+  }
+
+  // Gradient on K and V together.
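+  // (The fused kernel below walks the sequence backward once and produces grad_K
+  // and grad_V in the same pass; if it cannot run, we fall back to two separate
+  // lmha<true>() launches further down.)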
+
+  Lmha_bwd_params<float> bwd_params;
+  bwd_params.out_k = grad_keys.data_ptr<float>();
+  bwd_params.out_v = grad_values.data_ptr<float>();
+  bwd_params.q = queries.data_ptr<float>();
+  bwd_params.k = keys.data_ptr<float>();
+  bwd_params.v = values.data_ptr<float>();
+  bwd_params.g = grad_out.data_ptr<float>();
+
+  bwd_params.B = N;
+  bwd_params.L = L;
+  bwd_params.H = H;
+  bwd_params.E = E;
+  bwd_params.M = M;
+
+  bwd_params.q_stride_B = queries.stride(0);
+  bwd_params.q_stride_H = queries.stride(1);
+  bwd_params.q_stride_L = queries.stride(2);
+  bwd_params.k_stride_B = keys.stride(0);
+  bwd_params.k_stride_H = keys.stride(1);
+  bwd_params.k_stride_L = keys.stride(2);
+  bwd_params.v_stride_B = values.stride(0);
+  bwd_params.v_stride_H = values.stride(1);
+  bwd_params.v_stride_L = values.stride(2);
+  bwd_params.g_stride_B = grad_out.stride(0);
+  bwd_params.g_stride_H = grad_out.stride(1);
+  bwd_params.g_stride_L = grad_out.stride(2);
+
+  bwd_params.out_k_stride_B = grad_keys.stride(0);
+  bwd_params.out_k_stride_H = grad_keys.stride(1);
+  bwd_params.out_k_stride_L = grad_keys.stride(2);
+  bwd_params.out_v_stride_B = grad_values.stride(0);
+  bwd_params.out_v_stride_H = grad_values.stride(1);
+  bwd_params.out_v_stride_L = grad_values.stride(2);
+
+  // Try to run the fused kernel.
+  int fallback = lmha_bwd(bwd_params);
+
+  // If it failed, fallback on separate kernels for K and V.
+  if( fallback ) {
+
+    // Gradient on K.
+
+    // Launch the kernel.
+    set_params(params, values, grad_out, queries, grad_keys);
+    res = lmha<true>(params);
+    if( res ) {
+      return res;
+    }
+
+    // Gradient on V.
+
+    // Launch the kernel.
+    set_params(params, keys, queries, grad_out, grad_values);
+    return lmha<true>(params);
+  }
+
+  // It worked...
+  return 0;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace nvidia
+#endif // #ifdef ENABLE_NVIDIA_OPTIMIZATIONS
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+typedef torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits> float_accessor;
+
+#define E_BLOCK_SIZE 8
+
+__global__ void causal_dot_product_kernel(
+    const float_accessor queries,
+    const float_accessor keys,
+    const float_accessor values,
+    float_accessor result,
+    const int N,
+    const int H,
+    const int L,
+    const int E,
+    const int M
+) {
+    int n = blockIdx.y;
+    int h = blockIdx.z;
+
+    int e_start = blockIdx.x * E_BLOCK_SIZE;
+    int m = threadIdx.x % M;
+
+    extern __shared__ float shared_mem[];
+    float* shared_kv = shared_mem;
+
+    for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
+        shared_kv[m + e_local * M] = 0;
+    }
+
+    for (int t=0; t<L; t++) {
+        float res = 0;
+        for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
+            // Fold k_t * v_t into the running KV state, then contract with q_t.
+            shared_kv[e_local*M + m] += keys[n][h][t][e_local + e_start] * values[n][h][t][m];
+            res += queries[n][h][t][e_local + e_start] * shared_kv[e_local*M + m];
+        }
+        atomicAdd(&result[n][h][t][m], res);
+    }
+}
+
+void causal_dot_product_(const torch::Tensor queries,
+                         const torch::Tensor keys,
+                         const torch::Tensor values,
+                         torch::Tensor product) {
+    int N = queries.size(0);
+    int H = queries.size(1);
+    int L = queries.size(2);
+    int E = queries.size(3);
+    int M = values.size(3);
+
+    const int blocks_per_sequence = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE;
+
+    dim3 blockDim(M, 1, 1);
+    dim3 gridDim(blocks_per_sequence, N, H);
+    const int shared_mem_forward = E_BLOCK_SIZE * M * sizeof(float);
+    causal_dot_product_kernel<<<gridDim, blockDim, shared_mem_forward>>>(
+        queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        product.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        N, H, L, E, M
+    );
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void causal_dot_product(const torch::Tensor queries,
+                        const torch::Tensor keys,
+                        const torch::Tensor values,
+                        torch::Tensor product) {
+#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
+    int fallback = nvidia::lmha_fwd(queries, keys, values, product);
+#else
+    int fallback = 1;
+#endif
+    if( fallback ) {
+        causal_dot_product_(queries, keys, values, product);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#define M_BLOCK_SIZE 4
+
+// We need shared memory to store two running KV states: one for the forward
+// direction (kv) and one for the backward direction (kv_backwards).
+__global__ void causal_dot_backward_query_key_kernel(
+    const float_accessor queries,
+    const float_accessor keys,
+    const float_accessor values,
+    const float_accessor grad_out,
+    float_accessor grad_queries,
+    float_accessor grad_keys,
+    int N,
+    int H,
+    int L,
+    int E,
+    int M
+) {
+    int n = blockIdx.y;
+    int h = blockIdx.z;
+
+    int m_start = blockIdx.x * M_BLOCK_SIZE;
+    int e = threadIdx.x % E;
+
+    extern __shared__ float shared_mem[];
+    const int shared_kv_size = M_BLOCK_SIZE * E;
+    float* shared_kv = shared_mem;
+    float* shared_kv_bw = shared_mem + shared_kv_size;
+
+    for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) {
+        shared_kv[m_local * E + e] = 0;
+        shared_kv_bw[m_local * E + e] = 0;
+    }
+
+    // Walk the sequence forward for grad_queries and, mirrored via l_b, backward
+    // for grad_keys, keeping both running KV states in shared memory.
+    for (int l=0; l<L; l++) {
+        int l_b = L - l - 1;
+        float res = 0;
+        float res_bw = 0;
+        for (int m_local = 0; m_local < M_BLOCK_SIZE && m_local + m_start < M; m_local++) {
+            int m = m_local + m_start;
+            shared_kv[m_local*E + e] += keys[n][h][l][e] * values[n][h][l][m];
+            shared_kv_bw[m_local*E + e] += queries[n][h][l_b][e] * grad_out[n][h][l_b][m];
+            res += grad_out[n][h][l][m] * shared_kv[m_local*E + e];
+            res_bw += values[n][h][l_b][m] * shared_kv_bw[m_local*E + e];
+        }
+        atomicAdd(&grad_queries[n][h][l][e], res);
+        atomicAdd(&grad_keys[n][h][l_b][e], res_bw);
+    }
+}
+
+__global__ void causal_dot_backward_value_kernel(
+    const float_accessor queries,
+    const float_accessor keys,
+    const float_accessor values,
+    const float_accessor grad_out,
+    float_accessor grad_keys,
+    float_accessor grad_values,
+    int N,
+    int H,
+    int L,
+    int E,
+    int M
+) {
+    int n = blockIdx.y;
+    int h = blockIdx.z;
+
+    int e_start = blockIdx.x * E_BLOCK_SIZE;
+    int m = threadIdx.x % M;
+
+    extern __shared__ float shared_mem[];
+    float* shared_kv = shared_mem;
+
+    for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
+        shared_kv[m + e_local * M] = 0;
+    }
+
+    // Reverse walk: accumulate the Q^T * grad_out state and contract with K to
+    // produce grad_values.
+    for (int l = 0; l < L; l++) {
+        int l_b = L - l - 1;
+        float res = 0;
+        for (int e_local = 0; e_local < E_BLOCK_SIZE && e_local + e_start < E; e_local++) {
+            int e = e_local + e_start;
+            shared_kv[e_local*M + m] += queries[n][h][l_b][e] * grad_out[n][h][l_b][m];
+            res += keys[n][h][l_b][e] * shared_kv[e_local*M + m];
+        }
+        atomicAdd(&grad_values[n][h][l_b][m], res);
+    }
+}
+
+void causal_dot_backward_(const torch::Tensor queries,
+                          const torch::Tensor keys,
+                          const torch::Tensor values,
+                          const torch::Tensor grad_out,
+                          torch::Tensor grad_queries,
+                          torch::Tensor grad_keys,
+                          torch::Tensor grad_values) {
+    int N = queries.size(0);
+    int H = queries.size(1);
+    int L = queries.size(2);
+    int E = queries.size(3);
+    int M = values.size(3);
+
+    const int blocks_per_sequence = (M + M_BLOCK_SIZE - 1) / M_BLOCK_SIZE;
+
+    dim3 blockDim(E, 1, 1);
+    dim3 gridDim(blocks_per_sequence, N, H);
+    const int shared_mem_qk_backward = 2 * M_BLOCK_SIZE * E * sizeof(float);
+    causal_dot_backward_query_key_kernel<<<gridDim, blockDim, shared_mem_qk_backward>>>(
+        queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        grad_queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        N, H, L, E, M
+    );
+
+    const int blocks_per_sequence_value = (E + E_BLOCK_SIZE - 1) / E_BLOCK_SIZE;
+
+    dim3 blockDimv(M, 1, 1);
+    dim3 gridDimv(blocks_per_sequence_value, N, H);
+    const int shared_mem_v_backward = E_BLOCK_SIZE * M * sizeof(float);
+    causal_dot_backward_value_kernel<<<gridDimv, blockDimv, shared_mem_v_backward>>>(
+        queries.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        grad_out.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        grad_keys.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        grad_values.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
+        N, H, L, E, M
+    );
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void causal_dot_backward(const torch::Tensor queries,
+                         const torch::Tensor keys,
+                         const torch::Tensor values,
+                         const torch::Tensor grad_out,
+                         torch::Tensor grad_queries,
+                         torch::Tensor grad_keys,
+                         torch::Tensor grad_values) {
+#ifdef ENABLE_NVIDIA_OPTIMIZATIONS
+    int fallback = nvidia::lmha_bwd(queries,
+                                    keys,
+                                    values,
+                                    grad_out,
+                                    grad_queries,
+                                    grad_keys,
+                                    grad_values);
+#else
+    int fallback = 1;
+#endif
+    if( fallback ) {
+        // Make sure that the gradient tensors are 0. This is needed because the
+        // bwd pass might have partially executed and filled in some values in
+        // grad_queries or grad_keys.
+        //
+        // This adds a small overhead every time we have to fall back to the old
+        // kernel for the backward pass.
+        grad_queries.zero_();
+        grad_keys.zero_();
+        causal_dot_backward_(queries, keys, values, grad_out, grad_queries, grad_keys, grad_values);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def(
+        "causal_dot_product",
+        &causal_dot_product,
+        "Compute the weighted sum of values but attending only to previous "
+        "values."
+    );
+    m.def(
+        "causal_dot_backward",
+        &causal_dot_backward,
+        "Compute the gradients for the causal dot product."
+    );
+}
diff --git a/src/axolotl/integrations/lolcats/linear_llama/csrc/setup.py b/src/axolotl/integrations/lolcats/linear_llama/csrc/setup.py
new file mode 100644
index 000000000..554554080
--- /dev/null
+++ b/src/axolotl/integrations/lolcats/linear_llama/csrc/setup.py
@@ -0,0 +1,65 @@
+#
+# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
+# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
+# Apoorv Vyas <avyas@idiap.ch>
+#
+
+import subprocess  # nosec
+
+import torch
+from setuptools import setup
+from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
+
+
+def get_last_arch_torch():
+    arch = torch.cuda.get_arch_list()[-1]
+    print(f"Found arch: {arch} from existing torch installation")
+    return arch
+
+
+def get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output(
+        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True  # nosec
+    )
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def append_nvcc_threads(nvcc_extra_args):
+    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+    # Compare (major, minor) as a tuple so that e.g. CUDA 12.0 passes the >= 11.2 check.
+    if (int(bare_metal_major), int(bare_metal_minor)) >= (11, 2):
+        return nvcc_extra_args + ["--threads", "4"]
+    return nvcc_extra_args
+
+
+arch = get_last_arch_torch()
+sm_num = arch[-2:]
+cc_flag = ["--generate-code=arch=compute_90,code=compute_90"]  # for H100
+# cc_flag = ['--generate-code=arch=compute_80,code=compute_80']  # for A100
+# cc_flag = ['--generate-code=arch=compute_89,code=compute_89']  # for RTX 6000, 4090
+# cc_flag = ['--generate-code=arch=compute_86,code=compute_86']  # for A6000, 3090
+# cc_flag = ['--generate-code=arch=compute_75,code=compute_75']
+
+setup(
+    name="causal_attention_cuda_cpp",
+    ext_modules=[
+        CUDAExtension(
+            "causal_attention_cuda",
+            [
+                # 'causal_attention.cpp',
+                "causal_attention_cuda.cu",
+            ],
+            extra_compile_args={
+                "cxx": ["-O3"],
+                "nvcc": append_nvcc_threads(
+                    ["-O3", "-lineinfo", "--use_fast_math", "-std=c++17"] + cc_flag
+                ),
+            },
+        )
+    ],
+    cmdclass={"build_ext": BuildExtension},
+)
diff --git a/src/axolotl/integrations/lolcats/linear_llama/model/__init__.py b/src/axolotl/integrations/lolcats/linear_llama/model/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/integrations/lolcats/model/feature_map.py b/src/axolotl/integrations/lolcats/linear_llama/model/feature_map.py
similarity index 100%
rename from src/axolotl/integrations/lolcats/model/feature_map.py
rename to src/axolotl/integrations/lolcats/linear_llama/model/feature_map.py
diff --git a/src/axolotl/integrations/lolcats/model/rotary.py b/src/axolotl/integrations/lolcats/linear_llama/model/rotary.py
similarity index 100%
rename from src/axolotl/integrations/lolcats/model/rotary.py
rename to src/axolotl/integrations/lolcats/linear_llama/model/rotary.py
diff --git a/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py
new file mode 100644
index 000000000..937735e0d
--- /dev/null
+++ b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py
@@ -0,0 +1,115 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +"""Linear LLaMA model implementation.""" + + +from torch import nn +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, +) + +from axolotl.utils.dict import DictDefault + +from .attention import LolcatsLinearAttention +from .configuration_linear_llama import LinearLlamaConfig + + +class LinearLlamaDecoderLayer(LlamaDecoderLayer): + """ + Modified LlamaDecoderLayer that uses LinearAttention instead of standard attention. + """ + + def __init__(self, config: LinearLlamaConfig, layer_idx: int): + super().__init__(config, layer_idx) + # Replace the attention layer with our custom attention + self.self_attn = LolcatsLinearAttention( + base_attn=self.self_attn, # type: ignore + layer_idx=layer_idx, + **config.attention_config, + ) + + +class LinearLlamaModel(LlamaModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LinearLlamaDecoderLayer`] + + Args: + config: LinearLlamaConfig + """ + + config_class = LinearLlamaConfig + base_model_prefix = "linear_llama" + + def __init__(self, config: LinearLlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + LinearLlamaDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + +class LinearLlamaForCausalLM(LlamaForCausalLM): + def __init__(self, config): + super().__init__(config) + self.model = LinearLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @classmethod + def from_llama( + cls, + model: LlamaModel | LlamaForCausalLM, + config: LinearLlamaConfig, + train_attention: bool = False, + remove_base_attn: bool = True, + ) -> "LinearLlamaForCausalLM": + """ + Initialize a LinearLlamaForCausalLM from a LlamaModel + """ + + # Handle LlamaForCausalLM + if isinstance(model, LlamaForCausalLM): + model = model.model + + if config is None: + raise ValueError("Missing config") + + from axolotl.integrations.lolcats.linearize_attention import convert_attention + + new_model = convert_attention( + model, + DictDefault(**config.attention_config), + train_attention=train_attention, + remove_base_attn=remove_base_attn, + ) + + return new_model diff --git a/src/axolotl/integrations/lolcats/linearize_attention.py b/src/axolotl/integrations/lolcats/linearize_attention.py index c3a065ff2..075c59135 100644 --- a/src/axolotl/integrations/lolcats/linearize_attention.py +++ b/src/axolotl/integrations/lolcats/linearize_attention.py @@ -134,41 +134,39 @@ def get_attention(attention_type: str, **kwargs): kwargs["attention_type"] = attention_type if attention_type == "lolcats_llama": - from .linear_attention import LolcatsLinearAttention + from .linear_llama.attention import LolcatsLinearAttention return partial(LolcatsLinearAttention, **kwargs) elif attention_type == "lolcats_llama_window_tk": - from .linear_attention 
import LolcatsTKWindowAttention + from .linear_llama.attention import LolcatsTKWindowAttention return partial(LolcatsTKWindowAttention, **kwargs) elif attention_type == "lolcats_llama_window_sw": - from .linear_attention import LolcatsSlidingWindowAttention + from .linear_llama.attention import LolcatsSlidingWindowAttention return partial(LolcatsSlidingWindowAttention, **kwargs) elif attention_type == "lolcats_llama_window_sw_linear": - from .linear_attention.linear_window_attention_sw_linear import ( - LolcatsLinearSlidingWindowAttention, - ) + from .linear_llama.attention import LolcatsLinearSlidingWindowAttention return partial(LolcatsLinearSlidingWindowAttention, **kwargs) # Experimental chunked linear attentions below elif attention_type == "lolcats_long_llama_window_tk": - from .linear_attention import LolcatsTKWindowLongAttention + from .linear_llama.attention import LolcatsTKWindowLongAttention return partial(LolcatsTKWindowLongAttention, **kwargs) elif attention_type == "lolcats_long_llama_window_sw": - from .linear_attention import LolcatsSlidingWindowLongAttention + from .linear_llama.attention import LolcatsSlidingWindowLongAttention return partial(LolcatsSlidingWindowLongAttention, **kwargs) # TK generation build (requires Thunderkittens) elif attention_type == "lolcats_llama_window_tk_gen": - from .linear_attention import LolcatsWindowAttentionTKGen + from .linear_llama.attention import LolcatsWindowAttentionTKGen return partial(LolcatsWindowAttentionTKGen, **kwargs) @@ -186,30 +184,28 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None): # LOG.info(f'Returning attention cache based on attention_type == {attention_type}') elif "lolcats_llama_window_tk_gen" in attention_type: - from .linear_attention import LinearAttentionTKWindowGenerationCache + from .linear_llama.attention import LinearAttentionTKWindowGenerationCache return LinearAttentionTKWindowGenerationCache() elif "llama_window_tk" in attention_type: - from .linear_attention import LinearAttentionTKWindowCache + from .linear_llama.attention import LinearAttentionTKWindowCache return LinearAttentionTKWindowCache() elif "llama_window_sw" in attention_type: - from .linear_attention import LinearAttentionSlidingWindowCache + from .linear_llama.attention import LinearAttentionSlidingWindowCache return LinearAttentionSlidingWindowCache() elif "llama_window_sw_linear" in attention_type: - from .linear_attention import LinearAttentionSlidingWindowCache + from .linear_llama.attention import LinearAttentionSlidingWindowCache return LinearAttentionSlidingWindowCache() # TK generation build (requires Thunderkittens) elif attention_type == "lolcats_llama_window_tk_gen": - from .linear_attention.linear_window_attention_tk_gen import ( - LinearAttentionTKWindowGenerationCache, - ) + from .linear_llama.attention import LinearAttentionTKWindowGenerationCache return LinearAttentionTKWindowGenerationCache() @@ -217,6 +213,6 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None): return past_key_values else: - from .linear_attention import LinearAttentionState + from .linear_llama.attention import LinearAttentionState return LinearAttentionState()