[gemma4] use mixed Flash Attention and SDPA and add fused RMSNorm+RoPE Triton kernels (#3598)
This commit is contained in:
529
src/axolotl/kernels/gemma4_fused_rope.py
Normal file
529
src/axolotl/kernels/gemma4_fused_rope.py
Normal file
@@ -0,0 +1,529 @@
|
||||
"""
|
||||
Fused RMSNorm + RoPE Triton kernel for Gemma 4.
|
||||
|
||||
Fuses three operations into one kernel launch:
|
||||
1. RMSNorm: x_norm = (x / sqrt(mean(x^2) + eps)) * weight
|
||||
2. RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
|
||||
3. (optional) RMSNorm without scale (for v_norm)
|
||||
|
||||
This eliminates two intermediate tensor materializations per Q/K path and
avoids the allocation churn from rotate_half / apply_rotary_pos_emb.
|
||||
|
||||
Shapes:
|
||||
X: (rows, head_dim) — flattened from (batch, seq_len, num_heads, head_dim)
|
||||
W: (head_dim,) — RMSNorm weight (None for with_scale=False)
|
||||
cos: (rows, head_dim) — flattened from (batch, seq_len, 1, head_dim) after broadcast
|
||||
sin: (rows, head_dim) — same as cos
|
||||
"""
|
||||
|
||||
import math
|
||||
import operator
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from liger_kernel.ops.utils import (
|
||||
calculate_settings,
|
||||
compare_version,
|
||||
ensure_contiguous,
|
||||
torch_to_triton_dtype,
|
||||
)
|
||||
from liger_kernel.utils import is_npu_available
|
||||
|
||||
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
||||
try:
|
||||
from triton.language.extra.libdevice import rsqrt
|
||||
except ModuleNotFoundError:
|
||||
from triton.language.extra.cuda.libdevice import rsqrt
|
||||
else:
|
||||
from triton.language.math import rsqrt
|
||||
|
||||
|
||||
@triton.jit
def _rms_norm_rope_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    COS_ptr,
    COS_row_stride,
    SIN_ptr,
    SIN_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    n_cols,
    n_heads,
    eps,
    HAS_WEIGHT: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Fused forward:
        x_norm = x / rms(x) [* weight]                  (RMSNorm)
        y = x_norm * cos + rotate_half(x_norm) * sin    (RoPE)

    rotate_half swaps first/second halves and negates the first:
        rotate_half([a, b]) = [-b, a]

    cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
    (cos/sin have shape (B*S, D) while X has shape (B*S*H, D)).
    """
    # One program instance per row of X, i.e. per (batch, seq, head) triple.
    row_idx = tl.program_id(0).to(tl.int64)
    # cos/sin row: divide by n_heads since cos/sin are (B*S, D)
    cs_row_idx = row_idx // n_heads
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    half_dim = n_cols // 2

    # Load input row
    X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
    X_dtype = X_row.dtype
    # Do all math in fp32 regardless of storage dtype for numerical stability.
    X_fp32 = X_row.to(tl.float32)

    # RMSNorm: compute 1/rms
    mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
    rstd = rsqrt(mean_sq + eps)
    # Save 1/rms so the backward kernel does not have to recompute it.
    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)

    # Normalize
    X_norm = X_fp32 * rstd

    # Apply weight if present (with_scale=True)
    if HAS_WEIGHT:
        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
        X_norm = X_norm * W_row

    # RoPE: load cos/sin (broadcast across heads)
    cos_row = tl.load(
        COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
    ).to(tl.float32)
    sin_row = tl.load(
        SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, mask=mask, other=0
    ).to(tl.float32)

    # rotate_half: for col < half_dim, take -X_norm[col + half_dim]
    #              for col >= half_dim, take X_norm[col - half_dim]
    rot_offsets = tl.where(
        col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
    )
    rot_mask = rot_offsets < n_cols
    # Load X again at the rotated offsets instead of shuffling registers.
    X_rot = tl.load(
        X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_mask & mask, other=0
    ).to(tl.float32)
    # Re-normalize the rotated values
    X_rot_norm = X_rot * rstd
    if HAS_WEIGHT:
        # The weight must also be gathered at the rotated offsets.
        W_rot = tl.load(W_ptr + rot_offsets, mask=rot_mask & mask, other=0).to(
            tl.float32
        )
        X_rot_norm = X_rot_norm * W_rot

    # Negate the first half (rotate_half negates x2, which becomes the first half)
    sign = tl.where(col_offsets < half_dim, -1.0, 1.0)
    X_rot_norm = X_rot_norm * sign

    # Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
    Y_row = X_norm * cos_row + X_rot_norm * sin_row

    # Cast back to the input's storage dtype on store.
    tl.store(
        Y_ptr + row_idx * Y_row_stride + col_offsets,
        Y_row.to(X_dtype),
        mask=mask,
    )
|
||||
|
||||
|
||||
@triton.jit
def _rms_norm_rope_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    X_dtype: tl.constexpr,
    W_ptr,
    COS_ptr,
    COS_row_stride,
    SIN_ptr,
    SIN_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    dW_ptr,
    dW_row_stride,
    n_rows,
    n_cols,
    n_heads,
    rows_per_program,
    HAS_WEIGHT: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Backward for Y = RoPE(RMSNorm(X, W))
    cos/sin indexed by row_idx // n_heads for per-head broadcast.

    Each program walks `rows_per_program` consecutive rows and accumulates a
    partial weight gradient into one row of dW_ptr; the host sums the partial
    rows afterwards (see rms_norm_rope_backward).
    """
    row_block_id = tl.program_id(0).to(tl.int64)
    row_start = row_block_id * rows_per_program
    # Clamp so the last program does not run past the end of the input.
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    half_dim = n_cols // 2

    # fp32 accumulator for this program's partial dW.
    dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    if HAS_WEIGHT:
        # Weight is row-invariant: load once, reuse for every row below.
        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)

    for row_idx in range(row_start, row_end):
        cs_row_idx = row_idx // n_heads

        dY_row = tl.load(
            dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
        ).to(tl.float32)
        X_row = tl.load(
            X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
        ).to(tl.float32)
        # 1/rms saved by the forward kernel.
        rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)

        cos_row = tl.load(
            COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
        ).to(tl.float32)

        # dN = dY * cos + rotate_half^T(dY * sin)
        # rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half)
        #
        # Compute rotate_half_transpose(dY * sin) by loading dY and sin at
        # rotated offsets directly: dY[rot] * sin[rot] * adj_sign
        # This is equivalent to rotating (dY * sin) because the rotation
        # just permutes which elements are multiplied.
        rot_offsets = tl.where(
            col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
        )
        rot_mask = rot_offsets < n_cols
        dY_rot = tl.load(
            dY_ptr + row_idx * dY_row_stride + rot_offsets,
            mask=rot_mask & mask,
            other=0,
        ).to(tl.float32)
        sin_rot = tl.load(
            SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
            mask=rot_mask & mask,
            other=0,
        ).to(tl.float32)

        # Adjoint sign pattern: +1 on the first half, -1 on the second.
        adj_sign = tl.where(col_offsets < half_dim, 1.0, -1.0)
        dN = dY_row * cos_row + dY_rot * sin_rot * adj_sign

        # Pre-weight normalized: n = rstd * x
        n = X_row * rstd

        if HAS_WEIGHT:
            # dW contribution of this row: dN * n (summed across rows).
            dW_acc += dN * n
            dm = dN * W_row
        else:
            dm = dN

        # RMSNorm backward: dX = rstd * (dm - (1/n_cols) * rstd^2 * dot(dm, X) * X)
        dot_dm_x = tl.sum(dm * X_row, axis=0)
        dX_row = rstd * (dm - (1.0 / n_cols) * rstd * rstd * dot_dm_x * X_row)

        tl.store(
            dX_ptr + row_idx * dX_row_stride + col_offsets,
            dX_row.to(X_dtype),
            mask=mask,
        )

    if HAS_WEIGHT:
        # One partial-dW row per program; the host reduces with sum(dim=0).
        tl.store(
            dW_ptr + row_block_id * dW_row_stride + col_offsets,
            dW_acc,
            mask=mask,
        )
|
||||
|
||||
|
||||
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
    """
    Launch the fused RMSNorm+RoPE forward kernel, one program per row.

    Args:
        X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
        W: (head_dim,) or None — RMSNorm weight
        cos: (B*S, head_dim) — position embeddings (broadcast across heads)
        sin: (B*S, head_dim) — position embeddings (broadcast across heads)
        eps: float
        n_heads: int — number of attention heads (for cos/sin indexing)
    Returns:
        Y, X_saved, RSTD, BLOCK_SIZE, num_warps
        (X is returned unchanged as X_saved so autograd can stash it;
        BLOCK_SIZE/num_warps are reused by the backward launch.)
    """
    n_rows, n_cols = X.shape
    # Block size / warp count for this row width, chosen by the liger helper.
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    has_weight = W is not None

    Y = torch.empty_like(X)
    # Per-row 1/rms, kept in fp32 for the backward pass.
    RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)

    # NOTE: positional argument order must match the kernel signature exactly.
    _rms_norm_rope_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W if has_weight else X,  # dummy pointer when no weight
        cos,
        cos.stride(0),
        sin,
        sin.stride(0),
        RSTD,
        RSTD.stride(0),
        n_cols,
        n_heads,
        eps,
        HAS_WEIGHT=has_weight,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return Y, X, RSTD, BLOCK_SIZE, num_warps
|
||||
|
||||
|
||||
def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_warps):
    """
    Launch the fused RMSNorm+RoPE backward kernel.

    Grid is sized to the device SM count; each program processes a contiguous
    slice of rows and emits one partial weight-gradient row, which is summed
    here on the host.

    Args:
        dY: (rows, head_dim) upstream gradient
        X, W, cos, sin, RSTD: tensors saved from the forward pass
        n_heads: heads per (batch, seq) position, for cos/sin row indexing
        BLOCK_SIZE, num_warps: launch settings reused from the forward pass
    Returns:
        dX: gradient w.r.t. X (same shape/dtype as X)
        dW: gradient w.r.t. W cast to W.dtype, or None when W is None
    """
    n_rows, n_cols = dY.shape
    has_weight = W is not None

    # One program per SM so each program loads the weight row only once.
    sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    rows_per_program = math.ceil(n_rows / sm_count)

    dX = torch.empty_like(X)

    # fp32 partial-dW buffer, one row per program; 1-row dummy when no weight.
    if has_weight:
        _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=X.device)
    else:
        _dW = torch.empty((1, n_cols), dtype=torch.float32, device=X.device)

    # NOTE: positional argument order must match the kernel signature exactly.
    _rms_norm_rope_backward_kernel[(sm_count,)](
        dY,
        dY.stride(0),
        dX,
        dX.stride(0),
        X,
        X.stride(0),
        torch_to_triton_dtype[X.dtype],
        W if has_weight else X,  # dummy
        cos,
        cos.stride(0),
        sin,
        sin.stride(0),
        RSTD,
        RSTD.stride(0),
        _dW,
        _dW.stride(0),
        n_rows,
        n_cols,
        n_heads,
        rows_per_program,
        HAS_WEIGHT=has_weight,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    # Reduce per-program partial rows into the final weight gradient.
    dW = _dW.sum(dim=0).to(W.dtype) if has_weight else None
    return dX, dW
|
||||
|
||||
|
||||
class FusedRMSNormRoPEFunction(torch.autograd.Function):
    """Autograd wrapper around the fused RMSNorm+RoPE Triton kernels."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, cos, sin, eps, n_heads):
        """
        X: (B*S*H, head_dim)
        W: (head_dim,) or None
        cos: (B*S, head_dim) — broadcast across heads
        sin: (B*S, head_dim) — broadcast across heads
        n_heads: int
        """
        Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
            X,
            W,
            cos,
            sin,
            eps,
            n_heads,
        )
        # eps is not read in backward (RSTD is saved instead); kept for debugging.
        ctx.eps = eps
        # Reuse the forward launch settings for the backward launch.
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.n_heads = n_heads
        ctx.has_weight = W is not None
        ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        # Gradients flow only to X and W; cos/sin/eps/n_heads get None.
        X, W, cos, sin, RSTD = ctx.saved_tensors
        dX, dW = rms_norm_rope_backward(
            dY,
            X,
            W,
            cos,
            sin,
            RSTD,
            ctx.n_heads,
            ctx.BLOCK_SIZE,
            ctx.num_warps,
        )
        return dX, dW, None, None, None, None
|
||||
|
||||
|
||||
def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
    """
    Apply fused RMSNorm + RoPE.

    Args:
        x: (batch, seq_len, num_heads, head_dim) — after projection + view
        weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
        cos: (batch, seq_len, head_dim) — from RotaryEmbedding
        sin: (batch, seq_len, head_dim) — from RotaryEmbedding
        eps: float — RMSNorm epsilon

    Returns:
        y: (batch, seq_len, num_heads, head_dim) — normalized + rotated
    """
    batch, seq_len, num_heads, head_dim = x.shape
    positions = batch * seq_len

    # The kernel operates on 2D views: x as (B*S*H, D), cos/sin as (B*S, D).
    # Per-head broadcast happens inside the kernel via row_idx // num_heads.
    flat_x = x.reshape(-1, head_dim).contiguous()
    flat_cos = cos.reshape(positions, head_dim).contiguous()
    flat_sin = sin.reshape(positions, head_dim).contiguous()

    flat_y = FusedRMSNormRoPEFunction.apply(
        flat_x, weight, flat_cos, flat_sin, eps, num_heads
    )
    # Restore the caller's 4D layout.
    return flat_y.view(x.shape)
|
||||
|
||||
|
||||
@triton.jit
def _rms_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """RMSNorm without scale weight: y = x / rms(x)"""
    # One program per row; math in fp32, output cast back to the input dtype.
    row_idx = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
    X_dtype = X_row.dtype
    X_fp32 = X_row.to(tl.float32)

    mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
    rstd = rsqrt(mean_sq + eps)
    # Save 1/rms for the backward kernel.
    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)

    Y_row = X_fp32 * rstd
    tl.store(Y_ptr + row_idx * Y_row_stride + col_offsets, Y_row.to(X_dtype), mask=mask)
|
||||
|
||||
|
||||
@triton.jit
def _rms_norm_noscale_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    X_dtype: tl.constexpr,
    RSTD_ptr,
    RSTD_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Backward for y = x * rstd (no weight)."""
    # One program per row; rstd was computed and saved by the forward kernel.
    row_idx = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dY_row = tl.load(
        dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
    ).to(tl.float32)
    X_row = tl.load(
        X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
    ).to(tl.float32)
    rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)

    # RMSNorm backward: dX = rstd * (dY - (1/n_cols) * rstd^2 * dot(dY, X) * X)
    dot_dy_x = tl.sum(dY_row * X_row, axis=0)
    dX_row = rstd * (dY_row - (1.0 / n_cols) * rstd * rstd * dot_dy_x * X_row)

    tl.store(
        dX_ptr + row_idx * dX_row_stride + col_offsets, dX_row.to(X_dtype), mask=mask
    )
|
||||
|
||||
|
||||
class FusedRMSNormNoScaleFunction(torch.autograd.Function):
    """RMSNorm without learnable scale — used for Gemma4's v_norm."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, eps):
        # X: (rows, n_cols); the kernel launches one program per row.
        n_rows, n_cols = X.shape
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        Y = torch.empty_like(X)
        # Per-row 1/rms in fp32, saved for the backward pass.
        RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)

        _rms_norm_forward_kernel[(n_rows,)](
            Y,
            Y.stride(0),
            X,
            X.stride(0),
            RSTD,
            RSTD.stride(0),
            n_cols,
            eps,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
        # Reuse launch settings for the backward kernel.
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.save_for_backward(X, RSTD)
        ctx.n_cols = n_cols
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        # Gradient flows only to X; eps is non-differentiable (None).
        X, RSTD = ctx.saved_tensors
        n_rows = X.shape[0]
        dX = torch.empty_like(X)
        _rms_norm_noscale_backward_kernel[(n_rows,)](
            dY,
            dY.stride(0),
            dX,
            dX.stride(0),
            X,
            X.stride(0),
            torch_to_triton_dtype[X.dtype],
            RSTD,
            RSTD.stride(0),
            ctx.n_cols,
            BLOCK_SIZE=ctx.BLOCK_SIZE,
            num_warps=ctx.num_warps,
        )
        return dX, None
|
||||
|
||||
|
||||
def fused_rms_norm_noscale(x, eps=1e-6):
    """
    RMSNorm without scale for v_norm.

    Args:
        x: (batch, seq_len, num_heads, head_dim)
    Returns:
        y: same shape, normalized
    """
    original_shape = x.shape
    head_dim = original_shape[-1]
    # Collapse everything but the head dimension; the kernel is row-wise.
    flattened = x.reshape(-1, head_dim)
    normalized = FusedRMSNormNoScaleFunction.apply(flattened, eps)
    return normalized.view(original_shape)
|
||||
@@ -624,7 +624,14 @@ class ModelLoader:
|
||||
|
||||
def _set_attention_config(self):
|
||||
"""Sample packing uses custom FA2 patch"""
|
||||
if self.cfg.attn_implementation:
|
||||
if self.cfg.gemma4_hybrid_attn_impl:
|
||||
# Load model with flash_attention_2 for sliding window layers;
|
||||
# global layers will be patched to sdpa post-load.
|
||||
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
self.model_config._attn_implementation = "flash_attention_2"
|
||||
# Set flash_attention so multipack/sample_packing patches activate
|
||||
self.cfg.flash_attention = True
|
||||
elif self.cfg.attn_implementation:
|
||||
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
|
||||
elif self.cfg.flex_attention:
|
||||
self.model_kwargs["attn_implementation"] = "flex_attention"
|
||||
|
||||
@@ -156,6 +156,7 @@ class PatchManager:
|
||||
# which would clobber any earlier fix.
|
||||
self._fix_nemotron_h_conversion_mapping()
|
||||
|
||||
self._apply_gemma_hybrid_attention(model)
|
||||
self._finalize_moe_expert_quantization(model)
|
||||
|
||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||
@@ -165,6 +166,72 @@ class PatchManager:
|
||||
self._apply_lora_kernel_patch(model)
|
||||
self._apply_scaling_softmax_patch(model)
|
||||
|
||||
def _apply_gemma_hybrid_attention(self, model: PreTrainedModel):
    """Apply hybrid attention: FA2 for sliding window layers, SDPA for global layers.

    Gemma 4 has global (full_attention) layers with head_dim=512
    which exceeds flash attention's supported size. This patch loads the model
    with flash_attention_2 for the sliding window layers (head_dim=256), then
    gives each global layer a shallow-copied config with _attn_implementation="sdpa".
    """
    if not self.cfg.gemma4_hybrid_attn_impl:
        return

    import copy

    # Locate the module holding the decoder stack; structure differs by model:
    #   Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
    #   Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
    layers = None
    config_source = None
    for container in (model, getattr(model, "model", None)):
        if container is None:
            continue
        # Direct decoder stack on this container
        if hasattr(container, "layers"):
            layers, config_source = container.layers, container
            break
        # Multimodal wrapper: language_model.layers
        text_model = getattr(container, "language_model", None)
        if text_model is not None and hasattr(text_model, "layers"):
            layers, config_source = text_model.layers, text_model
            break

    if layers is None:
        LOG.warning(
            "gemma4_hybrid_attn_impl: could not find decoder layers in model, skipping"
        )
        return

    config = getattr(config_source, "config", self.model_config)
    layer_types = getattr(config, "layer_types", None)
    if layer_types is None:
        LOG.warning(
            "gemma4_hybrid_attn_impl: model config has no 'layer_types', skipping. "
            "This feature requires a model with mixed sliding/global attention layers."
        )
        return

    num_patched = 0
    for idx, decoder_layer in enumerate(layers):
        if layer_types[idx] == "sliding_attention":
            continue
        # Global / full_attention layer - use SDPA instead of FA2
        attn = getattr(decoder_layer, "self_attn", None)
        if attn is not None and hasattr(attn, "config"):
            # Shallow copy so sibling layers keep their own implementation.
            patched_config = copy.copy(attn.config)
            patched_config._attn_implementation = "sdpa"
            attn.config = patched_config
            num_patched += 1

    LOG.info(
        "gemma4_hybrid_attn_impl: patched %d global layers to use SDPA "
        "(remaining %d sliding layers use flash_attention_2)",
        num_patched,
        len(layers) - num_patched,
    )
|
||||
|
||||
def _apply_flash_attention_patches(self):
|
||||
"""Apply patches related to Flash Attention."""
|
||||
if self.cfg.xformers_attention and self.cfg.sample_packing:
|
||||
@@ -324,6 +391,13 @@ class PatchManager:
|
||||
|
||||
patch_qwen3_5_vlm_flash_attention()
|
||||
|
||||
if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
|
||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||
patch_gemma4_fused_attn,
|
||||
)
|
||||
|
||||
patch_gemma4_fused_attn()
|
||||
|
||||
@staticmethod
|
||||
def _fix_nemotron_h_conversion_mapping():
|
||||
"""Remove the spurious embedding→embeddings WeightRenaming from the
|
||||
|
||||
147
src/axolotl/monkeypatch/models/gemma4/fused_attn.py
Normal file
147
src/axolotl/monkeypatch/models/gemma4/fused_attn.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
Gemma 4 fused attention monkeypatch.
|
||||
|
||||
Replaces the per-layer RMSNorm + RoPE + transpose sequence with fused Triton
|
||||
kernels, eliminating intermediate tensor allocations from rotate_half / apply_rotary_pos_emb
|
||||
|
||||
Usage:
|
||||
from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
|
||||
patch_gemma4_fused_attn()
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _make_fused_forward(original_forward):
    """Create a patched forward that uses fused RMSNorm+RoPE kernels.

    NOTE(review): `original_forward` is not referenced by the returned
    closure; it is accepted so call sites can hand over the stock forward.
    """

    from axolotl.kernels.gemma4_fused_rope import (
        fused_rms_norm_noscale,
        fused_rms_norm_rope,
    )

    def fused_forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
        past_key_values=None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # Imported lazily to avoid importing transformers internals at
        # module import time.
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
        from transformers.models.gemma4.modeling_gemma4 import (
            eager_attention_forward,
        )

        input_shape = hidden_states.shape[:-1]
        # Target view for projections: (batch, seq, num_heads, head_dim).
        hidden_shape = (*input_shape, -1, self.head_dim)
        eps = self.config.rms_norm_eps

        cos, sin = position_embeddings

        # ---- Projections ----
        # Use apply_qkv if present (LoRA kernel patch), otherwise direct proj
        has_lora_qkv = hasattr(self, "apply_qkv")

        if has_lora_qkv:
            query_states, key_states, value_states = self.apply_qkv(hidden_states)
            query_states = query_states.view(hidden_shape)
        else:
            query_states = self.q_proj(hidden_states).view(hidden_shape)

        # ---- Q path: fused q_norm + RoPE ----
        query_states = fused_rms_norm_rope(
            query_states,
            self.q_norm.weight,
            cos,
            sin,
            eps=eps,
        )
        # (B, S, H, D) -> (B, H, S, D) for the attention implementation.
        query_states = query_states.transpose(1, 2)

        # ---- K/V path ----
        if self.is_kv_shared_layer:
            # Reuse K/V produced (and already normed+roped+transposed) by an
            # earlier layer; only a device move may be needed.
            key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
            key_states = key_states.to(query_states.device)
            value_states = value_states.to(query_states.device)
        else:
            if has_lora_qkv:
                # apply_qkv already computed k/v projections
                key_states = key_states.view(hidden_shape)
                # v_proj may be absent; fall back to the key projection.
                value_states = (
                    value_states.view(hidden_shape)
                    if self.v_proj is not None
                    else key_states
                )
            else:
                key_states = self.k_proj(hidden_states).view(hidden_shape)
                value_states = (
                    self.v_proj(hidden_states).view(hidden_shape)
                    if self.v_proj is not None
                    else key_states
                )

            # Fused k_norm + RoPE
            key_states = fused_rms_norm_rope(
                key_states,
                self.k_norm.weight,
                cos,
                sin,
                eps=eps,
            )
            key_states = key_states.transpose(1, 2)

            # Fused v_norm (no scale, no RoPE)
            value_states = fused_rms_norm_noscale(value_states, eps=eps)
            value_states = value_states.transpose(1, 2)

        # Shared layers read the earlier layer's cache entry; only non-shared
        # layers push their own K/V into the cache.
        if past_key_values is not None and not self.is_kv_shared_layer:
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx
            )
        if self.store_full_length_kv:
            # Publish this layer's K/V for downstream KV-sharing layers.
            shared_kv_states[self.layer_idx] = key_states, value_states

        # Resolve the configured attention backend (eager fallback).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation
            ]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        # Merge heads back: (B, S, H*D), then final output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

    return fused_forward
|
||||
|
||||
|
||||
def patch_gemma4_fused_attn():
    """
    Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels.
    """
    from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention

    # Hand the stock forward to the factory, then install the fused variant.
    stock_forward = Gemma4TextAttention.forward
    fused = _make_fused_forward(stock_forward)
    Gemma4TextAttention.forward = fused

    logger.info(
        "Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
    )
|
||||
@@ -777,6 +777,15 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
|
||||
# Opt-in flag; the default None leaves attention configuration untouched.
gemma4_hybrid_attn_impl: bool | None = Field(
    default=None,
    json_schema_extra={
        "description": "Use hybrid attention for Gemma 4: flash_attention_2 for sliding window layers "
        "and sdpa for global (full_attention) layers. Global layers have head_dim=512 which "
        "exceeds flash attention's supported size."
    },
)
|
||||
|
||||
experts_implementation: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
226
tests/kernels/test_gemma4_fused_rope.py
Normal file
226
tests/kernels/test_gemma4_fused_rope.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""
|
||||
Correctness tests for the fused RMSNorm+RoPE Triton kernel.
|
||||
|
||||
Tests forward and backward against the reference Gemma4 implementation
|
||||
(Gemma4RMSNorm + apply_rotary_pos_emb) across both sliding window
|
||||
(head_dim=256) and global attention (head_dim=512) layer configurations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Fixed seed so the randn-based inputs below are reproducible across runs.
torch.manual_seed(42)

# Skip entire module if no CUDA
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
|
||||
|
||||
|
||||
def _reference_norm_rope(x, weight, cos, sin, eps):
    """Reference: separate Gemma4RMSNorm + apply_rotary_pos_emb."""
    from transformers.models.gemma4.modeling_gemma4 import (
        Gemma4RMSNorm,
        apply_rotary_pos_emb,
    )

    head_dim = x.shape[-1]
    # Build a reference norm module carrying the supplied weight.
    ref_norm = Gemma4RMSNorm(head_dim, eps=eps).to(x.device, x.dtype)
    ref_norm.weight.data.copy_(weight)
    return apply_rotary_pos_emb(ref_norm(x), cos, sin, unsqueeze_dim=2)
|
||||
|
||||
|
||||
def _reference_norm_noscale(x, eps):
    """Reference: Gemma4RMSNorm with_scale=False."""
    from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm

    head_dim = x.shape[-1]
    reference = Gemma4RMSNorm(head_dim, eps=eps, with_scale=False)
    return reference.to(x.device, x.dtype)(x)
|
||||
|
||||
|
||||
@pytest.fixture(
    params=[
        (2, 64, 32, 256),  # sliding window layer shape
        (2, 64, 4, 512),  # global attention layer shape
        (1, 128, 16, 256),  # different batch/seq
        (1, 1, 1, 8),  # minimal size
    ],
    ids=["sliding_256", "global_512", "varied", "minimal"],
)
def shapes(request):
    """(batch, seq_len, num_heads, head_dim) tuples covering Gemma4 layer shapes."""
    return request.param
|
||||
|
||||
|
||||
@pytest.fixture(params=[torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
def dtype(request):
    """Low-precision dtypes exercised by the fused kernels."""
    return request.param
|
||||
|
||||
|
||||
class TestFusedRMSNormRoPEForward:
    """Forward pass correctness."""

    def test_matches_reference(self, shapes, dtype):
        """Fused output must agree with RMSNorm-then-RoPE reference."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        B, S, H, D = shapes
        eps = 1e-6
        x = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
        weight = torch.randn(D, device="cuda", dtype=dtype)
        cos = torch.randn(B, S, D, device="cuda", dtype=dtype)
        sin = torch.randn(B, S, D, device="cuda", dtype=dtype)

        y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
        y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)

        # Cosine similarity tolerates bf16/fp16 rounding differences between
        # the fp32-accumulating kernel and the reference.
        cos_sim = torch.nn.functional.cosine_similarity(
            y_ref.flatten().float(), y_fused.flatten().float(), dim=0
        )
        assert cos_sim > 0.999, f"Forward cosine_sim={cos_sim:.6f}, expected > 0.999"

    def test_output_shape(self, shapes):
        """Output preserves the input's shape and dtype."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        B, S, H, D = shapes
        x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
        weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)

        y = fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6)
        assert y.shape == x.shape
        assert y.dtype == x.dtype
|
||||
|
||||
|
||||
class TestFusedRMSNormRoPEBackward:
    """Backward pass correctness via gradient comparison."""

    @pytest.mark.parametrize(
        "B,S,H,D",
        [(2, 64, 32, 256), (2, 64, 4, 512)],
        ids=["sliding_256", "global_512"],
    )
    def test_x_grad_matches_reference(self, B, S, H, D):
        """dX from the fused kernel must track the reference autograd dX."""
        from transformers.models.gemma4.modeling_gemma4 import (
            Gemma4RMSNorm,
            apply_rotary_pos_emb,
        )

        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        eps = 1e-6
        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)

        # Reference backward
        x_ref = torch.randn(
            B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
        )
        norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
        norm_ref.weight.data.copy_(weight_init)
        y_ref = apply_rotary_pos_emb(norm_ref(x_ref), cos, sin, unsqueeze_dim=2)
        y_ref.sum().backward()

        # Fused backward — detach().clone() yields a fresh leaf with the same
        # values; Tensor.data is deprecated and bypasses autograd's checks.
        x_fused = x_ref.detach().clone().requires_grad_(True)
        w_fused = weight_init.clone().requires_grad_(True)
        y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps)
        y_fused.sum().backward()

        cos_sim_x = torch.nn.functional.cosine_similarity(
            x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
        )
        assert cos_sim_x > 0.999, f"x grad cosine_sim={cos_sim_x:.6f}, expected > 0.999"

    @pytest.mark.parametrize(
        "B,S,H,D",
        [(2, 64, 32, 256), (2, 64, 4, 512)],
        ids=["sliding_256", "global_512"],
    )
    def test_weight_grad_matches_reference(self, B, S, H, D):
        """dW from the fused kernel must track the reference autograd dW."""
        from transformers.models.gemma4.modeling_gemma4 import (
            Gemma4RMSNorm,
            apply_rotary_pos_emb,
        )

        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        eps = 1e-6
        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)

        # Reference
        x_ref = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
        norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
        norm_ref.weight = torch.nn.Parameter(weight_init.clone())
        apply_rotary_pos_emb(
            norm_ref(x_ref), cos, sin, unsqueeze_dim=2
        ).sum().backward()

        # Fused
        w_fused = weight_init.clone().requires_grad_(True)
        fused_rms_norm_rope(x_ref.clone(), w_fused, cos, sin, eps=eps).sum().backward()

        # Weight grads accumulate over B*S*H rows, so tolerance is looser.
        cos_sim_w = torch.nn.functional.cosine_similarity(
            w_fused.grad.flatten().float(),
            norm_ref.weight.grad.flatten().float(),
            dim=0,
        )
        assert cos_sim_w > 0.995, (
            f"weight grad cosine_sim={cos_sim_w:.6f}, expected > 0.995"
        )

    def test_grad_flows(self):
        """Verify gradients are non-zero and finite."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        B, S, H, D = 1, 16, 4, 64
        x = torch.randn(
            B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
        )
        w = torch.randn(D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)

        y = fused_rms_norm_rope(x, w, cos, sin, eps=1e-6)
        y.sum().backward()

        assert x.grad is not None, "x.grad is None"
        assert w.grad is not None, "w.grad is None"
        assert x.grad.isfinite().all(), "x.grad has non-finite values"
        assert w.grad.isfinite().all(), "w.grad has non-finite values"
        assert x.grad.abs().sum() > 0, "x.grad is all zeros"
        assert w.grad.abs().sum() > 0, "w.grad is all zeros"
|
||||
|
||||
|
||||
class TestFusedRMSNormNoScale:
    """Tests for v_norm (RMSNorm without learnable scale)."""

    def test_forward_matches_reference(self, shapes, dtype):
        """Fused no-scale norm must agree with Gemma4RMSNorm(with_scale=False)."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale

        B, S, H, D = shapes
        eps = 1e-6
        x = torch.randn(B, S, H, D, device="cuda", dtype=dtype)

        y_ref = _reference_norm_noscale(x.clone(), eps)
        y_fused = fused_rms_norm_noscale(x.clone(), eps=eps)

        cos_sim = torch.nn.functional.cosine_similarity(
            y_ref.flatten().float(), y_fused.flatten().float(), dim=0
        )
        assert cos_sim > 0.999, f"v_norm cosine_sim={cos_sim:.6f}, expected > 0.999"

    def test_backward_flows(self):
        """Gradients through the no-scale path are finite and non-zero."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale

        x = torch.randn(
            1, 16, 4, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True
        )
        y = fused_rms_norm_noscale(x, eps=1e-6)
        y.sum().backward()

        assert x.grad is not None
        assert x.grad.isfinite().all()
        assert x.grad.abs().sum() > 0
|
||||
Reference in New Issue
Block a user