Compare commits


1 Commit

Author: Wing Lian
SHA1: f8acc72dd8
Message: proof of concept for sage attention
Date: 2024-11-22 14:47:19 -05:00
6 changed files with 1143 additions and 0 deletions

View File

@@ -0,0 +1,361 @@
"""
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Any, Optional
import torch
from torch.autograd import Function
from .triton.attn_qk_int8_per_block_causal_varlen import (
backward as sageattn_varlen_backward,
)
from .triton.attn_qk_int8_per_block_causal_varlen import forward as attn_true_varlen
from .triton.quant_per_block_varlen import (
per_block_int8 as per_block_int8_varlen_triton,
)
def get_cuda_arch_versions():
cuda_archs = []
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(i)
cuda_archs.append(f"sm{major}{minor}")
return cuda_archs
def sageattn_varlen(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
sm_scale: Optional[float] = None,
smooth_k: bool = True,
**kwargs: Any,
) -> torch.Tensor:
"""
Parameters
----------
q : torch.Tensor
The query tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.
k : torch.Tensor
The key tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
v : torch.Tensor
The value tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
cu_seqlens_q : torch.Tensor
The cumulative sequence lengths for the query sequences in the batch, used to index into `q`.
Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.
cu_seqlens_k : torch.Tensor
The cumulative sequence lengths for the key and value sequences in the batch, used to index into `k` and `v`.
Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.
max_seqlen_q : int
The maximum sequence length for the query tensor in the batch.
max_seqlen_k : int
The maximum sequence length for the key and value tensors in the batch.
sm_scale : Optional[float]
The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
smooth_k : bool
Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
Default: True.
Returns
-------
torch.Tensor
The output tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.
Note
----
- ``num_qo_heads`` must be divisible by ``num_kv_heads``.
- The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
- There is no ``is_causal`` flag: the wrapped Triton kernel always applies a causal mask within each sequence (assuming ``qo_len == kv_len`` per sequence).
- The tensors `cu_seqlens_q` and `cu_seqlens_k` must have the dtype ``torch.int32`` or ``torch.int64``.
- All tensors must be on the same cuda device.
- `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
"""
dtype = q.dtype
assert q.is_cuda, "Input tensors must be on cuda."
assert dtype in [
torch.float16,
torch.bfloat16,
], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
assert q.device == k.device == v.device, "All tensors must be on the same device."
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
head_dim = q.size(-1)
assert head_dim in [64, 128], "varlen only supports head_dim in [64, 128]."
assert (
q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
), "Last dim of qkv must be contiguous."
assert (
cu_seqlens_q.is_contiguous() and cu_seqlens_k.is_contiguous()
), "cu_seqlens_q and cu_seqlens_k must be contiguous."
if dtype == torch.bfloat16 or dtype == torch.float32:
v = v.to(torch.float16)
if smooth_k:
km = k.mean(
dim=0, keepdim=True
        )  # km is computed over all packed sequences at once; computing it per sequence would require a dedicated kernel.
k -= km
(
q_int8,
q_scale,
k_int8,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
) = per_block_int8_varlen_triton(
q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=sm_scale
)
o = attn_true_varlen(
q_int8,
k_int8,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
q_scale,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
output_dtype=dtype,
)
return o
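# Illustrative usage sketch (not part of this commit's API surface; shapes follow the
# docstring above). Two packed sequences of lengths 3 and 5 share one token buffer and
# cu_seqlens marks their boundaries:
#
#   q = torch.randn(8, 8, 64, dtype=torch.float16, device="cuda")
#   k = torch.randn(8, 8, 64, dtype=torch.float16, device="cuda")
#   v = torch.randn(8, 8, 64, dtype=torch.float16, device="cuda")
#   cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
#   out = sageattn_varlen(q, k, v, cu_seqlens, cu_seqlens, max_seqlen_q=5, max_seqlen_k=5)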
class SageAttentionFunction(Function):
@staticmethod
def forward(
ctx,
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
"""
query: Tensor of shape [batch_size, num_heads, seq_len_q, head_dim]
key: Tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
value: Tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
attn_mask: Optional[Tensor], mask tensor
dropout_p: float, dropout probability
is_causal: bool, whether to apply causal masking
scale: Optional[float], scaling factor for attention scores
"""
# Ensure inputs are contiguous
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# Handle default scale
if scale is None:
scale = 1.0 / (query.size(-1) ** 0.5)
# Save parameters needed for backward
ctx.scale = scale
ctx.is_causal = is_causal
ctx.dropout_p = dropout_p
ctx.attn_mask = attn_mask
        # Prepare cumulative sequence lengths and max sequence lengths.
        # Assumes the batch and sequence dimensions are consistent across query, key, and value.
batch_size, num_heads, seq_len_q, head_dim = query.shape
seq_len_k = key.shape[2]
# Flatten batch and head dimensions
q = query.view(
-1, seq_len_q, head_dim
) # [batch_size * num_heads, seq_len_q, head_dim]
k = key.view(-1, seq_len_k, head_dim)
v = value.view(-1, seq_len_k, head_dim)
# Create cumulative sequence lengths
cu_seqlens_q = torch.arange(
0,
(batch_size * num_heads + 1) * seq_len_q,
seq_len_q,
dtype=torch.int32,
device=query.device,
)
cu_seqlens_k = torch.arange(
0,
(batch_size * num_heads + 1) * seq_len_k,
seq_len_k,
dtype=torch.int32,
device=key.device,
)
max_seqlen_q = seq_len_q
max_seqlen_k = seq_len_k
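        # Example of the layout this produces (illustrative numbers): with batch_size=2,
        # num_heads=4 and seq_len_q=16, cu_seqlens_q is [0, 16, 32, ..., 128] -- one
        # fixed-length "sequence" per (batch, head) pair after the view() above.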
        # Per-block int8 quantization of q and k (also returns per-block scales and their cu_seqlens)
(
q_int8,
q_scale,
k_int8,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
) = per_block_int8_varlen_triton(
q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=scale
)
        # Run the quantized varlen attention kernel (causal path only)
if is_causal:
output = attn_true_varlen(
q_int8,
k_int8,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
q_scale,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
output_dtype=query.dtype,
)
else:
raise NotImplementedError("Non-causal attention is not implemented yet.")
# Reshape output to match the expected shape
output = output.view(batch_size, num_heads, seq_len_q, head_dim)
# Save tensors for backward
ctx.save_for_backward(
query,
key,
value,
q_int8,
k_int8,
q_scale,
k_scale,
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
output,
)
return output
@staticmethod
def backward(ctx, grad_output):
(
query,
key,
value,
q_int8,
k_int8,
q_scale,
k_scale,
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
output,
) = ctx.saved_tensors
scale = ctx.scale
is_causal = ctx.is_causal
dropout_p = ctx.dropout_p
attn_mask = ctx.attn_mask
# Flatten batch and head dimensions
batch_size, num_heads, seq_len_q, head_dim = query.shape
seq_len_k = key.shape[2]
grad_output = grad_output.contiguous()
do = grad_output.view(-1, seq_len_q, head_dim)
# Compute gradients w.r.t. q, k, v
dq, dk, dv = sageattn_varlen_backward(
do,
query.view(-1, seq_len_q, head_dim),
key.view(-1, seq_len_k, head_dim),
value.view(-1, seq_len_k, head_dim),
cu_seqlens_q,
cu_seqlens_k,
seq_len_q,
seq_len_k,
q_int8,
k_int8,
q_scale,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
scale,
is_causal,
)
# Reshape gradients to match the input shapes
dq = dq.view(batch_size, num_heads, seq_len_q, head_dim)
dk = dk.view(batch_size, num_heads, seq_len_k, head_dim)
dv = dv.view(batch_size, num_heads, seq_len_k, head_dim)
# Handle optional arguments
d_attn_mask = None # Assuming attn_mask does not require gradients
        d_dropout_p = None  # dropout probability is a hyperparameter, typically not optimized
d_is_causal = None # Not differentiable
        d_scale = None  # scale is a plain float here, so no gradient is returned
return dq, dk, dv, d_attn_mask, d_dropout_p, d_is_causal, d_scale
def scaled_dot_product_attention(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
"""
Custom scaled dot product attention using SageAttentionFunction.
"""
return SageAttentionFunction.apply(
query, key, value, attn_mask, dropout_p, is_causal, scale
)
def monkeypatch_sdp_w_sage_attention():
"""
Replace torch.nn.functional.scaled_dot_product_attention with custom scaled dot product attention using SageAttentionFunction.
"""
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
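# Illustrative usage sketch (hedged; the call site mirrors how axolotl wires this up in
# utils/models.py). Apply the patch once before the model is built so that every module
# calling F.scaled_dot_product_attention picks up the Sage-quantized kernel:
#
#   from axolotl.integrations.sageattention.lib.core import monkeypatch_sdp_w_sage_attention
#   monkeypatch_sdp_w_sage_attention()
#   out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)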

View File

@@ -0,0 +1,622 @@
"""
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import math
import torch
import triton
import triton.language as tl
@triton.jit
def _attn_fwd_inner(
acc,
l_i,
m_i,
q,
q_scale,
kv_len,
K_ptrs,
K_scale_ptr,
V_ptrs,
stride_kn,
stride_vn,
start_m,
H: tl.constexpr,
BLOCK_M: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
STAGE: tl.constexpr,
offs_m: tl.constexpr,
offs_n: tl.constexpr,
):
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
K_scale_ptr += (lo // BLOCK_N) * H
K_ptrs += stride_kn * lo
V_ptrs += stride_vn * lo
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
k_mask = offs_n[None, :] < (kv_len - start_n)
k = tl.load(K_ptrs, mask=k_mask)
k_scale = tl.load(K_scale_ptr)
qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk + tl.where(mask, 0, -1.0e6)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
else:
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None]
p = tl.math.exp2(qk)
l_ij = tl.sum(p, 1)
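        # Online softmax: alpha = exp2(m_old - m_new) below rescales the running row sum
        # and output accumulator whenever this key block raises the running max m_i.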
alpha = tl.math.exp2(m_i - m_ij)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n))
p = p.to(tl.float16)
acc += tl.dot(p, v, out_dtype=tl.float16)
m_i = m_ij
K_ptrs += BLOCK_N * stride_kn
K_scale_ptr += H
V_ptrs += BLOCK_N * stride_vn
return acc, l_i, m_i
@triton.jit
def _attn_fwd(
Q,
K,
V,
cu_seqlens_q,
cu_seqlens_k,
Q_scale,
K_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
Out,
stride_qh,
stride_qn,
stride_kh,
stride_kn,
stride_vh,
stride_vn,
stride_oh,
stride_on,
H: tl.constexpr,
num_kv_groups: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
STAGE: tl.constexpr,
):
start_m = tl.program_id(0)
off_z = tl.program_id(2).to(tl.int64)
off_h = tl.program_id(1).to(tl.int64)
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
qo_len = cu_seqlens_q_end - cu_seqlens_q_start
if (start_m * BLOCK_M) >= qo_len:
return
cu_seq_lens_q_scale_start = tl.load(cu_seqlens_q_scale + off_z)
cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z)
q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H
k_scale_offset = (
cu_seq_lens_k_scale_start * (H // num_kv_groups) + off_h // num_kv_groups
)
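    # Per-block scales are stored as a flattened [total_blocks, heads] array, so the scale
    # for (sequence, head, query block start_m) lives at seq_scale_start * H + start_m * H + off_h.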
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
kv_len = cu_seqlens_k_end - cu_seqlens_k_start
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, HEAD_DIM)
Q_ptrs = (
Q
+ (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
+ offs_m[:, None] * stride_qn
+ offs_k[None, :]
)
Q_scale_ptr = Q_scale + q_scale_offset
K_ptrs = (
K
+ (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
+ offs_n[None, :] * stride_kn
+ offs_k[:, None]
)
K_scale_ptr = K_scale + k_scale_offset
V_ptrs = (
V
+ (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
+ offs_n[:, None] * stride_vn
+ offs_k[None, :]
)
O_block_ptr = (
Out
+ (cu_seqlens_q_start * stride_on + off_h * stride_oh)
+ offs_m[:, None] * stride_on
+ offs_k[None, :]
)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
q_scale = tl.load(Q_scale_ptr)
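    # Causal split (assuming STAGE == 3 from the launcher below): the first inner call runs
    # with stage 4 - STAGE = 1 over key blocks that are fully visible (no element mask), the
    # second with stage 2 over the single block that straddles the causal diagonal.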
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
q_scale,
kv_len,
K_ptrs,
K_scale_ptr,
V_ptrs,
stride_kn,
stride_vn,
start_m,
H // num_kv_groups,
BLOCK_M,
HEAD_DIM,
BLOCK_N,
4 - STAGE,
offs_m,
offs_n,
)
acc, l_i, _ = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
q_scale,
kv_len,
K_ptrs,
K_scale_ptr,
V_ptrs,
stride_kn,
stride_vn,
start_m,
H // num_kv_groups,
BLOCK_M,
HEAD_DIM,
BLOCK_N,
2,
offs_m,
offs_n,
)
acc = acc / l_i[:, None]
tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=(offs_m[:, None] < qo_len))
@triton.jit
def _attn_bwd_inner(
dq_acc,
dk_acc,
dv_acc,
l_i,
m_i,
q,
k,
v,
do,
q_scale,
k_scale,
kv_len,
stride_kn,
stride_vn,
start_m,
H,
BLOCK_M: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
STAGE: tl.constexpr,
offs_m: tl.constexpr,
offs_n: tl.constexpr,
):
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
k += stride_kn * lo
v += stride_vn * lo
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
        k_mask = offs_n[None, :] < (kv_len - start_n)
        v_mask = offs_n[:, None] < (kv_len - start_n)
        k_curr = tl.load(k, mask=k_mask)
        v_curr = tl.load(v, mask=v_mask)
        k_scale_curr = tl.load(k_scale)
        s = tl.dot(q, k_curr).to(tl.float32) * q_scale * k_scale_curr
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
s = s + tl.where(mask, 0.0, -float("inf"))
m_ij = tl.maximum(m_i, tl.max(s, 1))
s = s - m_ij[:, None]
else:
m_ij = tl.maximum(m_i, tl.max(s, 1))
s = s - m_ij[:, None]
p = tl.math.exp2(s)
l_ij = tl.sum(p, 1)
alpha = tl.math.exp2(m_i - m_ij)
l_i = l_i * alpha + l_ij
m_i = m_ij
p = p / l_i[:, None] # Normalize probabilities
# Compute gradients
# Compute softmax gradient
do_scaled = do / l_i[:, None]
        dv_contrib = tl.dot(tl.trans(p.to(tl.float16)), do_scaled.to(tl.float16))
        dv_acc += dv_contrib
        dp = tl.dot(do_scaled.to(tl.float16), tl.trans(v_curr.to(tl.float16)))
# Compute ds (gradient w.r.t. logits s)
p_dp = p * dp
sum_p_dp = tl.sum(p_dp, axis=1)
ds = (p_dp - p * sum_p_dp[:, None]) * tl.math.log(2.0) # Adjust for exp2
# Compute gradients w.r.t q and k
        dq_contrib = tl.dot(ds.to(tl.float16), tl.trans(k_curr.to(tl.float16)))
        dk_contrib = tl.dot(tl.trans(ds.to(tl.float16)), q.to(tl.float16))
dq_acc += dq_contrib * (q_scale * k_scale_curr)
dk_acc += dk_contrib * (q_scale * k_scale_curr)
k += BLOCK_N * stride_kn
k_scale += H
v += BLOCK_N * stride_vn
return dq_acc, dk_acc, dv_acc, l_i, m_i
@triton.jit
def _attn_bwd(
DO,
Q,
K,
V,
cu_seqlens_q,
cu_seqlens_k,
Q_scale,
K_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
L,
M,
DQ,
DK,
DV,
stride_qh,
stride_qn,
stride_kh,
stride_kn,
stride_vh,
stride_vn,
H: tl.constexpr,
num_kv_groups: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
STAGE: tl.constexpr,
):
start_m = tl.program_id(0)
off_z = tl.program_id(2).to(tl.int64)
off_h = tl.program_id(1).to(tl.int64)
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
qo_len = cu_seqlens_q_end - cu_seqlens_q_start
if (start_m * BLOCK_M) >= qo_len:
return
cu_seq_lens_q_scale_start = tl.load(cu_seqlens_q_scale + off_z)
cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z)
q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H
k_scale_offset = (
cu_seq_lens_k_scale_start * (H // num_kv_groups) + off_h // num_kv_groups
)
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
kv_len = cu_seqlens_k_end - cu_seqlens_k_start
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, HEAD_DIM)
Q_ptrs = (
Q
+ (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
+ offs_m[:, None] * stride_qn
+ offs_k[None, :]
)
DO_ptrs = (
DO
+ (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
+ offs_m[:, None] * stride_qn
+ offs_k[None, :]
)
Q_scale_ptr = Q_scale + q_scale_offset
K_ptrs = (
K
+ (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
+ offs_n[None, :] * stride_kn
+ offs_k[:, None]
)
K_scale_ptr = K_scale + k_scale_offset
V_ptrs = (
V
+ (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
+ offs_n[:, None] * stride_vn
+ offs_k[None, :]
)
DQ_ptrs = (
DQ
+ (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
+ offs_m[:, None] * stride_qn
+ offs_k[None, :]
)
DK_ptrs = (
DK
+ (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
+ offs_n[None, :] * stride_kn
+ offs_k[:, None]
)
DV_ptrs = (
DV
+ (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
+ offs_n[:, None] * stride_vn
+ offs_k[None, :]
)
L_ptrs = L + (cu_seqlens_q_start + offs_m)
M_ptrs = M + (cu_seqlens_q_start + offs_m)
m_i = tl.load(M_ptrs, mask=offs_m < qo_len, other=float("-inf"))
l_i = tl.load(L_ptrs, mask=offs_m < qo_len, other=1.0)
dq_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
dk_acc = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32)
dv_acc = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32)
q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
do = tl.load(DO_ptrs, mask=offs_m[:, None] < qo_len)
q_scale = tl.load(Q_scale_ptr)
dq_acc, dk_acc, dv_acc, l_i, m_i = _attn_bwd_inner(
dq_acc,
dk_acc,
dv_acc,
l_i,
m_i,
q,
K_ptrs,
V_ptrs,
do,
q_scale,
K_scale_ptr,
kv_len,
stride_kn,
stride_vn,
start_m,
H // num_kv_groups,
BLOCK_M,
HEAD_DIM,
BLOCK_N,
4 - STAGE,
offs_m,
offs_n,
)
dq_acc, dk_acc, dv_acc, l_i, m_i = _attn_bwd_inner(
dq_acc,
dk_acc,
dv_acc,
l_i,
m_i,
q,
K_ptrs,
V_ptrs,
do,
q_scale,
K_scale_ptr,
kv_len,
stride_kn,
stride_vn,
start_m,
H // num_kv_groups,
BLOCK_M,
HEAD_DIM,
BLOCK_N,
2,
offs_m,
offs_n,
)
tl.store(DQ_ptrs, dq_acc.to(DQ.dtype.element_ty), mask=offs_m[:, None] < qo_len)
tl.store(DK_ptrs, dk_acc.to(DK.dtype.element_ty), mask=offs_n[None, :] < kv_len)
tl.store(DV_ptrs, dv_acc.to(DV.dtype.element_ty), mask=offs_n[:, None] < kv_len)
def forward(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
q_scale,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
output_dtype=torch.float16,
):
BLOCK_M = 128
BLOCK_N = 64
stage = 3
o = torch.empty(q.shape, dtype=output_dtype, device=q.device)
b = cu_seqlens_q.shape[0] - 1
_, h_qo, head_dim = q.shape
_, h_kv, _ = k.shape
HEAD_DIM_K = head_dim
num_kv_groups = h_qo // h_kv
grid = (triton.cdiv(max_seqlen_q, BLOCK_M), h_qo, b)
_attn_fwd[grid](
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
q_scale,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
o,
q.stride(1),
q.stride(0),
k.stride(1),
k.stride(0),
v.stride(1),
v.stride(0),
o.stride(1),
o.stride(0),
h_qo,
num_kv_groups,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
HEAD_DIM=HEAD_DIM_K,
STAGE=stage,
num_warps=4 if head_dim == 64 else 8,
num_stages=4,
)
return o
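# Illustrative wiring sketch (hedged; this mirrors what lib/core.py does rather than adding
# a new entry point): quantize q/k per block, then hand the int8 tensors and their scales to
# forward().
#
#   from .quant_per_block_varlen import per_block_int8
#   q_i8, q_s, k_i8, k_s, cu_qs, cu_ks = per_block_int8(q, k, cu_q, cu_k, max_q, max_k)
#   o = forward(q_i8, k_i8, v, cu_q, cu_k, max_q, q_s, k_s, cu_qs, cu_ks, output_dtype=torch.float16)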
def backward(
do,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
q_scale,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
l,
m,
output_dtype=torch.float16,
):
BLOCK_M = 128
BLOCK_N = 64
stage = 3
device = q.device
dtype = q.dtype
b = cu_seqlens_q.shape[0] - 1
_, h_qo, head_dim = q.shape
_, h_kv, _ = k.shape
num_kv_groups = h_qo // h_kv
dq = torch.zeros_like(q, dtype=output_dtype)
dk = torch.zeros_like(k, dtype=output_dtype)
dv = torch.zeros_like(v, dtype=output_dtype)
grid = (triton.cdiv(max_seqlen_q, BLOCK_M), h_qo, b)
_attn_bwd[grid](
do,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
q_scale,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
l,
m,
dq,
dk,
dv,
q.stride(1),
q.stride(0),
k.stride(1),
k.stride(0),
v.stride(1),
v.stride(0),
h_qo,
num_kv_groups,
HEAD_DIM=head_dim,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
STAGE=stage,
num_warps=4 if head_dim == 64 else 8,
num_stages=4,
)
return dq, dk, dv
# class TritonAttentionFunction(torch.autograd.Function):
# @staticmethod
# def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale):
# l = torch.zeros(q.shape[0], device=q.device, dtype=torch.float32)
# m = torch.zeros(q.shape[0], device=q.device, dtype=torch.float32)
# output = forward(q, k, v, cu_seqlens_q, cu_seqlens_k, q.shape[0], q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m)
# ctx.save_for_backward(q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m)
# return output
#
# @staticmethod
# def backward(ctx, do):
# q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m = ctx.saved_tensors
# dq, dk, dv = backward(
# do, q, k, v,
# cu_seqlens_q, cu_seqlens_k,
# q.shape[0], q_scale, k_scale,
# cu_seqlens_q_scale, cu_seqlens_k_scale,
# l, m,
# )
# return dq, dk, dv, None, None, None, None, None, None

View File

@@ -0,0 +1,158 @@
"""
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import triton
import triton.language as tl
@triton.jit
def quant_per_block_int8_kernel(
Input,
Output,
Scale,
cu_seqlens_input,
cu_seqlens_scale,
stride_ih,
stride_in,
stride_oh,
stride_on,
sm_scale,
H: tl.constexpr,
C: tl.constexpr,
BLK: tl.constexpr,
):
off_blk = tl.program_id(0)
off_h = tl.program_id(1)
off_b = tl.program_id(2)
cu_seqlens_input_start = tl.load(cu_seqlens_input + off_b)
cu_seqlens_input_end = tl.load(cu_seqlens_input + off_b + 1)
L = cu_seqlens_input_end - cu_seqlens_input_start
if (off_blk * BLK) >= L:
return
cu_seqlens_scale_start = tl.load(cu_seqlens_scale + off_b)
offs_n = off_blk * BLK + tl.arange(0, BLK)
offs_k = tl.arange(0, C)
input_ptrs = (
Input
+ cu_seqlens_input_start * stride_in
+ off_h * stride_ih
+ offs_n[:, None] * stride_in
+ offs_k[None, :]
)
output_ptrs = (
Output
+ cu_seqlens_input_start * stride_on
+ off_h * stride_oh
+ offs_n[:, None] * stride_on
+ offs_k[None, :]
)
scale_ptrs = Scale + cu_seqlens_scale_start * H + off_h + off_blk * H
x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
x = x.to(tl.float32)
x *= sm_scale
scale = tl.max(tl.abs(x)) / 127.0
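    # Symmetric per-block quantization: one fp32 scale per (block, head). The +/- 0.5
    # added below implements round-to-nearest before the truncating cast to int8.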
x_int8 = x / scale
x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
x_int8 = x_int8.to(tl.int8)
tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
tl.store(scale_ptrs, scale)
def per_block_int8(
q,
k,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
BLKQ=128,
BLKK=64,
sm_scale=None,
):
q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
h_qo = q.shape[1]
h_kv = k.shape[1]
head_dim = q.shape[-1]
b = cu_seqlens_q.shape[0] - 1
q_batch_len = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
k_batch_len = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
q_scale_len = (q_batch_len + BLKQ - 1) // BLKQ
k_scale_len = (k_batch_len + BLKK - 1) // BLKK
cu_seqlens_q_scale = torch.nn.functional.pad(
torch.cumsum(q_scale_len, dim=0), (1, 0), value=0
)
cu_seqlens_k_scale = torch.nn.functional.pad(
torch.cumsum(k_scale_len, dim=0), (1, 0), value=0
)
q_scale = torch.empty(
(cu_seqlens_q_scale[-1], h_qo), device=q.device, dtype=torch.float32
)
k_scale = torch.empty(
(cu_seqlens_k_scale[-1], h_kv), device=k.device, dtype=torch.float32
)
if sm_scale is None:
sm_scale = head_dim**-0.5
grid = ((max_seqlen_q + BLKQ - 1) // BLKQ, h_qo, b)
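    # 1.44269504 ~= log2(e): the softmax scale is folded into q here so the attention
    # kernel can use exp2() instead of exp(); k is quantized with sm_scale=1.0 below.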
quant_per_block_int8_kernel[grid](
q,
q_int8,
q_scale,
cu_seqlens_q,
cu_seqlens_q_scale,
q.stride(1),
q.stride(0),
q_int8.stride(1),
q_int8.stride(0),
sm_scale=(sm_scale * 1.44269504),
H=h_qo,
C=head_dim,
BLK=BLKQ,
)
grid = ((max_seqlen_k + BLKK - 1) // BLKK, h_kv, b)
quant_per_block_int8_kernel[grid](
k,
k_int8,
k_scale,
cu_seqlens_k,
cu_seqlens_k_scale,
k.stride(1),
k.stride(0),
k_int8.stride(1),
k_int8.stride(0),
sm_scale=1.0,
H=h_kv,
C=head_dim,
BLK=BLKK,
)
return q_int8, q_scale, k_int8, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale
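# Shape sketch (illustrative numbers): for two packed query sequences of lengths 300 and 100
# with BLKQ=128, q_scale_len is [3, 1], cu_seqlens_q_scale is [0, 3, 4], and q_scale has
# shape [4, h_qo] -- one fp32 scale per (128-token block, head).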

View File

@@ -46,6 +46,7 @@ from transformers.integrations.deepspeed import (
)
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.integrations.sageattention.lib.core import monkeypatch_sdp_w_sage_attention
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -707,6 +708,7 @@ class ModelLoader:
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
monkeypatch_sdp_w_sage_attention()
elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access