[gemma4] use mixed Flash Attention and SDPA and add fused RMSNorm+RoPE Triton kernels (#3598)

Wing Lian
2026-04-12 10:29:55 -04:00
committed by GitHub
parent e079cf16a2
commit b8358aa5ab
6 changed files with 993 additions and 1 deletion


@@ -0,0 +1,529 @@
"""
Fused RMSNorm + RoPE Triton kernel for Gemma 4.
Fuses RMSNorm and RoPE into a single kernel launch:
1. RMSNorm: x_norm = (x / sqrt(mean(x^2) + eps)) * weight
2. RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
A separate no-scale RMSNorm kernel (for v_norm, no RoPE) is also provided below.
This eliminates two intermediate tensor materializations per Q/K path and the
tensor churn from rotate_half / apply_rotary_pos_emb.
Shapes:
X: (rows, head_dim) — flattened from (batch, seq_len, num_heads, head_dim)
W: (head_dim,) — RMSNorm weight (None for with_scale=False)
cos: (rows, head_dim) — flattened from (batch, seq_len, 1, head_dim) after broadcast
sin: (rows, head_dim) — same as cos
"""
import math
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import (
calculate_settings,
compare_version,
ensure_contiguous,
torch_to_triton_dtype,
)
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
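# Illustrative only: a minimal, unfused PyTorch sketch of the math the kernel
# below fuses. It is not used by the kernels, and it assumes cos/sin have already
# been expanded to X's rows (the Triton kernel instead indexes them by
# row_idx // n_heads).
def _unfused_reference_rms_norm_rope(x, weight, cos, sin, eps):
    x32 = x.float()
    rstd = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)  # 1 / rms(x)
    x_norm = x32 * rstd
    if weight is not None:
        x_norm = x_norm * weight.float()
    half = x.shape[-1] // 2
    # rotate_half([a, b]) = [-b, a]
    rot = torch.cat((-x_norm[..., half:], x_norm[..., :half]), dim=-1)
    return (x_norm * cos.float() + rot * sin.float()).to(x.dtype)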
@triton.jit
def _rms_norm_rope_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
W_ptr,
COS_ptr,
COS_row_stride,
SIN_ptr,
SIN_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_cols,
n_heads,
eps,
HAS_WEIGHT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Fused forward:
x_norm = x / rms(x) [* weight] (RMSNorm)
y = x_norm * cos + rotate_half(x_norm) * sin (RoPE)
rotate_half swaps first/second halves and negates the first:
rotate_half([a, b]) = [-b, a]
cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
(cos/sin have shape (B*S, D) while X has shape (B*S*H, D)).
"""
row_idx = tl.program_id(0).to(tl.int64)
# cos/sin row: divide by n_heads since cos/sin are (B*S, D)
cs_row_idx = row_idx // n_heads
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
half_dim = n_cols // 2
# Load input row
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
X_dtype = X_row.dtype
X_fp32 = X_row.to(tl.float32)
# RMSNorm: compute 1/rms
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
rstd = rsqrt(mean_sq + eps)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
# Normalize
X_norm = X_fp32 * rstd
# Apply weight if present (with_scale=True)
if HAS_WEIGHT:
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
X_norm = X_norm * W_row
# RoPE: load cos/sin (broadcast across heads)
cos_row = tl.load(
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
sin_row = tl.load(
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
# rotate_half: for col < half_dim, take -X_norm[col + half_dim]
# for col >= half_dim, take X_norm[col - half_dim]
rot_offsets = tl.where(
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
)
rot_mask = rot_offsets < n_cols
X_rot = tl.load(
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_mask & mask, other=0
).to(tl.float32)
# Normalize the values loaded at the rotated offsets (same per-row rstd)
X_rot_norm = X_rot * rstd
if HAS_WEIGHT:
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_mask & mask, other=0).to(
tl.float32
)
X_rot_norm = X_rot_norm * W_rot
# Negate the first half (rotate_half negates x2, which becomes the first half)
sign = tl.where(col_offsets < half_dim, -1.0, 1.0)
X_rot_norm = X_rot_norm * sign
# Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
Y_row = X_norm * cos_row + X_rot_norm * sin_row
tl.store(
Y_ptr + row_idx * Y_row_stride + col_offsets,
Y_row.to(X_dtype),
mask=mask,
)
@triton.jit
def _rms_norm_rope_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
W_ptr,
COS_ptr,
COS_row_stride,
SIN_ptr,
SIN_row_stride,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows,
n_cols,
n_heads,
rows_per_program,
HAS_WEIGHT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Backward for Y = RoPE(RMSNorm(X, W))
cos/sin indexed by row_idx // n_heads for per-head broadcast.
"""
row_block_id = tl.program_id(0).to(tl.int64)
row_start = row_block_id * rows_per_program
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
half_dim = n_cols // 2
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
if HAS_WEIGHT:
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
for row_idx in range(row_start, row_end):
cs_row_idx = row_idx // n_heads
dY_row = tl.load(
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
X_row = tl.load(
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
cos_row = tl.load(
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
# dN = dY * cos + rotate_half^T(dY * sin)
# rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half)
#
# Compute rotate_half_transpose(dY * sin) by loading dY and sin at
# rotated offsets directly: dY[rot] * sin[rot] * adj_sign
# This is equivalent to rotating (dY * sin) because the rotation
# just permutes which elements are multiplied.
rot_offsets = tl.where(
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
)
rot_mask = rot_offsets < n_cols
dY_rot = tl.load(
dY_ptr + row_idx * dY_row_stride + rot_offsets,
mask=rot_mask & mask,
other=0,
).to(tl.float32)
sin_rot = tl.load(
SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
mask=rot_mask & mask,
other=0,
).to(tl.float32)
adj_sign = tl.where(col_offsets < half_dim, 1.0, -1.0)
dN = dY_row * cos_row + dY_rot * sin_rot * adj_sign
# Pre-weight normalized: n = rstd * x
n = X_row * rstd
if HAS_WEIGHT:
dW_acc += dN * n
dm = dN * W_row
else:
dm = dN
# RMSNorm backward: dX = rstd * (dm - (1/n_cols) * rstd^2 * dot(dm, X) * X)
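# Derivation sketch: with m = mean(x^2) + eps and rstd = m^(-1/2), y_i = x_i * rstd,
# so dy_i/dx_j = rstd * delta_ij - (1/n_cols) * rstd^3 * x_i * x_j. Contracting with
# dm gives dX = rstd * dm - (1/n_cols) * rstd^3 * dot(dm, x) * x, i.e. the line
# below with rstd factored out.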
dot_dm_x = tl.sum(dm * X_row, axis=0)
dX_row = rstd * (dm - (1.0 / n_cols) * rstd * rstd * dot_dm_x * X_row)
tl.store(
dX_ptr + row_idx * dX_row_stride + col_offsets,
dX_row.to(X_dtype),
mask=mask,
)
if HAS_WEIGHT:
tl.store(
dW_ptr + row_block_id * dW_row_stride + col_offsets,
dW_acc,
mask=mask,
)
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
"""
Args:
X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
W: (head_dim,) or None — RMSNorm weight
cos: (B*S, head_dim) — position embeddings (broadcast across heads)
sin: (B*S, head_dim) — position embeddings (broadcast across heads)
eps: float
n_heads: int — number of attention heads (for cos/sin indexing)
Returns:
Y, X_saved, RSTD, BLOCK_SIZE, num_warps
"""
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
has_weight = W is not None
Y = torch.empty_like(X)
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
_rms_norm_rope_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
X,
X.stride(0),
W if has_weight else X, # dummy pointer when no weight
cos,
cos.stride(0),
sin,
sin.stride(0),
RSTD,
RSTD.stride(0),
n_cols,
n_heads,
eps,
HAS_WEIGHT=has_weight,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return Y, X, RSTD, BLOCK_SIZE, num_warps
def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_warps):
n_rows, n_cols = dY.shape
has_weight = W is not None
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
rows_per_program = math.ceil(n_rows / sm_count)
dX = torch.empty_like(X)
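# One partial dW row per program (grid size = sm_count); each program accumulates
# dW over its block of rows and the partials are reduced with .sum(dim=0) below.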
if has_weight:
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=X.device)
else:
_dW = torch.empty((1, n_cols), dtype=torch.float32, device=X.device)
_rms_norm_rope_backward_kernel[(sm_count,)](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
W if has_weight else X, # dummy
cos,
cos.stride(0),
sin,
sin.stride(0),
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0),
n_rows,
n_cols,
n_heads,
rows_per_program,
HAS_WEIGHT=has_weight,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
dW = _dW.sum(dim=0).to(W.dtype) if has_weight else None
return dX, dW
class FusedRMSNormRoPEFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, X, W, cos, sin, eps, n_heads):
"""
X: (B*S*H, head_dim)
W: (head_dim,) or None
cos: (B*S, head_dim) — broadcast across heads
sin: (B*S, head_dim) — broadcast across heads
n_heads: int
"""
Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
X,
W,
cos,
sin,
eps,
n_heads,
)
ctx.eps = eps
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.n_heads = n_heads
ctx.has_weight = W is not None
ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, W, cos, sin, RSTD = ctx.saved_tensors
dX, dW = rms_norm_rope_backward(
dY,
X,
W,
cos,
sin,
RSTD,
ctx.n_heads,
ctx.BLOCK_SIZE,
ctx.num_warps,
)
return dX, dW, None, None, None, None
def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
"""
Apply fused RMSNorm + RoPE.
Args:
x: (batch, seq_len, num_heads, head_dim) — after projection + view
weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
cos: (batch, seq_len, head_dim) — from RotaryEmbedding
sin: (batch, seq_len, head_dim) — from RotaryEmbedding
eps: float — RMSNorm epsilon
Returns:
y: (batch, seq_len, num_heads, head_dim) — normalized + rotated
"""
shape = x.shape # (B, S, H, D)
B, S, H, D = shape
# Flatten to 2D: (B*S*H, D)
x_flat = x.reshape(-1, D).contiguous()
# Flatten cos/sin to (B*S, D) — the kernel will handle per-head broadcast
# by dividing the row_idx by H to get the cos/sin row
cos_flat = cos.reshape(B * S, D).contiguous()
sin_flat = sin.reshape(B * S, D).contiguous()
y_flat = FusedRMSNormRoPEFunction.apply(x_flat, weight, cos_flat, sin_flat, eps, H)
return y_flat.view(shape)
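# Illustrative usage (a sketch mirroring the attention patch that consumes this
# kernel; `attn` stands for a Gemma 4 attention module with q_proj/q_norm/head_dim):
def _example_q_path(attn, hidden_states, cos, sin, eps=1e-6):
    bsz, seq_len, _ = hidden_states.shape
    q = attn.q_proj(hidden_states).view(bsz, seq_len, -1, attn.head_dim)
    q = fused_rms_norm_rope(q, attn.q_norm.weight, cos, sin, eps=eps)
    return q.transpose(1, 2)  # (B, H, S, D), ready for the attention kernel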
@triton.jit
def _rms_norm_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_cols,
eps,
BLOCK_SIZE: tl.constexpr,
):
"""RMSNorm without scale weight: y = x / rms(x)"""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
X_dtype = X_row.dtype
X_fp32 = X_row.to(tl.float32)
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
rstd = rsqrt(mean_sq + eps)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
Y_row = X_fp32 * rstd
tl.store(Y_ptr + row_idx * Y_row_stride + col_offsets, Y_row.to(X_dtype), mask=mask)
@triton.jit
def _rms_norm_noscale_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
RSTD_ptr,
RSTD_row_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
"""Backward for y = x * rstd (no weight)."""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
dY_row = tl.load(
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
X_row = tl.load(
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
).to(tl.float32)
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
dot_dy_x = tl.sum(dY_row * X_row, axis=0)
dX_row = rstd * (dY_row - (1.0 / n_cols) * rstd * rstd * dot_dy_x * X_row)
tl.store(
dX_ptr + row_idx * dX_row_stride + col_offsets, dX_row.to(X_dtype), mask=mask
)
class FusedRMSNormNoScaleFunction(torch.autograd.Function):
"""RMSNorm without learnable scale — used for Gemma4's v_norm."""
@staticmethod
@ensure_contiguous
def forward(ctx, X, eps):
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Y = torch.empty_like(X)
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
_rms_norm_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
X,
X.stride(0),
RSTD,
RSTD.stride(0),
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.save_for_backward(X, RSTD)
ctx.n_cols = n_cols
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, RSTD = ctx.saved_tensors
n_rows = X.shape[0]
dX = torch.empty_like(X)
_rms_norm_noscale_backward_kernel[(n_rows,)](
dY,
dY.stride(0),
dX,
dX.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
RSTD,
RSTD.stride(0),
ctx.n_cols,
BLOCK_SIZE=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
return dX, None
def fused_rms_norm_noscale(x, eps=1e-6):
"""
RMSNorm without scale for v_norm.
Args:
x: (batch, seq_len, num_heads, head_dim)
Returns:
y: same shape, normalized
"""
shape = x.shape
x_flat = x.reshape(-1, shape[-1])
y_flat = FusedRMSNormNoScaleFunction.apply(x_flat, eps)
return y_flat.view(shape)
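# For reference, fused_rms_norm_noscale(x, eps) matches the eager computation
# (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)).to(x.dtype)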


@@ -624,7 +624,14 @@ class ModelLoader:
def _set_attention_config(self):
"""Sample packing uses custom FA2 patch"""
if self.cfg.gemma4_hybrid_attn_impl:
# Load model with flash_attention_2 for sliding window layers;
# global layers will be patched to sdpa post-load.
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
# Set flash_attention so multipack/sample_packing patches activate
self.cfg.flash_attention = True
elif self.cfg.attn_implementation:
self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
elif self.cfg.flex_attention:
self.model_kwargs["attn_implementation"] = "flex_attention"


@@ -156,6 +156,7 @@ class PatchManager:
# which would clobber any earlier fix.
self._fix_nemotron_h_conversion_mapping()
self._apply_gemma_hybrid_attention(model)
self._finalize_moe_expert_quantization(model)
def apply_post_model_load_patches(self, model: PreTrainedModel):
@@ -165,6 +166,72 @@ class PatchManager:
self._apply_lora_kernel_patch(model)
self._apply_scaling_softmax_patch(model)
def _apply_gemma_hybrid_attention(self, model: PreTrainedModel):
"""Apply hybrid attention: FA2 for sliding window layers, SDPA for global layers.
Gemma 4 has global (full_attention) layers with head_dim=512, which exceeds
flash attention's supported head size. The model is loaded with flash_attention_2
(which covers the sliding window layers, head_dim=256); this patch then gives each
global layer a shallow-copied config with _attn_implementation="sdpa".
"""
if not self.cfg.gemma4_hybrid_attn_impl:
return
import copy
# Navigate to the module that has 'layers' - varies by model structure:
# Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
# Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
layers = None
config_source = None
for candidate in [model, getattr(model, "model", None)]:
if candidate is None:
continue
# Check direct layers
if hasattr(candidate, "layers"):
layers = candidate.layers
config_source = candidate
break
# Check language_model.layers (multimodal wrapper)
lang_model = getattr(candidate, "language_model", None)
if lang_model is not None and hasattr(lang_model, "layers"):
layers = lang_model.layers
config_source = lang_model
break
if layers is None:
LOG.warning(
"gemma4_hybrid_attn_impl: could not find decoder layers in model, skipping"
)
return
config = getattr(config_source, "config", self.model_config)
layer_types = getattr(config, "layer_types", None)
if layer_types is None:
LOG.warning(
"gemma4_hybrid_attn_impl: model config has no 'layer_types', skipping. "
"This feature requires a model with mixed sliding/global attention layers."
)
return
patched_count = 0
for layer_idx, layer in enumerate(layers):
if layer_types[layer_idx] != "sliding_attention":
# Global / full_attention layer - use SDPA instead of FA2
attn_module = getattr(layer, "self_attn", None)
if attn_module is not None and hasattr(attn_module, "config"):
sdpa_config = copy.copy(attn_module.config)
sdpa_config._attn_implementation = "sdpa"
attn_module.config = sdpa_config
patched_count += 1
LOG.info(
"gemma4_hybrid_attn_impl: patched %d global layers to use SDPA "
"(remaining %d sliding layers use flash_attention_2)",
patched_count,
len(layers) - patched_count,
)
def _apply_flash_attention_patches(self):
"""Apply patches related to Flash Attention."""
if self.cfg.xformers_attention and self.cfg.sample_packing:
@@ -324,6 +391,13 @@ class PatchManager:
patch_qwen3_5_vlm_flash_attention()
if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
patch_gemma4_fused_attn()
@staticmethod
def _fix_nemotron_h_conversion_mapping():
"""Remove the spurious embedding→embeddings WeightRenaming from the


@@ -0,0 +1,147 @@
"""
Gemma 4 fused attention monkeypatch.
Replaces the per-layer RMSNorm + RoPE sequence with fused Triton kernels,
eliminating the intermediate tensor allocations from rotate_half / apply_rotary_pos_emb.
Usage:
from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
patch_gemma4_fused_attn()
"""
import logging
from typing import Callable
import torch
logger = logging.getLogger(__name__)
def _make_fused_forward(original_forward):
"""Create a patched forward that uses fused RMSNorm+RoPE kernels."""
from axolotl.kernels.gemma4_fused_rope import (
fused_rms_norm_noscale,
fused_rms_norm_rope,
)
def fused_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None,
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
past_key_values=None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.gemma4.modeling_gemma4 import (
eager_attention_forward,
)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
eps = self.config.rms_norm_eps
cos, sin = position_embeddings
# ---- Projections ----
# Use apply_qkv if present (LoRA kernel patch), otherwise direct proj
has_lora_qkv = hasattr(self, "apply_qkv")
if has_lora_qkv:
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape)
else:
query_states = self.q_proj(hidden_states).view(hidden_shape)
# ---- Q path: fused q_norm + RoPE ----
query_states = fused_rms_norm_rope(
query_states,
self.q_norm.weight,
cos,
sin,
eps=eps,
)
query_states = query_states.transpose(1, 2)
# ---- K/V path ----
if self.is_kv_shared_layer:
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
if has_lora_qkv:
# apply_qkv already computed k/v projections
key_states = key_states.view(hidden_shape)
value_states = (
value_states.view(hidden_shape)
if self.v_proj is not None
else key_states
)
else:
key_states = self.k_proj(hidden_states).view(hidden_shape)
value_states = (
self.v_proj(hidden_states).view(hidden_shape)
if self.v_proj is not None
else key_states
)
# Fused k_norm + RoPE
key_states = fused_rms_norm_rope(
key_states,
self.k_norm.weight,
cos,
sin,
eps=eps,
)
key_states = key_states.transpose(1, 2)
# Fused v_norm (no scale, no RoPE)
value_states = fused_rms_norm_noscale(value_states, eps=eps)
value_states = value_states.transpose(1, 2)
if past_key_values is not None and not self.is_kv_shared_layer:
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx
)
if self.store_full_length_kv:
shared_kv_states[self.layer_idx] = key_states, value_states
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[
self.config._attn_implementation
]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=self.attention_dropout if self.training else 0.0,
scaling=self.scaling,
sliding_window=self.sliding_window,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
return fused_forward
def patch_gemma4_fused_attn():
"""
Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels.
"""
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
original_forward = Gemma4TextAttention.forward
Gemma4TextAttention.forward = _make_fused_forward(original_forward)
logger.info(
"Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
)


@@ -777,6 +777,15 @@ class AxolotlInputConfig(
},
)
gemma4_hybrid_attn_impl: bool | None = Field(
default=None,
json_schema_extra={
"description": "Use hybrid attention for Gemma 4: flash_attention_2 for sliding window layers "
"and sdpa for global (full_attention) layers. Global layers have head_dim=512 which "
"exceeds flash attention's supported size."
},
)
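# Example (hypothetical YAML snippet) enabling the hybrid implementation:
#   gemma4_hybrid_attn_impl: true
# flash_attention is then set to true automatically during model load so that the
# multipack/sample-packing patches still activate.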
experts_implementation: str | None = Field(
default=None,
json_schema_extra={


@@ -0,0 +1,226 @@
"""
Correctness tests for the fused RMSNorm+RoPE Triton kernel.
Tests forward and backward against the reference Gemma4 implementation
(Gemma4RMSNorm + apply_rotary_pos_emb) across both sliding window
(head_dim=256) and global attention (head_dim=512) layer configurations.
"""
import pytest
import torch
torch.manual_seed(42)
# Skip entire module if no CUDA
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def _reference_norm_rope(x, weight, cos, sin, eps):
"""Reference: separate Gemma4RMSNorm + apply_rotary_pos_emb."""
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4RMSNorm,
apply_rotary_pos_emb,
)
D = x.shape[-1]
norm = Gemma4RMSNorm(D, eps=eps).to(x.device, x.dtype)
norm.weight.data.copy_(weight)
normed = norm(x)
return apply_rotary_pos_emb(normed, cos, sin, unsqueeze_dim=2)
def _reference_norm_noscale(x, eps):
"""Reference: Gemma4RMSNorm with_scale=False."""
from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm
D = x.shape[-1]
norm = Gemma4RMSNorm(D, eps=eps, with_scale=False).to(x.device, x.dtype)
return norm(x)
@pytest.fixture(
params=[
(2, 64, 32, 256), # sliding window layer shape
(2, 64, 4, 512), # global attention layer shape
(1, 128, 16, 256), # different batch/seq
(1, 1, 1, 8), # minimal size
],
ids=["sliding_256", "global_512", "varied", "minimal"],
)
def shapes(request):
return request.param
@pytest.fixture(params=[torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
def dtype(request):
return request.param
class TestFusedRMSNormRoPEForward:
"""Forward pass correctness."""
def test_matches_reference(self, shapes, dtype):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
B, S, H, D = shapes
eps = 1e-6
x = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
weight = torch.randn(D, device="cuda", dtype=dtype)
cos = torch.randn(B, S, D, device="cuda", dtype=dtype)
sin = torch.randn(B, S, D, device="cuda", dtype=dtype)
y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
cos_sim = torch.nn.functional.cosine_similarity(
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
)
assert cos_sim > 0.999, f"Forward cosine_sim={cos_sim:.6f}, expected > 0.999"
def test_output_shape(self, shapes):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
B, S, H, D = shapes
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
y = fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6)
assert y.shape == x.shape
assert y.dtype == x.dtype
class TestFusedRMSNormRoPEBackward:
"""Backward pass correctness via gradient comparison."""
@pytest.mark.parametrize(
"B,S,H,D",
[(2, 64, 32, 256), (2, 64, 4, 512)],
ids=["sliding_256", "global_512"],
)
def test_x_grad_matches_reference(self, B, S, H, D):
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4RMSNorm,
apply_rotary_pos_emb,
)
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
eps = 1e-6
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
# Reference backward
x_ref = torch.randn(
B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
norm_ref.weight.data.copy_(weight_init)
y_ref = apply_rotary_pos_emb(norm_ref(x_ref), cos, sin, unsqueeze_dim=2)
y_ref.sum().backward()
# Fused backward
x_fused = x_ref.data.clone().requires_grad_(True)
w_fused = weight_init.clone().requires_grad_(True)
y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps)
y_fused.sum().backward()
cos_sim_x = torch.nn.functional.cosine_similarity(
x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
)
assert cos_sim_x > 0.999, f"x grad cosine_sim={cos_sim_x:.6f}, expected > 0.999"
@pytest.mark.parametrize(
"B,S,H,D",
[(2, 64, 32, 256), (2, 64, 4, 512)],
ids=["sliding_256", "global_512"],
)
def test_weight_grad_matches_reference(self, B, S, H, D):
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4RMSNorm,
apply_rotary_pos_emb,
)
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
eps = 1e-6
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
# Reference
x_ref = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
norm_ref.weight = torch.nn.Parameter(weight_init.clone())
apply_rotary_pos_emb(
norm_ref(x_ref), cos, sin, unsqueeze_dim=2
).sum().backward()
# Fused
w_fused = weight_init.clone().requires_grad_(True)
fused_rms_norm_rope(x_ref.clone(), w_fused, cos, sin, eps=eps).sum().backward()
cos_sim_w = torch.nn.functional.cosine_similarity(
w_fused.grad.flatten().float(),
norm_ref.weight.grad.flatten().float(),
dim=0,
)
assert cos_sim_w > 0.995, (
f"weight grad cosine_sim={cos_sim_w:.6f}, expected > 0.995"
)
def test_grad_flows(self):
"""Verify gradients are non-zero and finite."""
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
B, S, H, D = 1, 16, 4, 64
x = torch.randn(
B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
w = torch.randn(D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
y = fused_rms_norm_rope(x, w, cos, sin, eps=1e-6)
y.sum().backward()
assert x.grad is not None, "x.grad is None"
assert w.grad is not None, "w.grad is None"
assert x.grad.isfinite().all(), "x.grad has non-finite values"
assert w.grad.isfinite().all(), "w.grad has non-finite values"
assert x.grad.abs().sum() > 0, "x.grad is all zeros"
assert w.grad.abs().sum() > 0, "w.grad is all zeros"
class TestFusedRMSNormNoScale:
"""Tests for v_norm (RMSNorm without learnable scale)."""
def test_forward_matches_reference(self, shapes, dtype):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale
B, S, H, D = shapes
eps = 1e-6
x = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
y_ref = _reference_norm_noscale(x.clone(), eps)
y_fused = fused_rms_norm_noscale(x.clone(), eps=eps)
cos_sim = torch.nn.functional.cosine_similarity(
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
)
assert cos_sim > 0.999, f"v_norm cosine_sim={cos_sim:.6f}, expected > 0.999"
def test_backward_flows(self):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale
x = torch.randn(
1, 16, 4, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
y = fused_rms_norm_noscale(x, eps=1e-6)
y.sum().backward()
assert x.grad is not None
assert x.grad.isfinite().all()
assert x.grad.abs().sum() > 0