[gemma4] use mixed Flash Attention and SDPA and add fused RMSNorm+RoPE Triton kernels (#3598)
This commit is contained in:
529
src/axolotl/kernels/gemma4_fused_rope.py
Normal file
529
src/axolotl/kernels/gemma4_fused_rope.py
Normal file
@@ -0,0 +1,529 @@
|
|||||||
|
"""
|
||||||
|
Fused RMSNorm + RoPE Triton kernel for Gemma 4.
|
||||||
|
|
||||||
|
Fuses three operations into one kernel launch:
|
||||||
|
1. RMSNorm: x_norm = (x / sqrt(mean(x^2) + eps)) * weight
|
||||||
|
2. RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
|
||||||
|
3. (optional) RMSNorm without scale (for v_norm)
|
||||||
|
|
||||||
|
This eliminates two intermediate tensor materializations per Q/K path and the
allocation churn from rotate_half / apply_rotary_pos_emb.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
X: (rows, head_dim) — flattened from (batch, seq_len, num_heads, head_dim)
|
||||||
|
W: (head_dim,) — RMSNorm weight (None for with_scale=False)
|
||||||
|
cos: (rows, head_dim) — flattened from (batch, seq_len, 1, head_dim) after broadcast
|
||||||
|
sin: (rows, head_dim) — same as cos
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import operator
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from liger_kernel.ops.utils import (
|
||||||
|
calculate_settings,
|
||||||
|
compare_version,
|
||||||
|
ensure_contiguous,
|
||||||
|
torch_to_triton_dtype,
|
||||||
|
)
|
||||||
|
from liger_kernel.utils import is_npu_available
|
||||||
|
|
||||||
|
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
||||||
|
try:
|
||||||
|
from triton.language.extra.libdevice import rsqrt
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
from triton.language.extra.cuda.libdevice import rsqrt
|
||||||
|
else:
|
||||||
|
from triton.language.math import rsqrt
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _rms_norm_rope_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    W_ptr,
    COS_ptr,
    COS_row_stride,
    SIN_ptr,
    SIN_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    n_cols,
    n_heads,
    eps,
    HAS_WEIGHT: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Fused forward:
        x_norm = x / rms(x) [* weight]                  (RMSNorm)
        y = x_norm * cos + rotate_half(x_norm) * sin    (RoPE)

    rotate_half swaps first/second halves and negates the first:
        rotate_half([a, b]) = [-b, a]

    cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
    (cos/sin have shape (B*S, D) while X has shape (B*S*H, D)).

    One program instance handles one X row; the per-row 1/rms is also stored
    to RSTD so the backward pass does not have to recompute it.
    """
    # int64 row index avoids overflow for large B*S*H row counts.
    row_idx = tl.program_id(0).to(tl.int64)
    # cos/sin row: divide by n_heads since cos/sin are (B*S, D)
    cs_row_idx = row_idx // n_heads
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    half_dim = n_cols // 2

    # Load input row
    X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
    X_dtype = X_row.dtype
    # All math happens in fp32; the result is cast back to the input dtype.
    X_fp32 = X_row.to(tl.float32)

    # RMSNorm: compute 1/rms and cache it for backward
    mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
    rstd = rsqrt(mean_sq + eps)
    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)

    # Normalize
    X_norm = X_fp32 * rstd

    # Apply weight if present (with_scale=True)
    if HAS_WEIGHT:
        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
        X_norm = X_norm * W_row

    # RoPE: load cos/sin (broadcast across heads)
    cos_row = tl.load(
        COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
    ).to(tl.float32)
    sin_row = tl.load(
        SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, mask=mask, other=0
    ).to(tl.float32)

    # rotate_half: for col < half_dim, take -X_norm[col + half_dim]
    #              for col >= half_dim, take X_norm[col - half_dim]
    rot_offsets = tl.where(
        col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
    )
    rot_mask = rot_offsets < n_cols
    X_rot = tl.load(
        X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_mask & mask, other=0
    ).to(tl.float32)
    # Re-normalize the rotated values (same per-row rstd applies)
    X_rot_norm = X_rot * rstd
    if HAS_WEIGHT:
        # Weight is per-column, so the rotated elements need the weight at
        # their rotated positions.
        W_rot = tl.load(W_ptr + rot_offsets, mask=rot_mask & mask, other=0).to(
            tl.float32
        )
        X_rot_norm = X_rot_norm * W_rot

    # Negate the first half (rotate_half negates x2, which becomes the first half)
    sign = tl.where(col_offsets < half_dim, -1.0, 1.0)
    X_rot_norm = X_rot_norm * sign

    # Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
    Y_row = X_norm * cos_row + X_rot_norm * sin_row

    tl.store(
        Y_ptr + row_idx * Y_row_stride + col_offsets,
        Y_row.to(X_dtype),
        mask=mask,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _rms_norm_rope_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    X_dtype: tl.constexpr,
    W_ptr,
    COS_ptr,
    COS_row_stride,
    SIN_ptr,
    SIN_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    dW_ptr,
    dW_row_stride,
    n_rows,
    n_cols,
    n_heads,
    rows_per_program,
    HAS_WEIGHT: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Backward for Y = RoPE(RMSNorm(X, W)).

    cos/sin indexed by row_idx // n_heads for per-head broadcast.

    Each program processes a contiguous slice of rows ([row_start, row_end))
    so that the per-column weight gradient can be accumulated locally in
    dW_acc and written out once per program; the host sums the per-program
    partials afterwards.
    """
    row_block_id = tl.program_id(0).to(tl.int64)
    row_start = row_block_id * rows_per_program
    # Clamp the last program's slice to the actual row count.
    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    half_dim = n_cols // 2

    # Per-program partial weight gradient (only stored when HAS_WEIGHT).
    dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    if HAS_WEIGHT:
        # Weight is row-invariant; load it once outside the row loop.
        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)

    for row_idx in range(row_start, row_end):
        cs_row_idx = row_idx // n_heads

        dY_row = tl.load(
            dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
        ).to(tl.float32)
        X_row = tl.load(
            X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
        ).to(tl.float32)
        # 1/rms cached by the forward kernel — no recompute needed.
        rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)

        cos_row = tl.load(
            COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
        ).to(tl.float32)

        # dN = dY * cos + rotate_half^T(dY * sin)
        # rotate_half^T([a, b]) = [b, -a]  (adjoint of rotate_half)
        #
        # Compute rotate_half_transpose(dY * sin) by loading dY and sin at
        # rotated offsets directly: dY[rot] * sin[rot] * adj_sign
        # This is equivalent to rotating (dY * sin) because the rotation
        # just permutes which elements are multiplied.
        rot_offsets = tl.where(
            col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
        )
        rot_mask = rot_offsets < n_cols
        dY_rot = tl.load(
            dY_ptr + row_idx * dY_row_stride + rot_offsets,
            mask=rot_mask & mask,
            other=0,
        ).to(tl.float32)
        sin_rot = tl.load(
            SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
            mask=rot_mask & mask,
            other=0,
        ).to(tl.float32)

        # Adjoint sign is the mirror of the forward sign pattern.
        adj_sign = tl.where(col_offsets < half_dim, 1.0, -1.0)
        dN = dY_row * cos_row + dY_rot * sin_rot * adj_sign

        # Pre-weight normalized: n = rstd * x
        n = X_row * rstd

        if HAS_WEIGHT:
            # dW = dN * x_hat (pre-weight normalized activations)
            dW_acc += dN * n
            dm = dN * W_row
        else:
            dm = dN

        # RMSNorm backward: dX = rstd * (dm - (1/n_cols) * rstd^2 * dot(dm, X) * X)
        dot_dm_x = tl.sum(dm * X_row, axis=0)
        dX_row = rstd * (dm - (1.0 / n_cols) * rstd * rstd * dot_dm_x * X_row)

        tl.store(
            dX_ptr + row_idx * dX_row_stride + col_offsets,
            dX_row.to(X_dtype),
            mask=mask,
        )

    if HAS_WEIGHT:
        # One partial-gradient row per program; host reduces over dim 0.
        tl.store(
            dW_ptr + row_block_id * dW_row_stride + col_offsets,
            dW_acc,
            mask=mask,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
    """
    Launch the fused RMSNorm+RoPE forward kernel (one program per row).

    Args:
        X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
        W: (head_dim,) or None — RMSNorm weight
        cos: (B*S, head_dim) — position embeddings (broadcast across heads)
        sin: (B*S, head_dim) — position embeddings (broadcast across heads)
        eps: float — RMSNorm epsilon
        n_heads: int — number of attention heads (for cos/sin indexing)

    Returns:
        Y, X_saved, RSTD, BLOCK_SIZE, num_warps — X is returned unchanged so
        the autograd wrapper can stash it for backward; BLOCK_SIZE/num_warps
        are reused for the backward launch.
    """
    n_rows, n_cols = X.shape
    # Liger helper picks a power-of-two block covering n_cols plus warp count.
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    has_weight = W is not None

    Y = torch.empty_like(X)
    # Per-row 1/rms, kept in fp32 for the backward pass.
    RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)

    _rms_norm_rope_forward_kernel[(n_rows,)](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        W if has_weight else X,  # dummy pointer when no weight
        cos,
        cos.stride(0),
        sin,
        sin.stride(0),
        RSTD,
        RSTD.stride(0),
        n_cols,
        n_heads,
        eps,
        HAS_WEIGHT=has_weight,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return Y, X, RSTD, BLOCK_SIZE, num_warps
|
||||||
|
|
||||||
|
|
||||||
|
def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_warps):
    """
    Launch the fused RMSNorm+RoPE backward kernel.

    Uses one program per SM, each handling a contiguous slice of rows, so that
    the per-column weight gradient can be accumulated per-program and reduced
    on the host with a single sum over dim 0.

    Args:
        dY: (n_rows, n_cols) upstream gradient
        X, W, cos, sin, RSTD: tensors saved by the forward pass
        n_heads: head count used for cos/sin row indexing
        BLOCK_SIZE, num_warps: launch settings from the forward pass

    Returns:
        (dX, dW) — dW is None when W is None.
    """
    n_rows, n_cols = dY.shape
    has_weight = W is not None

    # NOTE(review): CUDA-only query — assumes X lives on a CUDA device.
    sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    rows_per_program = math.ceil(n_rows / sm_count)

    dX = torch.empty_like(X)

    if has_weight:
        # One partial dW row per program; summed below.
        _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=X.device)
    else:
        # Minimal dummy buffer; the kernel never stores to it without weight.
        _dW = torch.empty((1, n_cols), dtype=torch.float32, device=X.device)

    _rms_norm_rope_backward_kernel[(sm_count,)](
        dY,
        dY.stride(0),
        dX,
        dX.stride(0),
        X,
        X.stride(0),
        torch_to_triton_dtype[X.dtype],
        W if has_weight else X,  # dummy
        cos,
        cos.stride(0),
        sin,
        sin.stride(0),
        RSTD,
        RSTD.stride(0),
        _dW,
        _dW.stride(0),
        n_rows,
        n_cols,
        n_heads,
        rows_per_program,
        HAS_WEIGHT=has_weight,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    # Reduce per-program partials into the final weight gradient.
    dW = _dW.sum(dim=0).to(W.dtype) if has_weight else None
    return dX, dW
|
||||||
|
|
||||||
|
|
||||||
|
class FusedRMSNormRoPEFunction(torch.autograd.Function):
    """Autograd wrapper around the fused RMSNorm+RoPE Triton kernels."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, W, cos, sin, eps, n_heads):
        """
        X: (B*S*H, head_dim)
        W: (head_dim,) or None
        cos: (B*S, head_dim) — broadcast across heads
        sin: (B*S, head_dim) — broadcast across heads
        n_heads: int
        """
        Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
            X,
            W,
            cos,
            sin,
            eps,
            n_heads,
        )
        # Reuse the forward launch settings verbatim for backward.
        ctx.eps = eps
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.n_heads = n_heads
        ctx.has_weight = W is not None
        ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        # Gradients flow to X and W only; cos/sin/eps/n_heads get None.
        X, W, cos, sin, RSTD = ctx.saved_tensors
        dX, dW = rms_norm_rope_backward(
            dY,
            X,
            W,
            cos,
            sin,
            RSTD,
            ctx.n_heads,
            ctx.BLOCK_SIZE,
            ctx.num_warps,
        )
        return dX, dW, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
    """
    Apply fused RMSNorm + RoPE.

    Args:
        x: (batch, seq_len, num_heads, head_dim) — after projection + view
        weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
        cos: (batch, seq_len, head_dim) — from RotaryEmbedding
        sin: (batch, seq_len, head_dim) — from RotaryEmbedding
        eps: float — RMSNorm epsilon

    Returns:
        y: (batch, seq_len, num_heads, head_dim) — normalized + rotated
    """
    batch, seq_len, num_heads, head_dim = x.shape
    # The kernel works on 2D rows: one row per (batch, seq, head) triple.
    rows = x.reshape(-1, head_dim).contiguous()
    # cos/sin carry one row per (batch, seq) pair; the kernel maps each x-row
    # to its cos/sin row by integer-dividing the row index by num_heads.
    cos_rows = cos.reshape(batch * seq_len, head_dim).contiguous()
    sin_rows = sin.reshape(batch * seq_len, head_dim).contiguous()

    out = FusedRMSNormRoPEFunction.apply(
        rows, weight, cos_rows, sin_rows, eps, num_heads
    )
    return out.view(x.shape)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _rms_norm_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """RMSNorm without scale weight: y = x / rms(x).

    One program per row; the per-row 1/rms is stored to RSTD for reuse in
    the backward kernel.
    """
    # int64 index avoids overflow for very large row counts.
    row_idx = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
    X_dtype = X_row.dtype
    # Compute in fp32; cast back to the input dtype on store.
    X_fp32 = X_row.to(tl.float32)

    mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
    rstd = rsqrt(mean_sq + eps)
    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)

    Y_row = X_fp32 * rstd
    tl.store(Y_ptr + row_idx * Y_row_stride + col_offsets, Y_row.to(X_dtype), mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _rms_norm_noscale_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    X_ptr,
    X_row_stride,
    X_dtype: tl.constexpr,
    RSTD_ptr,
    RSTD_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Backward for y = x * rstd (no weight).

    With no learnable scale there is no weight gradient, so one program per
    row suffices — no cross-row accumulation is needed.
    """
    row_idx = tl.program_id(0).to(tl.int64)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dY_row = tl.load(
        dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
    ).to(tl.float32)
    X_row = tl.load(
        X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
    ).to(tl.float32)
    # 1/rms cached by the forward kernel.
    rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)

    # RMSNorm backward: dX = rstd * (dY - (1/n_cols) * rstd^2 * dot(dY, X) * X)
    dot_dy_x = tl.sum(dY_row * X_row, axis=0)
    dX_row = rstd * (dY_row - (1.0 / n_cols) * rstd * rstd * dot_dy_x * X_row)

    tl.store(
        dX_ptr + row_idx * dX_row_stride + col_offsets, dX_row.to(X_dtype), mask=mask
    )
|
||||||
|
|
||||||
|
|
||||||
|
class FusedRMSNormNoScaleFunction(torch.autograd.Function):
    """RMSNorm without learnable scale — used for Gemma4's v_norm."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, eps):
        # X: (n_rows, n_cols) — rows are normalized independently.
        n_rows, n_cols = X.shape
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        Y = torch.empty_like(X)
        # Per-row 1/rms, cached in fp32 for the backward kernel.
        RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)

        _rms_norm_forward_kernel[(n_rows,)](
            Y,
            Y.stride(0),
            X,
            X.stride(0),
            RSTD,
            RSTD.stride(0),
            n_cols,
            eps,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )
        # Reuse the forward launch settings for backward.
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.save_for_backward(X, RSTD)
        ctx.n_cols = n_cols
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        # Gradient flows to X only; eps gets None.
        X, RSTD = ctx.saved_tensors
        n_rows = X.shape[0]
        dX = torch.empty_like(X)
        _rms_norm_noscale_backward_kernel[(n_rows,)](
            dY,
            dY.stride(0),
            dX,
            dX.stride(0),
            X,
            X.stride(0),
            torch_to_triton_dtype[X.dtype],
            RSTD,
            RSTD.stride(0),
            ctx.n_cols,
            BLOCK_SIZE=ctx.BLOCK_SIZE,
            num_warps=ctx.num_warps,
        )
        return dX, None
|
||||||
|
|
||||||
|
|
||||||
|
def fused_rms_norm_noscale(x, eps=1e-6):
    """
    RMSNorm without scale for v_norm.

    Args:
        x: (batch, seq_len, num_heads, head_dim)
        eps: RMSNorm epsilon

    Returns:
        y: same shape, normalized
    """
    original_shape = x.shape
    # Flatten everything but the last dim; each row is normalized on its own.
    flattened = x.reshape(-1, original_shape[-1])
    normalized = FusedRMSNormNoScaleFunction.apply(flattened, eps)
    return normalized.view(original_shape)
|
||||||
@@ -624,7 +624,14 @@ class ModelLoader:
|
|||||||
|
|
||||||
def _set_attention_config(self):
|
def _set_attention_config(self):
|
||||||
"""Sample packing uses custom FA2 patch"""
|
"""Sample packing uses custom FA2 patch"""
|
||||||
if self.cfg.attn_implementation:
|
if self.cfg.gemma4_hybrid_attn_impl:
|
||||||
|
# Load model with flash_attention_2 for sliding window layers;
|
||||||
|
# global layers will be patched to sdpa post-load.
|
||||||
|
self.model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
|
self.model_config._attn_implementation = "flash_attention_2"
|
||||||
|
# Set flash_attention so multipack/sample_packing patches activate
|
||||||
|
self.cfg.flash_attention = True
|
||||||
|
elif self.cfg.attn_implementation:
|
||||||
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
|
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
|
||||||
elif self.cfg.flex_attention:
|
elif self.cfg.flex_attention:
|
||||||
self.model_kwargs["attn_implementation"] = "flex_attention"
|
self.model_kwargs["attn_implementation"] = "flex_attention"
|
||||||
|
|||||||
@@ -156,6 +156,7 @@ class PatchManager:
|
|||||||
# which would clobber any earlier fix.
|
# which would clobber any earlier fix.
|
||||||
self._fix_nemotron_h_conversion_mapping()
|
self._fix_nemotron_h_conversion_mapping()
|
||||||
|
|
||||||
|
self._apply_gemma_hybrid_attention(model)
|
||||||
self._finalize_moe_expert_quantization(model)
|
self._finalize_moe_expert_quantization(model)
|
||||||
|
|
||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
@@ -165,6 +166,72 @@ class PatchManager:
|
|||||||
self._apply_lora_kernel_patch(model)
|
self._apply_lora_kernel_patch(model)
|
||||||
self._apply_scaling_softmax_patch(model)
|
self._apply_scaling_softmax_patch(model)
|
||||||
|
|
||||||
|
def _apply_gemma_hybrid_attention(self, model: PreTrainedModel):
|
||||||
|
"""Apply hybrid attention: FA2 for sliding window layers, SDPA for global layers.
|
||||||
|
|
||||||
|
Gemma 4 has global (full_attention) layers with head_dim=512
|
||||||
|
which exceeds flash attention's supported size. This patch loads the model
|
||||||
|
with flash_attention_2 for the sliding window layers (head_dim=256), then
|
||||||
|
gives each global layer a shallow-copied config with _attn_implementation="sdpa".
|
||||||
|
"""
|
||||||
|
if not self.cfg.gemma4_hybrid_attn_impl:
|
||||||
|
return
|
||||||
|
|
||||||
|
import copy
|
||||||
|
|
||||||
|
# Navigate to the module that has 'layers' - varies by model structure:
|
||||||
|
# Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
|
||||||
|
# Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
|
||||||
|
layers = None
|
||||||
|
config_source = None
|
||||||
|
for candidate in [model, getattr(model, "model", None)]:
|
||||||
|
if candidate is None:
|
||||||
|
continue
|
||||||
|
# Check direct layers
|
||||||
|
if hasattr(candidate, "layers"):
|
||||||
|
layers = candidate.layers
|
||||||
|
config_source = candidate
|
||||||
|
break
|
||||||
|
# Check language_model.layers (multimodal wrapper)
|
||||||
|
lang_model = getattr(candidate, "language_model", None)
|
||||||
|
if lang_model is not None and hasattr(lang_model, "layers"):
|
||||||
|
layers = lang_model.layers
|
||||||
|
config_source = lang_model
|
||||||
|
break
|
||||||
|
|
||||||
|
if layers is None:
|
||||||
|
LOG.warning(
|
||||||
|
"gemma4_hybrid_attn_impl: could not find decoder layers in model, skipping"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
config = getattr(config_source, "config", self.model_config)
|
||||||
|
layer_types = getattr(config, "layer_types", None)
|
||||||
|
if layer_types is None:
|
||||||
|
LOG.warning(
|
||||||
|
"gemma4_hybrid_attn_impl: model config has no 'layer_types', skipping. "
|
||||||
|
"This feature requires a model with mixed sliding/global attention layers."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
patched_count = 0
|
||||||
|
for layer_idx, layer in enumerate(layers):
|
||||||
|
if layer_types[layer_idx] != "sliding_attention":
|
||||||
|
# Global / full_attention layer - use SDPA instead of FA2
|
||||||
|
attn_module = getattr(layer, "self_attn", None)
|
||||||
|
if attn_module is not None and hasattr(attn_module, "config"):
|
||||||
|
sdpa_config = copy.copy(attn_module.config)
|
||||||
|
sdpa_config._attn_implementation = "sdpa"
|
||||||
|
attn_module.config = sdpa_config
|
||||||
|
patched_count += 1
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
"gemma4_hybrid_attn_impl: patched %d global layers to use SDPA "
|
||||||
|
"(remaining %d sliding layers use flash_attention_2)",
|
||||||
|
patched_count,
|
||||||
|
len(layers) - patched_count,
|
||||||
|
)
|
||||||
|
|
||||||
def _apply_flash_attention_patches(self):
|
def _apply_flash_attention_patches(self):
|
||||||
"""Apply patches related to Flash Attention."""
|
"""Apply patches related to Flash Attention."""
|
||||||
if self.cfg.xformers_attention and self.cfg.sample_packing:
|
if self.cfg.xformers_attention and self.cfg.sample_packing:
|
||||||
@@ -324,6 +391,13 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_qwen3_5_vlm_flash_attention()
|
patch_qwen3_5_vlm_flash_attention()
|
||||||
|
|
||||||
|
if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
|
||||||
|
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||||
|
patch_gemma4_fused_attn,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_gemma4_fused_attn()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fix_nemotron_h_conversion_mapping():
|
def _fix_nemotron_h_conversion_mapping():
|
||||||
"""Remove the spurious embedding→embeddings WeightRenaming from the
|
"""Remove the spurious embedding→embeddings WeightRenaming from the
|
||||||
|
|||||||
147
src/axolotl/monkeypatch/models/gemma4/fused_attn.py
Normal file
147
src/axolotl/monkeypatch/models/gemma4/fused_attn.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
"""
|
||||||
|
Gemma 4 fused attention monkeypatch.
|
||||||
|
|
||||||
|
Replaces the per-layer RMSNorm + RoPE + transpose sequence with fused Triton
|
||||||
|
kernels, eliminating intermediate tensor allocations from rotate_half / apply_rotary_pos_emb
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
|
||||||
|
patch_gemma4_fused_attn()
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_fused_forward(original_forward):
    """Create a patched forward that uses fused RMSNorm+RoPE kernels.

    `original_forward` is kept for signature parity with the patch site; the
    returned closure re-implements the attention forward rather than wrapping it.
    """

    from axolotl.kernels.gemma4_fused_rope import (
        fused_rms_norm_noscale,
        fused_rms_norm_rope,
    )

    def fused_forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: torch.Tensor,
        attention_mask: torch.Tensor | None,
        shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
        past_key_values=None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
        from transformers.models.gemma4.modeling_gemma4 import (
            eager_attention_forward,
        )

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        eps = self.config.rms_norm_eps

        cos, sin = position_embeddings

        # ---- Projections ----
        # Use apply_qkv if present (LoRA kernel patch), otherwise direct proj
        has_lora_qkv = hasattr(self, "apply_qkv")

        if has_lora_qkv:
            query_states, key_states, value_states = self.apply_qkv(hidden_states)
            query_states = query_states.view(hidden_shape)
        else:
            query_states = self.q_proj(hidden_states).view(hidden_shape)

        # ---- Q path: fused q_norm + RoPE ----
        query_states = fused_rms_norm_rope(
            query_states,
            self.q_norm.weight,
            cos,
            sin,
            eps=eps,
        )
        query_states = query_states.transpose(1, 2)

        # ---- K/V path ----
        if self.is_kv_shared_layer:
            # Reuse K/V produced (already normed + roped + transposed) by an
            # earlier layer; just ensure they live on the right device.
            key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
            key_states = key_states.to(query_states.device)
            value_states = value_states.to(query_states.device)
        else:
            if has_lora_qkv:
                # apply_qkv already computed k/v projections
                key_states = key_states.view(hidden_shape)
                value_states = (
                    value_states.view(hidden_shape)
                    if self.v_proj is not None
                    else key_states
                )
            else:
                key_states = self.k_proj(hidden_states).view(hidden_shape)
                value_states = (
                    self.v_proj(hidden_states).view(hidden_shape)
                    if self.v_proj is not None
                    else key_states
                )

            # Fused k_norm + RoPE
            key_states = fused_rms_norm_rope(
                key_states,
                self.k_norm.weight,
                cos,
                sin,
                eps=eps,
            )
            key_states = key_states.transpose(1, 2)

            # Fused v_norm (no scale, no RoPE)
            value_states = fused_rms_norm_noscale(value_states, eps=eps)
            value_states = value_states.transpose(1, 2)

        # KV-shared layers read from an earlier layer's cache entry, so only
        # non-shared layers push into the cache.
        if past_key_values is not None and not self.is_kv_shared_layer:
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx
            )
        if self.store_full_length_kv:
            shared_kv_states[self.layer_idx] = key_states, value_states

        # eager fallback; otherwise dispatch on the (possibly per-layer
        # patched) _attn_implementation, e.g. FA2 vs SDPA hybrid.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation
            ]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

    return fused_forward
|
||||||
|
|
||||||
|
|
||||||
|
def patch_gemma4_fused_attn():
    """
    Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels.
    """
    from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention

    # Keep a handle to the stock forward and install the fused replacement.
    prior_forward = Gemma4TextAttention.forward
    Gemma4TextAttention.forward = _make_fused_forward(prior_forward)

    logger.info(
        "Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
    )
|
||||||
@@ -777,6 +777,15 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
gemma4_hybrid_attn_impl: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Use hybrid attention for Gemma 4: flash_attention_2 for sliding window layers "
|
||||||
|
"and sdpa for global (full_attention) layers. Global layers have head_dim=512 which "
|
||||||
|
"exceeds flash attention's supported size."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
experts_implementation: str | None = Field(
|
experts_implementation: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
226
tests/kernels/test_gemma4_fused_rope.py
Normal file
226
tests/kernels/test_gemma4_fused_rope.py
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
"""
|
||||||
|
Correctness tests for the fused RMSNorm+RoPE Triton kernel.
|
||||||
|
|
||||||
|
Tests forward and backward against the reference Gemma4 implementation
|
||||||
|
(Gemma4RMSNorm + apply_rotary_pos_emb) across both sliding window
|
||||||
|
(head_dim=256) and global attention (head_dim=512) layer configurations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch.manual_seed(42)
|
||||||
|
|
||||||
|
# Skip entire module if no CUDA
|
||||||
|
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
|
||||||
|
|
||||||
|
|
||||||
|
def _reference_norm_rope(x, weight, cos, sin, eps):
    """Reference: separate Gemma4RMSNorm + apply_rotary_pos_emb."""
    from transformers.models.gemma4.modeling_gemma4 import (
        Gemma4RMSNorm,
        apply_rotary_pos_emb,
    )

    D = x.shape[-1]
    # Build a norm module with the given weight so it matches the fused path.
    norm = Gemma4RMSNorm(D, eps=eps).to(x.device, x.dtype)
    norm.weight.data.copy_(weight)
    normed = norm(x)
    # unsqueeze_dim=2 broadcasts cos/sin over the heads axis of (B, S, H, D).
    return apply_rotary_pos_emb(normed, cos, sin, unsqueeze_dim=2)
|
||||||
|
|
||||||
|
|
||||||
|
def _reference_norm_noscale(x, eps):
    """Reference: Gemma4RMSNorm with_scale=False."""
    from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm

    head_dim = x.shape[-1]
    # No learnable scale — the module only normalizes.
    ref_norm = Gemma4RMSNorm(head_dim, eps=eps, with_scale=False)
    return ref_norm.to(x.device, x.dtype)(x)
|
||||||
|
|
||||||
|
|
||||||
|
# (batch, seq_len, num_heads, head_dim) configurations under test.
_SHAPE_CASES = [
    (2, 64, 32, 256),  # sliding window layer shape
    (2, 64, 4, 512),  # global attention layer shape
    (1, 128, 16, 256),  # different batch/seq
    (1, 1, 1, 8),  # minimal size
]


@pytest.fixture(
    params=_SHAPE_CASES,
    ids=["sliding_256", "global_512", "varied", "minimal"],
)
def shapes(request):
    """Yield one (batch, seq_len, num_heads, head_dim) tuple per case."""
    return request.param
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=[torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
def dtype(request):
    """Reduced-precision dtypes the fused kernel is exercised in."""
    return request.param
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusedRMSNormRoPEForward:
    """Forward pass correctness."""

    def test_matches_reference(self, shapes, dtype):
        """Fused output agrees with the reference norm + RoPE composition."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        batch, seq, heads, dim = shapes
        eps = 1e-6
        x = torch.randn(batch, seq, heads, dim, device="cuda", dtype=dtype)
        weight = torch.randn(dim, device="cuda", dtype=dtype)
        cos = torch.randn(batch, seq, dim, device="cuda", dtype=dtype)
        sin = torch.randn(batch, seq, dim, device="cuda", dtype=dtype)

        expected = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
        actual = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)

        # Cosine similarity is robust to the bf16/fp16 rounding noise that an
        # exact allclose would trip over.
        cos_sim = torch.nn.functional.cosine_similarity(
            expected.flatten().float(), actual.flatten().float(), dim=0
        )
        assert cos_sim > 0.999, f"Forward cosine_sim={cos_sim:.6f}, expected > 0.999"

    def test_output_shape(self, shapes):
        """Output preserves the input's shape and dtype."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        batch, seq, heads, dim = shapes
        x = torch.randn(batch, seq, heads, dim, device="cuda", dtype=torch.bfloat16)
        weight = torch.randn(dim, device="cuda", dtype=torch.bfloat16)
        cos = torch.randn(batch, seq, dim, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(batch, seq, dim, device="cuda", dtype=torch.bfloat16)

        out = fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6)
        assert out.shape == x.shape
        assert out.dtype == x.dtype
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusedRMSNormRoPEBackward:
    """Backward pass correctness via gradient comparison.

    Runs the fused kernel and the reference Gemma4RMSNorm +
    apply_rotary_pos_emb path on identical inputs, backprops a sum loss
    through both, and compares the resulting gradients.
    """

    @pytest.mark.parametrize(
        "B,S,H,D",
        [(2, 64, 32, 256), (2, 64, 4, 512)],
        ids=["sliding_256", "global_512"],
    )
    def test_x_grad_matches_reference(self, B, S, H, D):
        """Input gradients from the fused kernel match the reference path."""
        from transformers.models.gemma4.modeling_gemma4 import (
            Gemma4RMSNorm,
            apply_rotary_pos_emb,
        )

        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        eps = 1e-6
        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)

        # Reference backward
        x_ref = torch.randn(
            B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
        )
        norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
        norm_ref.weight.data.copy_(weight_init)
        y_ref = apply_rotary_pos_emb(norm_ref(x_ref), cos, sin, unsqueeze_dim=2)
        y_ref.sum().backward()

        # Fused backward. Use detach().clone() rather than the deprecated
        # Tensor.data, which bypasses autograd bookkeeping and can hide
        # version-counter errors.
        x_fused = x_ref.detach().clone().requires_grad_(True)
        w_fused = weight_init.clone().requires_grad_(True)
        y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps)
        y_fused.sum().backward()

        cos_sim_x = torch.nn.functional.cosine_similarity(
            x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
        )
        assert cos_sim_x > 0.999, f"x grad cosine_sim={cos_sim_x:.6f}, expected > 0.999"

    @pytest.mark.parametrize(
        "B,S,H,D",
        [(2, 64, 32, 256), (2, 64, 4, 512)],
        ids=["sliding_256", "global_512"],
    )
    def test_weight_grad_matches_reference(self, B, S, H, D):
        """RMSNorm weight gradients from the fused kernel match the reference."""
        from transformers.models.gemma4.modeling_gemma4 import (
            Gemma4RMSNorm,
            apply_rotary_pos_emb,
        )

        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        eps = 1e-6
        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)

        # Reference
        x_ref = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
        norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
        norm_ref.weight = torch.nn.Parameter(weight_init.clone())
        apply_rotary_pos_emb(
            norm_ref(x_ref), cos, sin, unsqueeze_dim=2
        ).sum().backward()

        # Fused
        w_fused = weight_init.clone().requires_grad_(True)
        fused_rms_norm_rope(x_ref.clone(), w_fused, cos, sin, eps=eps).sum().backward()

        # Weight grads reduce over every row, so tolerance is slightly looser
        # than for the per-element input gradients.
        cos_sim_w = torch.nn.functional.cosine_similarity(
            w_fused.grad.flatten().float(),
            norm_ref.weight.grad.flatten().float(),
            dim=0,
        )
        assert cos_sim_w > 0.995, (
            f"weight grad cosine_sim={cos_sim_w:.6f}, expected > 0.995"
        )

    def test_grad_flows(self):
        """Verify gradients are non-zero and finite."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        B, S, H, D = 1, 16, 4, 64
        x = torch.randn(
            B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
        )
        w = torch.randn(D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)

        y = fused_rms_norm_rope(x, w, cos, sin, eps=1e-6)
        y.sum().backward()

        assert x.grad is not None, "x.grad is None"
        assert w.grad is not None, "w.grad is None"
        assert x.grad.isfinite().all(), "x.grad has non-finite values"
        assert w.grad.isfinite().all(), "w.grad has non-finite values"
        assert x.grad.abs().sum() > 0, "x.grad is all zeros"
        assert w.grad.abs().sum() > 0, "w.grad is all zeros"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusedRMSNormNoScale:
    """Tests for v_norm (RMSNorm without learnable scale)."""

    def test_forward_matches_reference(self, shapes, dtype):
        """Fused no-scale norm agrees with the reference module."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale

        batch, seq, heads, dim = shapes
        eps = 1e-6
        x = torch.randn(batch, seq, heads, dim, device="cuda", dtype=dtype)

        expected = _reference_norm_noscale(x.clone(), eps)
        actual = fused_rms_norm_noscale(x.clone(), eps=eps)

        cos_sim = torch.nn.functional.cosine_similarity(
            expected.flatten().float(), actual.flatten().float(), dim=0
        )
        assert cos_sim > 0.999, f"v_norm cosine_sim={cos_sim:.6f}, expected > 0.999"

    def test_backward_flows(self):
        """Gradients propagate through the no-scale path and stay finite."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale

        x = torch.randn(
            1, 16, 4, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True
        )
        fused_rms_norm_noscale(x, eps=1e-6).sum().backward()

        assert x.grad is not None
        assert x.grad.isfinite().all()
        assert x.grad.abs().sum() > 0
|
||||||
Reference in New Issue
Block a user