Compare commits
1 Commits
coderabbit
...
sageattent
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8acc72dd8 |
0
src/axolotl/integrations/sageattention/__init__.py
Normal file
0
src/axolotl/integrations/sageattention/__init__.py
Normal file
361
src/axolotl/integrations/sageattention/lib/core.py
Normal file
361
src/axolotl/integrations/sageattention/lib/core.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Copyright (c) 2024 by SageAttention team.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.autograd import Function
|
||||
|
||||
from .triton.attn_qk_int8_per_block_causal_varlen import (
|
||||
backward as sageattn_varlen_backward,
|
||||
)
|
||||
from .triton.attn_qk_int8_per_block_causal_varlen import forward as attn_true_varlen
|
||||
from .triton.quant_per_block_varlen import (
|
||||
per_block_int8 as per_block_int8_varlen_triton,
|
||||
)
|
||||
|
||||
|
||||
def get_cuda_arch_versions():
    """Return the compute-capability tag (e.g. ``"sm80"``) for every visible CUDA device.

    Returns an empty list when no CUDA devices are present.
    """
    return [
        "sm{}{}".format(*torch.cuda.get_device_capability(device_index))
        for device_index in range(torch.cuda.device_count())
    ]
|
||||
|
||||
|
||||
def sageattn_varlen(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: Optional[float] = None,
    smooth_k: bool = True,
    **kwargs: Any,
) -> torch.Tensor:
    """
    INT8 per-block quantized attention over variable-length (packed) sequences.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.

    k : torch.Tensor
        The key tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.

    v : torch.Tensor
        The value tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.

    cu_seqlens_q : torch.Tensor
        The cumulative sequence lengths for the query sequences in the batch, used to index into `q`.
        Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.

    cu_seqlens_k : torch.Tensor
        The cumulative sequence lengths for the key and value sequences in the batch, used to index into `k` and `v`.
        Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.

    max_seqlen_q : int
        The maximum sequence length for the query tensor in the batch.

    max_seqlen_k : int
        The maximum sequence length for the key and value tensors in the batch.

    sm_scale : Optional[float]
        The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.

    smooth_k : bool
        Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
        Default: True.

    Returns
    -------
    torch.Tensor
        The output tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.

    Note
    ----
    - Causal masking is always applied: the underlying kernel is the
      ``attn_qk_int8_per_block_causal_varlen`` forward (there is no ``is_causal``
      parameter on this function).
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or
      ``torch.bfloat16`` (enforced below; float32 is rejected).
    - The tensors `cu_seqlens_q` and `cu_seqlens_k` must have the dtype ``torch.int32`` or ``torch.int64``.
    - All tensors must be on the same cuda device.
    - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
    """
    dtype = q.dtype
    assert q.is_cuda, "Input tensors must be on cuda."
    assert dtype in [
        torch.float16,
        torch.bfloat16,
    ], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
    assert q.device == k.device == v.device, "All tensors must be on the same device."
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."

    head_dim = q.size(-1)
    assert head_dim in [64, 128], "varlen only support head_dim [64, 128]."

    assert (
        q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
    ), "Last dim of qkv must be contiguous."
    assert (
        cu_seqlens_q.is_contiguous() and cu_seqlens_k.is_contiguous()
    ), "cu_seqlens_q and cu_seqlens_k must be contiguous."

    # The attention kernel accumulates P @ V in fp16, so bf16 values are cast
    # down here.  The original code also tested `dtype == torch.float32`, but
    # that branch was unreachable: the dtype assert above rejects float32.
    if dtype == torch.bfloat16:
        v = v.to(torch.float16)

    if smooth_k:
        km = k.mean(
            dim=0, keepdim=True
        )  # ! km is calculated on the all the batches. Calculate over each individual sequence requires dedicated kernel.
        # BUGFIX: subtract out-of-place.  The original `k -= km` mutated the
        # caller's key tensor, silently corrupting any later use of it.
        k = k - km

    # Quantize q/k to int8 per block; also returns per-block scales and the
    # cumulative offsets into the scale buffers.
    (
        q_int8,
        q_scale,
        k_int8,
        k_scale,
        cu_seqlens_q_scale,
        cu_seqlens_k_scale,
    ) = per_block_int8_varlen_triton(
        q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=sm_scale
    )

    o = attn_true_varlen(
        q_int8,
        k_int8,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        q_scale,
        k_scale,
        cu_seqlens_q_scale,
        cu_seqlens_k_scale,
        output_dtype=dtype,
    )

    return o
|
||||
|
||||
|
||||
class SageAttentionFunction(Function):
    """Autograd wrapper that runs SDPA-shaped attention through the quantized
    varlen kernels.

    The forward flattens ``[batch, heads, seq, dim]`` inputs into a varlen
    layout where each (batch, head) pair is treated as one independent
    sequence, then dispatches to the int8 per-block causal varlen kernel.
    """

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
    ):
        """
        query: Tensor of shape [batch_size, num_heads, seq_len_q, head_dim]
        key: Tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
        value: Tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
        attn_mask: Optional[Tensor], mask tensor
        dropout_p: float, dropout probability
        is_causal: bool, whether to apply causal masking
        scale: Optional[float], scaling factor for attention scores

        NOTE(review): `attn_mask` and `dropout_p` are stored on `ctx` but never
        applied anywhere in this function — callers passing a mask or a nonzero
        dropout probability silently get neither.  Verify this is intended.
        """
        # Ensure inputs are contiguous
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()

        # Handle default scale
        if scale is None:
            scale = 1.0 / (query.size(-1) ** 0.5)

        # Save parameters needed for backward
        ctx.scale = scale
        ctx.is_causal = is_causal
        ctx.dropout_p = dropout_p
        ctx.attn_mask = attn_mask

        # Prepare cumulative sequence lengths and max sequence lengths
        # Assuming batch sizes are consistent across query, key, and value
        batch_size, num_heads, seq_len_q, head_dim = query.shape
        seq_len_k = key.shape[2]

        # Flatten batch and head dimensions
        # NOTE(review): this produces [batch*heads, seq_len, head_dim], i.e.
        # the kernel will see `seq_len` in the position its docstring calls
        # "num heads" and `batch*heads` in the position it calls "total
        # tokens", while cu_seqlens below describe batch*heads sequences of
        # length seq_len (batch*heads*seq_len tokens total).  Confirm this
        # layout against the varlen kernel's expected
        # [cu_seqlens[-1], num_heads, head_dim] contract.
        q = query.view(
            -1, seq_len_q, head_dim
        )  # [batch_size * num_heads, seq_len_q, head_dim]
        k = key.view(-1, seq_len_k, head_dim)
        v = value.view(-1, seq_len_k, head_dim)

        # Create cumulative sequence lengths: each (batch, head) pair is
        # treated as one sequence, so offsets step by seq_len.
        cu_seqlens_q = torch.arange(
            0,
            (batch_size * num_heads + 1) * seq_len_q,
            seq_len_q,
            dtype=torch.int32,
            device=query.device,
        )
        cu_seqlens_k = torch.arange(
            0,
            (batch_size * num_heads + 1) * seq_len_k,
            seq_len_k,
            dtype=torch.int32,
            device=key.device,
        )
        max_seqlen_q = seq_len_q
        max_seqlen_k = seq_len_k

        # Call your custom per-block int8 quantization function
        (
            q_int8,
            q_scale,
            k_int8,
            k_scale,
            cu_seqlens_q_scale,
            cu_seqlens_k_scale,
        ) = per_block_int8_varlen_triton(
            q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=scale
        )

        # Call your custom attention function.  Only the causal kernel exists.
        if is_causal:
            output = attn_true_varlen(
                q_int8,
                k_int8,
                v,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                q_scale,
                k_scale,
                cu_seqlens_q_scale,
                cu_seqlens_k_scale,
                output_dtype=query.dtype,
            )
        else:
            raise NotImplementedError("Non-causal attention is not implemented yet.")

        # Reshape output to match the expected shape
        output = output.view(batch_size, num_heads, seq_len_q, head_dim)

        # Save tensors for backward
        ctx.save_for_backward(
            query,
            key,
            value,
            q_int8,
            k_int8,
            q_scale,
            k_scale,
            cu_seqlens_q,
            cu_seqlens_k,
            cu_seqlens_q_scale,
            cu_seqlens_k_scale,
            output,
        )

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for q/k/v via the custom triton backward kernel.

        Returns one gradient (or None) per forward input, in forward order:
        (query, key, value, attn_mask, dropout_p, is_causal, scale).
        """
        (
            query,
            key,
            value,
            q_int8,
            k_int8,
            q_scale,
            k_scale,
            cu_seqlens_q,
            cu_seqlens_k,
            cu_seqlens_q_scale,
            cu_seqlens_k_scale,
            output,
        ) = ctx.saved_tensors

        scale = ctx.scale
        is_causal = ctx.is_causal
        dropout_p = ctx.dropout_p
        attn_mask = ctx.attn_mask

        # Flatten batch and head dimensions (same layout as forward)
        batch_size, num_heads, seq_len_q, head_dim = query.shape
        seq_len_k = key.shape[2]
        grad_output = grad_output.contiguous()
        do = grad_output.view(-1, seq_len_q, head_dim)

        # Compute gradients w.r.t. q, k, v
        # NOTE(review): this positional argument list does not match the
        # `backward(do, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
        # q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m,
        # output_dtype)` signature defined in
        # .triton.attn_qk_int8_per_block_causal_varlen — verify before relying
        # on training through this path.
        dq, dk, dv = sageattn_varlen_backward(
            do,
            query.view(-1, seq_len_q, head_dim),
            key.view(-1, seq_len_k, head_dim),
            value.view(-1, seq_len_k, head_dim),
            cu_seqlens_q,
            cu_seqlens_k,
            seq_len_q,
            seq_len_k,
            q_int8,
            k_int8,
            q_scale,
            k_scale,
            cu_seqlens_q_scale,
            cu_seqlens_k_scale,
            scale,
            is_causal,
        )

        # Reshape gradients to match the input shapes
        dq = dq.view(batch_size, num_heads, seq_len_q, head_dim)
        dk = dk.view(batch_size, num_heads, seq_len_k, head_dim)
        dv = dv.view(batch_size, num_heads, seq_len_k, head_dim)

        # Handle optional arguments
        d_attn_mask = None  # Assuming attn_mask does not require gradients
        d_dropout_p = (
            None  # Dropout probability is a hyperparameter, typically not optimized
        )
        d_is_causal = None  # Not differentiable
        d_scale = None  # If scale is a tensor and requires grad, compute its gradient

        return dq, dk, dv, d_attn_mask, d_dropout_p, d_is_causal, d_scale
|
||||
|
||||
|
||||
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    """Drop-in stand-in for ``torch.nn.functional.scaled_dot_product_attention``
    that routes the computation through ``SageAttentionFunction``.
    """
    sdpa_args = (query, key, value, attn_mask, dropout_p, is_causal, scale)
    return SageAttentionFunction.apply(*sdpa_args)
|
||||
|
||||
|
||||
def monkeypatch_sdp_w_sage_attention():
    """
    Replace torch.nn.functional.scaled_dot_product_attention with custom scaled dot product attention using SageAttentionFunction.

    Notes:
        - The patch is process-wide: every caller of
          ``torch.nn.functional.scaled_dot_product_attention`` is affected,
          not just this integration.
        - The original function reference is not saved, so the patch cannot
          be undone within the process.
    """
    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
||||
@@ -0,0 +1,622 @@
|
||||
"""
|
||||
Copyright (c) 2024 by SageAttention team.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    q,
    q_scale,
    kv_len,
    K_ptrs,
    K_scale_ptr,
    V_ptrs,
    stride_kn,
    stride_vn,
    start_m,
    H: tl.constexpr,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,
):
    # Online-softmax inner loop over key/value blocks for one query block.
    # Maintains running row max `m_i`, running normalizer `l_i`, and the
    # unnormalized output accumulator `acc` (final division happens in the
    # caller).  STAGE == 1 processes key blocks strictly before the current
    # query block (no masking needed); STAGE == 2 processes the diagonal
    # block with an explicit causal mask.
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # Advance pointers to the first key/value block of this stage.  There is
    # one k_scale per BLOCK_N keys per head, stored head-interleaved, hence
    # the `* H` stride on the scale pointer.
    K_scale_ptr += (lo // BLOCK_N) * H
    K_ptrs += stride_kn * lo
    V_ptrs += stride_vn * lo
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # Mask out key positions past the end of this (variable-length) sequence.
        k_mask = offs_n[None, :] < (kv_len - start_n)
        k = tl.load(K_ptrs, mask=k_mask)
        k_scale = tl.load(K_scale_ptr)
        # int8 q @ int8 k, then dequantize with the per-block scales.
        # exp2 is used below instead of exp; the quantization step folds
        # log2(e) into sm_scale (see `sm_scale * 1.44269504` in the
        # quantization launcher), so these logits are already base-2.
        qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale

        if STAGE == 2:
            # Diagonal block: apply the causal mask via a large negative bias.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]

        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)

        # Rescale the previous running state to the new max.
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij

        acc = acc * alpha[:, None]

        v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n))
        p = p.to(tl.float16)

        # P @ V accumulated in fp16 (out_dtype=tl.float16), added into the
        # fp32 accumulator.
        acc += tl.dot(p, v, out_dtype=tl.float16)
        m_i = m_ij
        K_ptrs += BLOCK_N * stride_kn
        K_scale_ptr += H
        V_ptrs += BLOCK_N * stride_vn
    return acc, l_i, m_i
|
||||
|
||||
|
||||
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    cu_seqlens_q,
    cu_seqlens_k,
    Q_scale,
    K_scale,
    cu_seqlens_q_scale,
    cu_seqlens_k_scale,
    Out,
    stride_qh,
    stride_qn,
    stride_kh,
    stride_kn,
    stride_vh,
    stride_vn,
    stride_oh,
    stride_on,
    H: tl.constexpr,
    num_kv_groups: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
):
    # Varlen int8 attention forward kernel.  Grid (see launcher `forward`):
    # axis 0 = query block index, axis 1 = query head, axis 2 = sequence
    # index within the packed batch.
    start_m = tl.program_id(0)

    off_z = tl.program_id(2).to(tl.int64)  # sequence (batch) index
    off_h = tl.program_id(1).to(tl.int64)  # query head index

    cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
    cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)

    qo_len = cu_seqlens_q_end - cu_seqlens_q_start

    # Grid is sized for max_seqlen_q; programs past this sequence's end exit.
    if (start_m * BLOCK_M) >= qo_len:
        return

    cu_seq_lens_q_scale_start = tl.load(cu_seqlens_q_scale + off_z)
    cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z)

    # Scales are stored one per (block, head), head-fastest; key heads are
    # shared across num_kv_groups query heads (GQA).
    q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H
    k_scale_offset = (
        cu_seq_lens_k_scale_start * (H // num_kv_groups) + off_h // num_kv_groups
    )

    cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
    cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)

    kv_len = cu_seqlens_k_end - cu_seqlens_k_start

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, HEAD_DIM)
    # Q is loaded row-major [BLOCK_M, HEAD_DIM]; K column-major
    # [HEAD_DIM, BLOCK_N] so tl.dot(q, k) gives [BLOCK_M, BLOCK_N].
    Q_ptrs = (
        Q
        + (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
        + offs_m[:, None] * stride_qn
        + offs_k[None, :]
    )
    Q_scale_ptr = Q_scale + q_scale_offset
    K_ptrs = (
        K
        + (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
        + offs_n[None, :] * stride_kn
        + offs_k[:, None]
    )
    K_scale_ptr = K_scale + k_scale_offset
    V_ptrs = (
        V
        + (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
        + offs_n[:, None] * stride_vn
        + offs_k[None, :]
    )
    O_block_ptr = (
        Out
        + (cu_seqlens_q_start * stride_on + off_h * stride_oh)
        + offs_m[:, None] * stride_on
        + offs_k[None, :]
    )

    # Online-softmax state: m_i starts at -inf so the first block's rescale
    # factor exp2(m_i - m_ij) is 0, wiping the placeholder 1.0 in l_i.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
    q_scale = tl.load(Q_scale_ptr)
    # First pass: key blocks strictly below the diagonal (launcher uses
    # STAGE=3, so the inner kernel runs its STAGE==1, unmasked path).
    acc, l_i, m_i = _attn_fwd_inner(
        acc,
        l_i,
        m_i,
        q,
        q_scale,
        kv_len,
        K_ptrs,
        K_scale_ptr,
        V_ptrs,
        stride_kn,
        stride_vn,
        start_m,
        H // num_kv_groups,
        BLOCK_M,
        HEAD_DIM,
        BLOCK_N,
        4 - STAGE,
        offs_m,
        offs_n,
    )

    # Second pass: the diagonal block with the causal mask (inner STAGE==2).
    acc, l_i, _ = _attn_fwd_inner(
        acc,
        l_i,
        m_i,
        q,
        q_scale,
        kv_len,
        K_ptrs,
        K_scale_ptr,
        V_ptrs,
        stride_kn,
        stride_vn,
        start_m,
        H // num_kv_groups,
        BLOCK_M,
        HEAD_DIM,
        BLOCK_N,
        2,
        offs_m,
        offs_n,
    )
    # Final softmax normalization and store in the output dtype.
    acc = acc / l_i[:, None]
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=(offs_m[:, None] < qo_len))
|
||||
|
||||
|
||||
@triton.jit
def _attn_bwd_inner(
    dq_acc,
    dk_acc,
    dv_acc,
    l_i,
    m_i,
    q,
    k,
    v,
    do,
    q_scale,
    k_scale,
    kv_len,
    stride_kn,
    stride_vn,
    start_m,
    H,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,
):
    # Backward inner loop over key/value blocks for one query block,
    # recomputing attention probabilities on the fly.
    #
    # NOTE(review): several points here look suspect and should be verified
    # against a reference flash-attention backward before training with it:
    #   - `p` and `do` are normalized by the *running* l_i inside the loop,
    #     not by the final normalizer saved from the forward pass;
    #   - a single [BLOCK_N, HEAD_DIM] dk_acc/dv_acc accumulator is reused
    #     for every key block in the loop, rather than one per key block;
    #   - `tl.dot(..., trans_b=True)` and tensor `.T` are not accepted by
    #     recent Triton releases (use tl.trans), and `tl.math.log` may need
    #     to be `tl.log` depending on the pinned Triton version.
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # Advance the k/v block pointers to this stage's first block.
    k += stride_kn * lo
    v += stride_vn * lo

    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # Mask key positions past the end of this variable-length sequence.
        k_mask = offs_n[None, :] < (kv_len - start_n)
        k_curr = tl.load(k, mask=k_mask)
        v_curr = tl.load(v, mask=k_mask)
        k_scale_curr = tl.load(k_scale)
        # Recompute dequantized logits (base-2, as in the forward kernel).
        s = tl.dot(q, k_curr, trans_b=True).to(tl.float32) * q_scale * k_scale_curr

        if STAGE == 2:
            # Diagonal block: causal mask.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            s = s + tl.where(mask, 0.0, -float("inf"))
            m_ij = tl.maximum(m_i, tl.max(s, 1))
            s = s - m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(s, 1))
            s = s - m_ij[:, None]

        p = tl.math.exp2(s)
        l_ij = tl.sum(p, 1)
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        m_i = m_ij

        p = p / l_i[:, None]  # Normalize probabilities

        # Compute gradients
        # Compute softmax gradient
        do_scaled = do / l_i[:, None]
        # dV = P^T @ dO
        dv_contrib = tl.dot(p.to(tl.float16).T, do_scaled.to(tl.float16))
        dv_acc += dv_contrib

        # dP = dO @ V^T
        dp = tl.dot(do_scaled.to(tl.float16), v_curr.to(tl.float16).T)

        # Compute ds (gradient w.r.t. logits s)
        p_dp = p * dp
        sum_p_dp = tl.sum(p_dp, axis=1)
        # Factor ln(2) converts the derivative of exp2 back to natural-log units.
        ds = (p_dp - p * sum_p_dp[:, None]) * tl.math.log(2.0)  # Adjust for exp2

        # Compute gradients w.r.t q and k
        dq_contrib = tl.dot(ds.to(tl.float16), k_curr.to(tl.float16))
        dk_contrib = tl.dot(ds.to(tl.float16).T, q.to(tl.float16))

        # Dequantize: q/k used above are int8, so their scales re-enter here.
        dq_acc += dq_contrib * (q_scale * k_scale_curr)
        dk_acc += dk_contrib * (q_scale * k_scale_curr)

        k += BLOCK_N * stride_kn
        k_scale += H
        v += BLOCK_N * stride_vn

    return dq_acc, dk_acc, dv_acc, l_i, m_i
|
||||
|
||||
|
||||
@triton.jit
def _attn_bwd(
    DO,
    Q,
    K,
    V,
    cu_seqlens_q,
    cu_seqlens_k,
    Q_scale,
    K_scale,
    cu_seqlens_q_scale,
    cu_seqlens_k_scale,
    L,
    M,
    DQ,
    DK,
    DV,
    stride_qh,
    stride_qn,
    stride_kh,
    stride_kn,
    stride_vh,
    stride_vn,
    H: tl.constexpr,
    num_kv_groups: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
):
    # Varlen int8 attention backward kernel.  Grid mirrors the forward:
    # axis 0 = query block, axis 1 = query head, axis 2 = sequence index.
    #
    # NOTE(review): dk_acc/dv_acc are written with plain tl.store at pointers
    # that depend only on offs_n (the sequence's *first* BLOCK_N keys), so
    # every start_m program of the same head/sequence stores to the same
    # location — last writer wins, with no atomics.  Also, DK_ptrs is laid
    # out [HEAD_DIM, BLOCK_N] (offs_k rows, offs_n columns) while dk_acc is
    # [BLOCK_N, HEAD_DIM].  Verify both against a working reference before
    # using these gradients.
    start_m = tl.program_id(0)
    off_z = tl.program_id(2).to(tl.int64)
    off_h = tl.program_id(1).to(tl.int64)

    cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
    cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
    qo_len = cu_seqlens_q_end - cu_seqlens_q_start

    # Programs past this sequence's length exit early.
    if (start_m * BLOCK_M) >= qo_len:
        return

    cu_seq_lens_q_scale_start = tl.load(cu_seqlens_q_scale + off_z)
    cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z)

    # Per-(block, head) scale offsets; key heads shared across GQA groups.
    q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H
    k_scale_offset = (
        cu_seq_lens_k_scale_start * (H // num_kv_groups) + off_h // num_kv_groups
    )

    cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
    cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
    kv_len = cu_seqlens_k_end - cu_seqlens_k_start

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, HEAD_DIM)
    Q_ptrs = (
        Q
        + (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
        + offs_m[:, None] * stride_qn
        + offs_k[None, :]
    )
    # dO shares Q's layout and strides.
    DO_ptrs = (
        DO
        + (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
        + offs_m[:, None] * stride_qn
        + offs_k[None, :]
    )
    Q_scale_ptr = Q_scale + q_scale_offset
    K_ptrs = (
        K
        + (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
        + offs_n[None, :] * stride_kn
        + offs_k[:, None]
    )
    K_scale_ptr = K_scale + k_scale_offset
    V_ptrs = (
        V
        + (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
        + offs_n[:, None] * stride_vn
        + offs_k[None, :]
    )
    DQ_ptrs = (
        DQ
        + (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
        + offs_m[:, None] * stride_qn
        + offs_k[None, :]
    )
    DK_ptrs = (
        DK
        + (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
        + offs_n[None, :] * stride_kn
        + offs_k[:, None]
    )
    DV_ptrs = (
        DV
        + (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
        + offs_n[:, None] * stride_vn
        + offs_k[None, :]
    )
    # Softmax statistics saved from the forward pass (per query row).
    L_ptrs = L + (cu_seqlens_q_start + offs_m)
    M_ptrs = M + (cu_seqlens_q_start + offs_m)

    m_i = tl.load(M_ptrs, mask=offs_m < qo_len, other=float("-inf"))
    l_i = tl.load(L_ptrs, mask=offs_m < qo_len, other=1.0)

    dq_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    dk_acc = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32)
    dv_acc = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32)

    q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
    do = tl.load(DO_ptrs, mask=offs_m[:, None] < qo_len)
    q_scale = tl.load(Q_scale_ptr)

    # First pass: off-diagonal key blocks (inner STAGE = 4 - STAGE = 1 for
    # the launcher's STAGE=3), same two-stage split as the forward kernel.
    dq_acc, dk_acc, dv_acc, l_i, m_i = _attn_bwd_inner(
        dq_acc,
        dk_acc,
        dv_acc,
        l_i,
        m_i,
        q,
        K_ptrs,
        V_ptrs,
        do,
        q_scale,
        K_scale_ptr,
        kv_len,
        stride_kn,
        stride_vn,
        start_m,
        H // num_kv_groups,
        BLOCK_M,
        HEAD_DIM,
        BLOCK_N,
        4 - STAGE,
        offs_m,
        offs_n,
    )

    # Second pass: the causally-masked diagonal block (inner STAGE = 2).
    dq_acc, dk_acc, dv_acc, l_i, m_i = _attn_bwd_inner(
        dq_acc,
        dk_acc,
        dv_acc,
        l_i,
        m_i,
        q,
        K_ptrs,
        V_ptrs,
        do,
        q_scale,
        K_scale_ptr,
        kv_len,
        stride_kn,
        stride_vn,
        start_m,
        H // num_kv_groups,
        BLOCK_M,
        HEAD_DIM,
        BLOCK_N,
        2,
        offs_m,
        offs_n,
    )

    tl.store(DQ_ptrs, dq_acc.to(DQ.dtype.element_ty), mask=offs_m[:, None] < qo_len)
    tl.store(DK_ptrs, dk_acc.to(DK.dtype.element_ty), mask=offs_n[None, :] < kv_len)
    tl.store(DV_ptrs, dv_acc.to(DV.dtype.element_ty), mask=offs_n[:, None] < kv_len)
|
||||
|
||||
|
||||
def forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    q_scale,
    k_scale,
    cu_seqlens_q_scale,
    cu_seqlens_k_scale,
    output_dtype=torch.float16,
):
    """Launch the varlen int8 attention forward kernel and return the output.

    Inputs use the packed varlen layout ``[total_tokens, num_heads, head_dim]``;
    the output tensor has the same shape as ``q`` in ``output_dtype``.
    """
    # Tiling configuration; STAGE=3 selects the causal two-pass schedule.
    block_m, block_n, stage = 128, 64, 3

    out = torch.empty(q.shape, dtype=output_dtype, device=q.device)

    num_sequences = cu_seqlens_q.shape[0] - 1
    _, num_q_heads, head_dim = q.shape
    num_kv_heads = k.shape[1]
    kv_group_size = num_q_heads // num_kv_heads

    # Grid: (query blocks, query heads, sequences in the packed batch).
    launch_grid = (triton.cdiv(max_seqlen_q, block_m), num_q_heads, num_sequences)
    _attn_fwd[launch_grid](
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        q_scale,
        k_scale,
        cu_seqlens_q_scale,
        cu_seqlens_k_scale,
        out,
        q.stride(1),
        q.stride(0),
        k.stride(1),
        k.stride(0),
        v.stride(1),
        v.stride(0),
        out.stride(1),
        out.stride(0),
        num_q_heads,
        kv_group_size,
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        HEAD_DIM=head_dim,
        STAGE=stage,
        num_warps=4 if head_dim == 64 else 8,
        num_stages=4,
    )
    return out
|
||||
|
||||
|
||||
def backward(
    do,
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    q_scale,
    k_scale,
    cu_seqlens_q_scale,
    cu_seqlens_k_scale,
    l,
    m,
    output_dtype=torch.float16,
):
    """Launch the varlen int8 attention backward kernel.

    Args:
        do: Gradient of the attention output, varlen layout
            ``[total_tokens, num_heads, head_dim]`` (same strides as ``q``).
        q, k, v: The (quantized/packed) forward inputs in varlen layout.
        cu_seqlens_q, cu_seqlens_k: ``[batch + 1]`` cumulative sequence lengths.
        max_seqlen_q: Maximum query sequence length (sizes the launch grid).
        q_scale, k_scale: Per-block dequantization scales from quantization.
        cu_seqlens_q_scale, cu_seqlens_k_scale: Cumulative offsets into the
            scale buffers, one entry per sequence.
        l, m: Per-query-row softmax normalizer and running max saved from the
            forward pass.
        output_dtype: dtype of the returned gradient tensors.

    Returns:
        Tuple ``(dq, dk, dv)`` of gradients with the shapes of ``q``/``k``/``v``.
    """
    # Tiling configuration; STAGE=3 selects the causal two-pass schedule,
    # mirroring the forward launcher.
    BLOCK_M = 128
    BLOCK_N = 64
    stage = 3

    b = cu_seqlens_q.shape[0] - 1
    _, h_qo, head_dim = q.shape
    _, h_kv, _ = k.shape
    num_kv_groups = h_qo // h_kv

    # Gradients are zero-initialized because grid programs beyond a
    # sequence's length return without storing anything.
    dq = torch.zeros_like(q, dtype=output_dtype)
    dk = torch.zeros_like(k, dtype=output_dtype)
    dv = torch.zeros_like(v, dtype=output_dtype)

    # Grid: (query blocks, query heads, sequences in the packed batch).
    grid = (triton.cdiv(max_seqlen_q, BLOCK_M), h_qo, b)
    _attn_bwd[grid](
        do,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        q_scale,
        k_scale,
        cu_seqlens_q_scale,
        cu_seqlens_k_scale,
        l,
        m,
        dq,
        dk,
        dv,
        q.stride(1),
        q.stride(0),
        k.stride(1),
        k.stride(0),
        v.stride(1),
        v.stride(0),
        h_qo,
        num_kv_groups,
        HEAD_DIM=head_dim,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        STAGE=stage,
        num_warps=4 if head_dim == 64 else 8,
        num_stages=4,
    )
    return dq, dk, dv
|
||||
|
||||
|
||||
# class TritonAttentionFunction(torch.autograd.Function):
|
||||
# @staticmethod
|
||||
# def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale):
|
||||
# l = torch.zeros(q.shape[0], device=q.device, dtype=torch.float32)
|
||||
# m = torch.zeros(q.shape[0], device=q.device, dtype=torch.float32)
|
||||
# output = forward(q, k, v, cu_seqlens_q, cu_seqlens_k, q.shape[0], q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m)
|
||||
# ctx.save_for_backward(q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m)
|
||||
# return output
|
||||
#
|
||||
# @staticmethod
|
||||
# def backward(ctx, do):
|
||||
# q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m = ctx.saved_tensors
|
||||
# dq, dk, dv = backward(
|
||||
# do, q, k, v,
|
||||
# cu_seqlens_q, cu_seqlens_k,
|
||||
# q.shape[0], q_scale, k_scale,
|
||||
# cu_seqlens_q_scale, cu_seqlens_k_scale,
|
||||
# l, m,
|
||||
# )
|
||||
# return dq, dk, dv, None, None, None, None, None, None
|
||||
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
Copyright (c) 2024 by SageAttention team.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def quant_per_block_int8_kernel(
    Input,
    Output,
    Scale,
    cu_seqlens_input,
    cu_seqlens_scale,
    stride_ih,
    stride_in,
    stride_oh,
    stride_on,
    sm_scale,
    H: tl.constexpr,
    C: tl.constexpr,
    BLK: tl.constexpr,
):
    # Symmetric per-block int8 quantization for varlen tensors.  Each program
    # quantizes one BLK-row block of one head of one sequence and writes a
    # single fp32 scale for it.  Grid: (blocks, heads, sequences).
    off_blk = tl.program_id(0)
    off_h = tl.program_id(1)
    off_b = tl.program_id(2)

    cu_seqlens_input_start = tl.load(cu_seqlens_input + off_b)
    cu_seqlens_input_end = tl.load(cu_seqlens_input + off_b + 1)

    # L is this sequence's length; the grid is sized for the batch maximum,
    # so programs past the end exit early.
    L = cu_seqlens_input_end - cu_seqlens_input_start

    if (off_blk * BLK) >= L:
        return

    cu_seqlens_scale_start = tl.load(cu_seqlens_scale + off_b)

    offs_n = off_blk * BLK + tl.arange(0, BLK)
    offs_k = tl.arange(0, C)

    input_ptrs = (
        Input
        + cu_seqlens_input_start * stride_in
        + off_h * stride_ih
        + offs_n[:, None] * stride_in
        + offs_k[None, :]
    )
    output_ptrs = (
        Output
        + cu_seqlens_input_start * stride_on
        + off_h * stride_oh
        + offs_n[:, None] * stride_on
        + offs_k[None, :]
    )
    # Scales are stored head-interleaved: one entry per (block, head).
    scale_ptrs = Scale + cu_seqlens_scale_start * H + off_h + off_blk * H

    x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
    x = x.to(tl.float32)
    # sm_scale is pre-folded into the quantized values (the launcher passes
    # sm_scale * log2(e) for q and 1.0 for k), so the attention kernel can
    # use exp2 directly.
    x *= sm_scale
    # Symmetric scale so the block's max magnitude maps to int8 range 127.
    # NOTE(review): an all-zero block yields scale == 0 and x / scale is
    # NaN/inf — verify upstream guarantees non-zero blocks or add a floor.
    scale = tl.max(tl.abs(x)) / 127.0
    x_int8 = x / scale
    # Round half away from zero (add +/-0.5 before the truncating int8 cast).
    x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
    x_int8 = x_int8.to(tl.int8)
    tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
    tl.store(scale_ptrs, scale)
|
||||
|
||||
|
||||
def per_block_int8(
    q,
    k,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    BLKQ=128,
    BLKK=64,
    sm_scale=None,
):
    """Quantize varlen q and k tensors to int8 with one scale per block.

    q is pre-scaled by ``sm_scale * log2(e)`` so downstream kernels can use
    exp2; k is quantized with unit scale.  Returns
    ``(q_int8, q_scale, k_int8, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale)``.
    """
    quantized_q = torch.empty(q.shape, dtype=torch.int8, device=q.device)
    quantized_k = torch.empty(k.shape, dtype=torch.int8, device=k.device)

    num_q_heads = q.shape[1]
    num_k_heads = k.shape[1]
    head_dim = q.shape[-1]
    num_sequences = cu_seqlens_q.shape[0] - 1

    # Per-sequence lengths, then the number of quantization blocks each
    # sequence needs (ceil division), accumulated into cumulative offsets
    # with a leading zero.
    per_seq_q_len = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    per_seq_k_len = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
    q_blocks_per_seq = (per_seq_q_len + BLKQ - 1) // BLKQ
    k_blocks_per_seq = (per_seq_k_len + BLKK - 1) // BLKK
    cu_seqlens_q_scale = torch.nn.functional.pad(
        torch.cumsum(q_blocks_per_seq, dim=0), (1, 0), value=0
    )
    cu_seqlens_k_scale = torch.nn.functional.pad(
        torch.cumsum(k_blocks_per_seq, dim=0), (1, 0), value=0
    )

    # One fp32 scale per (block, head).
    q_scales = torch.empty(
        (cu_seqlens_q_scale[-1], num_q_heads), device=q.device, dtype=torch.float32
    )
    k_scales = torch.empty(
        (cu_seqlens_k_scale[-1], num_k_heads), device=k.device, dtype=torch.float32
    )

    if sm_scale is None:
        sm_scale = head_dim**-0.5

    # Quantize q; 1.44269504 = log2(e) folds the softmax base change into
    # the quantized values.
    q_grid = (triton.cdiv(max_seqlen_q, BLKQ), num_q_heads, num_sequences)
    quant_per_block_int8_kernel[q_grid](
        q,
        quantized_q,
        q_scales,
        cu_seqlens_q,
        cu_seqlens_q_scale,
        q.stride(1),
        q.stride(0),
        quantized_q.stride(1),
        quantized_q.stride(0),
        sm_scale=(sm_scale * 1.44269504),
        H=num_q_heads,
        C=head_dim,
        BLK=BLKQ,
    )

    # Quantize k with no extra scaling.
    k_grid = (triton.cdiv(max_seqlen_k, BLKK), num_k_heads, num_sequences)
    quant_per_block_int8_kernel[k_grid](
        k,
        quantized_k,
        k_scales,
        cu_seqlens_k,
        cu_seqlens_k_scale,
        k.stride(1),
        k.stride(0),
        quantized_k.stride(1),
        quantized_k.stride(0),
        sm_scale=1.0,
        H=num_k_heads,
        C=head_dim,
        BLK=BLKK,
    )

    return (
        quantized_q,
        q_scales,
        quantized_k,
        k_scales,
        cu_seqlens_q_scale,
        cu_seqlens_k_scale,
    )
|
||||
@@ -46,6 +46,7 @@ from transformers.integrations.deepspeed import (
|
||||
)
|
||||
|
||||
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
||||
from axolotl.integrations.sageattention.lib.core import monkeypatch_sdp_w_sage_attention
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.monkeypatch.multipack import (
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||
@@ -707,6 +708,7 @@ class ModelLoader:
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"sdpa"
|
||||
)
|
||||
monkeypatch_sdp_w_sage_attention()
|
||||
elif self.cfg.eager_attention:
|
||||
self.model_kwargs["attn_implementation"] = "eager"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
|
||||
Reference in New Issue
Block a user