diff --git a/src/axolotl/integrations/sageattention/__init__.py b/src/axolotl/integrations/sageattention/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/integrations/sageattention/lib/core.py b/src/axolotl/integrations/sageattention/lib/core.py new file mode 100644 index 000000000..a9f389d98 --- /dev/null +++ b/src/axolotl/integrations/sageattention/lib/core.py @@ -0,0 +1,361 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Any, Optional + +import torch +from torch.autograd import Function + +from .triton.attn_qk_int8_per_block_causal_varlen import ( + backward as sageattn_varlen_backward, +) +from .triton.attn_qk_int8_per_block_causal_varlen import forward as attn_true_varlen +from .triton.quant_per_block_varlen import ( + per_block_int8 as per_block_int8_varlen_triton, +) + + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + + +def sageattn_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: Optional[float] = None, + smooth_k: bool = True, + **kwargs: Any, +) -> torch.Tensor: + """ + + Parameters + ---------- + q : torch.Tensor + The query tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``. + + k : torch.Tensor + The key tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``. + + v : torch.Tensor + The value tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``. + + cu_seqlens_q : torch.Tensor + The cumulative sequence lengths for the query sequences in the batch, used to index into `q`. + Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index. + + cu_seqlens_k : torch.Tensor + The cumulative sequence lengths for the key and value sequences in the batch, used to index into `k` and `v`. + Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index. + + max_seqlen_q : int + The maximum sequence length for the query tensor in the batch. + + max_seqlen_k : int + The maximum sequence length for the key and value tensors in the batch. + + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len for each sequence. + Default: False. + + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + + Returns + ------- + torch.Tensor + The output tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``. + + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. 
+ - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16``, ``torch.bfloat16`` or ``torch.float32``. + - The tensors `cu_seqlens_q` and `cu_seqlens_k` must have the dtype ``torch.int32`` or ``torch.int64``. + - All tensors must be on the same cuda device. + - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. + """ + + dtype = q.dtype + assert q.is_cuda, "Input tensors must be on cuda." + assert dtype in [ + torch.float16, + torch.bfloat16, + ], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + + head_dim = q.size(-1) + assert head_dim in [64, 128], "varlen only support head_dim [64, 128]." + + assert ( + q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1 + ), "Last dim of qkv must be contiguous." + assert ( + cu_seqlens_q.is_contiguous() and cu_seqlens_k.is_contiguous() + ), "cu_seqlens_q and cu_seqlens_k must be contiguous." + + if dtype == torch.bfloat16 or dtype == torch.float32: + v = v.to(torch.float16) + + if smooth_k: + km = k.mean( + dim=0, keepdim=True + ) # ! km is calculated on the all the batches. Calculate over each individual sequence requires dedicated kernel. + k -= km + + ( + q_int8, + q_scale, + k_int8, + k_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + ) = per_block_int8_varlen_triton( + q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=sm_scale + ) + + o = attn_true_varlen( + q_int8, + k_int8, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + q_scale, + k_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + output_dtype=dtype, + ) + + return o + + +class SageAttentionFunction(Function): + @staticmethod + def forward( + ctx, + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + ): + """ + query: Tensor of shape [batch_size, num_heads, seq_len_q, head_dim] + key: Tensor of shape [batch_size, num_heads, seq_len_k, head_dim] + value: Tensor of shape [batch_size, num_heads, seq_len_k, head_dim] + attn_mask: Optional[Tensor], mask tensor + dropout_p: float, dropout probability + is_causal: bool, whether to apply causal masking + scale: Optional[float], scaling factor for attention scores + """ + # Ensure inputs are contiguous + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # Handle default scale + if scale is None: + scale = 1.0 / (query.size(-1) ** 0.5) + + # Save parameters needed for backward + ctx.scale = scale + ctx.is_causal = is_causal + ctx.dropout_p = dropout_p + ctx.attn_mask = attn_mask + + # Prepare cumulative sequence lengths and max sequence lengths + # Assuming batch sizes are consistent across query, key, and value + batch_size, num_heads, seq_len_q, head_dim = query.shape + seq_len_k = key.shape[2] + + # Flatten batch and head dimensions + q = query.view( + -1, seq_len_q, head_dim + ) # [batch_size * num_heads, seq_len_q, head_dim] + k = key.view(-1, seq_len_k, head_dim) + v = value.view(-1, seq_len_k, head_dim) + + # Create cumulative sequence lengths + cu_seqlens_q = torch.arange( + 0, + (batch_size * num_heads + 1) * seq_len_q, + seq_len_q, + dtype=torch.int32, + device=query.device, + ) + cu_seqlens_k = torch.arange( + 0, + (batch_size * num_heads + 1) * seq_len_k, + seq_len_k, + dtype=torch.int32, + device=key.device, + ) + max_seqlen_q = seq_len_q + max_seqlen_k = seq_len_k + + # Call your 
custom per-block int8 quantization function + ( + q_int8, + q_scale, + k_int8, + k_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + ) = per_block_int8_varlen_triton( + q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=scale + ) + + # Call your custom attention function + if is_causal: + output = attn_true_varlen( + q_int8, + k_int8, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + q_scale, + k_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + output_dtype=query.dtype, + ) + else: + raise NotImplementedError("Non-causal attention is not implemented yet.") + + # Reshape output to match the expected shape + output = output.view(batch_size, num_heads, seq_len_q, head_dim) + + # Save tensors for backward + ctx.save_for_backward( + query, + key, + value, + q_int8, + k_int8, + q_scale, + k_scale, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + output, + ) + + return output + + @staticmethod + def backward(ctx, grad_output): + ( + query, + key, + value, + q_int8, + k_int8, + q_scale, + k_scale, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + output, + ) = ctx.saved_tensors + + scale = ctx.scale + is_causal = ctx.is_causal + dropout_p = ctx.dropout_p + attn_mask = ctx.attn_mask + + # Flatten batch and head dimensions + batch_size, num_heads, seq_len_q, head_dim = query.shape + seq_len_k = key.shape[2] + grad_output = grad_output.contiguous() + do = grad_output.view(-1, seq_len_q, head_dim) + + # Compute gradients w.r.t. q, k, v + dq, dk, dv = sageattn_varlen_backward( + do, + query.view(-1, seq_len_q, head_dim), + key.view(-1, seq_len_k, head_dim), + value.view(-1, seq_len_k, head_dim), + cu_seqlens_q, + cu_seqlens_k, + seq_len_q, + seq_len_k, + q_int8, + k_int8, + q_scale, + k_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + scale, + is_causal, + ) + + # Reshape gradients to match the input shapes + dq = dq.view(batch_size, num_heads, seq_len_q, head_dim) + dk = dk.view(batch_size, num_heads, seq_len_k, head_dim) + dv = dv.view(batch_size, num_heads, seq_len_k, head_dim) + + # Handle optional arguments + d_attn_mask = None # Assuming attn_mask does not require gradients + d_dropout_p = ( + None # Dropout probability is a hyperparameter, typically not optimized + ) + d_is_causal = None # Not differentiable + d_scale = None # If scale is a tensor and requires grad, compute its gradient + + return dq, dk, dv, d_attn_mask, d_dropout_p, d_is_causal, d_scale + + +def scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, +): + """ + Custom scaled dot product attention using SageAttentionFunction. + """ + return SageAttentionFunction.apply( + query, key, value, attn_mask, dropout_p, is_causal, scale + ) + + +def monkeypatch_sdp_w_sage_attention(): + """ + Replace torch.nn.functional.scaled_dot_product_attention with custom scaled dot product attention using SageAttentionFunction. 
+ """ + torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention diff --git a/src/axolotl/integrations/sageattention/lib/triton/__init__.py b/src/axolotl/integrations/sageattention/lib/triton/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/integrations/sageattention/lib/triton/attn_qk_int8_per_block_causal_varlen.py b/src/axolotl/integrations/sageattention/lib/triton/attn_qk_int8_per_block_causal_varlen.py new file mode 100644 index 000000000..3e7892651 --- /dev/null +++ b/src/axolotl/integrations/sageattention/lib/triton/attn_qk_int8_per_block_causal_varlen.py @@ -0,0 +1,622 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import math + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + q_scale, + kv_len, + K_ptrs, + K_scale_ptr, + V_ptrs, + stride_kn, + stride_vn, + start_m, + H: tl.constexpr, + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, +): + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + K_scale_ptr += (lo // BLOCK_N) * H + K_ptrs += stride_kn * lo + V_ptrs += stride_vn * lo + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k_mask = offs_n[None, :] < (kv_len - start_n) + k = tl.load(K_ptrs, mask=k_mask) + k_scale = tl.load(K_scale_ptr) + qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale + + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + + acc = acc * alpha[:, None] + + v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n)) + p = p.to(tl.float16) + + acc += tl.dot(p, v, out_dtype=tl.float16) + m_i = m_ij + K_ptrs += BLOCK_N * stride_kn + K_scale_ptr += H + V_ptrs += BLOCK_N * stride_vn + return acc, l_i, m_i + + +@triton.jit +def _attn_fwd( + Q, + K, + V, + cu_seqlens_q, + cu_seqlens_k, + Q_scale, + K_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + Out, + stride_qh, + stride_qn, + stride_kh, + stride_kn, + stride_vh, + stride_vn, + stride_oh, + stride_on, + H: tl.constexpr, + num_kv_groups: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, +): + start_m = tl.program_id(0) + + off_z = tl.program_id(2).to(tl.int64) + off_h = tl.program_id(1).to(tl.int64) + + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + + qo_len = cu_seqlens_q_end - cu_seqlens_q_start + + if (start_m * BLOCK_M) >= qo_len: + return + + cu_seq_lens_q_scale_start = 
tl.load(cu_seqlens_q_scale + off_z) + cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z) + + q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H + k_scale_offset = ( + cu_seq_lens_k_scale_start * (H // num_kv_groups) + off_h // num_kv_groups + ) + + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + + kv_len = cu_seqlens_k_end - cu_seqlens_k_start + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, HEAD_DIM) + Q_ptrs = ( + Q + + (cu_seqlens_q_start * stride_qn + off_h * stride_qh) + + offs_m[:, None] * stride_qn + + offs_k[None, :] + ) + Q_scale_ptr = Q_scale + q_scale_offset + K_ptrs = ( + K + + (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh) + + offs_n[None, :] * stride_kn + + offs_k[:, None] + ) + K_scale_ptr = K_scale + k_scale_offset + V_ptrs = ( + V + + (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh) + + offs_n[:, None] * stride_vn + + offs_k[None, :] + ) + O_block_ptr = ( + Out + + (cu_seqlens_q_start * stride_on + off_h * stride_oh) + + offs_m[:, None] * stride_on + + offs_k[None, :] + ) + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len) + q_scale = tl.load(Q_scale_ptr) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + q_scale, + kv_len, + K_ptrs, + K_scale_ptr, + V_ptrs, + stride_kn, + stride_vn, + start_m, + H // num_kv_groups, + BLOCK_M, + HEAD_DIM, + BLOCK_N, + 4 - STAGE, + offs_m, + offs_n, + ) + + acc, l_i, _ = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + q_scale, + kv_len, + K_ptrs, + K_scale_ptr, + V_ptrs, + stride_kn, + stride_vn, + start_m, + H // num_kv_groups, + BLOCK_M, + HEAD_DIM, + BLOCK_N, + 2, + offs_m, + offs_n, + ) + acc = acc / l_i[:, None] + tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=(offs_m[:, None] < qo_len)) + + +@triton.jit +def _attn_bwd_inner( + dq_acc, + dk_acc, + dv_acc, + l_i, + m_i, + q, + k, + v, + do, + q_scale, + k_scale, + kv_len, + stride_kn, + stride_vn, + start_m, + H, + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, +): + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + k += stride_kn * lo + v += stride_vn * lo + + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k_mask = offs_n[None, :] < (kv_len - start_n) + k_curr = tl.load(k, mask=k_mask) + v_curr = tl.load(v, mask=k_mask) + k_scale_curr = tl.load(k_scale) + s = tl.dot(q, k_curr, trans_b=True).to(tl.float32) * q_scale * k_scale_curr + + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + s = s + tl.where(mask, 0.0, -float("inf")) + m_ij = tl.maximum(m_i, tl.max(s, 1)) + s = s - m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(s, 1)) + s = s - m_ij[:, None] + + p = tl.math.exp2(s) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + m_i = m_ij + + p = p / l_i[:, None] # Normalize probabilities + + # Compute gradients + # Compute softmax gradient + do_scaled = do / l_i[:, None] + dv_contrib = tl.dot(p.to(tl.float16).T, do_scaled.to(tl.float16)) + dv_acc += dv_contrib + + dp = 
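# ---------------------------------------------------------------------------
# Editor's note: a minimal PyTorch sketch (not part of this patch) of the
# online-softmax accumulation that the forward inner loop (_attn_fwd_inner)
# above implements per K/V block. It folds log2(e) into the scale so base-2
# exponentials (exp2) can be used, as the kernels do; the INT8 dequantization
# (q_scale / k_scale) and the causal-masking stage are omitted for clarity.
import math

import torch


def streaming_softmax_attention(q, k, v, block_n=64):
    """q: [seq_q, d], k/v: [seq_k, d]. Consumes K/V block by block while keeping
    a running max (m_i), running denominator (l_i) and a rescaled accumulator."""
    scale = q.shape[-1] ** -0.5 * math.log2(math.e)  # 1/sqrt(d), moved to base-2 domain
    m_i = torch.full((q.shape[0],), float("-inf"))
    l_i = torch.zeros(q.shape[0])
    acc = torch.zeros(q.shape[0], v.shape[-1])
    for start in range(0, k.shape[0], block_n):
        kb, vb = k[start:start + block_n], v[start:start + block_n]
        s = (q @ kb.T) * scale                        # block of logits (base-2)
        m_new = torch.maximum(m_i, s.max(dim=-1).values)
        p = torch.exp2(s - m_new[:, None])            # unnormalized block probabilities
        alpha = torch.exp2(m_i - m_new)               # rescales everything accumulated so far
        l_i = l_i * alpha + p.sum(dim=-1)
        acc = acc * alpha[:, None] + p @ vb
        m_i = m_new
    return acc / l_i[:, None]
# ---------------------------------------------------------------------------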
tl.dot(do_scaled.to(tl.float16), v_curr.to(tl.float16).T) + + # Compute ds (gradient w.r.t. logits s) + p_dp = p * dp + sum_p_dp = tl.sum(p_dp, axis=1) + ds = (p_dp - p * sum_p_dp[:, None]) * tl.math.log(2.0) # Adjust for exp2 + + # Compute gradients w.r.t q and k + dq_contrib = tl.dot(ds.to(tl.float16), k_curr.to(tl.float16)) + dk_contrib = tl.dot(ds.to(tl.float16).T, q.to(tl.float16)) + + dq_acc += dq_contrib * (q_scale * k_scale_curr) + dk_acc += dk_contrib * (q_scale * k_scale_curr) + + k += BLOCK_N * stride_kn + k_scale += H + v += BLOCK_N * stride_vn + + return dq_acc, dk_acc, dv_acc, l_i, m_i + + +@triton.jit +def _attn_bwd( + DO, + Q, + K, + V, + cu_seqlens_q, + cu_seqlens_k, + Q_scale, + K_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + L, + M, + DQ, + DK, + DV, + stride_qh, + stride_qn, + stride_kh, + stride_kn, + stride_vh, + stride_vn, + H: tl.constexpr, + num_kv_groups: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, +): + start_m = tl.program_id(0) + off_z = tl.program_id(2).to(tl.int64) + off_h = tl.program_id(1).to(tl.int64) + + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + qo_len = cu_seqlens_q_end - cu_seqlens_q_start + + if (start_m * BLOCK_M) >= qo_len: + return + + cu_seq_lens_q_scale_start = tl.load(cu_seqlens_q_scale + off_z) + cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z) + + q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H + k_scale_offset = ( + cu_seq_lens_k_scale_start * (H // num_kv_groups) + off_h // num_kv_groups + ) + + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + kv_len = cu_seqlens_k_end - cu_seqlens_k_start + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, HEAD_DIM) + Q_ptrs = ( + Q + + (cu_seqlens_q_start * stride_qn + off_h * stride_qh) + + offs_m[:, None] * stride_qn + + offs_k[None, :] + ) + DO_ptrs = ( + DO + + (cu_seqlens_q_start * stride_qn + off_h * stride_qh) + + offs_m[:, None] * stride_qn + + offs_k[None, :] + ) + Q_scale_ptr = Q_scale + q_scale_offset + K_ptrs = ( + K + + (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh) + + offs_n[None, :] * stride_kn + + offs_k[:, None] + ) + K_scale_ptr = K_scale + k_scale_offset + V_ptrs = ( + V + + (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh) + + offs_n[:, None] * stride_vn + + offs_k[None, :] + ) + DQ_ptrs = ( + DQ + + (cu_seqlens_q_start * stride_qn + off_h * stride_qh) + + offs_m[:, None] * stride_qn + + offs_k[None, :] + ) + DK_ptrs = ( + DK + + (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh) + + offs_n[None, :] * stride_kn + + offs_k[:, None] + ) + DV_ptrs = ( + DV + + (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh) + + offs_n[:, None] * stride_vn + + offs_k[None, :] + ) + L_ptrs = L + (cu_seqlens_q_start + offs_m) + M_ptrs = M + (cu_seqlens_q_start + offs_m) + + m_i = tl.load(M_ptrs, mask=offs_m < qo_len, other=float("-inf")) + l_i = tl.load(L_ptrs, mask=offs_m < qo_len, other=1.0) + + dq_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + dk_acc = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + dv_acc = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32) + + q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len) + do = tl.load(DO_ptrs, mask=offs_m[:, None] < qo_len) + q_scale = tl.load(Q_scale_ptr) + + dq_acc, 
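# ---------------------------------------------------------------------------
# Editor's note: a small autograd cross-check (not part of this patch) of the
# softmax-backward identity that _attn_bwd_inner above is built around: with
# p = softmax(s) over the key axis and dp = dO @ V^T, the gradient w.r.t. the
# logits is ds = p * (dp - sum(p * dp, axis=-1, keepdim=True)). The extra
# log(2) factor in the kernel only converts from natural-log logits to the
# base-2 logits that exp2 operates on.
import torch

torch.manual_seed(0)
s = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)    # logits
v = torch.randn(16, 8, dtype=torch.float64)
do = torch.randn(4, 8, dtype=torch.float64)                        # upstream gradient

p = torch.softmax(s, dim=-1)
(p @ v).backward(do)                                                # autograd reference

dp = do @ v.T
ds_manual = p * (dp - (p * dp).sum(dim=-1, keepdim=True))
assert torch.allclose(s.grad, ds_manual)
# ---------------------------------------------------------------------------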
dk_acc, dv_acc, l_i, m_i = _attn_bwd_inner( + dq_acc, + dk_acc, + dv_acc, + l_i, + m_i, + q, + K_ptrs, + V_ptrs, + do, + q_scale, + K_scale_ptr, + kv_len, + stride_kn, + stride_vn, + start_m, + H // num_kv_groups, + BLOCK_M, + HEAD_DIM, + BLOCK_N, + 4 - STAGE, + offs_m, + offs_n, + ) + + dq_acc, dk_acc, dv_acc, l_i, m_i = _attn_bwd_inner( + dq_acc, + dk_acc, + dv_acc, + l_i, + m_i, + q, + K_ptrs, + V_ptrs, + do, + q_scale, + K_scale_ptr, + kv_len, + stride_kn, + stride_vn, + start_m, + H // num_kv_groups, + BLOCK_M, + HEAD_DIM, + BLOCK_N, + 2, + offs_m, + offs_n, + ) + + tl.store(DQ_ptrs, dq_acc.to(DQ.dtype.element_ty), mask=offs_m[:, None] < qo_len) + tl.store(DK_ptrs, dk_acc.to(DK.dtype.element_ty), mask=offs_n[None, :] < kv_len) + tl.store(DV_ptrs, dv_acc.to(DV.dtype.element_ty), mask=offs_n[:, None] < kv_len) + + +def forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + q_scale, + k_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + output_dtype=torch.float16, +): + BLOCK_M = 128 + BLOCK_N = 64 + stage = 3 + + o = torch.empty(q.shape, dtype=output_dtype, device=q.device) + + b = cu_seqlens_q.shape[0] - 1 + _, h_qo, head_dim = q.shape + _, h_kv, _ = k.shape + + HEAD_DIM_K = head_dim + num_kv_groups = h_qo // h_kv + + grid = (triton.cdiv(max_seqlen_q, BLOCK_M), h_qo, b) + _attn_fwd[grid]( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + q_scale, + k_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + o, + q.stride(1), + q.stride(0), + k.stride(1), + k.stride(0), + v.stride(1), + v.stride(0), + o.stride(1), + o.stride(0), + h_qo, + num_kv_groups, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + HEAD_DIM=HEAD_DIM_K, + STAGE=stage, + num_warps=4 if head_dim == 64 else 8, + num_stages=4, + ) + return o + + +def backward( + do, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + q_scale, + k_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + l, + m, + output_dtype=torch.float16, +): + BLOCK_M = 128 + BLOCK_N = 64 + stage = 3 + + device = q.device + dtype = q.dtype + b = cu_seqlens_q.shape[0] - 1 + _, h_qo, head_dim = q.shape + _, h_kv, _ = k.shape + num_kv_groups = h_qo // h_kv + + dq = torch.zeros_like(q, dtype=output_dtype) + dk = torch.zeros_like(k, dtype=output_dtype) + dv = torch.zeros_like(v, dtype=output_dtype) + + grid = (triton.cdiv(max_seqlen_q, BLOCK_M), h_qo, b) + _attn_bwd[grid]( + do, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + q_scale, + k_scale, + cu_seqlens_q_scale, + cu_seqlens_k_scale, + l, + m, + dq, + dk, + dv, + q.stride(1), + q.stride(0), + k.stride(1), + k.stride(0), + v.stride(1), + v.stride(0), + h_qo, + num_kv_groups, + HEAD_DIM=head_dim, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + STAGE=stage, + num_warps=4 if head_dim == 64 else 8, + num_stages=4, + ) + return dq, dk, dv + + +# class TritonAttentionFunction(torch.autograd.Function): +# @staticmethod +# def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale): +# l = torch.zeros(q.shape[0], device=q.device, dtype=torch.float32) +# m = torch.zeros(q.shape[0], device=q.device, dtype=torch.float32) +# output = forward(q, k, v, cu_seqlens_q, cu_seqlens_k, q.shape[0], q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m) +# ctx.save_for_backward(q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m) +# return output +# +# @staticmethod +# def backward(ctx, do): +# q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m = 
ctx.saved_tensors +# dq, dk, dv = backward( +# do, q, k, v, +# cu_seqlens_q, cu_seqlens_k, +# q.shape[0], q_scale, k_scale, +# cu_seqlens_q_scale, cu_seqlens_k_scale, +# l, m, +# ) +# return dq, dk, dv, None, None, None, None, None, None diff --git a/src/axolotl/integrations/sageattention/lib/triton/quant_per_block_varlen.py b/src/axolotl/integrations/sageattention/lib/triton/quant_per_block_varlen.py new file mode 100644 index 000000000..b169e052b --- /dev/null +++ b/src/axolotl/integrations/sageattention/lib/triton/quant_per_block_varlen.py @@ -0,0 +1,158 @@ +""" +Copyright (c) 2024 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def quant_per_block_int8_kernel( + Input, + Output, + Scale, + cu_seqlens_input, + cu_seqlens_scale, + stride_ih, + stride_in, + stride_oh, + stride_on, + sm_scale, + H: tl.constexpr, + C: tl.constexpr, + BLK: tl.constexpr, +): + off_blk = tl.program_id(0) + off_h = tl.program_id(1) + off_b = tl.program_id(2) + + cu_seqlens_input_start = tl.load(cu_seqlens_input + off_b) + cu_seqlens_input_end = tl.load(cu_seqlens_input + off_b + 1) + + L = cu_seqlens_input_end - cu_seqlens_input_start + + if (off_blk * BLK) >= L: + return + + cu_seqlens_scale_start = tl.load(cu_seqlens_scale + off_b) + + offs_n = off_blk * BLK + tl.arange(0, BLK) + offs_k = tl.arange(0, C) + + input_ptrs = ( + Input + + cu_seqlens_input_start * stride_in + + off_h * stride_ih + + offs_n[:, None] * stride_in + + offs_k[None, :] + ) + output_ptrs = ( + Output + + cu_seqlens_input_start * stride_on + + off_h * stride_oh + + offs_n[:, None] * stride_on + + offs_k[None, :] + ) + scale_ptrs = Scale + cu_seqlens_scale_start * H + off_h + off_blk * H + + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + x *= sm_scale + scale = tl.max(tl.abs(x)) / 127.0 + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + + +def per_block_int8( + q, + k, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + BLKQ=128, + BLKK=64, + sm_scale=None, +): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + + h_qo = q.shape[1] + h_kv = k.shape[1] + head_dim = q.shape[-1] + + b = cu_seqlens_q.shape[0] - 1 + q_batch_len = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + k_batch_len = cu_seqlens_k[1:] - cu_seqlens_k[:-1] + + q_scale_len = (q_batch_len + BLKQ - 1) // BLKQ + k_scale_len = (k_batch_len + BLKK - 1) // BLKK + + cu_seqlens_q_scale = torch.nn.functional.pad( + torch.cumsum(q_scale_len, dim=0), (1, 0), value=0 + ) + cu_seqlens_k_scale = torch.nn.functional.pad( + torch.cumsum(k_scale_len, dim=0), (1, 0), value=0 + ) + + q_scale = torch.empty( + (cu_seqlens_q_scale[-1], h_qo), device=q.device, dtype=torch.float32 + ) + k_scale = torch.empty( + (cu_seqlens_k_scale[-1], h_kv), device=k.device, 
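# ---------------------------------------------------------------------------
# Editor's note: a plain PyTorch sketch (not part of this patch) of the
# per-block INT8 scheme that quant_per_block_int8_kernel above implements, for
# a single head: every block of `blk` rows shares one float32 scale equal to
# max(|x * sm_scale|) / 127, and values are rounded half away from zero before
# the int8 cast. In the patch, q is quantized with sm_scale * log2(e) folded in
# (the 1.44269504 factor) while k uses sm_scale = 1.0.
import torch


def quantize_per_block_int8(x, blk=128, sm_scale=1.0):
    """x: [seq_len, head_dim]. Returns (int8 tensor, one float32 scale per block)."""
    n_blocks = (x.shape[0] + blk - 1) // blk
    x_int8 = torch.empty_like(x, dtype=torch.int8)
    scales = torch.empty(n_blocks, dtype=torch.float32, device=x.device)
    for i in range(n_blocks):
        block = x[i * blk:(i + 1) * blk].float() * sm_scale
        scale = block.abs().max() / 127.0
        q = block / scale
        q = q + (q >= 0).float() - 0.5               # round half away from zero
        x_int8[i * blk:(i + 1) * blk] = q.to(torch.int8)
        scales[i] = scale
    return x_int8, scales
# ---------------------------------------------------------------------------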
dtype=torch.float32 + ) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((max_seqlen_q + BLKQ - 1) // BLKQ, h_qo, b) + quant_per_block_int8_kernel[grid]( + q, + q_int8, + q_scale, + cu_seqlens_q, + cu_seqlens_q_scale, + q.stride(1), + q.stride(0), + q_int8.stride(1), + q_int8.stride(0), + sm_scale=(sm_scale * 1.44269504), + H=h_qo, + C=head_dim, + BLK=BLKQ, + ) + + grid = ((max_seqlen_k + BLKK - 1) // BLKK, h_kv, b) + quant_per_block_int8_kernel[grid]( + k, + k_int8, + k_scale, + cu_seqlens_k, + cu_seqlens_k_scale, + k.stride(1), + k.stride(0), + k_int8.stride(1), + k_int8.stride(0), + sm_scale=1.0, + H=h_kv, + C=head_dim, + BLK=BLKK, + ) + + return q_int8, q_scale, k_int8, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 082df7c27..0a9b7be7f 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -46,6 +46,7 @@ from transformers.integrations.deepspeed import ( ) from axolotl.common.architectures import MOE_ARCH_BLOCK +from axolotl.integrations.sageattention.lib.core import monkeypatch_sdp_w_sage_attention from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.monkeypatch.multipack import ( SUPPORTED_MULTIPACK_MODEL_TYPES, @@ -707,6 +708,7 @@ class ModelLoader: self.model_config._attn_implementation = ( # pylint: disable=protected-access "sdpa" ) + monkeypatch_sdp_w_sage_attention() elif self.cfg.eager_attention: self.model_kwargs["attn_implementation"] = "eager" self.model_config._attn_implementation = ( # pylint: disable=protected-access
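# ---------------------------------------------------------------------------
# Editor's note: a hedged sketch (not part of this patch) of what the
# models.py hook above does. When the config selects sdpa attention,
# ModelLoader calls monkeypatch_sdp_w_sage_attention(), which rebinds
# torch.nn.functional.scaled_dot_product_attention to the SageAttention-backed
# wrapper, so any module using attn_implementation="sdpa" routes through it.
# The custom path currently supports only is_causal=True, head_dim in
# {64, 128}, and fp16/bf16 CUDA tensors.
import torch.nn.functional as F

from axolotl.integrations.sageattention.lib.core import (
    monkeypatch_sdp_w_sage_attention,
    scaled_dot_product_attention as sage_sdpa,
)

monkeypatch_sdp_w_sage_attention()

# torch.nn.functional is patched in place; call sites that resolve the attribute
# at call time (as HF "sdpa" attention modules do) now dispatch to SageAttention.
assert F.scaled_dot_product_attention is sage_sdpa
# ---------------------------------------------------------------------------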