Compare commits
1 Commits
preprocess
...
sageattent
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8acc72dd8 |
0
src/axolotl/integrations/sageattention/__init__.py
Normal file
0
src/axolotl/integrations/sageattention/__init__.py
Normal file
361
src/axolotl/integrations/sageattention/lib/core.py
Normal file
361
src/axolotl/integrations/sageattention/lib/core.py
Normal file
@@ -0,0 +1,361 @@
|
|||||||
|
"""
|
||||||
|
Copyright (c) 2024 by SageAttention team.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.autograd import Function
|
||||||
|
|
||||||
|
from .triton.attn_qk_int8_per_block_causal_varlen import (
|
||||||
|
backward as sageattn_varlen_backward,
|
||||||
|
)
|
||||||
|
from .triton.attn_qk_int8_per_block_causal_varlen import forward as attn_true_varlen
|
||||||
|
from .triton.quant_per_block_varlen import (
|
||||||
|
per_block_int8 as per_block_int8_varlen_triton,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cuda_arch_versions():
    """Return the SM architecture tag (e.g. ``"sm80"``) for each visible CUDA device.

    Returns an empty list when no CUDA devices are available.
    """
    return [
        "sm{}{}".format(*torch.cuda.get_device_capability(device_index))
        for device_index in range(torch.cuda.device_count())
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def sageattn_varlen(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: Optional[float] = None,
    smooth_k: bool = True,
    **kwargs: Any,
) -> torch.Tensor:
    """Variable-length (packed) SageAttention forward pass.

    Quantizes ``q`` and ``k`` to per-block INT8, then runs the causal varlen
    Triton attention kernel.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.
    k : torch.Tensor
        The key tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
    v : torch.Tensor
        The value tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
    cu_seqlens_q : torch.Tensor
        The cumulative sequence lengths for the query sequences in the batch,
        used to index into `q`. Shape: ``[batch_size + 1]``.
    cu_seqlens_k : torch.Tensor
        The cumulative sequence lengths for the key and value sequences in the
        batch, used to index into `k` and `v`. Shape: ``[batch_size + 1]``.
    max_seqlen_q : int
        The maximum sequence length for the query tensor in the batch.
    max_seqlen_k : int
        The maximum sequence length for the key and value tensors in the batch.
    sm_scale : Optional[float]
        The scale used in softmax; if not provided, set downstream to
        ``1.0 / sqrt(head_dim)``.
    smooth_k : bool
        Whether to smooth the key tensor by subtracting the mean along the
        sequence dimension. Default: True.
    **kwargs : Any
        Ignored; accepted for drop-in signature compatibility.

    Returns
    -------
    torch.Tensor
        The output tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - `q`, `k`, and `v` must be ``torch.float16``, ``torch.bfloat16`` or
      ``torch.float32``.
    - `cu_seqlens_q` and `cu_seqlens_k` must be ``torch.int32`` or ``torch.int64``.
    - All tensors must be on the same cuda device.
    - `smooth_k` introduces slight overhead but improves accuracy under most
      circumstances.
    - Causal masking is always applied by the underlying kernel; there is no
      ``is_causal`` parameter (the previous docstring documented one in error).
    """
    dtype = q.dtype
    assert q.is_cuda, "Input tensors must be on cuda."
    # FIX: torch.float32 was documented and handled below (v is downcast to
    # fp16 before the kernel) but was rejected here, making that path dead.
    assert dtype in [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ], "Input tensors must be in dtype of torch.float16, torch.bfloat16 or torch.float32"
    assert q.device == k.device == v.device, "All tensors must be on the same device."
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."

    head_dim = q.size(-1)
    assert head_dim in [64, 128], "varlen only support head_dim [64, 128]."

    assert (
        q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
    ), "Last dim of qkv must be contiguous."
    assert (
        cu_seqlens_q.is_contiguous() and cu_seqlens_k.is_contiguous()
    ), "cu_seqlens_q and cu_seqlens_k must be contiguous."

    # The attention kernel consumes fp16 values; bf16/fp32 v is downcast.
    if dtype == torch.bfloat16 or dtype == torch.float32:
        v = v.to(torch.float16)

    if smooth_k:
        # km is calculated over ALL batches; a per-sequence mean would
        # require a dedicated kernel.
        km = k.mean(dim=0, keepdim=True)
        # FIX: was ``k -= km``, which mutated the caller's key tensor in
        # place. Build a new tensor so the input is left untouched.
        k = k - km

    (
        q_int8,
        q_scale,
        k_int8,
        k_scale,
        cu_seqlens_q_scale,
        cu_seqlens_k_scale,
    ) = per_block_int8_varlen_triton(
        q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=sm_scale
    )

    o = attn_true_varlen(
        q_int8,
        k_int8,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        q_scale,
        k_scale,
        cu_seqlens_q_scale,
        cu_seqlens_k_scale,
        output_dtype=dtype,
    )

    return o
|
||||||
|
|
||||||
|
|
||||||
|
class SageAttentionFunction(Function):
    """Autograd wrapper exposing SageAttention varlen kernels with an
    SDPA-compatible ``forward(query, key, value, attn_mask, dropout_p,
    is_causal, scale)`` signature.

    NOTE(review): ``attn_mask`` and ``dropout_p`` are stored on ``ctx`` but
    never used by either pass — callers relying on masking or dropout get
    neither. Only ``is_causal=True`` is supported (see forward).
    """

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
    ):
        """
        query: Tensor of shape [batch_size, num_heads, seq_len_q, head_dim]
        key: Tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
        value: Tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
        attn_mask: Optional[Tensor], mask tensor (currently ignored)
        dropout_p: float, dropout probability (currently ignored)
        is_causal: bool, whether to apply causal masking (must be True)
        scale: Optional[float], scaling factor for attention scores
        """
        # Ensure inputs are contiguous so the flat views below are valid.
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()

        # Handle default scale: the conventional 1/sqrt(head_dim).
        if scale is None:
            scale = 1.0 / (query.size(-1) ** 0.5)

        # Save parameters needed for backward.
        ctx.scale = scale
        ctx.is_causal = is_causal
        ctx.dropout_p = dropout_p
        ctx.attn_mask = attn_mask

        # Prepare cumulative sequence lengths and max sequence lengths.
        # Assuming batch sizes are consistent across query, key, and value.
        batch_size, num_heads, seq_len_q, head_dim = query.shape
        seq_len_k = key.shape[2]

        # Flatten batch and head dimensions. Each (batch, head) pair becomes
        # its own "sequence" in the varlen layout.
        # NOTE(review): the varlen kernels index tokens as
        # [total_tokens, num_heads, head_dim]; here the flattened tensors are
        # [batch*heads, seq_len, head_dim] — confirm this layout matches what
        # per_block_int8_varlen_triton / attn_true_varlen expect.
        q = query.view(
            -1, seq_len_q, head_dim
        )  # [batch_size * num_heads, seq_len_q, head_dim]
        k = key.view(-1, seq_len_k, head_dim)
        v = value.view(-1, seq_len_k, head_dim)

        # Create cumulative sequence lengths: uniform-length sequences, one
        # per (batch, head) pair, so boundaries are multiples of seq_len.
        cu_seqlens_q = torch.arange(
            0,
            (batch_size * num_heads + 1) * seq_len_q,
            seq_len_q,
            dtype=torch.int32,
            device=query.device,
        )
        cu_seqlens_k = torch.arange(
            0,
            (batch_size * num_heads + 1) * seq_len_k,
            seq_len_k,
            dtype=torch.int32,
            device=key.device,
        )
        max_seqlen_q = seq_len_q
        max_seqlen_k = seq_len_k

        # Per-block INT8 quantization of q and k (scales returned per block).
        (
            q_int8,
            q_scale,
            k_int8,
            k_scale,
            cu_seqlens_q_scale,
            cu_seqlens_k_scale,
        ) = per_block_int8_varlen_triton(
            q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=scale
        )

        # Run the causal varlen attention kernel; non-causal is unsupported.
        if is_causal:
            output = attn_true_varlen(
                q_int8,
                k_int8,
                v,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                q_scale,
                k_scale,
                cu_seqlens_q_scale,
                cu_seqlens_k_scale,
                output_dtype=query.dtype,
            )
        else:
            raise NotImplementedError("Non-causal attention is not implemented yet.")

        # Reshape output back to the SDPA layout.
        output = output.view(batch_size, num_heads, seq_len_q, head_dim)

        # Save tensors for backward.
        ctx.save_for_backward(
            query,
            key,
            value,
            q_int8,
            k_int8,
            q_scale,
            k_scale,
            cu_seqlens_q,
            cu_seqlens_k,
            cu_seqlens_q_scale,
            cu_seqlens_k_scale,
            output,
        )

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients w.r.t. query, key and value via the Triton
        backward kernel; all other forward arguments get ``None`` grads.

        NOTE(review): this call passes 16 positional arguments to
        sageattn_varlen_backward, which does not match the 13-parameter
        signature of ``backward`` in
        triton/attn_qk_int8_per_block_causal_varlen.py — verify before use.
        """
        (
            query,
            key,
            value,
            q_int8,
            k_int8,
            q_scale,
            k_scale,
            cu_seqlens_q,
            cu_seqlens_k,
            cu_seqlens_q_scale,
            cu_seqlens_k_scale,
            output,
        ) = ctx.saved_tensors

        scale = ctx.scale
        is_causal = ctx.is_causal
        dropout_p = ctx.dropout_p  # unused below; kept for symmetry
        attn_mask = ctx.attn_mask  # unused below; kept for symmetry

        # Flatten batch and head dimensions to the same varlen layout used
        # in forward.
        batch_size, num_heads, seq_len_q, head_dim = query.shape
        seq_len_k = key.shape[2]
        grad_output = grad_output.contiguous()
        do = grad_output.view(-1, seq_len_q, head_dim)

        # Compute gradients w.r.t. q, k, v.
        dq, dk, dv = sageattn_varlen_backward(
            do,
            query.view(-1, seq_len_q, head_dim),
            key.view(-1, seq_len_k, head_dim),
            value.view(-1, seq_len_k, head_dim),
            cu_seqlens_q,
            cu_seqlens_k,
            seq_len_q,
            seq_len_k,
            q_int8,
            k_int8,
            q_scale,
            k_scale,
            cu_seqlens_q_scale,
            cu_seqlens_k_scale,
            scale,
            is_causal,
        )

        # Reshape gradients to match the input shapes.
        dq = dq.view(batch_size, num_heads, seq_len_q, head_dim)
        dk = dk.view(batch_size, num_heads, seq_len_k, head_dim)
        dv = dv.view(batch_size, num_heads, seq_len_k, head_dim)

        # Handle optional arguments.
        d_attn_mask = None  # Assuming attn_mask does not require gradients
        d_dropout_p = (
            None  # Dropout probability is a hyperparameter, typically not optimized
        )
        d_is_causal = None  # Not differentiable
        d_scale = None  # If scale is a tensor and requires grad, compute its gradient

        return dq, dk, dv, d_attn_mask, d_dropout_p, d_is_causal, d_scale
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    """Custom scaled dot product attention using SageAttentionFunction.

    Signature mirrors ``torch.nn.functional.scaled_dot_product_attention``
    so it can be used as a drop-in replacement.
    """
    sdpa_args = (query, key, value, attn_mask, dropout_p, is_causal, scale)
    return SageAttentionFunction.apply(*sdpa_args)
|
||||||
|
|
||||||
|
|
||||||
|
def monkeypatch_sdp_w_sage_attention():
    """
    Replace torch.nn.functional.scaled_dot_product_attention with custom scaled dot product attention using SageAttentionFunction.

    NOTE(review): the original function is not saved anywhere, so this patch
    is process-global and irreversible within the current session.
    """
    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
||||||
@@ -0,0 +1,622 @@
|
|||||||
|
"""
|
||||||
|
Copyright (c) 2024 by SageAttention team.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    q,
    q_scale,
    kv_len,
    K_ptrs,
    K_scale_ptr,
    V_ptrs,
    stride_kn,
    stride_vn,
    start_m,
    H: tl.constexpr,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,
):
    # Online-softmax inner loop over key/value blocks for one query block.
    # STAGE == 1: off-diagonal blocks (fully below the diagonal, no mask);
    # STAGE == 2: the diagonal block, where the causal mask is applied.
    # Returns the updated (acc, l_i, m_i) running statistics.
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # Advance pointers to the first KV block in [lo, hi). The K scale is
    # stored one entry per (block, head), hence the stride of H.
    K_scale_ptr += (lo // BLOCK_N) * H
    K_ptrs += stride_kn * lo
    V_ptrs += stride_vn * lo
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # Mask out key positions past the end of this sequence.
        k_mask = offs_n[None, :] < (kv_len - start_n)
        k = tl.load(K_ptrs, mask=k_mask)
        k_scale = tl.load(K_scale_ptr)
        # INT8 q·k, dequantized by the per-block scales.
        # NOTE(review): exp2 is used below, so q_scale presumably folds in
        # sm_scale * log2(e) from the quantization step — confirm in
        # quant_per_block_varlen.
        qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale

        if STAGE == 2:
            # Causal mask: query i may only attend to keys <= i.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]

        # Unnormalized probabilities and their row sums (base-2 softmax).
        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)

        # Rescale the running sum/accumulator by the change in row max.
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij

        acc = acc * alpha[:, None]

        v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n))
        p = p.to(tl.float16)

        # Accumulate p·v in fp16 (acc itself is fp32).
        acc += tl.dot(p, v, out_dtype=tl.float16)
        m_i = m_ij
        # Advance to the next KV block.
        K_ptrs += BLOCK_N * stride_kn
        K_scale_ptr += H
        V_ptrs += BLOCK_N * stride_vn
    return acc, l_i, m_i
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    cu_seqlens_q,
    cu_seqlens_k,
    Q_scale,
    K_scale,
    cu_seqlens_q_scale,
    cu_seqlens_k_scale,
    Out,
    stride_qh,
    stride_qn,
    stride_kh,
    stride_kn,
    stride_vh,
    stride_vn,
    stride_oh,
    stride_on,
    H: tl.constexpr,
    num_kv_groups: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
):
    # Forward attention kernel for packed variable-length sequences.
    # Grid: (ceil(max_seqlen_q / BLOCK_M), num_qo_heads, batch).
    # Each program computes one BLOCK_M slab of queries for one head of one
    # sequence. q/k are INT8 with per-block scales; output is dequantized.
    start_m = tl.program_id(0)

    # int64 to keep the pointer arithmetic below from overflowing on large
    # packed batches.
    off_z = tl.program_id(2).to(tl.int64)  # sequence index
    off_h = tl.program_id(1).to(tl.int64)  # query head index

    cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
    cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)

    qo_len = cu_seqlens_q_end - cu_seqlens_q_start

    # Sequences shorter than this slab's offset have no work here.
    if (start_m * BLOCK_M) >= qo_len:
        return

    cu_seq_lens_q_scale_start = tl.load(cu_seqlens_q_scale + off_z)
    cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z)

    # Scales are laid out one entry per (block, head); KV heads are shared by
    # num_kv_groups query heads (GQA), hence the divisions below.
    q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H
    k_scale_offset = (
        cu_seq_lens_k_scale_start * (H // num_kv_groups) + off_h // num_kv_groups
    )

    cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
    cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)

    kv_len = cu_seqlens_k_end - cu_seqlens_k_start

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, HEAD_DIM)
    # Q tile: [BLOCK_M, HEAD_DIM] rows of this sequence/head.
    Q_ptrs = (
        Q
        + (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
        + offs_m[:, None] * stride_qn
        + offs_k[None, :]
    )
    Q_scale_ptr = Q_scale + q_scale_offset
    # K tile is loaded transposed: [HEAD_DIM, BLOCK_N].
    K_ptrs = (
        K
        + (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
        + offs_n[None, :] * stride_kn
        + offs_k[:, None]
    )
    K_scale_ptr = K_scale + k_scale_offset
    # V tile: [BLOCK_N, HEAD_DIM].
    V_ptrs = (
        V
        + (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
        + offs_n[:, None] * stride_vn
        + offs_k[None, :]
    )
    O_block_ptr = (
        Out
        + (cu_seqlens_q_start * stride_on + off_h * stride_oh)
        + offs_m[:, None] * stride_on
        + offs_k[None, :]
    )

    # Online-softmax running state: row max, row sum, output accumulator.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
    q_scale = tl.load(Q_scale_ptr)
    # Pass 1: off-diagonal KV blocks (4 - STAGE == 1 for the causal STAGE=3
    # launch), no masking needed inside the inner loop.
    acc, l_i, m_i = _attn_fwd_inner(
        acc,
        l_i,
        m_i,
        q,
        q_scale,
        kv_len,
        K_ptrs,
        K_scale_ptr,
        V_ptrs,
        stride_kn,
        stride_vn,
        start_m,
        H // num_kv_groups,
        BLOCK_M,
        HEAD_DIM,
        BLOCK_N,
        4 - STAGE,
        offs_m,
        offs_n,
    )

    # Pass 2: the diagonal block with the causal mask applied (STAGE == 2).
    acc, l_i, _ = _attn_fwd_inner(
        acc,
        l_i,
        m_i,
        q,
        q_scale,
        kv_len,
        K_ptrs,
        K_scale_ptr,
        V_ptrs,
        stride_kn,
        stride_vn,
        start_m,
        H // num_kv_groups,
        BLOCK_M,
        HEAD_DIM,
        BLOCK_N,
        2,
        offs_m,
        offs_n,
    )
    # Final softmax normalization and write-back in the output dtype.
    acc = acc / l_i[:, None]
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=(offs_m[:, None] < qo_len))
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _attn_bwd_inner(
    dq_acc,
    dk_acc,
    dv_acc,
    l_i,
    m_i,
    q,
    k,
    v,
    do,
    q_scale,
    k_scale,
    kv_len,
    stride_kn,
    stride_vn,
    start_m,
    H,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,
):
    # Backward inner loop: recompute the attention probabilities for each KV
    # block and accumulate dq, dk, dv. STAGE semantics mirror the forward
    # inner loop (1 = off-diagonal, 2 = masked diagonal block).
    #
    # NOTE(review): several constructs here look off versus the Triton
    # language API and versus the forward pass — verify before trusting:
    #   * tl.dot(..., trans_b=True) and the .T attribute are not standard
    #     tl ops in current Triton releases (use tl.trans instead);
    #   * k_scale is NOT advanced by (lo // BLOCK_N) * H before the loop,
    #     unlike K_scale_ptr in _attn_fwd_inner, so the diagonal pass reads
    #     the first block's scale;
    #   * p is normalized by the RUNNING l_i mid-loop, not the final row sum,
    #     which differs from the standard flash-attention backward.
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    k += stride_kn * lo
    v += stride_vn * lo

    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # Mask out key positions beyond this sequence's length.
        k_mask = offs_n[None, :] < (kv_len - start_n)
        k_curr = tl.load(k, mask=k_mask)
        v_curr = tl.load(v, mask=k_mask)
        k_scale_curr = tl.load(k_scale)
        # Recompute dequantized logits for this KV block.
        s = tl.dot(q, k_curr, trans_b=True).to(tl.float32) * q_scale * k_scale_curr

        if STAGE == 2:
            # Causal mask on the diagonal block.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            s = s + tl.where(mask, 0.0, -float("inf"))
            m_ij = tl.maximum(m_i, tl.max(s, 1))
            s = s - m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(s, 1))
            s = s - m_ij[:, None]

        # Base-2 online softmax statistics, as in the forward pass.
        p = tl.math.exp2(s)
        l_ij = tl.sum(p, 1)
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        m_i = m_ij

        p = p / l_i[:, None]  # Normalize probabilities

        # Compute gradients
        # Compute softmax gradient
        do_scaled = do / l_i[:, None]
        dv_contrib = tl.dot(p.to(tl.float16).T, do_scaled.to(tl.float16))
        dv_acc += dv_contrib

        dp = tl.dot(do_scaled.to(tl.float16), v_curr.to(tl.float16).T)

        # Compute ds (gradient w.r.t. logits s)
        p_dp = p * dp
        sum_p_dp = tl.sum(p_dp, axis=1)
        ds = (p_dp - p * sum_p_dp[:, None]) * tl.math.log(2.0)  # Adjust for exp2

        # Compute gradients w.r.t q and k, dequantized by the block scales.
        dq_contrib = tl.dot(ds.to(tl.float16), k_curr.to(tl.float16))
        dk_contrib = tl.dot(ds.to(tl.float16).T, q.to(tl.float16))

        dq_acc += dq_contrib * (q_scale * k_scale_curr)
        dk_acc += dk_contrib * (q_scale * k_scale_curr)

        # Advance to the next KV block.
        k += BLOCK_N * stride_kn
        k_scale += H
        v += BLOCK_N * stride_vn

    return dq_acc, dk_acc, dv_acc, l_i, m_i
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _attn_bwd(
    DO,
    Q,
    K,
    V,
    cu_seqlens_q,
    cu_seqlens_k,
    Q_scale,
    K_scale,
    cu_seqlens_q_scale,
    cu_seqlens_k_scale,
    L,
    M,
    DQ,
    DK,
    DV,
    stride_qh,
    stride_qn,
    stride_kh,
    stride_kn,
    stride_vh,
    stride_vn,
    H: tl.constexpr,
    num_kv_groups: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
):
    # Backward attention kernel: one program per (query slab, head, sequence),
    # mirroring _attn_fwd's grid. Recomputes probabilities via
    # _attn_bwd_inner and writes dQ, dK, dV.
    #
    # NOTE(review): dk_acc/dv_acc are [BLOCK_N, HEAD_DIM] accumulators summed
    # across ALL KV blocks of this query slab, yet they are stored once at the
    # K/V base pointers below (and the DK store mask/pointer shapes do not
    # match the accumulator's [BLOCK_N, HEAD_DIM] shape). Programs for
    # different query slabs also overwrite rather than atomically add into
    # DK/DV. Verify against a reference implementation before relying on
    # these gradients.
    start_m = tl.program_id(0)
    off_z = tl.program_id(2).to(tl.int64)  # sequence index
    off_h = tl.program_id(1).to(tl.int64)  # query head index

    cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
    cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
    qo_len = cu_seqlens_q_end - cu_seqlens_q_start

    # No queries in this slab for this sequence.
    if (start_m * BLOCK_M) >= qo_len:
        return

    cu_seq_lens_q_scale_start = tl.load(cu_seqlens_q_scale + off_z)
    cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z)

    # Per-(block, head) scale layout; KV heads are shared across
    # num_kv_groups query heads (GQA).
    q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H
    k_scale_offset = (
        cu_seq_lens_k_scale_start * (H // num_kv_groups) + off_h // num_kv_groups
    )

    cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
    cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
    kv_len = cu_seqlens_k_end - cu_seqlens_k_start

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, HEAD_DIM)
    # Q / DO tiles: [BLOCK_M, HEAD_DIM] (DO shares Q's strides).
    Q_ptrs = (
        Q
        + (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
        + offs_m[:, None] * stride_qn
        + offs_k[None, :]
    )
    DO_ptrs = (
        DO
        + (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
        + offs_m[:, None] * stride_qn
        + offs_k[None, :]
    )
    Q_scale_ptr = Q_scale + q_scale_offset
    # K tile transposed: [HEAD_DIM, BLOCK_N]; V tile: [BLOCK_N, HEAD_DIM].
    K_ptrs = (
        K
        + (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
        + offs_n[None, :] * stride_kn
        + offs_k[:, None]
    )
    K_scale_ptr = K_scale + k_scale_offset
    V_ptrs = (
        V
        + (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
        + offs_n[:, None] * stride_vn
        + offs_k[None, :]
    )
    DQ_ptrs = (
        DQ
        + (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
        + offs_m[:, None] * stride_qn
        + offs_k[None, :]
    )
    DK_ptrs = (
        DK
        + (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
        + offs_n[None, :] * stride_kn
        + offs_k[:, None]
    )
    DV_ptrs = (
        DV
        + (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
        + offs_n[:, None] * stride_vn
        + offs_k[None, :]
    )
    # Saved forward softmax statistics (row sums L, row maxima M), indexed
    # per query token of this sequence.
    L_ptrs = L + (cu_seqlens_q_start + offs_m)
    M_ptrs = M + (cu_seqlens_q_start + offs_m)

    m_i = tl.load(M_ptrs, mask=offs_m < qo_len, other=float("-inf"))
    l_i = tl.load(L_ptrs, mask=offs_m < qo_len, other=1.0)

    dq_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    dk_acc = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32)
    dv_acc = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32)

    q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
    do = tl.load(DO_ptrs, mask=offs_m[:, None] < qo_len)
    q_scale = tl.load(Q_scale_ptr)

    # Pass 1: off-diagonal KV blocks (4 - STAGE == 1 for the STAGE=3 launch).
    dq_acc, dk_acc, dv_acc, l_i, m_i = _attn_bwd_inner(
        dq_acc,
        dk_acc,
        dv_acc,
        l_i,
        m_i,
        q,
        K_ptrs,
        V_ptrs,
        do,
        q_scale,
        K_scale_ptr,
        kv_len,
        stride_kn,
        stride_vn,
        start_m,
        H // num_kv_groups,
        BLOCK_M,
        HEAD_DIM,
        BLOCK_N,
        4 - STAGE,
        offs_m,
        offs_n,
    )

    # Pass 2: the causally-masked diagonal block (STAGE == 2).
    dq_acc, dk_acc, dv_acc, l_i, m_i = _attn_bwd_inner(
        dq_acc,
        dk_acc,
        dv_acc,
        l_i,
        m_i,
        q,
        K_ptrs,
        V_ptrs,
        do,
        q_scale,
        K_scale_ptr,
        kv_len,
        stride_kn,
        stride_vn,
        start_m,
        H // num_kv_groups,
        BLOCK_M,
        HEAD_DIM,
        BLOCK_N,
        2,
        offs_m,
        offs_n,
    )

    tl.store(DQ_ptrs, dq_acc.to(DQ.dtype.element_ty), mask=offs_m[:, None] < qo_len)
    tl.store(DK_ptrs, dk_acc.to(DK.dtype.element_ty), mask=offs_n[None, :] < kv_len)
    tl.store(DV_ptrs, dv_acc.to(DV.dtype.element_ty), mask=offs_n[:, None] < kv_len)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    q_scale,
    k_scale,
    cu_seqlens_q_scale,
    cu_seqlens_k_scale,
    output_dtype=torch.float16,
):
    """Host-side launcher for the varlen INT8 causal attention forward kernel.

    ``q``/``k`` are per-block INT8 with their scales in ``q_scale``/``k_scale``;
    ``v`` is dense. Returns the attention output in ``output_dtype`` with the
    same packed shape as ``q``.
    """
    # Tile sizes and pipeline stage used by the kernel launch.
    tile_m, tile_n, stage = 128, 64, 3

    num_seqs = cu_seqlens_q.shape[0] - 1
    num_qo_heads, head_dim = q.shape[1], q.shape[2]
    num_kv_heads = k.shape[1]
    # GQA: how many query heads share each KV head.
    kv_group_size = num_qo_heads // num_kv_heads

    out = torch.empty(q.shape, dtype=output_dtype, device=q.device)

    # One program per (query tile, query head, sequence).
    launch_grid = (triton.cdiv(max_seqlen_q, tile_m), num_qo_heads, num_seqs)
    _attn_fwd[launch_grid](
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        q_scale,
        k_scale,
        cu_seqlens_q_scale,
        cu_seqlens_k_scale,
        out,
        q.stride(1),
        q.stride(0),
        k.stride(1),
        k.stride(0),
        v.stride(1),
        v.stride(0),
        out.stride(1),
        out.stride(0),
        num_qo_heads,
        kv_group_size,
        BLOCK_M=tile_m,
        BLOCK_N=tile_n,
        HEAD_DIM=head_dim,
        STAGE=stage,
        num_warps=4 if head_dim == 64 else 8,
        num_stages=4,
    )
    return out
|
||||||
|
|
||||||
|
|
||||||
|
def backward(
    do,
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    q_scale,
    k_scale,
    cu_seqlens_q_scale,
    cu_seqlens_k_scale,
    l,
    m,
    output_dtype=torch.float16,
):
    """Host-side launcher for the varlen INT8 causal attention backward kernel.

    Parameters
    ----------
    do : upstream gradient w.r.t. the attention output, same packed shape as q.
    q, k : per-block INT8 tensors with scales q_scale / k_scale.
    v : dense value tensor.
    cu_seqlens_q, cu_seqlens_k : cumulative sequence lengths, [batch + 1].
    max_seqlen_q : maximum query sequence length (sizes the launch grid).
    l, m : softmax row-sum / row-max statistics saved by the forward pass.
    output_dtype : dtype of the returned gradients. Default: torch.float16.

    Returns
    -------
    (dq, dk, dv) : gradients with the same shapes as q, k, v.

    NOTE(review): the caller in lib/core.py (SageAttentionFunction.backward)
    passes 16 positional arguments, which does not match this 14-parameter
    signature — verify the call site before use.
    """
    # Tile sizes / stage must match the forward launch.
    BLOCK_M = 128
    BLOCK_N = 64
    stage = 3

    # FIX: removed unused locals ``device`` and ``dtype`` that were assigned
    # but never read.
    b = cu_seqlens_q.shape[0] - 1
    _, h_qo, head_dim = q.shape
    _, h_kv, _ = k.shape
    # GQA: query heads per KV head.
    num_kv_groups = h_qo // h_kv

    # Zero-initialized so untouched positions contribute no gradient.
    dq = torch.zeros_like(q, dtype=output_dtype)
    dk = torch.zeros_like(k, dtype=output_dtype)
    dv = torch.zeros_like(v, dtype=output_dtype)

    # One program per (query tile, query head, sequence), as in forward.
    grid = (triton.cdiv(max_seqlen_q, BLOCK_M), h_qo, b)
    _attn_bwd[grid](
        do,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        q_scale,
        k_scale,
        cu_seqlens_q_scale,
        cu_seqlens_k_scale,
        l,
        m,
        dq,
        dk,
        dv,
        q.stride(1),
        q.stride(0),
        k.stride(1),
        k.stride(0),
        v.stride(1),
        v.stride(0),
        h_qo,
        num_kv_groups,
        HEAD_DIM=head_dim,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        STAGE=stage,
        num_warps=4 if head_dim == 64 else 8,
        num_stages=4,
    )
    return dq, dk, dv
|
||||||
|
|
||||||
|
|
||||||
|
# class TritonAttentionFunction(torch.autograd.Function):
|
||||||
|
# @staticmethod
|
||||||
|
# def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale):
|
||||||
|
# l = torch.zeros(q.shape[0], device=q.device, dtype=torch.float32)
|
||||||
|
# m = torch.zeros(q.shape[0], device=q.device, dtype=torch.float32)
|
||||||
|
# output = forward(q, k, v, cu_seqlens_q, cu_seqlens_k, q.shape[0], q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m)
|
||||||
|
# ctx.save_for_backward(q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m)
|
||||||
|
# return output
|
||||||
|
#
|
||||||
|
# @staticmethod
|
||||||
|
# def backward(ctx, do):
|
||||||
|
# q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m = ctx.saved_tensors
|
||||||
|
# dq, dk, dv = backward(
|
||||||
|
# do, q, k, v,
|
||||||
|
# cu_seqlens_q, cu_seqlens_k,
|
||||||
|
# q.shape[0], q_scale, k_scale,
|
||||||
|
# cu_seqlens_q_scale, cu_seqlens_k_scale,
|
||||||
|
# l, m,
|
||||||
|
# )
|
||||||
|
# return dq, dk, dv, None, None, None, None, None, None
|
||||||
@@ -0,0 +1,158 @@
|
|||||||
|
"""
|
||||||
|
Copyright (c) 2024 by SageAttention team.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def quant_per_block_int8_kernel(
    Input,  # ptr to input tensor, laid out (total_tokens, H, C) per the strides below
    Output,  # ptr to int8 output tensor, same logical layout as Input
    Scale,  # ptr to per-(block, head) fp32 scales, laid out (total_blocks, H)
    cu_seqlens_input,  # ptr to cumulative sequence lengths, shape (batch + 1,)
    cu_seqlens_scale,  # ptr to cumulative per-sequence block counts, shape (batch + 1,)
    stride_ih,  # Input head stride (element units)
    stride_in,  # Input token stride (element units)
    stride_oh,  # Output head stride
    stride_on,  # Output token stride
    sm_scale,  # multiplier applied to the input before quantization (1.0 for K)
    H: tl.constexpr,  # number of heads in this tensor
    C: tl.constexpr,  # head dimension
    BLK: tl.constexpr,  # tokens per quantization block
):
    """Quantize one BLK-token slice of one head to int8 with a single shared scale.

    Grid: (ceil(max_seqlen / BLK), H, batch) — axis 0 selects the token block
    within a sequence, axis 1 the head, axis 2 the sequence in the varlen batch.
    Writes the int8 values to ``Output`` and the fp32 scale to ``Scale``.
    """
    off_blk = tl.program_id(0)
    off_h = tl.program_id(1)
    off_b = tl.program_id(2)

    # Token range [start, end) of this sequence inside the packed varlen layout.
    cu_seqlens_input_start = tl.load(cu_seqlens_input + off_b)
    cu_seqlens_input_end = tl.load(cu_seqlens_input + off_b + 1)

    L = cu_seqlens_input_end - cu_seqlens_input_start

    # The grid is sized by the longest sequence; blocks past this sequence's
    # end have nothing to do.
    if (off_blk * BLK) >= L:
        return

    cu_seqlens_scale_start = tl.load(cu_seqlens_scale + off_b)

    offs_n = off_blk * BLK + tl.arange(0, BLK)  # token offsets within the sequence
    offs_k = tl.arange(0, C)  # channel offsets

    input_ptrs = (
        Input
        + cu_seqlens_input_start * stride_in  # jump to this sequence's first token
        + off_h * stride_ih
        + offs_n[:, None] * stride_in
        + offs_k[None, :]
    )
    output_ptrs = (
        Output
        + cu_seqlens_input_start * stride_on
        + off_h * stride_oh
        + offs_n[:, None] * stride_on
        + offs_k[None, :]
    )
    # Scale is (total_blocks, H) row-major: row = sequence's first block + off_blk.
    scale_ptrs = Scale + cu_seqlens_scale_start * H + off_h + off_blk * H

    # Mask out-of-range tokens of the (possibly partial) last block.
    x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
    x = x.to(tl.float32)
    x *= sm_scale
    # One scale for the whole (BLK x C) tile, mapping max|x| to the int8 limit.
    # NOTE(review): an all-zero tile yields scale == 0 and a 0/0 division below;
    # confirm upstream SageAttention accepts this (NaN casts to int8) or guard.
    scale = tl.max(tl.abs(x)) / 127.0
    x_int8 = x / scale
    # Round half away from zero before truncating to int8.
    x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
    x_int8 = x_int8.to(tl.int8)
    tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
    tl.store(scale_ptrs, scale)
|
||||||
|
|
||||||
|
|
||||||
|
def per_block_int8(
    q,
    k,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    BLKQ=128,
    BLKK=64,
    sm_scale=None,
):
    """Quantize varlen-packed Q and K tensors to int8 with per-block scales.

    Args:
        q, k: packed tensors of shape (total_tokens, num_heads, head_dim).
        cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths, shape (batch + 1,).
        max_seqlen_q, max_seqlen_k: longest sequence lengths (size the launch grids).
        BLKQ, BLKK: tokens per quantization block for Q and K respectively.
        sm_scale: softmax scale folded into Q; defaults to head_dim ** -0.5.

    Returns:
        (q_int8, q_scale, k_int8, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale)
        where the *_scale tensors are fp32 of shape (total_blocks, num_heads) and
        the cu_seqlens_*_scale tensors index each sequence's first scale row.
    """
    num_q_heads = q.shape[1]
    num_k_heads = k.shape[1]
    head_dim = q.shape[-1]
    batch = cu_seqlens_q.shape[0] - 1

    if sm_scale is None:
        sm_scale = head_dim**-0.5

    def _scale_offsets(cu_seqlens, blk):
        # Per-sequence quant-block counts -> exclusive prefix sum (starts at 0).
        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
        blocks_per_seq = (seq_lens + blk - 1) // blk
        return torch.nn.functional.pad(
            torch.cumsum(blocks_per_seq, dim=0), (1, 0), value=0
        )

    cu_seqlens_q_scale = _scale_offsets(cu_seqlens_q, BLKQ)
    cu_seqlens_k_scale = _scale_offsets(cu_seqlens_k, BLKK)

    q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
    k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
    q_scale = torch.empty(
        (cu_seqlens_q_scale[-1], num_q_heads), device=q.device, dtype=torch.float32
    )
    k_scale = torch.empty(
        (cu_seqlens_k_scale[-1], num_k_heads), device=k.device, dtype=torch.float32
    )

    # One program per (token block, head, sequence).
    grid = ((max_seqlen_q + BLKQ - 1) // BLKQ, num_q_heads, batch)
    quant_per_block_int8_kernel[grid](
        q,
        q_int8,
        q_scale,
        cu_seqlens_q,
        cu_seqlens_q_scale,
        q.stride(1),
        q.stride(0),
        q_int8.stride(1),
        q_int8.stride(0),
        # 1.44269504 == log2(e); presumably folded in so the attention kernel
        # can use exp2 instead of exp — confirm against the consuming kernel.
        sm_scale=(sm_scale * 1.44269504),
        H=num_q_heads,
        C=head_dim,
        BLK=BLKQ,
    )

    grid = ((max_seqlen_k + BLKK - 1) // BLKK, num_k_heads, batch)
    quant_per_block_int8_kernel[grid](
        k,
        k_int8,
        k_scale,
        cu_seqlens_k,
        cu_seqlens_k_scale,
        k.stride(1),
        k.stride(0),
        k_int8.stride(1),
        k_int8.stride(0),
        sm_scale=1.0,  # K carries no softmax scaling
        H=num_k_heads,
        C=head_dim,
        BLK=BLKK,
    )

    return q_int8, q_scale, k_int8, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale
|
||||||
@@ -46,6 +46,7 @@ from transformers.integrations.deepspeed import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
||||||
|
from axolotl.integrations.sageattention.lib.core import monkeypatch_sdp_w_sage_attention
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.monkeypatch.multipack import (
|
from axolotl.monkeypatch.multipack import (
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||||
@@ -707,6 +708,7 @@ class ModelLoader:
|
|||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"sdpa"
|
"sdpa"
|
||||||
)
|
)
|
||||||
|
monkeypatch_sdp_w_sage_attention()
|
||||||
elif self.cfg.eager_attention:
|
elif self.cfg.eager_attention:
|
||||||
self.model_kwargs["attn_implementation"] = "eager"
|
self.model_kwargs["attn_implementation"] = "eager"
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
|||||||
Reference in New Issue
Block a user