[gemma4] fix fused RMSNorm+RoPE on hybrid attention models

- Kernel: fused_rms_norm_rope crashed when cos.shape[-1] < x.shape[-1].
  Triton forward/backward take an n_rot runtime arg that restricts
  rotate_half to [0, n_rot) and treats trailing cols as RMSNorm-only
  pass-through (cos=1, sin=0 defaults). Wrapper also expands cos/sin
  that broadcast over batch.

- Forward: _make_fused_forward used a stale shared_kv_states kwarg the
  current decoder layer no longer passes. Now mirrors stock attention,
  reading/writing past_key_values.shared_layers.
Wing Lian
2026-04-15 12:59:00 +00:00
parent d4e9cf2eec
commit dc16859983
9 changed files with 1217 additions and 86 deletions
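
For orientation, a minimal PyTorch sketch of what the fused op now computes when cos.shape[-1] < head_dim. The function name is illustrative, and the plain-weight multiply mirrors the Triton code in this commit rather than any public API:

import torch

def _sketch_rms_norm_partial_rope(x, weight, cos, sin, eps=1e-6):
    # x: (B, S, H, D); cos/sin: (B, S, n_rot) with n_rot <= D.
    n_rot = cos.shape[-1]
    # RMSNorm over the full head_dim (rotary + pass-through columns).
    rstd = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    x_norm = x.float() * rstd
    if weight is not None:
        x_norm = x_norm * weight.float()
    # RoPE only on the first n_rot columns; the tail passes through
    # unchanged (this is what the cos=1 / sin=0 defaults achieve).
    x_rot, x_pass = x_norm[..., :n_rot], x_norm[..., n_rot:]
    half = n_rot // 2
    rotated = torch.cat([-x_rot[..., half:], x_rot[..., :half]], dim=-1)
    c, s = cos.unsqueeze(2).float(), sin.unsqueeze(2).float()  # broadcast over H
    return torch.cat([x_rot * c + rotated * s, x_pass], dim=-1).to(x.dtype)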

View File

@@ -242,6 +242,87 @@ class ProducerConfig:
)
class _GroupShardedSampler:
"""Rank-aware shard of a ``RepeatSampler`` that preserves GRPO groups.
``RepeatSampler`` yields ``num_generations`` consecutive copies of
each prompt, forming a GRPO group. For distributed training each
rank must see a disjoint slice of prompts (otherwise every rank
dogpiles on the first 1/world_size of the batch) while keeping each
group intact on a single rank so advantage normalization sees all
peer generations.
``accelerator.prepare(DataLoader)`` does not handle this correctly
for custom samplers with ``split_batches=False`` (the default): it
leaves the sampler alone and every rank replays identical indices.
This wrapper fixes that by consuming the inner sampler's full
output, chunking it into ``num_generations``-sized groups, and
round-robining whole groups across ranks.
Intended to be used ONLY when distributed training is active
(``num_replicas > 1``); for single-rank it is a no-op but still
correct.
"""
def __init__(
self,
inner: Any,
num_generations: int,
rank: int,
num_replicas: int,
):
if num_generations < 1:
raise ValueError(f"num_generations must be >= 1, got {num_generations}")
if num_replicas < 1:
raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
if not (0 <= rank < num_replicas):
raise ValueError(
f"rank must be in [0, {num_replicas}), got {rank}"
)
self.inner = inner
self.num_generations = num_generations
self.rank = rank
self.num_replicas = num_replicas
def __iter__(self):
all_indices = list(self.inner)
if len(all_indices) % self.num_generations != 0:
raise ValueError(
f"inner sampler yielded {len(all_indices)} indices, "
f"not a multiple of num_generations={self.num_generations}"
)
# Chunk the flat index sequence into groups of num_generations
# consecutive indices. ``RepeatSampler`` guarantees that each
# group contains num_generations copies of the same prompt id.
groups = [
all_indices[i : i + self.num_generations]
for i in range(0, len(all_indices), self.num_generations)
]
# Round-robin whole groups across ranks. Round-robin (vs.
# contiguous chunking) preserves approximate shuffled order on
# each rank even when the group count is small relative to the
# world size.
for group in groups[self.rank :: self.num_replicas]:
yield from group
def __len__(self):
try:
inner_len = len(self.inner)
except TypeError:
# Non-sized inner sampler — we can't know the per-rank
# length without materializing. Return 0 as a hint that the
# DataLoader should fall back to iteration.
return 0
total_groups = inner_len // self.num_generations
# Ceiling division for the trailing groups that don't divide
# evenly — extra groups go to the first ``total_groups %
# num_replicas`` ranks, matching the round-robin above.
my_groups = (
total_groups + self.num_replicas - self.rank - 1
) // self.num_replicas
return my_groups * self.num_generations
class DataProducer(ABC):
"""Abstract base class for online data producers.
@@ -556,6 +637,34 @@ class GRPODataProducer(BaseDataProducer):
seed=self._seed,
)
# Shard the sampler across distributed ranks so each rank sees
# a disjoint slice of prompts. ``RepeatSampler`` groups each
# prompt with ``num_generations`` consecutive copies — our
# wrapper round-robins WHOLE groups across ranks so all
# generations of a given prompt stay on the same rank (needed
# for GRPO advantage normalization within a group).
#
# Without this, ``accelerator.prepare(dl)`` with the default
# ``split_batches=False`` leaves the custom sampler alone, so
# every rank iterates the identical index sequence and the
# cluster dogpiles on the first 1/world_size of the prompts.
num_replicas = max(1, trainer.accelerator.num_processes)
if num_replicas > 1:
sampler = _GroupShardedSampler(
inner=sampler,
num_generations=self._num_generations,
rank=trainer.accelerator.process_index,
num_replicas=num_replicas,
)
logger.info(
"[RANK:%d] _GroupShardedSampler active "
"(num_replicas=%d, num_generations=%d, gen_batch=%d)",
trainer.accelerator.process_index,
num_replicas,
self._num_generations,
self._generation_batch_size,
)
# Use identity collator (same as stock GRPOTrainer)
def _identity(x):
return x
@@ -574,12 +683,11 @@ class GRPODataProducer(BaseDataProducer):
rank=trainer.args.process_index,
),
)
self._prompt_dl = trainer.accelerator.prepare(dl)
# Don't let accelerator track this dataloader
acc_dls = trainer.accelerator._dataloaders
if self._prompt_dl in acc_dls:
acc_dls.remove(self._prompt_dl)
# Skip accelerator.prepare — we're handling per-rank sharding
# ourselves via ``_GroupShardedSampler``. ``prepare()`` would
# otherwise try to wrap the DataLoader with its own sharding
# logic which does not understand our group structure.
self._prompt_dl = dl
self._prompt_iter = iter(self._prompt_dl)
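
A toy check of the round-robin group shard, with a hand-written inner index sequence standing in for ``RepeatSampler`` output (num_generations=2, 5 prompts, 2 ranks):

inner = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]  # each prompt index repeated num_generations times

rank0 = _GroupShardedSampler(inner, num_generations=2, rank=0, num_replicas=2)
rank1 = _GroupShardedSampler(inner, num_generations=2, rank=1, num_replicas=2)

assert list(rank0) == [0, 0, 2, 2, 4, 4]  # groups 0, 2, 4
assert list(rank1) == [1, 1, 3, 3]        # groups 1, 3
assert len(rank0) == 6 and len(rank1) == 4  # matches __len__'s round-robin count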

View File

@@ -110,11 +110,35 @@ class NemoGymDataProducer(GRPODataProducer):
item["agent_ref"] = full_item["agent_ref"]
dataset_items.append(item)
# Expand by num_generations (agent produces one rollout per call)
expanded_items = []
for item in dataset_items:
for _ in range(self._num_generations):
expanded_items.append(item)
# NOTE: do NOT re-expand by num_generations here.
# ``RepeatSampler(mini_repeat_count=num_generations)`` already
# yields ``num_generations`` consecutive copies of each unique
# prompt, so ``inputs`` is a list of ``(unique_prompts_per_rank *
# num_generations)`` items — one entry per rollout. Expanding
# again here would fire ``num_generations^2`` rollouts per
# prompt per rank and make every step dogpile on a handful of
# tasks.
expanded_items = dataset_items
# Diagnostic: log what this rank is about to fire.
try:
import collections
iid_counts = collections.Counter()
for it in dataset_items:
iid_counts[
(it.get("responses_create_params", {}).get("metadata") or {}).get(
"instance_id"
)
] += 1
LOG.info(
"[RANK:%d] produce(): firing %d agent /run calls covering %d unique prompts: %s",
trainer.accelerator.process_index,
len(dataset_items),
len(iid_counts),
list(iid_counts.most_common(5)),
)
except Exception:
pass
# Call NeMo Gym agents
loop = asyncio.new_event_loop()
@@ -140,6 +164,7 @@ class NemoGymDataProducer(GRPODataProducer):
logprobs_list = []
rewards_list = []
num_turns_list: list[int] = []
for resp in responses:
parsed = _parse_agent_response(resp, eos_token_id)
prompt_ids_list.append(parsed["prompt_ids"])
@@ -147,6 +172,7 @@ class NemoGymDataProducer(GRPODataProducer):
env_mask_list.append(parsed["env_mask"])
logprobs_list.append(parsed["logprobs"])
rewards_list.append(parsed["reward"])
num_turns_list.append(parsed.get("num_turns", 0))
# Pad to tensors
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -179,22 +205,49 @@ class NemoGymDataProducer(GRPODataProducer):
tool_mask = [torch.tensor(m, device=device) for m in env_mask_list]
tool_mask = pad(tool_mask, padding_value=1, padding_side="right")
# Inject rewards into inputs so _compute_deferred_scores can use them
# The deferred scoring path calls _calculate_rewards which reads reward_funcs.
# Our passthrough reward_fn reads "env_reward" from kwargs.
# Inject per-rollout reward + num_turns into each input. Since
# ``RepeatSampler`` already yields ``num_generations`` copies of
# each prompt, ``inputs`` has ONE entry per rollout (matching
# ``rewards_list`` 1:1). No per-prompt grouping happens here —
# GRPO advantage normalization is the trainer's job downstream.
assert len(inputs) == len(rewards_list), (
f"rewards/inputs length mismatch: "
f"{len(rewards_list)} rewards vs {len(inputs)} inputs"
)
for i, inp in enumerate(inputs):
# Each input gets rewards for its num_generations rollouts
start = i * self._num_generations
end = start + self._num_generations
inp["env_reward"] = rewards_list[start:end]
inp["env_reward"] = rewards_list[i]
inp["num_turns"] = num_turns_list[i]
# One expanded_input per rollout (already correct count because
# inputs has num_generations copies baked in by the sampler).
expanded_inputs = [dict(inp) for inp in inputs]
# Log rollout-level stats to wandb from rank 0. These are the
# true agent-side metrics (not the tokenized TRL view) — so
# num_turns reflects how many /run iterations each rollout
# actually took before finishing or hitting max_turns.
if is_main and num_turns_list:
try:
import wandb
if wandb.run is not None:
import statistics as _stats
nonzero = sum(1 for r in rewards_list if r > 0)
log_payload = {
"rollout/num_turns/mean": float(_stats.mean(num_turns_list)),
"rollout/num_turns/min": float(min(num_turns_list)),
"rollout/num_turns/max": float(max(num_turns_list)),
"rollout/reward/mean": float(_stats.mean(rewards_list)),
"rollout/reward/nonzero_frac": (
nonzero / len(rewards_list) if rewards_list else 0.0
),
"rollout/n_samples": float(len(rewards_list)),
}
wandb.log(log_payload, commit=False)
except Exception as exc: # never let metric logging break training
LOG.warning("rollout wandb log failed: %s", exc)
# Expand inputs to match expanded rollouts (num_generations copies)
expanded_inputs = []
for inp in inputs:
for g in range(self._num_generations):
expanded_inp = dict(inp)
expanded_inp["env_reward"] = inp["env_reward"][g]
expanded_inputs.append(expanded_inp)
# Decode completions for reward functions
completions = trainer.processing_class.batch_decode(

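A small worked example of the 1:1 reward/input mapping (hypothetical values; with num_generations=2 the sampler has already duplicated each prompt, so no further expansion happens here):

inputs = [{"prompt": "A"}, {"prompt": "A"}, {"prompt": "B"}, {"prompt": "B"}]
rewards_list = [0.0, 1.0, 0.5, 0.5]  # one reward per rollout
num_turns_list = [3, 2, 4, 4]

assert len(inputs) == len(rewards_list)
for i, inp in enumerate(inputs):
    inp["env_reward"] = rewards_list[i]
    inp["num_turns"] = num_turns_list[i]
expanded_inputs = [dict(inp) for inp in inputs]  # one entry per rollout, no num_generations^2 blow-up
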
View File

@@ -53,6 +53,7 @@ def _rms_norm_rope_forward_kernel(
RSTD_ptr,
RSTD_row_stride,
n_cols,
n_rot,
n_heads,
eps,
HAS_WEIGHT: tl.constexpr,
@@ -60,28 +61,35 @@ def _rms_norm_rope_forward_kernel(
):
"""
Fused forward:
x_norm = x / rms(x) [* weight] (RMSNorm)
y = x_norm * cos + rotate_half(x_norm) * sin (RoPE)
x_norm = x / rms(x) [* weight] (RMSNorm, full n_cols)
y[..., :n_rot] = rope(x_norm[..., :n_rot])
y[..., n_rot:] = x_norm[..., n_rot:] (pass-through for partial rotary)
rotate_half swaps first/second halves and negates the first:
rotate_half([a, b]) = [-b, a]
rotate_half swaps first/second halves and negates the first, restricted
to the rotary span [0, n_rot):
rotate_half([a, b]) = [-b, a] where len(a) = len(b) = n_rot/2
For the partial-rotary pass-through region we load cos with default 1.0
and sin with default 0.0 outside [0, n_rot), so the same formula
`Y = X_norm * cos + X_rot_norm * sin` collapses to `Y = X_norm`.
cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
(cos/sin have shape (B*S, D) while X has shape (B*S*H, D)).
(cos/sin have shape (B*S, n_rot) while X has shape (B*S*H, n_cols)).
"""
row_idx = tl.program_id(0).to(tl.int64)
# cos/sin row: divide by n_heads since cos/sin are (B*S, D)
# cos/sin row: divide by n_heads since cos/sin are (B*S, n_rot)
cs_row_idx = row_idx // n_heads
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
half_dim = n_cols // 2
rot_mask_col = col_offsets < n_rot
half_rot = n_rot // 2
# Load input row
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
X_dtype = X_row.dtype
X_fp32 = X_row.to(tl.float32)
# RMSNorm: compute 1/rms
# RMSNorm: compute 1/rms over the full row (rotary + pass-through)
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
rstd = rsqrt(mean_sq + eps)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
@@ -94,33 +102,38 @@ def _rms_norm_rope_forward_kernel(
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
X_norm = X_norm * W_row
# RoPE: load cos/sin (broadcast across heads)
# RoPE: load cos/sin (broadcast across heads). For col >= n_rot we get
# cos=1, sin=0 so the formula leaves X_norm untouched.
cos_row = tl.load(
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
mask=rot_mask_col,
other=1.0,
).to(tl.float32)
sin_row = tl.load(
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, mask=mask, other=0
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets,
mask=rot_mask_col,
other=0.0,
).to(tl.float32)
# rotate_half: for col < half_dim, take -X_norm[col + half_dim]
# for col >= half_dim, take X_norm[col - half_dim]
# rotate_half within [0, n_rot):
# for col < half_rot: take -X_norm[col + half_rot]
# for col in [half_rot, n_rot): take X_norm[col - half_rot]
# For col >= n_rot the rotation is irrelevant (sin = 0 zeros it out).
rot_offsets = tl.where(
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
)
rot_mask = rot_offsets < n_cols
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
X_rot = tl.load(
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_mask & mask, other=0
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_load_mask, other=0
).to(tl.float32)
# Re-normalize the rotated values
X_rot_norm = X_rot * rstd
if HAS_WEIGHT:
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_mask & mask, other=0).to(
tl.float32
)
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_load_mask, other=0).to(tl.float32)
X_rot_norm = X_rot_norm * W_rot
# Negate the first half (rotate_half negates x2, which becomes the first half)
sign = tl.where(col_offsets < half_dim, -1.0, 1.0)
sign = tl.where(col_offsets < half_rot, -1.0, 1.0)
X_rot_norm = X_rot_norm * sign
# Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
@@ -153,13 +166,21 @@ def _rms_norm_rope_backward_kernel(
dW_row_stride,
n_rows,
n_cols,
n_rot,
n_heads,
rows_per_program,
HAS_WEIGHT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Backward for Y = RoPE(RMSNorm(X, W))
Backward for Y = RoPE(RMSNorm(X, W)) with optional partial rotary
(`n_rot <= n_cols`).
For col < n_rot the standard RoPE adjoint applies. For col >= n_rot the
output is just the normalized row, so dN[col] = dY[col] (achieved by
loading cos with default 1.0 and forcing the rotate-half contribution
to zero outside the rotary span).
cos/sin indexed by row_idx // n_heads for per-head broadcast.
"""
row_block_id = tl.program_id(0).to(tl.int64)
@@ -167,7 +188,8 @@ def _rms_norm_rope_backward_kernel(
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
half_dim = n_cols // 2
rot_mask_col = col_offsets < n_rot
half_rot = n_rot // 2
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
@@ -186,33 +208,37 @@ def _rms_norm_rope_backward_kernel(
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
cos_row = tl.load(
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
mask=rot_mask_col,
other=1.0,
).to(tl.float32)
# dN = dY * cos + rotate_half^T(dY * sin)
# dN = dY * cos + rotate_half^T(dY * sin) (within the rotary span)
# rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half)
#
# Compute rotate_half_transpose(dY * sin) by loading dY and sin at
# rotated offsets directly: dY[rot] * sin[rot] * adj_sign
# This is equivalent to rotating (dY * sin) because the rotation
# just permutes which elements are multiplied.
# For col >= n_rot the formula must collapse to dN = dY (since the
# forward is just a pass-through). cos defaults to 1.0 above; the
# rotate-half contribution is masked to zero below.
rot_offsets = tl.where(
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
)
rot_mask = rot_offsets < n_cols
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
dY_rot = tl.load(
dY_ptr + row_idx * dY_row_stride + rot_offsets,
mask=rot_mask & mask,
mask=rot_load_mask,
other=0,
).to(tl.float32)
sin_rot = tl.load(
SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
mask=rot_mask & mask,
mask=rot_load_mask,
other=0,
).to(tl.float32)
adj_sign = tl.where(col_offsets < half_dim, 1.0, -1.0)
dN = dY_row * cos_row + dY_rot * sin_rot * adj_sign
adj_sign = tl.where(col_offsets < half_rot, 1.0, -1.0)
rotate_term = dY_rot * sin_rot * adj_sign
# Zero out rotate-half contribution outside the rotary span.
rotate_term = tl.where(rot_mask_col, rotate_term, 0.0)
dN = dY_row * cos_row + rotate_term
# Pre-weight normalized: n = rstd * x
n = X_row * rstd
@@ -241,15 +267,17 @@ def _rms_norm_rope_backward_kernel(
)
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads, n_rot):
"""
Args:
X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
W: (head_dim,) or None — RMSNorm weight
cos: (B*S, head_dim) — position embeddings (broadcast across heads)
sin: (B*S, head_dim) — position embeddings (broadcast across heads)
cos: (B*S, n_rot) — position embeddings (broadcast across heads)
sin: (B*S, n_rot) — position embeddings (broadcast across heads)
eps: float
n_heads: int — number of attention heads (for cos/sin indexing)
n_rot: int — rotary dim (== head_dim for full rotary, < head_dim for
partial rotary). Must be even and ``<= head_dim``.
Returns:
Y, X_saved, RSTD, BLOCK_SIZE, num_warps
"""
@@ -273,6 +301,7 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
RSTD,
RSTD.stride(0),
n_cols,
n_rot,
n_heads,
eps,
HAS_WEIGHT=has_weight,
@@ -282,7 +311,9 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
return Y, X, RSTD, BLOCK_SIZE, num_warps
def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_warps):
def rms_norm_rope_backward(
dY, X, W, cos, sin, RSTD, n_heads, n_rot, BLOCK_SIZE, num_warps
):
n_rows, n_cols = dY.shape
has_weight = W is not None
@@ -315,6 +346,7 @@ def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_wa
_dW.stride(0),
n_rows,
n_cols,
n_rot,
n_heads,
rows_per_program,
HAS_WEIGHT=has_weight,
@@ -329,13 +361,14 @@ def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_wa
class FusedRMSNormRoPEFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, X, W, cos, sin, eps, n_heads):
def forward(ctx, X, W, cos, sin, eps, n_heads, n_rot):
"""
X: (B*S*H, head_dim)
W: (head_dim,) or None
cos: (B*S, head_dim) — broadcast across heads
sin: (B*S, head_dim) — broadcast across heads
X: (B*S*H, head_dim)
W: (head_dim,) or None
cos: (B*S, n_rot) — broadcast across heads
sin: (B*S, n_rot) — broadcast across heads
n_heads: int
n_rot: int — rotary dim (<= head_dim)
"""
Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
X,
@@ -344,11 +377,13 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function):
sin,
eps,
n_heads,
n_rot,
)
ctx.eps = eps
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.n_heads = n_heads
ctx.n_rot = n_rot
ctx.has_weight = W is not None
ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
return Y
@@ -365,21 +400,26 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function):
sin,
RSTD,
ctx.n_heads,
ctx.n_rot,
ctx.BLOCK_SIZE,
ctx.num_warps,
)
return dX, dW, None, None, None, None
return dX, dW, None, None, None, None, None
def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
"""
Apply fused RMSNorm + RoPE.
Apply fused RMSNorm + (partial) RoPE.
Args:
x: (batch, seq_len, num_heads, head_dim) — after projection + view
weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
cos: (batch, seq_len, head_dim) — from RotaryEmbedding
sin: (batch, seq_len, head_dim) — from RotaryEmbedding
cos: (batch, seq_len, n_rot) — from RotaryEmbedding. ``n_rot``
must be even and ``<= head_dim``. When ``n_rot < head_dim``
the trailing ``head_dim - n_rot`` columns are RMSNorm-only
(partial-rotary pass-through), matching stock Gemma 4 with
``partial_rotary_factor < 1.0``.
sin: (batch, seq_len, n_rot) — same shape as ``cos``
eps: float — RMSNorm epsilon
Returns:
@@ -387,14 +427,38 @@ def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
"""
shape = x.shape # (B, S, H, D)
B, S, H, D = shape
n_rot = cos.shape[-1]
if sin.shape[-1] != n_rot:
raise ValueError(
f"cos and sin must have the same last dim, got cos={cos.shape[-1]} "
f"sin={sin.shape[-1]}"
)
if n_rot > D:
raise ValueError(f"rotary dim ({n_rot}) cannot exceed head_dim ({D})")
if n_rot % 2 != 0:
raise ValueError(f"rotary dim must be even, got {n_rot}")
# Flatten to 2D: (B*S*H, D)
x_flat = x.reshape(-1, D).contiguous()
# Flatten cos/sin to (B*S, D) — the kernel will handle per-head broadcast
# by dividing the row_idx by H to get the cos/sin row
cos_flat = cos.reshape(B * S, D).contiguous()
sin_flat = sin.reshape(B * S, D).contiguous()
# cos/sin may broadcast over the batch dim (e.g. (1, S, n_rot) when
# all sequences share the same rotary positions). The kernel needs a
# dense (B*S, n_rot) buffer so that row_idx // n_heads maps cleanly
# onto a single (b, s) pair, so expand-then-contiguous materializes
# the per-batch broadcast. No expand is needed when B == cos.shape[0].
if cos.shape[0] != B:
if cos.shape[0] != 1:
raise ValueError(
f"cos/sin batch dim ({cos.shape[0]}) must be 1 or equal "
f"to x batch dim ({B})"
)
cos = cos.expand(B, S, n_rot)
sin = sin.expand(B, S, n_rot)
cos_flat = cos.reshape(B * S, n_rot).contiguous()
sin_flat = sin.reshape(B * S, n_rot).contiguous()
y_flat = FusedRMSNormRoPEFunction.apply(x_flat, weight, cos_flat, sin_flat, eps, H)
y_flat = FusedRMSNormRoPEFunction.apply(
x_flat, weight, cos_flat, sin_flat, eps, H, n_rot
)
return y_flat.view(shape)
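
A call-shape sketch for the wrapper above (assumes a CUDA device; all shapes are illustrative): partial rotary with n_rot=64 on head_dim=256 and cos/sin broadcasting over the batch dim.

import torch

B, S, H, D, n_rot = 2, 128, 8, 256, 64
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
w = torch.randn(D, device="cuda", dtype=torch.bfloat16)
cos = torch.randn(1, S, n_rot, device="cuda", dtype=torch.bfloat16)  # batch dim of 1 broadcasts
sin = torch.randn(1, S, n_rot, device="cuda", dtype=torch.bfloat16)

y = fused_rms_norm_rope(x, w, cos, sin, eps=1e-6)
assert y.shape == (B, S, H, D)
# columns >= n_rot are RMSNorm-only pass-through; columns < n_rot get RoPE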

View File

@@ -156,6 +156,14 @@ class PatchManager:
# which would clobber any earlier fix.
self._fix_nemotron_h_conversion_mapping()
# Gemma 4 hybrid attention runs here in post-build (NOT post-load):
# the per-layer ``self_attn.config._attn_implementation="sdpa"``
# override needs to walk the raw model tree, which the post-load
# PEFT wrapping would otherwise break. The accompanying
# ``patch_gemma4_hybrid_mask`` monkey-patch is module-level and does
# not care when it is installed, so both halves of the fix live
# cleanly in the same call even though one is instance-scoped
# and the other is module-scoped.
self._apply_gemma_hybrid_attention(model)
self._finalize_moe_expert_quantization(model)
@@ -173,12 +181,23 @@ class PatchManager:
which exceeds flash attention's supported size. This patch loads the model
with flash_attention_2 for the sliding window layers (head_dim=256), then
gives each global layer a shallow-copied config with _attn_implementation="sdpa".
We also install :func:`axolotl.monkeypatch.gemma4_hybrid_mask.patch_gemma4_hybrid_mask`
which fixes the corresponding mask construction inside
``Gemma4TextModel.forward``. Without it, the per-layer SDPA config
override is not enough — the forward still builds a 2D FA2-format mask
at the model level and the SDPA layers crash at long context lengths
with ``RuntimeError: The expanded size of the tensor ... must match``.
"""
if not self.cfg.gemma4_hybrid_attn_impl:
return
import copy
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
patch_gemma4_hybrid_mask()
# Navigate to the module that has 'layers' - varies by model structure:
# Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
# Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
@@ -392,6 +411,14 @@ class PatchManager:
patch_qwen3_5_vlm_flash_attention()
if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
# The fused attn path is now compatible with
# ``gemma4_hybrid_attn_impl``: the kernel handles partial
# rotary (cos.shape[-1] < head_dim) and the fused forward
# mirrors the current ``Gemma4TextAttention.forward`` API
# for shared kv (read from / write to
# ``past_key_values.shared_layers``). See
# ``src/axolotl/kernels/GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``
# for the history.
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)

View File

@@ -0,0 +1,115 @@
"""Hybrid attention mask fix for Gemma 4.
Gemma 4 has full-attention (global) layers with ``head_dim=512`` which
exceeds flash-attention-2's supported size. Axolotl's hybrid-attention
patch in ``patch_manager._apply_gemma_hybrid_attention`` works around
this by forcing ``_attn_implementation="sdpa"`` on each global layer's
``self_attn.config``, leaving sliding-window layers on FA2.
The per-layer config override alone is insufficient, however:
``Gemma4TextModel.forward`` builds a single ``causal_mask_mapping`` dict
using the **model-level** config and passes the mapped mask to each
decoder layer. With FA2 still set at the model level, the ``full_attention``
entry in that mapping is a 2D mask (FA2 format), but SDPA needs a 4D mask.
The global layers then fail with::
RuntimeError: The expanded size of the tensor (S) must match the existing
size (B) at non-singleton dimension 2. Target sizes: [B, H, S, S]. Tensor
sizes: [B, S]
...when the sequence length grows past roughly 7k tokens.
This module fixes the symptom by monkey-patching ``create_causal_mask`` in
``transformers.models.gemma4.modeling_gemma4``'s module namespace — NOT
the original in ``masking_utils``. The wrapper forces
``_attn_implementation="sdpa"`` on a shallow-copied config before calling
through, so the ``full_attention`` mask built inside ``Gemma4TextModel.forward``
is always 4D/SDPA-compatible. ``create_sliding_window_causal_mask`` is left
alone, so sliding-window layers continue to receive FA2-format masks.
The patch is idempotent. Install once per process, before any Gemma 4
forward pass runs.
"""
from __future__ import annotations
import copy
from typing import Any
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
_PATCH_APPLIED = False
def patch_gemma4_hybrid_mask() -> bool:
"""Install the Gemma 4 hybrid-attention mask fix.
Returns ``True`` if the patch was installed (or was already installed),
``False`` if the target module could not be imported (e.g. transformers
version predates Gemma 4) — in which case nothing is done and the
caller can continue unaffected.
"""
global _PATCH_APPLIED
if _PATCH_APPLIED:
return True
try:
from transformers.models.gemma4 import modeling_gemma4
except ImportError:
LOG.debug(
"gemma4_hybrid_mask: transformers.models.gemma4 not importable, "
"skipping. This is fine for non-Gemma4 training."
)
return False
if not hasattr(modeling_gemma4, "create_causal_mask"):
LOG.warning(
"gemma4_hybrid_mask: modeling_gemma4 has no 'create_causal_mask' "
"binding, skipping. Transformers API may have changed."
)
return False
original = modeling_gemma4.create_causal_mask
def hybrid_create_causal_mask(config: Any, *args: Any, **kwargs: Any):
"""Wrapper that forces SDPA format for the full-attention mask.
The global layers were patched to SDPA by
``_apply_gemma_hybrid_attention``, so their mask must be 4D. The
original ``create_causal_mask`` dispatches on
``config._attn_implementation``; we shadow that with a local
override.
"""
sdpa_config = copy.copy(config)
sdpa_config._attn_implementation = "sdpa"
return original(sdpa_config, *args, **kwargs)
# Preserve the original reference on the wrapper for tests / teardown.
hybrid_create_causal_mask._axolotl_original = original # type: ignore[attr-defined]
modeling_gemma4.create_causal_mask = hybrid_create_causal_mask
_PATCH_APPLIED = True
LOG.info(
"gemma4_hybrid_mask: patched modeling_gemma4.create_causal_mask to "
"force SDPA-format masks for full-attention layers"
)
return True
def unpatch_gemma4_hybrid_mask() -> None:
"""Restore the original ``create_causal_mask``. Useful for tests."""
global _PATCH_APPLIED
if not _PATCH_APPLIED:
return
try:
from transformers.models.gemma4 import modeling_gemma4
except ImportError:
_PATCH_APPLIED = False
return
current = modeling_gemma4.create_causal_mask
original = getattr(current, "_axolotl_original", None)
if original is not None:
modeling_gemma4.create_causal_mask = original
_PATCH_APPLIED = False
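
A minimal usage sketch (hypothetical call site; the return value follows the docstring above):

from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask

installed = patch_gemma4_hybrid_mask()  # idempotent; safe to call more than once
if not installed:
    pass  # transformers predates Gemma 4; nothing was patched and training proceeds unaffected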

View File

@@ -30,7 +30,6 @@ def _make_fused_forward(original_forward):
hidden_states: torch.Tensor,
position_embeddings: torch.Tensor,
attention_mask: torch.Tensor | None,
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
past_key_values=None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
@@ -66,8 +65,14 @@ def _make_fused_forward(original_forward):
query_states = query_states.transpose(1, 2)
# ---- K/V path ----
if self.is_kv_shared_layer:
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
# Current transformers stores shared kv on `past_key_values.shared_layers`
# (the legacy `shared_kv_states` decoder kwarg was removed). We mirror
# the stock attention forward exactly so the dispatch is identical
# regardless of whether the model was patched.
if self.is_kv_shared_layer and past_key_values is not None:
key_states, value_states = past_key_values.shared_layers[
self.kv_shared_layer_index
]
key_states = key_states.to(query_states.device)
value_states = value_states.to(query_states.device)
else:
@@ -101,12 +106,18 @@ def _make_fused_forward(original_forward):
value_states = fused_rms_norm_noscale(value_states, eps=eps)
value_states = value_states.transpose(1, 2)
if past_key_values is not None and not self.is_kv_shared_layer:
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx
)
if self.store_full_length_kv:
shared_kv_states[self.layer_idx] = key_states, value_states
if past_key_values is not None:
if not self.is_kv_shared_layer:
key_states, value_states = past_key_values.update(
key_states, value_states, self.layer_idx
)
if self.store_full_length_kv:
if not hasattr(past_key_values, "shared_layers"):
past_key_values.shared_layers = {}
past_key_values.shared_layers[self.layer_idx] = (
key_states,
value_states,
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":

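A sketch of the shared-kv contract the fused forward now mirrors. The cache class here is a stand-in for illustration; the real code uses whatever past_key_values object transformers passes in, with shapes made up:

import torch

class _CacheWithSharedLayers:
    def __init__(self):
        self.shared_layers = {}  # layer_idx -> (key_states, value_states)

cache = _CacheWithSharedLayers()
k_full = torch.zeros(1, 2, 16, 32)  # (B, n_kv_heads, S, head_dim), illustrative
v_full = torch.zeros(1, 2, 16, 32)
# A layer with store_full_length_kv=True writes its K/V after the cache update ...
cache.shared_layers[3] = (k_full, v_full)
# ... and a later layer with is_kv_shared_layer=True and kv_shared_layer_index=3
# reads it back instead of projecting/normalizing K/V itself.
k_shared, v_shared = cache.shared_layers[3]
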
View File

@@ -38,6 +38,30 @@ def _reference_norm_noscale(x, eps):
return norm(x)
def _reference_partial_norm_rope(x, weight, cos, sin, eps):
"""Reference: Gemma4RMSNorm over the full head_dim, then stock
``apply_rotary_pos_emb`` over the first ``cos.shape[-1]`` columns, with
the trailing columns passed through unchanged. Mirrors how Llama-style
partial rotary is layered on top of the stock RMSNorm + RoPE primitives.
"""
from transformers.models.gemma4.modeling_gemma4 import (
Gemma4RMSNorm,
apply_rotary_pos_emb,
)
D = x.shape[-1]
n_rot = cos.shape[-1]
norm = Gemma4RMSNorm(D, eps=eps).to(x.device, x.dtype)
norm.weight.data.copy_(weight)
normed = norm(x)
if n_rot == D:
return apply_rotary_pos_emb(normed, cos, sin, unsqueeze_dim=2)
x_rot = normed[..., :n_rot]
x_pass = normed[..., n_rot:]
rotated = apply_rotary_pos_emb(x_rot, cos, sin, unsqueeze_dim=2)
return torch.cat([rotated, x_pass], dim=-1)
@pytest.fixture(
params=[
(2, 64, 32, 256), # sliding window layer shape
@@ -194,6 +218,172 @@ class TestFusedRMSNormRoPEBackward:
assert w.grad.abs().sum() > 0, "w.grad is all zeros"
class TestFusedRMSNormRoPEPartialRotary:
"""Partial-rotary: cos/sin last dim is smaller than head_dim.
Compares against the original primitives (`Gemma4RMSNorm` +
`apply_rotary_pos_emb`) applied to the rotated slice with the trailing
columns passed through. Without the kernel fix this used to crash with
`RuntimeError: shape '[..., D]' is invalid for input of size B*S*n_rot`.
"""
@pytest.mark.parametrize(
"B,S,H,D,n_rot",
[
(2, 16, 4, 64, 32), # half rotary (Llama-style 0.5)
(2, 16, 4, 64, 16), # quarter rotary
(2, 32, 8, 128, 64), # half rotary, larger heads
(1, 8, 2, 256, 64), # 26B sliding-shape, 0.25 partial
(1, 8, 2, 64, 64), # n_rot == D: must still match full-rotary path
],
ids=["half_64", "quarter_64", "half_128", "quarter_256", "full_64"],
)
def test_forward_matches_reference(self, B, S, H, D, n_rot):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
eps = 1e-6
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
y_ref = _reference_partial_norm_rope(x.clone(), weight, cos, sin, eps)
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
assert y_fused.shape == y_ref.shape == (B, S, H, D)
cos_sim = torch.nn.functional.cosine_similarity(
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
)
assert cos_sim > 0.999, (
f"partial rotary forward cosine_sim={cos_sim:.6f} "
f"(B={B},S={S},H={H},D={D},n_rot={n_rot})"
)
# The pass-through tail must match the reference RMSNorm output (up to
# bf16 tolerance); any larger deviation would mean the kernel is touching
# it with a spurious rotation, which is the original bug class.
torch.testing.assert_close(
y_fused[..., n_rot:], y_ref[..., n_rot:], rtol=1e-2, atol=1e-2
)
@pytest.mark.parametrize(
"B,S,H,D,n_rot",
[(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
ids=["half_64", "quarter_256"],
)
def test_x_grad_matches_reference(self, B, S, H, D, n_rot):
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
eps = 1e-6
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
# Reference backward via the original primitives
x_ref = x_data.clone().requires_grad_(True)
w_ref = weight_init.clone()
y_ref = _reference_partial_norm_rope(x_ref, w_ref, cos, sin, eps)
y_ref.sum().backward()
# Fused backward
x_fused = x_data.clone().requires_grad_(True)
w_fused = weight_init.clone().requires_grad_(True)
y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps)
y_fused.sum().backward()
cos_sim_x = torch.nn.functional.cosine_similarity(
x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
)
assert cos_sim_x > 0.999, f"partial rotary x grad cosine_sim={cos_sim_x:.6f}"
@pytest.mark.parametrize(
"B,S,H,D,n_rot",
[(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
ids=["half_64", "quarter_256"],
)
def test_weight_grad_matches_reference(self, B, S, H, D, n_rot):
from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
eps = 1e-6
cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
# Reference: Gemma4RMSNorm whose .weight collects grads, then partial
# rotary applied to the rotated slice.
norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
norm_ref.weight = torch.nn.Parameter(weight_init.clone())
normed = norm_ref(x_data)
from transformers.models.gemma4.modeling_gemma4 import apply_rotary_pos_emb
rotated = apply_rotary_pos_emb(normed[..., :n_rot], cos, sin, unsqueeze_dim=2)
y_ref = torch.cat([rotated, normed[..., n_rot:]], dim=-1)
y_ref.sum().backward()
w_fused = weight_init.clone().requires_grad_(True)
fused_rms_norm_rope(x_data.clone(), w_fused, cos, sin, eps=eps).sum().backward()
cos_sim_w = torch.nn.functional.cosine_similarity(
w_fused.grad.flatten().float(),
norm_ref.weight.grad.flatten().float(),
dim=0,
)
assert cos_sim_w > 0.995, (
f"partial rotary weight grad cosine_sim={cos_sim_w:.6f}"
)
def test_full_rotary_unchanged_when_n_rot_equals_d(self):
"""Regression: passing cos/sin with shape == head_dim must still
match the full-rotary reference (the partial-rotary code path must
not perturb the existing full-rotary output)."""
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
B, S, H, D = 2, 16, 4, 64
eps = 1e-6
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
cos_sim = torch.nn.functional.cosine_similarity(
y_ref.flatten().float(), y_fused.flatten().float(), dim=0
)
assert cos_sim > 0.999, f"full-rotary regression cos_sim={cos_sim:.6f}"
def test_validation_errors(self):
"""Wrapper rejects misshaped inputs cleanly (instead of a cryptic
Triton crash deeper in the kernel)."""
from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope
B, S, H, D = 1, 4, 2, 64
x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
w = torch.randn(D, device="cuda", dtype=torch.bfloat16)
# n_rot > head_dim
cos_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
sin_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
with pytest.raises(ValueError, match="cannot exceed head_dim"):
fused_rms_norm_rope(x, w, cos_big, sin_big)
# cos/sin last-dim mismatch
cos = torch.randn(B, S, 32, device="cuda", dtype=torch.bfloat16)
sin = torch.randn(B, S, 16, device="cuda", dtype=torch.bfloat16)
with pytest.raises(ValueError, match="same last dim"):
fused_rms_norm_rope(x, w, cos, sin)
# odd rotary dim
cos_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
sin_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
with pytest.raises(ValueError, match="must be even"):
fused_rms_norm_rope(x, w, cos_odd, sin_odd)
class TestFusedRMSNormNoScale:
"""Tests for v_norm (RMSNorm without learnable scale)."""

View File

@@ -0,0 +1,220 @@
"""Tests for the Gemma 4 fused-attention monkey-patch.
These tests exercise the patched ``Gemma4TextAttention.forward`` against
the stock implementation it replaces. The hybrid Gemma 4 model intentionally
mixes a sliding (`head_dim=32`) layer with a full-attention proportional-rope
layer (`global_head_dim=64`, `partial_rotary_factor=0.25`) so that:
1. The partial-rotary RMSNorm+RoPE path through the fused Triton kernel
gets exercised end-to-end (this is the bug originally documented in
``GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``).
2. The fused forward must match the current transformers attention API,
where the decoder layer no longer passes a ``shared_kv_states`` kwarg
and shared kv lives on ``past_key_values.shared_layers``. An older
fused_forward signature would raise ``TypeError: ... missing 1
required positional argument: 'shared_kv_states'`` here.
"""
import pytest
import torch
pytestmark = [
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"),
]
pytest.importorskip(
"transformers.models.gemma4",
reason="fused_attn patch only matters when Gemma 4 is available",
)
@pytest.fixture
def restore_gemma4_attention():
"""Snapshot ``Gemma4TextAttention.forward`` and restore after the test
so the monkey-patch does not leak across the suite."""
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
saved = Gemma4TextAttention.forward
yield Gemma4TextAttention
Gemma4TextAttention.forward = saved
def _build_hybrid_config():
"""Tiny hybrid Gemma 4 config: one sliding layer + one full-attention
layer with proportional rope and partial_rotary_factor=0.25. This is
the same shape pattern as ``google/gemma-4-26B-A4B-it`` but small
enough to fit on any GPU."""
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
cfg = Gemma4TextConfig(
vocab_size=128,
hidden_size=64,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=2,
num_key_value_heads=2,
head_dim=32,
global_head_dim=64,
layer_types=["sliding_attention", "full_attention"],
sliding_window=64,
max_position_embeddings=2048,
hidden_size_per_layer_input=16,
vocab_size_per_layer_input=128,
rope_parameters={
"sliding_attention": {
"rope_type": "default",
"rope_theta": 10000.0,
},
"full_attention": {
"rope_type": "proportional",
"rope_theta": 1000000.0,
"partial_rotary_factor": 0.25,
},
},
)
cfg._attn_implementation = "sdpa"
return cfg
def _build_model(seed=0):
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel
torch.manual_seed(seed)
cfg = _build_hybrid_config()
return Gemma4TextModel(cfg).cuda().to(torch.bfloat16).eval()
class TestFusedAttnSignature:
"""The fused forward must accept the same call shape as
``Gemma4TextDecoderLayer`` produces under the current transformers API
(no ``shared_kv_states`` kwarg)."""
def test_decoder_layer_can_call_fused_forward(self, restore_gemma4_attention):
"""Regression for the API drift: decoder layer calls
``self.self_attn(hidden_states=..., position_embeddings=...,
attention_mask=..., position_ids=..., past_key_values=...)`` and
nothing else. A signature with a positional ``shared_kv_states``
used to raise ``TypeError`` here before reaching the kernel."""
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
m = _build_model()
ids = torch.randint(0, 128, (2, 16), device="cuda")
mask = torch.ones(2, 16, dtype=torch.long, device="cuda")
patch_gemma4_fused_attn()
with torch.no_grad():
out = m(input_ids=ids, attention_mask=mask).last_hidden_state
assert out.shape == (2, 16, 64)
assert torch.isfinite(out).all()
class TestFusedAttnPerLayerCorrectness:
"""Compare the patched attention layer to the stock implementation
on a single forward call. This isolates the fused kernel correctness
from cross-layer numerical drift."""
def _run_attention(self, model, layer_idx, hidden_states, position_ids):
"""Call ``Gemma4TextAttention.forward`` (whatever is currently
installed) for one layer and return the output."""
attn = model.layers[layer_idx].self_attn
layer_type = model.config.layer_types[layer_idx]
cos, sin = model.rotary_emb(hidden_states, position_ids, layer_type)
out, _ = attn(
hidden_states=hidden_states,
position_embeddings=(cos, sin),
attention_mask=None,
)
return out
@pytest.mark.parametrize(
"layer_idx",
[0, 1],
ids=["sliding_head32", "global_head64_proportional"],
)
def test_forward_matches_stock(self, restore_gemma4_attention, layer_idx):
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
m = _build_model(seed=1)
hs = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16)
pos = torch.arange(16, device="cuda").unsqueeze(0).expand(2, -1)
with torch.no_grad():
ref = self._run_attention(m, layer_idx, hs, pos)
patch_gemma4_fused_attn()
with torch.no_grad():
got = self._run_attention(m, layer_idx, hs, pos)
assert got.shape == ref.shape
assert torch.isfinite(got).all()
cos_sim = torch.nn.functional.cosine_similarity(
ref.flatten().float(), got.flatten().float(), dim=0
)
assert cos_sim > 0.999, (
f"layer {layer_idx} fused vs stock cosine_sim={cos_sim:.6f}"
)
# bf16 precision: a few hundredths of absolute drift per element is
# acceptable for a Q/K/V projection pipeline. Anything larger is
# a real bug.
torch.testing.assert_close(got, ref, rtol=5e-2, atol=5e-2)
class TestFusedAttnFullModel:
"""End-to-end model forward + backward through both layer types."""
def test_full_forward_matches_stock(self, restore_gemma4_attention):
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
m = _build_model(seed=2)
ids = torch.randint(0, 128, (2, 32), device="cuda")
mask = torch.ones(2, 32, dtype=torch.long, device="cuda")
with torch.no_grad():
ref = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone()
patch_gemma4_fused_attn()
with torch.no_grad():
got = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone()
assert got.shape == ref.shape
assert torch.isfinite(got).all()
cos_sim = torch.nn.functional.cosine_similarity(
ref.flatten().float(), got.flatten().float(), dim=0
)
# End-to-end through 2 layers (RMSNorm, attention, MLP/MoE) in bf16
# accumulates a small amount of numerical drift; we just want to
# pin that the two paths are computing the same function.
assert cos_sim > 0.999, f"end-to-end cosine_sim={cos_sim:.6f}"
def test_backward_grad_flows_through_fused_path(self, restore_gemma4_attention):
"""Gradients must propagate through the fused RMSNorm+RoPE kernels
for both the sliding and proportional-rope layers."""
from axolotl.monkeypatch.models.gemma4.fused_attn import (
patch_gemma4_fused_attn,
)
m = _build_model(seed=3).train()
patch_gemma4_fused_attn()
ids = torch.randint(0, 128, (2, 16), device="cuda")
mask = torch.ones(2, 16, dtype=torch.long, device="cuda")
out = m(input_ids=ids, attention_mask=mask).last_hidden_state
out.sum().backward()
# Both layers must accumulate gradients on q_norm.weight and
# k_norm.weight — that proves the fused kernel ran the backward.
for i, layer in enumerate(m.layers[:2]):
attn = layer.self_attn
assert attn.q_norm.weight.grad is not None, f"layer {i} q_norm no grad"
assert attn.k_norm.weight.grad is not None, f"layer {i} k_norm no grad"
assert attn.q_norm.weight.grad.isfinite().all()
assert attn.k_norm.weight.grad.isfinite().all()
assert attn.q_norm.weight.grad.abs().sum() > 0
assert attn.k_norm.weight.grad.abs().sum() > 0

View File

@@ -0,0 +1,343 @@
"""Tests for the Gemma 4 hybrid-attention mask fix.
These tests pin the single critical behavior: after installing the patch,
``modeling_gemma4.create_causal_mask`` passes an SDPA-overridden config to
the underlying mask builder regardless of what the caller's config says.
This is what keeps full-attention (head_dim=512) global layers from
crashing at long sequence lengths — they need a 4D SDPA-format mask, not
the 2D FA2 mask that would be built from the model-level config.
The tests use a mocked ``create_causal_mask`` so they don't have to load
a real 26B Gemma 4 model or even have access to its weights. What matters
for the bug fix is which config is handed to the mask factory, not the
factory's actual output.
"""
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
pytest.importorskip(
"transformers.models.gemma4",
reason="gemma4_hybrid_mask patch only matters when Gemma 4 is available",
)
@pytest.fixture
def restore_gemma4_module():
"""Snapshot ``modeling_gemma4.create_causal_mask`` and restore after
each test so patch state doesn't leak across the suite."""
from transformers.models.gemma4 import modeling_gemma4
saved = modeling_gemma4.create_causal_mask
yield modeling_gemma4
modeling_gemma4.create_causal_mask = saved
# Reset the module-level flag so the next test can re-install cleanly.
from axolotl.monkeypatch import gemma4_hybrid_mask
gemma4_hybrid_mask._PATCH_APPLIED = False
def test_patch_replaces_create_causal_mask(restore_gemma4_module):
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
original = modeling_gemma4.create_causal_mask
assert patch_gemma4_hybrid_mask() is True
assert modeling_gemma4.create_causal_mask is not original
assert modeling_gemma4.create_causal_mask._axolotl_original is original, (
"patched wrapper must expose the original reference for teardown"
)
def test_patch_is_idempotent(restore_gemma4_module):
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
patch_gemma4_hybrid_mask()
wrapper_first = modeling_gemma4.create_causal_mask
# Second call must not re-wrap the already-wrapped function (which
# would leak the original reference through a chain of wrappers).
patch_gemma4_hybrid_mask()
wrapper_second = modeling_gemma4.create_causal_mask
assert wrapper_first is wrapper_second
def test_patched_mask_forces_sdpa_config(restore_gemma4_module):
"""Core invariant: when the patched wrapper is called with a config
that says ``flash_attention_2``, the underlying mask factory receives
a shallow-copied config whose ``_attn_implementation`` is ``"sdpa"``.
Without this, the full-attention global layers get a 2D FA2 mask and
crash at long seq lens with the [B, H, S, S] / [B, S] expand error.
"""
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
# Swap in a mock BEFORE installing the patch so the wrapper captures
# it as the "original". The mock records every call so we can inspect
# what config got passed through.
mock_factory = MagicMock(name="create_causal_mask", return_value="mask_4d")
modeling_gemma4.create_causal_mask = mock_factory
patch_gemma4_hybrid_mask()
# Caller-supplied config says FA2 (that's the model-level setting).
caller_config = SimpleNamespace(
_attn_implementation="flash_attention_2",
head_dim=512,
some_other_attr="preserved",
)
result = modeling_gemma4.create_causal_mask(
caller_config,
inputs_embeds=None,
attention_mask=None,
past_key_values=None,
position_ids=None,
)
# Wrapper returned whatever the mock returned — no transformation of
# the result itself.
assert result == "mask_4d"
# The mock was called exactly once with a config whose
# ``_attn_implementation`` is sdpa, NOT the caller's fa2.
assert mock_factory.call_count == 1
(passed_config, *_), passed_kwargs = mock_factory.call_args
assert passed_config._attn_implementation == "sdpa"
# The wrapper must NOT mutate the caller's config in place — other
# mask builders (e.g. create_sliding_window_causal_mask) read from
# the same config and must still see fa2.
assert caller_config._attn_implementation == "flash_attention_2"
# Other attributes on the config must be preserved so the underlying
# factory has everything it needs (head_dim, rope_theta, vocab_size, ...).
assert passed_config.head_dim == 512
assert passed_config.some_other_attr == "preserved"
def test_patched_wrapper_passes_through_all_kwargs(restore_gemma4_module):
"""The wrapper must forward positional + keyword args to the original
unchanged, so transformers' own call-site in Gemma4TextModel.forward
keeps working across minor transformers-version signature drift."""
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
mock_factory = MagicMock(return_value="mask")
modeling_gemma4.create_causal_mask = mock_factory
patch_gemma4_hybrid_mask()
caller_config = SimpleNamespace(_attn_implementation="flash_attention_2")
modeling_gemma4.create_causal_mask(
caller_config,
"positional_arg",
inputs_embeds="embeds",
attention_mask="mask_2d",
past_key_values="cache",
position_ids="positions",
or_mask_function="or_fn",
)
args, kwargs = mock_factory.call_args
# First positional (after config override) is preserved.
assert args[1] == "positional_arg"
# All kwargs are forwarded untouched.
assert kwargs["inputs_embeds"] == "embeds"
assert kwargs["attention_mask"] == "mask_2d"
assert kwargs["past_key_values"] == "cache"
assert kwargs["position_ids"] == "positions"
assert kwargs["or_mask_function"] == "or_fn"
def test_unpatch_restores_original(restore_gemma4_module):
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import (
patch_gemma4_hybrid_mask,
unpatch_gemma4_hybrid_mask,
)
sentinel = MagicMock(name="original")
modeling_gemma4.create_causal_mask = sentinel
patch_gemma4_hybrid_mask()
assert modeling_gemma4.create_causal_mask is not sentinel
unpatch_gemma4_hybrid_mask()
assert modeling_gemma4.create_causal_mask is sentinel
def test_unpatch_is_safe_without_prior_patch(restore_gemma4_module):
from axolotl.monkeypatch.gemma4_hybrid_mask import unpatch_gemma4_hybrid_mask
# Should be a no-op, no exception.
unpatch_gemma4_hybrid_mask()
def test_sliding_window_mask_builder_is_not_patched(restore_gemma4_module):
"""Only ``create_causal_mask`` is overridden — the sliding-window
factory must remain bound to its original to preserve FA2 masks for
the sliding-attention layers. If we accidentally patch both, the
sliding layers get SDPA format and lose the FA2 speedup."""
modeling_gemma4 = restore_gemma4_module
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
if not hasattr(modeling_gemma4, "create_sliding_window_causal_mask"):
pytest.skip("transformers version has no create_sliding_window_causal_mask")
sliding_before = modeling_gemma4.create_sliding_window_causal_mask
patch_gemma4_hybrid_mask()
sliding_after = modeling_gemma4.create_sliding_window_causal_mask
assert sliding_after is sliding_before
# ---------------------------------------------------------------------------
# Integration tests with a tiny randomly-initialized Gemma4TextModel.
#
# These do NOT load real 26B weights. They build a ~350k-param Gemma 4 text
# model with 2 layers (one sliding, one full_attention), apply the hybrid
# attention path end-to-end, and run a forward pass with a padded
# attention_mask at a long-ish seq len. The invariant we're pinning is that
# the full_attention layer does not crash with the
# "Target sizes: [B, H, S, S]. Tensor sizes: [B, S]"
# error — the exact failure that blew up the Gemma 4 MoE 26B pilot at ~7k
# tokens in the FSDP2 training run.
# ---------------------------------------------------------------------------
def _build_tiny_gemma4_text_model():
"""Return a tiny randomly-initialized Gemma4TextModel with mixed layers."""
import torch
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel
cfg = Gemma4TextConfig(
vocab_size=128,
hidden_size=64,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=2,
num_key_value_heads=2,
head_dim=32,
layer_types=["sliding_attention", "full_attention"],
sliding_window=64,
max_position_embeddings=2048,
hidden_size_per_layer_input=16,
vocab_size_per_layer_input=128,
)
# The pilot ran flash_attention_2 at the model level; this tiny model
# starts on "sdpa" so the forward runs without FA2 kernels. The hybrid
# patch is what makes this survive long context.
cfg._attn_implementation = "sdpa"  # start safe; fa2 at the model level is exercised in a later test
torch.manual_seed(42)
model = Gemma4TextModel(cfg).eval()
return model, cfg
def _apply_hybrid_attn_inline(model, cfg):
"""Replicate what ``patch_manager._apply_gemma_hybrid_attention`` does
to a model, without needing a full PatchManager / pydantic cfg."""
import copy
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
for layer_idx, layer in enumerate(model.layers):
if cfg.layer_types[layer_idx] != "sliding_attention":
attn = getattr(layer, "self_attn", None)
if attn is not None and hasattr(attn, "config"):
sdpa_cfg = copy.copy(attn.config)
sdpa_cfg._attn_implementation = "sdpa"
attn.config = sdpa_cfg
patch_gemma4_hybrid_mask()
def test_tiny_gemma4_long_context_forward_does_not_crash(restore_gemma4_module):
"""End-to-end invariant: with the hybrid attn patch applied, a tiny
Gemma4TextModel runs a forward at long context (1024 tokens) with
real padding in the attention mask, producing the expected output
shape. This exercises the actual code path that crashed the pilot
without needing a real 26B checkpoint or CUDA."""
import torch
model, cfg = _build_tiny_gemma4_text_model()
_apply_hybrid_attn_inline(model, cfg)
B, S = 2, 1024
input_ids = torch.randint(0, cfg.vocab_size, (B, S))
attn_mask = torch.ones(B, S, dtype=torch.long)
# Pad positions in the second row. Without padding, SDPA falls back to
# ``is_causal=True`` with ``mask=None`` — we need a materialized 4D
# mask to exercise the actual bug site.
attn_mask[1, S // 2 :] = 0
with torch.no_grad():
out = model(input_ids=input_ids, attention_mask=attn_mask)
assert out.last_hidden_state.shape == (B, S, cfg.hidden_size)
assert torch.isfinite(out.last_hidden_state).all()
def test_patched_create_causal_mask_returns_4d_for_real_config(
restore_gemma4_module,
):
"""Hit the REAL ``create_causal_mask`` (not a mock) via the wrapper
and verify the returned mask is a 4D tensor — which is the shape the
SDPA-patched global layers need. Without the patch and with a
caller-supplied FA2 config this would return a 2D mask and the layer
would crash at long context."""
import torch
from transformers.cache_utils import DynamicCache
from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
patch_gemma4_hybrid_mask()
modeling_gemma4 = restore_gemma4_module
cfg = Gemma4TextConfig(
vocab_size=128,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=2,
num_key_value_heads=2,
head_dim=32,
layer_types=["sliding_attention", "full_attention"],
sliding_window=64,
max_position_embeddings=2048,
hidden_size_per_layer_input=16,
vocab_size_per_layer_input=128,
)
# Simulate the pilot: caller says flash_attention_2, but global layers
# were switched to SDPA per-layer. Without the patch, create_causal_mask
# would return an FA2 2D mask here and the SDPA layer would crash.
cfg._attn_implementation = "flash_attention_2"
B, S = 2, 1024
inputs_embeds = torch.zeros((B, S, cfg.hidden_size), dtype=torch.float32)
attention_mask = torch.ones((B, S), dtype=torch.long)
attention_mask[1, S // 2 :] = 0 # force the 4D materialized path
position_ids = torch.arange(S).unsqueeze(0).expand(B, -1)
past_key_values = DynamicCache(config=cfg)
mask = modeling_gemma4.create_causal_mask(
config=cfg,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
)
assert mask is not None
assert isinstance(mask, torch.Tensor)
assert mask.dim() == 4, (
f"expected a 4D SDPA-format mask, got {mask.dim()}D "
f"shape={tuple(mask.shape)}. The full_attention global layers need "
"this shape or they crash at long context."
)
assert mask.shape[0] == B
assert mask.shape[-1] == S
assert mask.shape[-2] == S
# Caller's config must be untouched — other code paths still read it.
assert cfg._attn_implementation == "flash_attention_2"