diff --git a/src/axolotl/kernels/geglu.py b/src/axolotl/kernels/geglu.py index 0aa035c94..6acbea0d4 100644 --- a/src/axolotl/kernels/geglu.py +++ b/src/axolotl/kernels/geglu.py @@ -1,5 +1,4 @@ -""" -Module for definition of GEGLU Triton kernels. +"""Module for definition of GEGLU Triton kernels. See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202). @@ -12,8 +11,6 @@ import torch import triton import triton.language as tl -SQRT_2_PI: tl.constexpr = 0.7978845608028654 # sqrt(2/π) - @triton.jit def _geglu_fwd_kernel(