From b8358aa5abcb750b423457f67d374bc39d0a14a3 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 12 Apr 2026 10:29:55 -0400
Subject: [PATCH] [gemma4] use mixed Flash Attention and SDPA and add fused RMSNorm+RoPE Triton kernels (#3598)

---
 src/axolotl/kernels/gemma4_fused_rope.py      | 529 ++++++++++++++++++
 src/axolotl/loaders/model.py                  |   9 +-
 src/axolotl/loaders/patch_manager.py          |  74 +++
 .../monkeypatch/models/gemma4/fused_attn.py   | 147 +++++
 src/axolotl/utils/schemas/config.py           |   9 +
 tests/kernels/test_gemma4_fused_rope.py       | 226 ++++++++
 6 files changed, 993 insertions(+), 1 deletion(-)
 create mode 100644 src/axolotl/kernels/gemma4_fused_rope.py
 create mode 100644 src/axolotl/monkeypatch/models/gemma4/fused_attn.py
 create mode 100644 tests/kernels/test_gemma4_fused_rope.py

diff --git a/src/axolotl/kernels/gemma4_fused_rope.py b/src/axolotl/kernels/gemma4_fused_rope.py
new file mode 100644
index 000000000..f3b68e603
--- /dev/null
+++ b/src/axolotl/kernels/gemma4_fused_rope.py
@@ -0,0 +1,529 @@
+"""
+Fused RMSNorm + RoPE Triton kernel for Gemma 4.
+
+Fuses three operations into one kernel launch:
+  1. RMSNorm: x_norm = (x / sqrt(mean(x^2) + eps)) * weight
+  2. RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
+  3. (optional) RMSNorm without scale (for v_norm)
+
+This eliminates two intermediate tensor materializations per Q/K path and the
+allocation churn from rotate_half / apply_rotary_pos_emb.
+
+Shapes:
+  X: (rows, head_dim) — flattened from (batch, seq_len, num_heads, head_dim)
+  W: (head_dim,) — RMSNorm weight (None for with_scale=False)
+  cos: (rows // num_heads, head_dim) — flattened from (batch, seq_len, head_dim)
+  sin: (rows // num_heads, head_dim) — same as cos
+"""
+
+import math
+import operator
+
+import torch
+import triton
+import triton.language as tl
+from liger_kernel.ops.utils import (
+    calculate_settings,
+    compare_version,
+    ensure_contiguous,
+    torch_to_triton_dtype,
+)
+from liger_kernel.utils import is_npu_available
+
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
+    try:
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+
+@triton.jit
+def _rms_norm_rope_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    W_ptr,
+    COS_ptr,
+    COS_row_stride,
+    SIN_ptr,
+    SIN_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    n_cols,
+    n_heads,
+    eps,
+    HAS_WEIGHT: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Fused forward:
+      x_norm = x / rms(x) [* weight]  (RMSNorm)
+      y = x_norm * cos + rotate_half(x_norm) * sin  (RoPE)
+
+    rotate_half swaps first/second halves and negates the first:
+      rotate_half([a, b]) = [-b, a]
+
+    cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
+    (cos/sin have shape (B*S, D) while X has shape (B*S*H, D)).
+    """
+    row_idx = tl.program_id(0).to(tl.int64)
+    # cos/sin row: divide by n_heads since cos/sin are (B*S, D)
+    cs_row_idx = row_idx // n_heads
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+    half_dim = n_cols // 2
+
+    # Load input row
+    X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
+    X_dtype = X_row.dtype
+    X_fp32 = X_row.to(tl.float32)
+
+    # RMSNorm: compute 1/rms
+    mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
+    rstd = rsqrt(mean_sq + eps)
+    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
+
+    # Normalize
+    X_norm = X_fp32 * rstd
+
+    # Apply weight if present (with_scale=True)
+    if HAS_WEIGHT:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
+        X_norm = X_norm * W_row
+
+    # RoPE: load cos/sin (broadcast across heads)
+    cos_row = tl.load(
+        COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
+    ).to(tl.float32)
+    sin_row = tl.load(
+        SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, mask=mask, other=0
+    ).to(tl.float32)
+
+    # rotate_half: for col < half_dim, take -X_norm[col + half_dim]
+    #              for col >= half_dim, take X_norm[col - half_dim]
+    rot_offsets = tl.where(
+        col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
+    )
+    rot_mask = rot_offsets < n_cols
+    X_rot = tl.load(
+        X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_mask & mask, other=0
+    ).to(tl.float32)
+    # Re-normalize the rotated values
+    X_rot_norm = X_rot * rstd
+    if HAS_WEIGHT:
+        W_rot = tl.load(W_ptr + rot_offsets, mask=rot_mask & mask, other=0).to(
+            tl.float32
+        )
+        X_rot_norm = X_rot_norm * W_rot
+
+    # Negate the first half (rotate_half negates x2, which becomes the first half)
+    sign = tl.where(col_offsets < half_dim, -1.0, 1.0)
+    X_rot_norm = X_rot_norm * sign
+
+    # Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
+    Y_row = X_norm * cos_row + X_rot_norm * sin_row
+
+    tl.store(
+        Y_ptr + row_idx * Y_row_stride + col_offsets,
+        Y_row.to(X_dtype),
+        mask=mask,
+    )
+
+
+@triton.jit
+def _rms_norm_rope_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    X_dtype: tl.constexpr,
+    W_ptr,
+    COS_ptr,
+    COS_row_stride,
+    SIN_ptr,
+    SIN_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    dW_ptr,
+    dW_row_stride,
+    n_rows,
+    n_cols,
+    n_heads,
+    rows_per_program,
+    HAS_WEIGHT: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Backward for Y = RoPE(RMSNorm(X, W))
+    cos/sin indexed by row_idx // n_heads for per-head broadcast.
+    """
+    row_block_id = tl.program_id(0).to(tl.int64)
+    row_start = row_block_id * rows_per_program
+    row_end = min((row_block_id + 1) * rows_per_program, n_rows)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+    half_dim = n_cols // 2
+
+    dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    if HAS_WEIGHT:
+        W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
+
+    for row_idx in range(row_start, row_end):
+        cs_row_idx = row_idx // n_heads
+
+        dY_row = tl.load(
+            dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
+        ).to(tl.float32)
+        X_row = tl.load(
+            X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
+        ).to(tl.float32)
+        rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
+
+        cos_row = tl.load(
+            COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
+        ).to(tl.float32)
+
+        # dN = dY * cos + rotate_half^T(dY * sin)
+        # rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half)
+        #
+        # Compute rotate_half_transpose(dY * sin) by loading dY and sin at
+        # rotated offsets directly: dY[rot] * sin[rot] * adj_sign
+        # This is equivalent to rotating (dY * sin) because the rotation
+        # just permutes which elements are multiplied.
+        rot_offsets = tl.where(
+            col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
+        )
+        rot_mask = rot_offsets < n_cols
+        dY_rot = tl.load(
+            dY_ptr + row_idx * dY_row_stride + rot_offsets,
+            mask=rot_mask & mask,
+            other=0,
+        ).to(tl.float32)
+        sin_rot = tl.load(
+            SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
+            mask=rot_mask & mask,
+            other=0,
+        ).to(tl.float32)
+
+        adj_sign = tl.where(col_offsets < half_dim, 1.0, -1.0)
+        dN = dY_row * cos_row + dY_rot * sin_rot * adj_sign
+
+        # Pre-weight normalized: n = rstd * x
+        n = X_row * rstd
+
+        if HAS_WEIGHT:
+            dW_acc += dN * n
+            dm = dN * W_row
+        else:
+            dm = dN
+
+        # RMSNorm backward: dX = rstd * (dm - (1/n_cols) * rstd^2 * dot(dm, X) * X)
+        dot_dm_x = tl.sum(dm * X_row, axis=0)
+        dX_row = rstd * (dm - (1.0 / n_cols) * rstd * rstd * dot_dm_x * X_row)
+
+        tl.store(
+            dX_ptr + row_idx * dX_row_stride + col_offsets,
+            dX_row.to(X_dtype),
+            mask=mask,
+        )
+
+    if HAS_WEIGHT:
+        tl.store(
+            dW_ptr + row_block_id * dW_row_stride + col_offsets,
+            dW_acc,
+            mask=mask,
+        )
+
+
+def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
+    """
+    Args:
+        X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
+        W: (head_dim,) or None — RMSNorm weight
+        cos: (B*S, head_dim) — position embeddings (broadcast across heads)
+        sin: (B*S, head_dim) — position embeddings (broadcast across heads)
+        eps: float
+        n_heads: int — number of attention heads (for cos/sin indexing)
+    Returns:
+        Y, X_saved, RSTD, BLOCK_SIZE, num_warps
+    """
+    n_rows, n_cols = X.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+    has_weight = W is not None
+
+    Y = torch.empty_like(X)
+    RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
+
+    _rms_norm_rope_forward_kernel[(n_rows,)](
+        Y,
+        Y.stride(0),
+        X,
+        X.stride(0),
+        W if has_weight else X,  # dummy pointer when no weight
+        cos,
+        cos.stride(0),
+        sin,
+        sin.stride(0),
+        RSTD,
+        RSTD.stride(0),
+        n_cols,
+        n_heads,
+        eps,
+        HAS_WEIGHT=has_weight,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+    return Y, X, RSTD, BLOCK_SIZE, num_warps
+
+
+def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_warps):
+    n_rows, n_cols = dY.shape
+    has_weight = W is not None
+
+    sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
+    rows_per_program = math.ceil(n_rows / sm_count)
+
+    dX = torch.empty_like(X)
+
+    if has_weight:
+        _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=X.device)
+    else:
+        _dW = torch.empty((1, n_cols), dtype=torch.float32, device=X.device)
+
+    _rms_norm_rope_backward_kernel[(sm_count,)](
+        dY,
+        dY.stride(0),
+        dX,
+        dX.stride(0),
+        X,
+        X.stride(0),
+        torch_to_triton_dtype[X.dtype],
+        W if has_weight else X,  # dummy
+        cos,
+        cos.stride(0),
+        sin,
+        sin.stride(0),
+        RSTD,
+        RSTD.stride(0),
+        _dW,
+        _dW.stride(0),
+        n_rows,
+        n_cols,
+        n_heads,
+        rows_per_program,
+        HAS_WEIGHT=has_weight,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+
+    dW = _dW.sum(dim=0).to(W.dtype) if has_weight else None
+    return dX, dW
+
+
+class FusedRMSNormRoPEFunction(torch.autograd.Function):
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, X, W, cos, sin, eps, n_heads):
+        """
+        X: (B*S*H, head_dim)
+        W: (head_dim,) or None
+        cos: (B*S, head_dim) — broadcast across heads
+        sin: (B*S, head_dim) — broadcast across heads
+        n_heads: int
+        """
+        Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
+            X,
+            W,
+            cos,
+            sin,
+            eps,
+            n_heads,
+        )
+        ctx.eps = eps
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        ctx.num_warps = num_warps
+        ctx.n_heads = n_heads
+        ctx.has_weight = W is not None
+        ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dY):
+        X, W, cos, sin, RSTD = ctx.saved_tensors
+        dX, dW = rms_norm_rope_backward(
+            dY,
+            X,
+            W,
+            cos,
+            sin,
+            RSTD,
+            ctx.n_heads,
+            ctx.BLOCK_SIZE,
+            ctx.num_warps,
+        )
+        return dX, dW, None, None, None, None
+
+
+def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
+    """
+    Apply fused RMSNorm + RoPE.
+
+    Args:
+        x: (batch, seq_len, num_heads, head_dim) — after projection + view
+        weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
+        cos: (batch or 1, seq_len, head_dim) — from RotaryEmbedding
+        sin: (batch or 1, seq_len, head_dim) — from RotaryEmbedding
+        eps: float — RMSNorm epsilon
+
+    Returns:
+        y: (batch, seq_len, num_heads, head_dim) — normalized + rotated
+    """
+    shape = x.shape  # (B, S, H, D)
+    B, S, H, D = shape
+    # Flatten to 2D: (B*S*H, D)
+    x_flat = x.reshape(-1, D).contiguous()
+    # Broadcast cos/sin over batch (they may carry a size-1 batch dim), then flatten
+    # to (B*S, D); the kernel indexes rows by row_idx // H for per-head broadcast
+    cos_flat = cos.expand(B, S, D).reshape(B * S, D).contiguous()
+    sin_flat = sin.expand(B, S, D).reshape(B * S, D).contiguous()
+
+    y_flat = FusedRMSNormRoPEFunction.apply(x_flat, weight, cos_flat, sin_flat, eps, H)
+    return y_flat.view(shape)
+
+
+@triton.jit
+def _rms_norm_forward_kernel(
+    Y_ptr,
+    Y_row_stride,
+    X_ptr,
+    X_row_stride,
+    RSTD_ptr,
+    RSTD_row_stride,
+    n_cols,
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """RMSNorm without scale weight: y = x / rms(x)"""
+    row_idx = tl.program_id(0).to(tl.int64)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
+    X_dtype = X_row.dtype
+    X_fp32 = X_row.to(tl.float32)
+
+    mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
+    rstd = rsqrt(mean_sq + eps)
+    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
+
+    Y_row = X_fp32 * rstd
+    tl.store(Y_ptr + row_idx * Y_row_stride + col_offsets, Y_row.to(X_dtype), mask=mask)
+
+
+@triton.jit
+def _rms_norm_noscale_backward_kernel(
+    dY_ptr,
+    dY_row_stride,
+    dX_ptr,
+    dX_row_stride,
+    X_ptr,
+    X_row_stride,
+    X_dtype: tl.constexpr,
+    RSTD_ptr,
+    RSTD_row_stride,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """Backward for y = x * rstd (no weight)."""
+    row_idx = tl.program_id(0).to(tl.int64)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    dY_row = tl.load(
+        dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0
+    ).to(tl.float32)
+    X_row = tl.load(
+        X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0
+    ).to(tl.float32)
+    rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
+
+    dot_dy_x = tl.sum(dY_row * X_row, axis=0)
+    dX_row = rstd * (dY_row - (1.0 / n_cols) * rstd * rstd * dot_dy_x * X_row)
+
+    tl.store(
+        dX_ptr + row_idx * dX_row_stride + col_offsets, dX_row.to(X_dtype), mask=mask
+    )
+
+
+class FusedRMSNormNoScaleFunction(torch.autograd.Function):
+    """RMSNorm without learnable scale — used for Gemma4's v_norm."""
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, X, eps):
+        n_rows, n_cols = X.shape
+        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+        Y = torch.empty_like(X)
+        RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
+
+        _rms_norm_forward_kernel[(n_rows,)](
+            Y,
+            Y.stride(0),
+            X,
+            X.stride(0),
+            RSTD,
+            RSTD.stride(0),
+            n_cols,
+            eps,
+            BLOCK_SIZE=BLOCK_SIZE,
+            num_warps=num_warps,
+        )
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        ctx.num_warps = num_warps
+        ctx.save_for_backward(X, RSTD)
+        ctx.n_cols = n_cols
+        return Y
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, dY):
+        X, RSTD = ctx.saved_tensors
+        n_rows = X.shape[0]
+        dX = torch.empty_like(X)
+        _rms_norm_noscale_backward_kernel[(n_rows,)](
+            dY,
+            dY.stride(0),
+            dX,
+            dX.stride(0),
+            X,
+            X.stride(0),
+            torch_to_triton_dtype[X.dtype],
+            RSTD,
+            RSTD.stride(0),
+            ctx.n_cols,
+            BLOCK_SIZE=ctx.BLOCK_SIZE,
+            num_warps=ctx.num_warps,
+        )
+        return dX, None
+
+
+def fused_rms_norm_noscale(x, eps=1e-6):
+    """
+    RMSNorm without scale for v_norm.
+
+    Args:
+        x: (batch, seq_len, num_heads, head_dim)
+    Returns:
+        y: same shape, normalized
+    """
+    shape = x.shape
+    x_flat = x.reshape(-1, shape[-1])
+    y_flat = FusedRMSNormNoScaleFunction.apply(x_flat, eps)
+    return y_flat.view(shape)
diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 3bfda7e23..83b6452dc 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -624,7 +624,14 @@ class ModelLoader:
 
     def _set_attention_config(self):
         """Sample packing uses custom FA2 patch"""
-        if self.cfg.attn_implementation:
+        if self.cfg.gemma4_hybrid_attn_impl:
+            # Load model with flash_attention_2 for sliding window layers;
+            # global layers will be patched to sdpa post-load.
+            self.model_kwargs["attn_implementation"] = "flash_attention_2"
+            self.model_config._attn_implementation = "flash_attention_2"
+            # Set flash_attention so multipack/sample_packing patches activate
+            self.cfg.flash_attention = True
+        elif self.cfg.attn_implementation:
             self.model_kwargs["attn_implementation"] = self.cfg.attn_implementation
         elif self.cfg.flex_attention:
             self.model_kwargs["attn_implementation"] = "flex_attention"
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 018ca52a0..41fc35e6e 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -156,6 +156,7 @@ class PatchManager:
         # which would clobber any earlier fix.
         self._fix_nemotron_h_conversion_mapping()
 
+        self._apply_gemma_hybrid_attention(model)
         self._finalize_moe_expert_quantization(model)
 
     def apply_post_model_load_patches(self, model: PreTrainedModel):
@@ -165,6 +166,72 @@ class PatchManager:
         self._apply_lora_kernel_patch(model)
         self._apply_scaling_softmax_patch(model)
 
+    def _apply_gemma_hybrid_attention(self, model: PreTrainedModel):
+        """Apply hybrid attention: FA2 for sliding window layers, SDPA for global layers.
+
+        Gemma 4 has global (full_attention) layers with head_dim=512
+        which exceeds flash attention's supported size. This patch loads the model
+        with flash_attention_2 for the sliding window layers (head_dim=256), then
+        gives each global layer a shallow-copied config with _attn_implementation="sdpa".
+        """
+        if not self.cfg.gemma4_hybrid_attn_impl:
+            return
+
+        import copy
+
+        # Navigate to the module that has 'layers' - varies by model structure:
+        # Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
+        # Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
+        layers = None
+        config_source = None
+        for candidate in [model, getattr(model, "model", None)]:
+            if candidate is None:
+                continue
+            # Check direct layers
+            if hasattr(candidate, "layers"):
+                layers = candidate.layers
+                config_source = candidate
+                break
+            # Check language_model.layers (multimodal wrapper)
+            lang_model = getattr(candidate, "language_model", None)
+            if lang_model is not None and hasattr(lang_model, "layers"):
+                layers = lang_model.layers
+                config_source = lang_model
+                break
+
+        if layers is None:
+            LOG.warning(
+                "gemma4_hybrid_attn_impl: could not find decoder layers in model, skipping"
+            )
+            return
+
+        config = getattr(config_source, "config", self.model_config)
+        layer_types = getattr(config, "layer_types", None)
+        if layer_types is None:
+            LOG.warning(
+                "gemma4_hybrid_attn_impl: model config has no 'layer_types', skipping. "
+                "This feature requires a model with mixed sliding/global attention layers."
+            )
+            return
+
+        patched_count = 0
+        for layer_idx, layer in enumerate(layers):
+            if layer_types[layer_idx] != "sliding_attention":
+                # Global / full_attention layer - use SDPA instead of FA2
+                attn_module = getattr(layer, "self_attn", None)
+                if attn_module is not None and hasattr(attn_module, "config"):
+                    sdpa_config = copy.copy(attn_module.config)
+                    sdpa_config._attn_implementation = "sdpa"
+                    attn_module.config = sdpa_config
+                    patched_count += 1
+
+        LOG.info(
+            "gemma4_hybrid_attn_impl: patched %d global layers to use SDPA "
+            "(remaining %d sliding layers use flash_attention_2)",
+            patched_count,
+            len(layers) - patched_count,
+        )
+
     def _apply_flash_attention_patches(self):
         """Apply patches related to Flash Attention."""
         if self.cfg.xformers_attention and self.cfg.sample_packing:
@@ -324,6 +391,13 @@ class PatchManager:
 
             patch_qwen3_5_vlm_flash_attention()
 
+        if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
+            from axolotl.monkeypatch.models.gemma4.fused_attn import (
+                patch_gemma4_fused_attn,
+            )
+
+            patch_gemma4_fused_attn()
+
     @staticmethod
     def _fix_nemotron_h_conversion_mapping():
         """Remove the spurious embedding→embeddings WeightRenaming from the
diff --git a/src/axolotl/monkeypatch/models/gemma4/fused_attn.py b/src/axolotl/monkeypatch/models/gemma4/fused_attn.py
new file mode 100644
index 000000000..7cb5c6beb
--- /dev/null
+++ b/src/axolotl/monkeypatch/models/gemma4/fused_attn.py
@@ -0,0 +1,147 @@
+"""
+Gemma 4 fused attention monkeypatch.
+
+Replaces the per-layer RMSNorm + RoPE sequence in attention with fused Triton kernels,
+eliminating intermediate tensor allocations from rotate_half / apply_rotary_pos_emb.
+
+Usage:
+    from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
+    patch_gemma4_fused_attn()
+"""
+
+import logging
+from typing import Callable
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def _make_fused_forward(original_forward):
+    """Create a patched forward that uses fused RMSNorm+RoPE kernels."""
+
+    from axolotl.kernels.gemma4_fused_rope import (
+        fused_rms_norm_noscale,
+        fused_rms_norm_rope,
+    )
+
+    def fused_forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: torch.Tensor,
+        attention_mask: torch.Tensor | None,
+        shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
+        past_key_values=None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+        from transformers.models.gemma4.modeling_gemma4 import (
+            eager_attention_forward,
+        )
+
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+        eps = self.config.rms_norm_eps
+
+        cos, sin = position_embeddings
+
+        # ---- Projections ----
+        # Use apply_qkv if present (LoRA kernel patch), otherwise direct proj
+        has_lora_qkv = hasattr(self, "apply_qkv")
+
+        if has_lora_qkv:
+            query_states, key_states, value_states = self.apply_qkv(hidden_states)
+            query_states = query_states.view(hidden_shape)
+        else:
+            query_states = self.q_proj(hidden_states).view(hidden_shape)
+
+        # ---- Q path: fused q_norm + RoPE ----
+        query_states = fused_rms_norm_rope(
+            query_states,
+            self.q_norm.weight,
+            cos,
+            sin,
+            eps=eps,
+        )
+        query_states = query_states.transpose(1, 2)
+
+        # ---- K/V path ----
+        if self.is_kv_shared_layer:
+            key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
+            key_states = key_states.to(query_states.device)
+            value_states = value_states.to(query_states.device)
+        else:
+            if has_lora_qkv:
+                # apply_qkv already computed k/v projections
+                key_states = key_states.view(hidden_shape)
+                value_states = (
+                    value_states.view(hidden_shape)
+                    if self.v_proj is not None
+                    else key_states
+                )
+            else:
+                key_states = self.k_proj(hidden_states).view(hidden_shape)
+                value_states = (
+                    self.v_proj(hidden_states).view(hidden_shape)
+                    if self.v_proj is not None
+                    else key_states
+                )
+
+            # Fused k_norm + RoPE
+            key_states = fused_rms_norm_rope(
+                key_states,
+                self.k_norm.weight,
+                cos,
+                sin,
+                eps=eps,
+            )
+            key_states = key_states.transpose(1, 2)
+
+            # Fused v_norm (no scale, no RoPE)
+            value_states = fused_rms_norm_noscale(value_states, eps=eps)
+            value_states = value_states.transpose(1, 2)
+
+        if past_key_values is not None and not self.is_kv_shared_layer:
+            key_states, value_states = past_key_values.update(
+                key_states, value_states, self.layer_idx
+            )
+        if self.store_full_length_kv:
+            shared_kv_states[self.layer_idx] = key_states, value_states
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[
+                self.config._attn_implementation
+            ]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=self.attention_dropout if self.training else 0.0,
+            scaling=self.scaling,
+            sliding_window=self.sliding_window,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
+    return fused_forward
+
+
+def patch_gemma4_fused_attn():
+    """
+    Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels.
+    """
+    from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
+
+    original_forward = Gemma4TextAttention.forward
+    Gemma4TextAttention.forward = _make_fused_forward(original_forward)
+
+    logger.info(
+        "Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
+    )
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 474c3a349..496657030 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -777,6 +777,15 @@ class AxolotlInputConfig(
         },
     )
 
+    gemma4_hybrid_attn_impl: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Use hybrid attention for Gemma 4: flash_attention_2 for sliding window layers "
+            "and sdpa for global (full_attention) layers. Global layers have head_dim=512 which "
+            "exceeds flash attention's supported size."
+        },
+    )
+
     experts_implementation: str | None = Field(
         default=None,
         json_schema_extra={
diff --git a/tests/kernels/test_gemma4_fused_rope.py b/tests/kernels/test_gemma4_fused_rope.py
new file mode 100644
index 000000000..7daedd612
--- /dev/null
+++ b/tests/kernels/test_gemma4_fused_rope.py
@@ -0,0 +1,226 @@
+"""
+Correctness tests for the fused RMSNorm+RoPE Triton kernel.
+
+Tests forward and backward against the reference Gemma4 implementation
+(Gemma4RMSNorm + apply_rotary_pos_emb) across both sliding window
+(head_dim=256) and global attention (head_dim=512) layer configurations.
+"""
+
+import pytest
+import torch
+
+torch.manual_seed(42)
+
+# Skip entire module if no CUDA
+pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
+
+
+def _reference_norm_rope(x, weight, cos, sin, eps):
+    """Reference: separate Gemma4RMSNorm + apply_rotary_pos_emb."""
+    from transformers.models.gemma4.modeling_gemma4 import (
+        Gemma4RMSNorm,
+        apply_rotary_pos_emb,
+    )
+
+    D = x.shape[-1]
+    norm = Gemma4RMSNorm(D, eps=eps).to(x.device, x.dtype)
+    norm.weight.data.copy_(weight)
+    normed = norm(x)
+    return apply_rotary_pos_emb(normed, cos, sin, unsqueeze_dim=2)
+
+
+def _reference_norm_noscale(x, eps):
+    """Reference: Gemma4RMSNorm with_scale=False."""
+    from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm
+
+    D = x.shape[-1]
+    norm = Gemma4RMSNorm(D, eps=eps, with_scale=False).to(x.device, x.dtype)
+    return norm(x)
+
+
+@pytest.fixture(
+    params=[
+        (2, 64, 32, 256),  # sliding window layer shape
+        (2, 64, 4, 512),  # global attention layer shape
+        (1, 128, 16, 256),  # different batch/seq
+        (1, 1, 1, 8),  # minimal size
+    ],
+    ids=["sliding_256", "global_512", "varied", "minimal"],
+)
+def shapes(request):
+    return request.param
+
+
+@pytest.fixture(params=[torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
+def dtype(request):
+    return request.param
+
+
+class TestFusedRMSNormRoPEForward:
+    """Forward pass correctness."""
+
+    def test_matches_reference(self, shapes, dtype):
+        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
+
+        B, S, H, D = shapes
+        eps = 1e-6
+        x = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
+        weight = torch.randn(D, device="cuda", dtype=dtype)
+        cos = torch.randn(B, S, D, device="cuda", dtype=dtype)
+        sin = torch.randn(B, S, D, device="cuda", dtype=dtype)
+
+        y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
+        y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
+
+        cos_sim = torch.nn.functional.cosine_similarity(
+            y_ref.flatten().float(), y_fused.flatten().float(), dim=0
+        )
+        assert cos_sim > 0.999, f"Forward cosine_sim={cos_sim:.6f}, expected > 0.999"
+
+    def test_output_shape(self, shapes):
+        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
+
+        B, S, H, D = shapes
+        x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
+        weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
+        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
+        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
+
+        y = fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6)
+        assert y.shape == x.shape
+        assert y.dtype == x.dtype
+
+
+class TestFusedRMSNormRoPEBackward:
+    """Backward pass correctness via gradient comparison."""
+
+    @pytest.mark.parametrize(
+        "B,S,H,D",
+        [(2, 64, 32, 256), (2, 64, 4, 512)],
+        ids=["sliding_256", "global_512"],
+    )
+    def test_x_grad_matches_reference(self, B, S, H, D):
+        from transformers.models.gemma4.modeling_gemma4 import (
+            Gemma4RMSNorm,
+            apply_rotary_pos_emb,
+        )
+
+        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
+
+        eps = 1e-6
+        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
+        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
+        weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
+
+        # Reference backward
+        x_ref = torch.randn(
+            B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
+        )
+        norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
+        norm_ref.weight.data.copy_(weight_init)
+        y_ref = apply_rotary_pos_emb(norm_ref(x_ref), cos, sin, unsqueeze_dim=2)
+        y_ref.sum().backward()
+
+        # Fused backward
+        x_fused = x_ref.data.clone().requires_grad_(True)
+        w_fused = weight_init.clone().requires_grad_(True)
+        y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps)
+        y_fused.sum().backward()
+
+        cos_sim_x = torch.nn.functional.cosine_similarity(
+            x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
+        )
+        assert cos_sim_x > 0.999, f"x grad cosine_sim={cos_sim_x:.6f}, expected > 0.999"
+
+    @pytest.mark.parametrize(
+        "B,S,H,D",
+        [(2, 64, 32, 256), (2, 64, 4, 512)],
+        ids=["sliding_256", "global_512"],
+    )
+    def test_weight_grad_matches_reference(self, B, S, H, D):
+        from transformers.models.gemma4.modeling_gemma4 import (
+            Gemma4RMSNorm,
+            apply_rotary_pos_emb,
+        )
+
+        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
+
+        eps = 1e-6
+        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
+        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
+        weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
+
+        # Reference
+        x_ref = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
+        norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
+        norm_ref.weight = torch.nn.Parameter(weight_init.clone())
+        apply_rotary_pos_emb(
+            norm_ref(x_ref), cos, sin, unsqueeze_dim=2
+        ).sum().backward()
+
+        # Fused
+        w_fused = weight_init.clone().requires_grad_(True)
+        fused_rms_norm_rope(x_ref.clone(), w_fused, cos, sin, eps=eps).sum().backward()
+
+        cos_sim_w = torch.nn.functional.cosine_similarity(
+            w_fused.grad.flatten().float(),
+            norm_ref.weight.grad.flatten().float(),
+            dim=0,
+        )
+        assert cos_sim_w > 0.995, (
+            f"weight grad cosine_sim={cos_sim_w:.6f}, expected > 0.995"
+        )
+
+    def test_grad_flows(self):
+        """Verify gradients are non-zero and finite."""
+        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
+
+        B, S, H, D = 1, 16, 4, 64
+        x = torch.randn(
+            B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
+        )
+        w = torch.randn(D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
+        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
+        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
+
+        y = fused_rms_norm_rope(x, w, cos, sin, eps=1e-6)
+        y.sum().backward()
+
+        assert x.grad is not None, "x.grad is None"
+        assert w.grad is not None, "w.grad is None"
+        assert x.grad.isfinite().all(), "x.grad has non-finite values"
+        assert w.grad.isfinite().all(), "w.grad has non-finite values"
+        assert x.grad.abs().sum() > 0, "x.grad is all zeros"
+        assert w.grad.abs().sum() > 0, "w.grad is all zeros"
+
+
+class TestFusedRMSNormNoScale:
+    """Tests for v_norm (RMSNorm without learnable scale)."""
+
+    def test_forward_matches_reference(self, shapes, dtype):
+        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale
+
+        B, S, H, D = shapes
+        eps = 1e-6
+        x = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
+
+        y_ref = _reference_norm_noscale(x.clone(), eps)
+        y_fused = fused_rms_norm_noscale(x.clone(), eps=eps)
+
+        cos_sim = torch.nn.functional.cosine_similarity(
+            y_ref.flatten().float(), y_fused.flatten().float(), dim=0
+        )
+        assert cos_sim > 0.999, f"v_norm cosine_sim={cos_sim:.6f}, expected > 0.999"
+
+    def test_backward_flows(self):
+        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale
+
+        x = torch.randn(
+            1, 16, 4, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True
+        )
+        y = fused_rms_norm_noscale(x, eps=1e-6)
+        y.sum().backward()
+
+        assert x.grad is not None
+        assert x.grad.isfinite().all()
+        assert x.grad.abs().sum() > 0