diff --git a/src/axolotl/core/trainers/grpo/async_trainer.py b/src/axolotl/core/trainers/grpo/async_trainer.py index 47c3a07ad..334db8bb2 100644 --- a/src/axolotl/core/trainers/grpo/async_trainer.py +++ b/src/axolotl/core/trainers/grpo/async_trainer.py @@ -242,6 +242,87 @@ class ProducerConfig: ) +class _GroupShardedSampler: + """Rank-aware shard of a ``RepeatSampler`` that preserves GRPO groups. + + ``RepeatSampler`` yields ``num_generations`` consecutive copies of + each prompt, forming a GRPO group. For distributed training each + rank must see a disjoint slice of prompts (otherwise every rank + dogpiles on the first 1/world_size of the batch) while keeping each + group intact on a single rank so advantage normalization sees all + peer generations. + + ``accelerator.prepare(DataLoader)`` does not handle this correctly + for custom samplers with ``split_batches=False`` (the default): it + leaves the sampler alone and every rank replays identical indices. + This wrapper fixes that by consuming the inner sampler's full + output, chunking it into ``num_generations``-sized groups, and + round-robining whole groups across ranks. + + Intended to be used ONLY when distributed training is active + (``num_replicas > 1``); for single-rank it is a no-op but still + correct. + """ + + def __init__( + self, + inner: Any, + num_generations: int, + rank: int, + num_replicas: int, + ): + if num_generations < 1: + raise ValueError(f"num_generations must be >= 1, got {num_generations}") + if num_replicas < 1: + raise ValueError(f"num_replicas must be >= 1, got {num_replicas}") + if not (0 <= rank < num_replicas): + raise ValueError( + f"rank must be in [0, {num_replicas}), got {rank}" + ) + self.inner = inner + self.num_generations = num_generations + self.rank = rank + self.num_replicas = num_replicas + + def __iter__(self): + all_indices = list(self.inner) + if len(all_indices) % self.num_generations != 0: + raise ValueError( + f"inner sampler yielded {len(all_indices)} indices, " + f"not a multiple of num_generations={self.num_generations}" + ) + # Chunk the flat index sequence into groups of num_generations + # consecutive indices. ``RepeatSampler`` guarantees that each + # group contains num_generations copies of the same prompt id. + groups = [ + all_indices[i : i + self.num_generations] + for i in range(0, len(all_indices), self.num_generations) + ] + # Round-robin whole groups across ranks. Round-robin (vs. + # contiguous chunking) preserves approximate shuffled order on + # each rank even when the group count is small relative to the + # world size. + for group in groups[self.rank :: self.num_replicas]: + yield from group + + def __len__(self): + try: + inner_len = len(self.inner) + except TypeError: + # Non-sized inner sampler — we can't know the per-rank + # length without materializing. Return 0 as a hint that the + # DataLoader should fall back to iteration. + return 0 + total_groups = inner_len // self.num_generations + # Ceiling division for the trailing groups that don't divide + # evenly — extra groups go to the first ``total_groups % + # num_replicas`` ranks, matching the round-robin above. + my_groups = ( + total_groups + self.num_replicas - self.rank - 1 + ) // self.num_replicas + return my_groups * self.num_generations + + class DataProducer(ABC): """Abstract base class for online data producers. @@ -556,6 +637,34 @@ class GRPODataProducer(BaseDataProducer): seed=self._seed, ) + # Shard the sampler across distributed ranks so each rank sees + # a disjoint slice of prompts. 
``RepeatSampler`` groups each + # prompt with ``num_generations`` consecutive copies — our + # wrapper round-robins WHOLE groups across ranks so all + # generations of a given prompt stay on the same rank (needed + # for GRPO advantage normalization within a group). + # + # Without this, ``accelerator.prepare(dl)`` with the default + # ``split_batches=False`` leaves the custom sampler alone, so + # every rank iterates the identical index sequence and the + # cluster dogpiles on the first 1/world_size of the prompts. + num_replicas = max(1, trainer.accelerator.num_processes) + if num_replicas > 1: + sampler = _GroupShardedSampler( + inner=sampler, + num_generations=self._num_generations, + rank=trainer.accelerator.process_index, + num_replicas=num_replicas, + ) + logger.info( + "[RANK:%d] _GroupShardedSampler active " + "(num_replicas=%d, num_generations=%d, gen_batch=%d)", + trainer.accelerator.process_index, + num_replicas, + self._num_generations, + self._generation_batch_size, + ) + # Use identity collator (same as stock GRPOTrainer) def _identity(x): return x @@ -574,12 +683,11 @@ class GRPODataProducer(BaseDataProducer): rank=trainer.args.process_index, ), ) - self._prompt_dl = trainer.accelerator.prepare(dl) - - # Don't let accelerator track this dataloader - acc_dls = trainer.accelerator._dataloaders - if self._prompt_dl in acc_dls: - acc_dls.remove(self._prompt_dl) + # Skip accelerator.prepare — we're handling per-rank sharding + # ourselves via ``_GroupShardedSampler``. ``prepare()`` would + # otherwise try to wrap the DataLoader with its own sharding + # logic which does not understand our group structure. + self._prompt_dl = dl self._prompt_iter = iter(self._prompt_dl) diff --git a/src/axolotl/integrations/nemo_gym/data_producer.py b/src/axolotl/integrations/nemo_gym/data_producer.py index 64b76d780..1cbe5ad71 100644 --- a/src/axolotl/integrations/nemo_gym/data_producer.py +++ b/src/axolotl/integrations/nemo_gym/data_producer.py @@ -110,11 +110,35 @@ class NemoGymDataProducer(GRPODataProducer): item["agent_ref"] = full_item["agent_ref"] dataset_items.append(item) - # Expand by num_generations (agent produces one rollout per call) - expanded_items = [] - for item in dataset_items: - for _ in range(self._num_generations): - expanded_items.append(item) + # NOTE: do NOT re-expand by num_generations here. + # ``RepeatSampler(mini_repeat_count=num_generations)`` already + # yields ``num_generations`` consecutive copies of each unique + # prompt, so ``inputs`` is a list of ``(unique_prompts_per_rank * + # num_generations)`` items — one entry per rollout. Expanding + # again here would fire ``num_generations^2`` rollouts per + # prompt per rank and make every step dogpile on a handful of + # tasks. + expanded_items = dataset_items + + # Diagnostic: log what this rank is about to fire. 
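The group round-robin introduced above is easy to check in isolation. A standalone sketch of the chunk-and-stride rule — a plain list stands in for ``RepeatSampler``'s output, and only the slicing logic mirrors ``_GroupShardedSampler``, not the class itself:

```python
# Standalone sketch of the group round-robin described above. A plain
# list stands in for RepeatSampler's output; only the chunk-and-stride
# logic mirrors _GroupShardedSampler.
def shard_groups(indices, num_generations, rank, num_replicas):
    assert len(indices) % num_generations == 0
    groups = [
        indices[i : i + num_generations]
        for i in range(0, len(indices), num_generations)
    ]
    # Whole groups round-robin across ranks: rank r takes groups
    # r, r + num_replicas, r + 2 * num_replicas, ...
    return [idx for group in groups[rank::num_replicas] for idx in group]

# RepeatSampler(mini_repeat_count=2) over prompts [5, 3, 9, 0] yields:
flat = [5, 5, 3, 3, 9, 9, 0, 0]
assert shard_groups(flat, 2, rank=0, num_replicas=2) == [5, 5, 9, 9]
assert shard_groups(flat, 2, rank=1, num_replicas=2) == [3, 3, 0, 0]
```

Every group stays intact on one rank and the two ranks' index sets are disjoint — the two invariants the wrapper's docstring calls out.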
+ try: + import collections + iid_counts = collections.Counter() + for it in dataset_items: + iid_counts[ + (it.get("responses_create_params", {}).get("metadata") or {}).get( + "instance_id" + ) + ] += 1 + LOG.info( + "[RANK:%d] produce(): firing %d agent /run calls covering %d unique prompts: %s", + trainer.accelerator.process_index, + len(dataset_items), + len(iid_counts), + list(iid_counts.most_common(5)), + ) + except Exception: + pass # Call NeMo Gym agents loop = asyncio.new_event_loop() @@ -140,6 +164,7 @@ class NemoGymDataProducer(GRPODataProducer): logprobs_list = [] rewards_list = [] + num_turns_list: list[int] = [] for resp in responses: parsed = _parse_agent_response(resp, eos_token_id) prompt_ids_list.append(parsed["prompt_ids"]) @@ -147,6 +172,7 @@ class NemoGymDataProducer(GRPODataProducer): env_mask_list.append(parsed["env_mask"]) logprobs_list.append(parsed["logprobs"]) rewards_list.append(parsed["reward"]) + num_turns_list.append(parsed.get("num_turns", 0)) # Pad to tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] @@ -179,22 +205,49 @@ class NemoGymDataProducer(GRPODataProducer): tool_mask = [torch.tensor(m, device=device) for m in env_mask_list] tool_mask = pad(tool_mask, padding_value=1, padding_side="right") - # Inject rewards into inputs so _compute_deferred_scores can use them - # The deferred scoring path calls _calculate_rewards which reads reward_funcs. - # Our passthrough reward_fn reads "env_reward" from kwargs. + # Inject per-rollout reward + num_turns into each input. Since + # ``RepeatSampler`` already yields ``num_generations`` copies of + # each prompt, ``inputs`` has ONE entry per rollout (matching + # ``rewards_list`` 1:1). No per-prompt grouping happens here — + # GRPO advantage normalization is the trainer's job downstream. + assert len(inputs) == len(rewards_list), ( + f"rewards/inputs length mismatch: " + f"{len(rewards_list)} rewards vs {len(inputs)} inputs" + ) for i, inp in enumerate(inputs): - # Each input gets rewards for its num_generations rollouts - start = i * self._num_generations - end = start + self._num_generations - inp["env_reward"] = rewards_list[start:end] + inp["env_reward"] = rewards_list[i] + inp["num_turns"] = num_turns_list[i] + + # One expanded_input per rollout (already correct count because + # inputs has num_generations copies baked in by the sampler). + expanded_inputs = [dict(inp) for inp in inputs] + + # Log rollout-level stats to wandb from rank 0. These are the + # true agent-side metrics (not the tokenized TRL view) — so + # num_turns reflects how many /run iterations each rollout + # actually took before finishing or hitting max_turns. 
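The indexing change above is worth a concrete check: with per-rollout ``inputs``, the old ``i * num_generations`` window slicing runs off the end of ``rewards_list``. A toy illustration with ``num_generations=2`` (illustrative values only):

```python
# Toy illustration of the 1:1 reward mapping vs the old windowed slicing.
num_generations = 2
inputs = ["p0", "p0", "p1", "p1"]    # RepeatSampler output: per-rollout
rewards = [0.0, 1.0, 0.5, 0.25]      # one reward per rollout

# New: per-rollout inputs zip 1:1 with rewards.
assert list(zip(inputs, rewards))[1] == ("p0", 1.0)

# Old: windowed slicing assumed one input per UNIQUE prompt; with
# per-rollout inputs the later windows silently come back empty.
windows = [
    rewards[i * num_generations : (i + 1) * num_generations]
    for i in range(len(inputs))
]
assert windows[2:] == [[], []]   # the mismatch the assert now guards against
```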
+ if is_main and num_turns_list: + try: + import wandb + + if wandb.run is not None: + import statistics as _stats + + nonzero = sum(1 for r in rewards_list if r > 0) + log_payload = { + "rollout/num_turns/mean": float(_stats.mean(num_turns_list)), + "rollout/num_turns/min": float(min(num_turns_list)), + "rollout/num_turns/max": float(max(num_turns_list)), + "rollout/reward/mean": float(_stats.mean(rewards_list)), + "rollout/reward/nonzero_frac": ( + nonzero / len(rewards_list) if rewards_list else 0.0 + ), + "rollout/n_samples": float(len(rewards_list)), + } + wandb.log(log_payload, commit=False) + except Exception as exc: # never let metric logging break training + LOG.warning("rollout wandb log failed: %s", exc) - # Expand inputs to match expanded rollouts (num_generations copies) - expanded_inputs = [] - for inp in inputs: - for g in range(self._num_generations): - expanded_inp = dict(inp) - expanded_inp["env_reward"] = inp["env_reward"][g] - expanded_inputs.append(expanded_inp) # Decode completions for reward functions completions = trainer.processing_class.batch_decode( diff --git a/src/axolotl/kernels/gemma4_fused_rope.py b/src/axolotl/kernels/gemma4_fused_rope.py index f3b68e603..f98e9a3de 100644 --- a/src/axolotl/kernels/gemma4_fused_rope.py +++ b/src/axolotl/kernels/gemma4_fused_rope.py @@ -53,6 +53,7 @@ def _rms_norm_rope_forward_kernel( RSTD_ptr, RSTD_row_stride, n_cols, + n_rot, n_heads, eps, HAS_WEIGHT: tl.constexpr, @@ -60,28 +61,35 @@ def _rms_norm_rope_forward_kernel( ): """ Fused forward: - x_norm = x / rms(x) [* weight] (RMSNorm) - y = x_norm * cos + rotate_half(x_norm) * sin (RoPE) + x_norm = x / rms(x) [* weight] (RMSNorm, full n_cols) + y[..., :n_rot] = rope(x_norm[..., :n_rot]) + y[..., n_rot:] = x_norm[..., n_rot:] (pass-through for partial rotary) - rotate_half swaps first/second halves and negates the first: - rotate_half([a, b]) = [-b, a] + rotate_half swaps first/second halves and negates the first, restricted + to the rotary span [0, n_rot): + rotate_half([a, b]) = [-b, a] where len(a) = len(b) = n_rot/2 + + For the partial-rotary pass-through region we load cos with default 1.0 + and sin with default 0.0 outside [0, n_rot), so the same formula + `Y = X_norm * cos + X_rot_norm * sin` collapses to `Y = X_norm`. cos/sin are indexed by row_idx // n_heads to handle per-head broadcast - (cos/sin have shape (B*S, D) while X has shape (B*S*H, D)). + (cos/sin have shape (B*S, n_rot) while X has shape (B*S*H, n_cols)). """ row_idx = tl.program_id(0).to(tl.int64) - # cos/sin row: divide by n_heads since cos/sin are (B*S, D) + # cos/sin row: divide by n_heads since cos/sin are (B*S, n_rot) cs_row_idx = row_idx // n_heads col_offsets = tl.arange(0, BLOCK_SIZE) mask = col_offsets < n_cols - half_dim = n_cols // 2 + rot_mask_col = col_offsets < n_rot + half_rot = n_rot // 2 # Load input row X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0) X_dtype = X_row.dtype X_fp32 = X_row.to(tl.float32) - # RMSNorm: compute 1/rms + # RMSNorm: compute 1/rms over the full row (rotary + pass-through) mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols rstd = rsqrt(mean_sq + eps) tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd) @@ -94,33 +102,38 @@ def _rms_norm_rope_forward_kernel( W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32) X_norm = X_norm * W_row - # RoPE: load cos/sin (broadcast across heads) + # RoPE: load cos/sin (broadcast across heads). For col >= n_rot we get + # cos=1, sin=0 so the formula leaves X_norm untouched. 
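The cos=1 / sin=0 trick can be sanity-checked outside Triton. A pure-PyTorch sketch of the padding semantics the masked loads rely on (toy sizes, not the kernel's code):

```python
# Why cos=1 / sin=0 outside the rotary span is a pass-through: the RoPE
# formula y = x * cos + rotate_half(x) * sin collapses to y = x there.
# Toy sizes; this mirrors the kernel's masked-load defaults, not its code.
import torch

def rotate_half(t):
    a, b = t.chunk(2, dim=-1)
    return torch.cat([-b, a], dim=-1)

D, n_rot = 8, 4
x = torch.randn(D)
cos = torch.cat([torch.randn(n_rot), torch.ones(D - n_rot)])   # pad with 1
sin = torch.cat([torch.randn(n_rot), torch.zeros(D - n_rot)])  # pad with 0
rot = torch.cat([rotate_half(x[:n_rot]), torch.zeros(D - n_rot)])
y = x * cos + rot * sin
assert torch.equal(y[n_rot:], x[n_rot:])  # tail is exactly untouched
```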
cos_row = tl.load( - COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0 + COS_ptr + cs_row_idx * COS_row_stride + col_offsets, + mask=rot_mask_col, + other=1.0, ).to(tl.float32) sin_row = tl.load( - SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, mask=mask, other=0 + SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, + mask=rot_mask_col, + other=0.0, ).to(tl.float32) - # rotate_half: for col < half_dim, take -X_norm[col + half_dim] - # for col >= half_dim, take X_norm[col - half_dim] + # rotate_half within [0, n_rot): + # for col < half_rot: take -X_norm[col + half_rot] + # for col in [half_rot, n_rot): take X_norm[col - half_rot] + # For col >= n_rot the rotation is irrelevant (sin = 0 zeros it out). rot_offsets = tl.where( - col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim + col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot ) - rot_mask = rot_offsets < n_cols + rot_load_mask = (rot_offsets < n_cols) & rot_mask_col X_rot = tl.load( - X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_mask & mask, other=0 + X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_load_mask, other=0 ).to(tl.float32) # Re-normalize the rotated values X_rot_norm = X_rot * rstd if HAS_WEIGHT: - W_rot = tl.load(W_ptr + rot_offsets, mask=rot_mask & mask, other=0).to( - tl.float32 - ) + W_rot = tl.load(W_ptr + rot_offsets, mask=rot_load_mask, other=0).to(tl.float32) X_rot_norm = X_rot_norm * W_rot # Negate the first half (rotate_half negates x2, which becomes the first half) - sign = tl.where(col_offsets < half_dim, -1.0, 1.0) + sign = tl.where(col_offsets < half_rot, -1.0, 1.0) X_rot_norm = X_rot_norm * sign # Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin @@ -153,13 +166,21 @@ def _rms_norm_rope_backward_kernel( dW_row_stride, n_rows, n_cols, + n_rot, n_heads, rows_per_program, HAS_WEIGHT: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): """ - Backward for Y = RoPE(RMSNorm(X, W)) + Backward for Y = RoPE(RMSNorm(X, W)) with optional partial rotary + (`n_rot <= n_cols`). + + For col < n_rot the standard RoPE adjoint applies. For col >= n_rot the + output is just the normalized row, so dN[col] = dY[col] (achieved by + loading cos with default 1.0 and forcing the rotate-half contribution + to zero outside the rotary span). + cos/sin indexed by row_idx // n_heads for per-head broadcast. """ row_block_id = tl.program_id(0).to(tl.int64) @@ -167,7 +188,8 @@ def _rms_norm_rope_backward_kernel( row_end = min((row_block_id + 1) * rows_per_program, n_rows) col_offsets = tl.arange(0, BLOCK_SIZE) mask = col_offsets < n_cols - half_dim = n_cols // 2 + rot_mask_col = col_offsets < n_rot + half_rot = n_rot // 2 dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) @@ -186,33 +208,37 @@ def _rms_norm_rope_backward_kernel( rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride) cos_row = tl.load( - COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0 + COS_ptr + cs_row_idx * COS_row_stride + col_offsets, + mask=rot_mask_col, + other=1.0, ).to(tl.float32) - # dN = dY * cos + rotate_half^T(dY * sin) + # dN = dY * cos + rotate_half^T(dY * sin) (within the rotary span) # rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half) # - # Compute rotate_half_transpose(dY * sin) by loading dY and sin at - # rotated offsets directly: dY[rot] * sin[rot] * adj_sign - # This is equivalent to rotating (dY * sin) because the rotation - # just permutes which elements are multiplied. 
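The adjoint identity behind ``dN = dY * cos + rotate_half^T(dY * sin)`` is quick to verify numerically. A standalone check in plain PyTorch:

```python
# rotate_half's adjoint: <rotate_half(x), y> == <x, rotate_half_T(y)>,
# with rotate_half([a, b]) = [-b, a] and rotate_half_T([a, b]) = [b, -a].
import torch

def rotate_half(t):
    a, b = t.chunk(2, dim=-1)
    return torch.cat([-b, a], dim=-1)

def rotate_half_t(t):
    a, b = t.chunk(2, dim=-1)
    return torch.cat([b, -a], dim=-1)

x, y = torch.randn(16), torch.randn(16)
assert torch.allclose(rotate_half(x) @ y, x @ rotate_half_t(y))
```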
+ # For col >= n_rot the formula must collapse to dN = dY (since the + # forward is just a pass-through). cos defaults to 1.0 above; the + # rotate-half contribution is masked to zero below. rot_offsets = tl.where( - col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim + col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot ) - rot_mask = rot_offsets < n_cols + rot_load_mask = (rot_offsets < n_cols) & rot_mask_col dY_rot = tl.load( dY_ptr + row_idx * dY_row_stride + rot_offsets, - mask=rot_mask & mask, + mask=rot_load_mask, other=0, ).to(tl.float32) sin_rot = tl.load( SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets, - mask=rot_mask & mask, + mask=rot_load_mask, other=0, ).to(tl.float32) - adj_sign = tl.where(col_offsets < half_dim, 1.0, -1.0) - dN = dY_row * cos_row + dY_rot * sin_rot * adj_sign + adj_sign = tl.where(col_offsets < half_rot, 1.0, -1.0) + rotate_term = dY_rot * sin_rot * adj_sign + # Zero out rotate-half contribution outside the rotary span. + rotate_term = tl.where(rot_mask_col, rotate_term, 0.0) + dN = dY_row * cos_row + rotate_term # Pre-weight normalized: n = rstd * x n = X_row * rstd @@ -241,15 +267,17 @@ def _rms_norm_rope_backward_kernel( ) -def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads): +def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads, n_rot): """ Args: X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D) W: (head_dim,) or None — RMSNorm weight - cos: (B*S, head_dim) — position embeddings (broadcast across heads) - sin: (B*S, head_dim) — position embeddings (broadcast across heads) + cos: (B*S, n_rot) — position embeddings (broadcast across heads) + sin: (B*S, n_rot) — position embeddings (broadcast across heads) eps: float n_heads: int — number of attention heads (for cos/sin indexing) + n_rot: int — rotary dim (== head_dim for full rotary, < head_dim for + partial rotary). Must be even and ``<= head_dim``. 
Returns: Y, X_saved, RSTD, BLOCK_SIZE, num_warps """ @@ -273,6 +301,7 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads): RSTD, RSTD.stride(0), n_cols, + n_rot, n_heads, eps, HAS_WEIGHT=has_weight, @@ -282,7 +311,9 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads): return Y, X, RSTD, BLOCK_SIZE, num_warps -def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_warps): +def rms_norm_rope_backward( + dY, X, W, cos, sin, RSTD, n_heads, n_rot, BLOCK_SIZE, num_warps +): n_rows, n_cols = dY.shape has_weight = W is not None @@ -315,6 +346,7 @@ def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_wa _dW.stride(0), n_rows, n_cols, + n_rot, n_heads, rows_per_program, HAS_WEIGHT=has_weight, @@ -329,13 +361,14 @@ def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_wa class FusedRMSNormRoPEFunction(torch.autograd.Function): @staticmethod @ensure_contiguous - def forward(ctx, X, W, cos, sin, eps, n_heads): + def forward(ctx, X, W, cos, sin, eps, n_heads, n_rot): """ - X: (B*S*H, head_dim) - W: (head_dim,) or None - cos: (B*S, head_dim) — broadcast across heads - sin: (B*S, head_dim) — broadcast across heads + X: (B*S*H, head_dim) + W: (head_dim,) or None + cos: (B*S, n_rot) — broadcast across heads + sin: (B*S, n_rot) — broadcast across heads n_heads: int + n_rot: int — rotary dim (<= head_dim) """ Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward( X, @@ -344,11 +377,13 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function): sin, eps, n_heads, + n_rot, ) ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps ctx.n_heads = n_heads + ctx.n_rot = n_rot ctx.has_weight = W is not None ctx.save_for_backward(X_saved, W, cos, sin, RSTD) return Y @@ -365,21 +400,26 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function): sin, RSTD, ctx.n_heads, + ctx.n_rot, ctx.BLOCK_SIZE, ctx.num_warps, ) - return dX, dW, None, None, None, None + return dX, dW, None, None, None, None, None def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6): """ - Apply fused RMSNorm + RoPE. + Apply fused RMSNorm + (partial) RoPE. Args: x: (batch, seq_len, num_heads, head_dim) — after projection + view weight: (head_dim,) — RMSNorm weight, or None for no-scale norm - cos: (batch, seq_len, head_dim) — from RotaryEmbedding - sin: (batch, seq_len, head_dim) — from RotaryEmbedding + cos: (batch, seq_len, n_rot) — from RotaryEmbedding. ``n_rot`` + must be even and ``<= head_dim``. When ``n_rot < head_dim`` + the trailing ``head_dim - n_rot`` columns are RMSNorm-only + (partial-rotary pass-through), matching stock Gemma 4 with + ``partial_rotary_factor < 1.0``. 
+ sin: (batch, seq_len, n_rot) — same shape as ``cos`` eps: float — RMSNorm epsilon Returns: @@ -387,14 +427,38 @@ def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6): """ shape = x.shape # (B, S, H, D) B, S, H, D = shape + n_rot = cos.shape[-1] + if sin.shape[-1] != n_rot: + raise ValueError( + f"cos and sin must have the same last dim, got cos={cos.shape[-1]} " + f"sin={sin.shape[-1]}" + ) + if n_rot > D: + raise ValueError(f"rotary dim ({n_rot}) cannot exceed head_dim ({D})") + if n_rot % 2 != 0: + raise ValueError(f"rotary dim must be even, got {n_rot}") + # Flatten to 2D: (B*S*H, D) x_flat = x.reshape(-1, D).contiguous() - # Flatten cos/sin to (B*S, D) — the kernel will handle per-head broadcast - # by dividing the row_idx by H to get the cos/sin row - cos_flat = cos.reshape(B * S, D).contiguous() - sin_flat = sin.reshape(B * S, D).contiguous() + # cos/sin may broadcast over the batch dim (e.g. (1, S, n_rot) when + # all sequences share the same rotary positions). The kernel needs a + # dense (B*S, n_rot) buffer so that row_idx // n_heads maps cleanly + # onto a single (b, s) pair, so expand-then-contiguous to materialize + # the per-batch broadcast. Expand is a no-op when B == cos.shape[0]. + if cos.shape[0] != B: + if cos.shape[0] != 1: + raise ValueError( + f"cos/sin batch dim ({cos.shape[0]}) must be 1 or equal " + f"to x batch dim ({B})" + ) + cos = cos.expand(B, S, n_rot) + sin = sin.expand(B, S, n_rot) + cos_flat = cos.reshape(B * S, n_rot).contiguous() + sin_flat = sin.reshape(B * S, n_rot).contiguous() - y_flat = FusedRMSNormRoPEFunction.apply(x_flat, weight, cos_flat, sin_flat, eps, H) + y_flat = FusedRMSNormRoPEFunction.apply( + x_flat, weight, cos_flat, sin_flat, eps, H, n_rot + ) return y_flat.view(shape) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 41fc35e6e..ccaf04bfd 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -156,6 +156,14 @@ class PatchManager: # which would clobber any earlier fix. self._fix_nemotron_h_conversion_mapping() + # Gemma 4 hybrid attention runs here in post-build (NOT post-load): + # the per-layer ``self_attn.config._attn_implementation="sdpa"`` + # override needs to walk the raw model tree, which is broken by + # the post-load PEFT wrapping. The accompanying + # ``patch_gemma4_hybrid_mask`` monkey-patch is module-level and + # installation-time-independent, so both halves of the fix live + # cleanly in the same call even though one is instance-scoped + # and the other is module-scoped. self._apply_gemma_hybrid_attention(model) self._finalize_moe_expert_quantization(model) @@ -173,12 +181,23 @@ class PatchManager: which exceeds flash attention's supported size. This patch loads the model with flash_attention_2 for the sliding window layers (head_dim=256), then gives each global layer a shallow-copied config with _attn_implementation="sdpa". + + We also install :func:`axolotl.monkeypatch.gemma4_hybrid_mask.patch_gemma4_hybrid_mask` + which fixes the corresponding mask construction inside + ``Gemma4TextModel.forward``. Without it, the per-layer SDPA config + override is not enough — the forward still builds a 2D FA2-format mask + at the model level and the SDPA layers crash at long context lengths + with ``RuntimeError: The expanded size of the tensor ... must match``. 
""" if not self.cfg.gemma4_hybrid_attn_impl: return import copy + from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask + + patch_gemma4_hybrid_mask() + # Navigate to the module that has 'layers' - varies by model structure: # Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers # Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers @@ -392,6 +411,14 @@ class PatchManager: patch_qwen3_5_vlm_flash_attention() if self.cfg.model_config_type in ("gemma4", "gemma4_text"): + # The fused attn path is now compatible with + # ``gemma4_hybrid_attn_impl``: the kernel handles partial + # rotary (cos.shape[-1] < head_dim) and the fused forward + # mirrors the current ``Gemma4TextAttention.forward`` API + # for shared kv (read from / write to + # ``past_key_values.shared_layers``). See + # ``src/axolotl/kernels/GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md`` + # for the history. from axolotl.monkeypatch.models.gemma4.fused_attn import ( patch_gemma4_fused_attn, ) diff --git a/src/axolotl/monkeypatch/gemma4_hybrid_mask.py b/src/axolotl/monkeypatch/gemma4_hybrid_mask.py new file mode 100644 index 000000000..17b8cf053 --- /dev/null +++ b/src/axolotl/monkeypatch/gemma4_hybrid_mask.py @@ -0,0 +1,115 @@ +"""Hybrid attention mask fix for Gemma 4. + +Gemma 4 has full-attention (global) layers with ``head_dim=512`` which +exceeds flash-attention-2's supported size. Axolotl's hybrid-attention +patch in ``patch_manager._apply_gemma_hybrid_attention`` works around +this by forcing ``_attn_implementation="sdpa"`` on each global layer's +``self_attn.config``, leaving sliding-window layers on FA2. + +The per-layer config override alone is insufficient, however: +``Gemma4TextModel.forward`` builds a single ``causal_mask_mapping`` dict +using the **model-level** config and passes the mapped mask to each +decoder layer. With FA2 still set at the model level, the ``full_attention`` +entry in that mapping is a 2D mask (FA2 format), but SDPA needs a 4D mask. +The global layers then fail with:: + + RuntimeError: The expanded size of the tensor (S) must match the existing + size (B) at non-singleton dimension 2. Target sizes: [B, H, S, S]. Tensor + sizes: [B, S] + +...when the sequence length grows past roughly 7k tokens. + +This module fixes the symptom by monkey-patching ``create_causal_mask`` in +``transformers.models.gemma4.modeling_gemma4``'s module namespace — NOT +the original in ``masking_utils``. The wrapper forces +``_attn_implementation="sdpa"`` on a shallow-copied config before calling +through, so the ``full_attention`` mask built inside ``Gemma4TextModel.forward`` +is always 4D/SDPA-compatible. ``create_sliding_window_causal_mask`` is left +alone, so sliding-window layers continue to receive FA2-format masks. + +The patch is idempotent. Install once per process, before any Gemma 4 +forward pass runs. +""" + +from __future__ import annotations + +import copy +from typing import Any + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +_PATCH_APPLIED = False + + +def patch_gemma4_hybrid_mask() -> bool: + """Install the Gemma 4 hybrid-attention mask fix. + + Returns ``True`` if the patch was installed (or was already installed), + ``False`` if the target module could not be imported (e.g. transformers + version predates Gemma 4) — in which case nothing is done and the + caller can continue unaffected. 
+ """ + global _PATCH_APPLIED + if _PATCH_APPLIED: + return True + + try: + from transformers.models.gemma4 import modeling_gemma4 + except ImportError: + LOG.debug( + "gemma4_hybrid_mask: transformers.models.gemma4 not importable, " + "skipping. This is fine for non-Gemma4 training." + ) + return False + + if not hasattr(modeling_gemma4, "create_causal_mask"): + LOG.warning( + "gemma4_hybrid_mask: modeling_gemma4 has no 'create_causal_mask' " + "binding, skipping. Transformers API may have changed." + ) + return False + + original = modeling_gemma4.create_causal_mask + + def hybrid_create_causal_mask(config: Any, *args: Any, **kwargs: Any): + """Wrapper that forces SDPA format for the full-attention mask. + + The global layers were patched to SDPA by + ``_apply_gemma_hybrid_attention``, so their mask must be 4D. The + original ``create_causal_mask`` dispatches on + ``config._attn_implementation``; we shadow that with a local + override. + """ + sdpa_config = copy.copy(config) + sdpa_config._attn_implementation = "sdpa" + return original(sdpa_config, *args, **kwargs) + + # Preserve the original reference on the wrapper for tests / teardown. + hybrid_create_causal_mask._axolotl_original = original # type: ignore[attr-defined] + + modeling_gemma4.create_causal_mask = hybrid_create_causal_mask + _PATCH_APPLIED = True + LOG.info( + "gemma4_hybrid_mask: patched modeling_gemma4.create_causal_mask to " + "force SDPA-format masks for full-attention layers" + ) + return True + + +def unpatch_gemma4_hybrid_mask() -> None: + """Restore the original ``create_causal_mask``. Useful for tests.""" + global _PATCH_APPLIED + if not _PATCH_APPLIED: + return + try: + from transformers.models.gemma4 import modeling_gemma4 + except ImportError: + _PATCH_APPLIED = False + return + current = modeling_gemma4.create_causal_mask + original = getattr(current, "_axolotl_original", None) + if original is not None: + modeling_gemma4.create_causal_mask = original + _PATCH_APPLIED = False diff --git a/src/axolotl/monkeypatch/models/gemma4/fused_attn.py b/src/axolotl/monkeypatch/models/gemma4/fused_attn.py index 7cb5c6beb..4a171db8f 100644 --- a/src/axolotl/monkeypatch/models/gemma4/fused_attn.py +++ b/src/axolotl/monkeypatch/models/gemma4/fused_attn.py @@ -30,7 +30,6 @@ def _make_fused_forward(original_forward): hidden_states: torch.Tensor, position_embeddings: torch.Tensor, attention_mask: torch.Tensor | None, - shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]], past_key_values=None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: @@ -66,8 +65,14 @@ def _make_fused_forward(original_forward): query_states = query_states.transpose(1, 2) # ---- K/V path ---- - if self.is_kv_shared_layer: - key_states, value_states = shared_kv_states[self.kv_shared_layer_index] + # Current transformers stores shared kv on `past_key_values.shared_layers` + # (the legacy `shared_kv_states` decoder kwarg was removed). We mirror + # the stock attention forward exactly so the dispatch is identical + # regardless of whether the model was patched. 
+ if self.is_kv_shared_layer and past_key_values is not None: + key_states, value_states = past_key_values.shared_layers[ + self.kv_shared_layer_index + ] key_states = key_states.to(query_states.device) value_states = value_states.to(query_states.device) else: @@ -101,12 +106,18 @@ def _make_fused_forward(original_forward): value_states = fused_rms_norm_noscale(value_states, eps=eps) value_states = value_states.transpose(1, 2) - if past_key_values is not None and not self.is_kv_shared_layer: - key_states, value_states = past_key_values.update( - key_states, value_states, self.layer_idx - ) - if self.store_full_length_kv: - shared_kv_states[self.layer_idx] = key_states, value_states + if past_key_values is not None: + if not self.is_kv_shared_layer: + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx + ) + if self.store_full_length_kv: + if not hasattr(past_key_values, "shared_layers"): + past_key_values.shared_layers = {} + past_key_values.shared_layers[self.layer_idx] = ( + key_states, + value_states, + ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/tests/kernels/test_gemma4_fused_rope.py b/tests/kernels/test_gemma4_fused_rope.py index 7daedd612..297bb2527 100644 --- a/tests/kernels/test_gemma4_fused_rope.py +++ b/tests/kernels/test_gemma4_fused_rope.py @@ -38,6 +38,30 @@ def _reference_norm_noscale(x, eps): return norm(x) +def _reference_partial_norm_rope(x, weight, cos, sin, eps): + """Reference: Gemma4RMSNorm over the full head_dim, then stock + ``apply_rotary_pos_emb`` over the first ``cos.shape[-1]`` columns, with + the trailing columns passed through unchanged. Mirrors how Llama-style + partial rotary is layered on top of the stock RMSNorm + RoPE primitives. + """ + from transformers.models.gemma4.modeling_gemma4 import ( + Gemma4RMSNorm, + apply_rotary_pos_emb, + ) + + D = x.shape[-1] + n_rot = cos.shape[-1] + norm = Gemma4RMSNorm(D, eps=eps).to(x.device, x.dtype) + norm.weight.data.copy_(weight) + normed = norm(x) + if n_rot == D: + return apply_rotary_pos_emb(normed, cos, sin, unsqueeze_dim=2) + x_rot = normed[..., :n_rot] + x_pass = normed[..., n_rot:] + rotated = apply_rotary_pos_emb(x_rot, cos, sin, unsqueeze_dim=2) + return torch.cat([rotated, x_pass], dim=-1) + + @pytest.fixture( params=[ (2, 64, 32, 256), # sliding window layer shape @@ -194,6 +218,172 @@ class TestFusedRMSNormRoPEBackward: assert w.grad.abs().sum() > 0, "w.grad is all zeros" +class TestFusedRMSNormRoPEPartialRotary: + """Partial-rotary: cos/sin last dim is smaller than head_dim. + + Compares against the original primitives (`Gemma4RMSNorm` + + `apply_rotary_pos_emb`) applied to the rotated slice with the trailing + columns passed through. Without the kernel fix this used to crash with + `RuntimeError: shape '[..., D]' is invalid for input of size B*S*n_rot`. 
+ """ + + @pytest.mark.parametrize( + "B,S,H,D,n_rot", + [ + (2, 16, 4, 64, 32), # half rotary (Llama-style 0.5) + (2, 16, 4, 64, 16), # quarter rotary + (2, 32, 8, 128, 64), # half rotary, larger heads + (1, 8, 2, 256, 64), # 26B sliding-shape, 0.25 partial + (1, 8, 2, 64, 64), # n_rot == D: must still match full-rotary path + ], + ids=["half_64", "quarter_64", "half_128", "quarter_256", "full_64"], + ) + def test_forward_matches_reference(self, B, S, H, D, n_rot): + from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope + + eps = 1e-6 + x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + weight = torch.randn(D, device="cuda", dtype=torch.bfloat16) + cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16) + sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16) + + y_ref = _reference_partial_norm_rope(x.clone(), weight, cos, sin, eps) + y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps) + + assert y_fused.shape == y_ref.shape == (B, S, H, D) + cos_sim = torch.nn.functional.cosine_similarity( + y_ref.flatten().float(), y_fused.flatten().float(), dim=0 + ) + assert cos_sim > 0.999, ( + f"partial rotary forward cosine_sim={cos_sim:.6f} " + f"(B={B},S={S},H={H},D={D},n_rot={n_rot})" + ) + + # The pass-through tail must equal the reference RMSNorm output bit- + # for-bit (any deviation would mean the kernel is touching it with a + # spurious rotation, which is the original bug class). + torch.testing.assert_close( + y_fused[..., n_rot:], y_ref[..., n_rot:], rtol=1e-2, atol=1e-2 + ) + + @pytest.mark.parametrize( + "B,S,H,D,n_rot", + [(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)], + ids=["half_64", "quarter_256"], + ) + def test_x_grad_matches_reference(self, B, S, H, D, n_rot): + from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope + + eps = 1e-6 + cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16) + sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16) + weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16) + x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + + # Reference backward via the original primitives + x_ref = x_data.clone().requires_grad_(True) + w_ref = weight_init.clone() + y_ref = _reference_partial_norm_rope(x_ref, w_ref, cos, sin, eps) + y_ref.sum().backward() + + # Fused backward + x_fused = x_data.clone().requires_grad_(True) + w_fused = weight_init.clone().requires_grad_(True) + y_fused = fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps) + y_fused.sum().backward() + + cos_sim_x = torch.nn.functional.cosine_similarity( + x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0 + ) + assert cos_sim_x > 0.999, f"partial rotary x grad cosine_sim={cos_sim_x:.6f}" + + @pytest.mark.parametrize( + "B,S,H,D,n_rot", + [(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)], + ids=["half_64", "quarter_256"], + ) + def test_weight_grad_matches_reference(self, B, S, H, D, n_rot): + from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm + + from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope + + eps = 1e-6 + cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16) + sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16) + weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16) + x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + + # Reference: Gemma4RMSNorm whose .weight collects grads, then partial + # rotary applied to the rotated slice. 
+ norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16) + norm_ref.weight = torch.nn.Parameter(weight_init.clone()) + normed = norm_ref(x_data) + from transformers.models.gemma4.modeling_gemma4 import apply_rotary_pos_emb + + rotated = apply_rotary_pos_emb(normed[..., :n_rot], cos, sin, unsqueeze_dim=2) + y_ref = torch.cat([rotated, normed[..., n_rot:]], dim=-1) + y_ref.sum().backward() + + w_fused = weight_init.clone().requires_grad_(True) + fused_rms_norm_rope(x_data.clone(), w_fused, cos, sin, eps=eps).sum().backward() + + cos_sim_w = torch.nn.functional.cosine_similarity( + w_fused.grad.flatten().float(), + norm_ref.weight.grad.flatten().float(), + dim=0, + ) + assert cos_sim_w > 0.995, ( + f"partial rotary weight grad cosine_sim={cos_sim_w:.6f}" + ) + + def test_full_rotary_unchanged_when_n_rot_equals_d(self): + """Regression: passing cos/sin with shape == head_dim must still + match the full-rotary reference (the partial-rotary code path must + not perturb the existing full-rotary output).""" + from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope + + B, S, H, D = 2, 16, 4, 64 + eps = 1e-6 + x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + weight = torch.randn(D, device="cuda", dtype=torch.bfloat16) + cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16) + sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16) + + y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps) + y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps) + cos_sim = torch.nn.functional.cosine_similarity( + y_ref.flatten().float(), y_fused.flatten().float(), dim=0 + ) + assert cos_sim > 0.999, f"full-rotary regression cos_sim={cos_sim:.6f}" + + def test_validation_errors(self): + """Wrapper rejects misshaped inputs cleanly (instead of a cryptic + Triton crash deeper in the kernel).""" + from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope + + B, S, H, D = 1, 4, 2, 64 + x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16) + w = torch.randn(D, device="cuda", dtype=torch.bfloat16) + + # n_rot > head_dim + cos_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16) + sin_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16) + with pytest.raises(ValueError, match="cannot exceed head_dim"): + fused_rms_norm_rope(x, w, cos_big, sin_big) + + # cos/sin last-dim mismatch + cos = torch.randn(B, S, 32, device="cuda", dtype=torch.bfloat16) + sin = torch.randn(B, S, 16, device="cuda", dtype=torch.bfloat16) + with pytest.raises(ValueError, match="same last dim"): + fused_rms_norm_rope(x, w, cos, sin) + + # odd rotary dim + cos_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16) + sin_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16) + with pytest.raises(ValueError, match="must be even"): + fused_rms_norm_rope(x, w, cos_odd, sin_odd) + + class TestFusedRMSNormNoScale: """Tests for v_norm (RMSNorm without learnable scale).""" diff --git a/tests/monkeypatch/test_gemma4_fused_attn.py b/tests/monkeypatch/test_gemma4_fused_attn.py new file mode 100644 index 000000000..ce8431477 --- /dev/null +++ b/tests/monkeypatch/test_gemma4_fused_attn.py @@ -0,0 +1,220 @@ +"""Tests for the Gemma 4 fused-attention monkey-patch. + +These tests exercise the patched ``Gemma4TextAttention.forward`` against +the stock implementation it replaces. 
The hybrid Gemma 4 model intentionally +mixes a sliding (`head_dim=32`) layer with a full-attention proportional-rope +layer (`global_head_dim=64`, `partial_rotary_factor=0.25`) so that: + + 1. The partial-rotary RMSNorm+RoPE path through the fused Triton kernel + gets exercised end-to-end (this is the bug originally documented in + ``GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``). + 2. The fused forward must match the current transformers attention API, + where the decoder layer no longer passes a ``shared_kv_states`` kwarg + and shared kv lives on ``past_key_values.shared_layers``. An older + fused_forward signature would raise ``TypeError: ... missing 1 + required positional argument: 'shared_kv_states'`` here. +""" + +import pytest +import torch + +pytestmark = [ + pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"), +] + +pytest.importorskip( + "transformers.models.gemma4", + reason="fused_attn patch only matters when Gemma 4 is available", +) + + +@pytest.fixture +def restore_gemma4_attention(): + """Snapshot ``Gemma4TextAttention.forward`` and restore after the test + so the monkey-patch does not leak across the suite.""" + from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention + + saved = Gemma4TextAttention.forward + yield Gemma4TextAttention + Gemma4TextAttention.forward = saved + + +def _build_hybrid_config(): + """Tiny hybrid Gemma 4 config: one sliding layer + one full-attention + layer with proportional rope and partial_rotary_factor=0.25. This is + the same shape pattern as ``google/gemma-4-26B-A4B-it`` but small + enough to fit on any GPU.""" + from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig + + cfg = Gemma4TextConfig( + vocab_size=128, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + head_dim=32, + global_head_dim=64, + layer_types=["sliding_attention", "full_attention"], + sliding_window=64, + max_position_embeddings=2048, + hidden_size_per_layer_input=16, + vocab_size_per_layer_input=128, + rope_parameters={ + "sliding_attention": { + "rope_type": "default", + "rope_theta": 10000.0, + }, + "full_attention": { + "rope_type": "proportional", + "rope_theta": 1000000.0, + "partial_rotary_factor": 0.25, + }, + }, + ) + cfg._attn_implementation = "sdpa" + return cfg + + +def _build_model(seed=0): + from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel + + torch.manual_seed(seed) + cfg = _build_hybrid_config() + return Gemma4TextModel(cfg).cuda().to(torch.bfloat16).eval() + + +class TestFusedAttnSignature: + """The fused forward must accept the same call shape as + ``Gemma4TextDecoderLayer`` produces under the current transformers API + (no ``shared_kv_states`` kwarg).""" + + def test_decoder_layer_can_call_fused_forward(self, restore_gemma4_attention): + """Regression for the API drift: decoder layer calls + ``self.self_attn(hidden_states=..., position_embeddings=..., + attention_mask=..., position_ids=..., past_key_values=...)`` and + nothing else. 
A signature with a positional ``shared_kv_states`` + used to raise ``TypeError`` here before reaching the kernel.""" + from axolotl.monkeypatch.models.gemma4.fused_attn import ( + patch_gemma4_fused_attn, + ) + + m = _build_model() + ids = torch.randint(0, 128, (2, 16), device="cuda") + mask = torch.ones(2, 16, dtype=torch.long, device="cuda") + + patch_gemma4_fused_attn() + with torch.no_grad(): + out = m(input_ids=ids, attention_mask=mask).last_hidden_state + + assert out.shape == (2, 16, 64) + assert torch.isfinite(out).all() + + +class TestFusedAttnPerLayerCorrectness: + """Compare the patched attention layer to the stock implementation + on a single forward call. This isolates the fused kernel correctness + from cross-layer numerical drift.""" + + def _run_attention(self, model, layer_idx, hidden_states, position_ids): + """Call ``Gemma4TextAttention.forward`` (whatever is currently + installed) for one layer and return the output.""" + attn = model.layers[layer_idx].self_attn + layer_type = model.config.layer_types[layer_idx] + cos, sin = model.rotary_emb(hidden_states, position_ids, layer_type) + out, _ = attn( + hidden_states=hidden_states, + position_embeddings=(cos, sin), + attention_mask=None, + ) + return out + + @pytest.mark.parametrize( + "layer_idx", + [0, 1], + ids=["sliding_head32", "global_head64_proportional"], + ) + def test_forward_matches_stock(self, restore_gemma4_attention, layer_idx): + from axolotl.monkeypatch.models.gemma4.fused_attn import ( + patch_gemma4_fused_attn, + ) + + m = _build_model(seed=1) + hs = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16) + pos = torch.arange(16, device="cuda").unsqueeze(0).expand(2, -1) + + with torch.no_grad(): + ref = self._run_attention(m, layer_idx, hs, pos) + + patch_gemma4_fused_attn() + with torch.no_grad(): + got = self._run_attention(m, layer_idx, hs, pos) + + assert got.shape == ref.shape + assert torch.isfinite(got).all() + cos_sim = torch.nn.functional.cosine_similarity( + ref.flatten().float(), got.flatten().float(), dim=0 + ) + assert cos_sim > 0.999, ( + f"layer {layer_idx} fused vs stock cosine_sim={cos_sim:.6f}" + ) + # bf16 precision: a few millis of absolute drift per element is + # acceptable for a Q/K/V projection pipeline. Anything larger is + # a real bug. + torch.testing.assert_close(got, ref, rtol=5e-2, atol=5e-2) + + +class TestFusedAttnFullModel: + """End-to-end model forward + backward through both layer types.""" + + def test_full_forward_matches_stock(self, restore_gemma4_attention): + from axolotl.monkeypatch.models.gemma4.fused_attn import ( + patch_gemma4_fused_attn, + ) + + m = _build_model(seed=2) + ids = torch.randint(0, 128, (2, 32), device="cuda") + mask = torch.ones(2, 32, dtype=torch.long, device="cuda") + + with torch.no_grad(): + ref = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone() + + patch_gemma4_fused_attn() + with torch.no_grad(): + got = m(input_ids=ids, attention_mask=mask).last_hidden_state.clone() + + assert got.shape == ref.shape + assert torch.isfinite(got).all() + cos_sim = torch.nn.functional.cosine_similarity( + ref.flatten().float(), got.flatten().float(), dim=0 + ) + # End-to-end through 2 layers (RMSNorm, attention, MLP/MoE) in bf16 + # accumulates a small amount of numerical drift; we just want to + # pin that the two paths are computing the same function. 
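All of these comparisons reduce to the same metric. An equivalent standalone helper (the helper name is ours, not the test suite's):

```python
# Flattened cosine similarity, the drift metric used throughout these
# tests: scale-insensitive, so it flags "different function" while
# tolerating bf16 rounding. Helper name is illustrative.
import torch

def flat_cos_sim(a: torch.Tensor, b: torch.Tensor) -> float:
    return torch.nn.functional.cosine_similarity(
        a.flatten().float(), b.flatten().float(), dim=0
    ).item()

torch.manual_seed(0)
x = torch.randn(4, 8)
assert flat_cos_sim(x, 1.5 * x) > 0.9999                      # rescale passes
assert flat_cos_sim(x, x + 1e-3 * torch.randn(4, 8)) > 0.999  # tiny noise ok
```

Note that a pure rescale still scores ~1.0, which is why the per-layer test pairs the cosine check with ``torch.testing.assert_close``.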
+ assert cos_sim > 0.999, f"end-to-end cosine_sim={cos_sim:.6f}" + + def test_backward_grad_flows_through_fused_path(self, restore_gemma4_attention): + """Gradients must propagate through the fused RMSNorm+RoPE kernels + for both the sliding and proportional-rope layers.""" + from axolotl.monkeypatch.models.gemma4.fused_attn import ( + patch_gemma4_fused_attn, + ) + + m = _build_model(seed=3).train() + patch_gemma4_fused_attn() + + ids = torch.randint(0, 128, (2, 16), device="cuda") + mask = torch.ones(2, 16, dtype=torch.long, device="cuda") + out = m(input_ids=ids, attention_mask=mask).last_hidden_state + out.sum().backward() + + # Both layers must accumulate gradients on q_norm.weight and + # k_norm.weight — that proves the fused kernel ran the backward. + for i, layer in enumerate(m.layers[:2]): + attn = layer.self_attn + assert attn.q_norm.weight.grad is not None, f"layer {i} q_norm no grad" + assert attn.k_norm.weight.grad is not None, f"layer {i} k_norm no grad" + assert attn.q_norm.weight.grad.isfinite().all() + assert attn.k_norm.weight.grad.isfinite().all() + assert attn.q_norm.weight.grad.abs().sum() > 0 + assert attn.k_norm.weight.grad.abs().sum() > 0 diff --git a/tests/monkeypatch/test_gemma4_hybrid_mask.py b/tests/monkeypatch/test_gemma4_hybrid_mask.py new file mode 100644 index 000000000..66d56bcf1 --- /dev/null +++ b/tests/monkeypatch/test_gemma4_hybrid_mask.py @@ -0,0 +1,343 @@ +"""Tests for the Gemma 4 hybrid-attention mask fix. + +These tests pin the single critical behavior: after installing the patch, +``modeling_gemma4.create_causal_mask`` passes an SDPA-overridden config to +the underlying mask builder regardless of what the caller's config says. +This is what keeps full-attention (head_dim=512) global layers from +crashing at long sequence lengths — they need a 4D SDPA-format mask, not +the 2D FA2 mask that would be built from the model-level config. + +The tests use a mocked ``create_causal_mask`` so they don't have to load +a real 26B Gemma 4 model or even have access to its weights. What matters +for the bug fix is which config is handed to the mask factory, not the +factory's actual output. +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +pytest.importorskip( + "transformers.models.gemma4", + reason="gemma4_hybrid_mask patch only matters when Gemma 4 is available", +) + + +@pytest.fixture +def restore_gemma4_module(): + """Snapshot ``modeling_gemma4.create_causal_mask`` and restore after + each test so patch state doesn't leak across the suite.""" + from transformers.models.gemma4 import modeling_gemma4 + + saved = modeling_gemma4.create_causal_mask + yield modeling_gemma4 + modeling_gemma4.create_causal_mask = saved + # Reset the module-level flag so the next test can re-install cleanly. 
+ from axolotl.monkeypatch import gemma4_hybrid_mask + + gemma4_hybrid_mask._PATCH_APPLIED = False + + +def test_patch_replaces_create_causal_mask(restore_gemma4_module): + modeling_gemma4 = restore_gemma4_module + from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask + + original = modeling_gemma4.create_causal_mask + assert patch_gemma4_hybrid_mask() is True + + assert modeling_gemma4.create_causal_mask is not original + assert modeling_gemma4.create_causal_mask._axolotl_original is original, ( + "patched wrapper must expose the original reference for teardown" + ) + + +def test_patch_is_idempotent(restore_gemma4_module): + modeling_gemma4 = restore_gemma4_module + from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask + + patch_gemma4_hybrid_mask() + wrapper_first = modeling_gemma4.create_causal_mask + + # Second call must not re-wrap the already-wrapped function (which + # would leak the original reference through a chain of wrappers). + patch_gemma4_hybrid_mask() + wrapper_second = modeling_gemma4.create_causal_mask + + assert wrapper_first is wrapper_second + + +def test_patched_mask_forces_sdpa_config(restore_gemma4_module): + """Core invariant: when the patched wrapper is called with a config + that says ``flash_attention_2``, the underlying mask factory receives + a shallow-copied config whose ``_attn_implementation`` is ``"sdpa"``. + + Without this, the full-attention global layers get a 2D FA2 mask and + crash at long seq lens with the [B, H, S, S] / [B, S] expand error. + """ + modeling_gemma4 = restore_gemma4_module + from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask + + # Swap in a mock BEFORE installing the patch so the wrapper captures + # it as the "original". The mock records every call so we can inspect + # what config got passed through. + mock_factory = MagicMock(name="create_causal_mask", return_value="mask_4d") + modeling_gemma4.create_causal_mask = mock_factory + patch_gemma4_hybrid_mask() + + # Caller-supplied config says FA2 (that's the model-level setting). + caller_config = SimpleNamespace( + _attn_implementation="flash_attention_2", + head_dim=512, + some_other_attr="preserved", + ) + result = modeling_gemma4.create_causal_mask( + caller_config, + inputs_embeds=None, + attention_mask=None, + past_key_values=None, + position_ids=None, + ) + + # Wrapper returned whatever the mock returned — no transformation of + # the result itself. + assert result == "mask_4d" + + # The mock was called exactly once with a config whose + # ``_attn_implementation`` is sdpa, NOT the caller's fa2. + assert mock_factory.call_count == 1 + (passed_config, *_), passed_kwargs = mock_factory.call_args + assert passed_config._attn_implementation == "sdpa" + + # The wrapper must NOT mutate the caller's config in place — other + # mask builders (e.g. create_sliding_window_causal_mask) read from + # the same config and must still see fa2. + assert caller_config._attn_implementation == "flash_attention_2" + + # Other attributes on the config must be preserved so the underlying + # factory has everything it needs (head_dim, rope_theta, vocab_size, ...). 
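The full invariant set this test pins — override visible to the factory, caller unmutated, other attributes preserved — fits in a dozen lines. A miniature of the pattern, with illustrative names rather than axolotl's:

```python
# Override-by-shallow-copy in miniature: the wrapped factory sees
# _attn_implementation == "sdpa", the caller's config object is never
# mutated, and the original stays reachable for teardown.
import copy
from types import SimpleNamespace

def force_sdpa(factory):
    def wrapper(config, *args, **kwargs):
        cfg = copy.copy(config)            # caller's object stays untouched
        cfg._attn_implementation = "sdpa"
        return factory(cfg, *args, **kwargs)
    wrapper._original = factory            # handle for unpatching
    return wrapper

probe = force_sdpa(lambda cfg: cfg._attn_implementation)
cfg = SimpleNamespace(_attn_implementation="flash_attention_2", head_dim=512)
assert probe(cfg) == "sdpa"                             # factory sees sdpa
assert cfg._attn_implementation == "flash_attention_2"  # caller untouched
assert cfg.head_dim == 512                              # attrs preserved
```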
+ assert passed_config.head_dim == 512 + assert passed_config.some_other_attr == "preserved" + + +def test_patched_wrapper_passes_through_all_kwargs(restore_gemma4_module): + """The wrapper must forward positional + keyword args to the original + unchanged, so transformers' own call-site in Gemma4TextModel.forward + keeps working across minor transformers-version signature drift.""" + modeling_gemma4 = restore_gemma4_module + from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask + + mock_factory = MagicMock(return_value="mask") + modeling_gemma4.create_causal_mask = mock_factory + patch_gemma4_hybrid_mask() + + caller_config = SimpleNamespace(_attn_implementation="flash_attention_2") + modeling_gemma4.create_causal_mask( + caller_config, + "positional_arg", + inputs_embeds="embeds", + attention_mask="mask_2d", + past_key_values="cache", + position_ids="positions", + or_mask_function="or_fn", + ) + + args, kwargs = mock_factory.call_args + # First positional (after config override) is preserved. + assert args[1] == "positional_arg" + # All kwargs are forwarded untouched. + assert kwargs["inputs_embeds"] == "embeds" + assert kwargs["attention_mask"] == "mask_2d" + assert kwargs["past_key_values"] == "cache" + assert kwargs["position_ids"] == "positions" + assert kwargs["or_mask_function"] == "or_fn" + + +def test_unpatch_restores_original(restore_gemma4_module): + modeling_gemma4 = restore_gemma4_module + from axolotl.monkeypatch.gemma4_hybrid_mask import ( + patch_gemma4_hybrid_mask, + unpatch_gemma4_hybrid_mask, + ) + + sentinel = MagicMock(name="original") + modeling_gemma4.create_causal_mask = sentinel + patch_gemma4_hybrid_mask() + assert modeling_gemma4.create_causal_mask is not sentinel + + unpatch_gemma4_hybrid_mask() + assert modeling_gemma4.create_causal_mask is sentinel + + +def test_unpatch_is_safe_without_prior_patch(restore_gemma4_module): + from axolotl.monkeypatch.gemma4_hybrid_mask import unpatch_gemma4_hybrid_mask + + # Should be a no-op, no exception. + unpatch_gemma4_hybrid_mask() + + +def test_sliding_window_mask_builder_is_not_patched(restore_gemma4_module): + """Only ``create_causal_mask`` is overridden — the sliding-window + factory must remain bound to its original to preserve FA2 masks for + the sliding-attention layers. If we accidentally patch both, the + sliding layers get SDPA format and lose the FA2 speedup.""" + modeling_gemma4 = restore_gemma4_module + from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask + + if not hasattr(modeling_gemma4, "create_sliding_window_causal_mask"): + pytest.skip("transformers version has no create_sliding_window_causal_mask") + + sliding_before = modeling_gemma4.create_sliding_window_causal_mask + patch_gemma4_hybrid_mask() + sliding_after = modeling_gemma4.create_sliding_window_causal_mask + assert sliding_after is sliding_before + + +# --------------------------------------------------------------------------- +# Integration tests with a tiny randomly-initialized Gemma4TextModel. +# +# These do NOT load real 26B weights. They build a ~350k-param Gemma 4 text +# model with 2 layers (one sliding, one full_attention), apply the hybrid +# attention path end-to-end, and run a forward pass with a padded +# attention_mask at a long-ish seq len. The invariant we're pinning is that +# the full_attention layer does not crash with the +# "Target sizes: [B, H, S, S]. 
Tensor sizes: [B, S]" +# error — the exact failure that blew up the Gemma 4 MoE 26B pilot at ~7k +# tokens in the FSDP2 training run. +# --------------------------------------------------------------------------- + + +def _build_tiny_gemma4_text_model(): + """Return a tiny randomly-initialized Gemma4TextModel with mixed layers.""" + import torch + from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig + from transformers.models.gemma4.modeling_gemma4 import Gemma4TextModel + + cfg = Gemma4TextConfig( + vocab_size=128, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + head_dim=32, + layer_types=["sliding_attention", "full_attention"], + sliding_window=64, + max_position_embeddings=2048, + hidden_size_per_layer_input=16, + vocab_size_per_layer_input=128, + ) + # Caller-supplied attn impl simulates the pilot config (fa2 at model + # level). The hybrid patch is what makes this survive long context. + cfg._attn_implementation = "sdpa" # start safe; the test toggles fa2 later + torch.manual_seed(42) + model = Gemma4TextModel(cfg).eval() + return model, cfg + + +def _apply_hybrid_attn_inline(model, cfg): + """Replicate what ``patch_manager._apply_gemma_hybrid_attention`` does + to a model, without needing a full PatchManager / pydantic cfg.""" + import copy + + from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask + + for layer_idx, layer in enumerate(model.layers): + if cfg.layer_types[layer_idx] != "sliding_attention": + attn = getattr(layer, "self_attn", None) + if attn is not None and hasattr(attn, "config"): + sdpa_cfg = copy.copy(attn.config) + sdpa_cfg._attn_implementation = "sdpa" + attn.config = sdpa_cfg + patch_gemma4_hybrid_mask() + + +def test_tiny_gemma4_long_context_forward_does_not_crash(restore_gemma4_module): + """End-to-end invariant: with the hybrid attn patch applied, a tiny + Gemma4TextModel runs a forward at long context (1024 tokens) with + real padding in the attention mask, producing the expected output + shape. This exercises the actual code path that crashed the pilot + without needing a real 26B checkpoint or CUDA.""" + import torch + + model, cfg = _build_tiny_gemma4_text_model() + _apply_hybrid_attn_inline(model, cfg) + + B, S = 2, 1024 + input_ids = torch.randint(0, cfg.vocab_size, (B, S)) + attn_mask = torch.ones(B, S, dtype=torch.long) + # Pad positions in the second row. Without padding, SDPA falls back to + # ``is_causal=True`` with ``mask=None`` — we need a materialized 4D + # mask to exercise the actual bug site. + attn_mask[1, S // 2 :] = 0 + + with torch.no_grad(): + out = model(input_ids=input_ids, attention_mask=attn_mask) + + assert out.last_hidden_state.shape == (B, S, cfg.hidden_size) + assert torch.isfinite(out.last_hidden_state).all() + + +def test_patched_create_causal_mask_returns_4d_for_real_config( + restore_gemma4_module, +): + """Hit the REAL ``create_causal_mask`` (not a mock) via the wrapper + and verify the returned mask is a 4D tensor — which is the shape the + SDPA-patched global layers need. 
Without the patch and with a + caller-supplied FA2 config this would return a 2D mask and the layer + would crash at long context.""" + import torch + from transformers.cache_utils import DynamicCache + from transformers.models.gemma4.configuration_gemma4 import Gemma4TextConfig + + from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask + + patch_gemma4_hybrid_mask() + modeling_gemma4 = restore_gemma4_module + + cfg = Gemma4TextConfig( + vocab_size=128, + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + head_dim=32, + layer_types=["sliding_attention", "full_attention"], + sliding_window=64, + max_position_embeddings=2048, + hidden_size_per_layer_input=16, + vocab_size_per_layer_input=128, + ) + # Simulate the pilot: caller says flash_attention_2, but global layers + # were switched to SDPA per-layer. Without the patch, create_causal_mask + # would return an FA2 2D mask here and the SDPA layer would crash. + cfg._attn_implementation = "flash_attention_2" + + B, S = 2, 1024 + inputs_embeds = torch.zeros((B, S, cfg.hidden_size), dtype=torch.float32) + attention_mask = torch.ones((B, S), dtype=torch.long) + attention_mask[1, S // 2 :] = 0 # force the 4D materialized path + position_ids = torch.arange(S).unsqueeze(0).expand(B, -1) + past_key_values = DynamicCache(config=cfg) + + mask = modeling_gemma4.create_causal_mask( + config=cfg, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + assert mask is not None + assert isinstance(mask, torch.Tensor) + assert mask.dim() == 4, ( + f"expected a 4D SDPA-format mask, got {mask.dim()}D " + f"shape={tuple(mask.shape)}. The full_attention global layers need " + "this shape or they crash at long context." + ) + assert mask.shape[0] == B + assert mask.shape[-1] == S + assert mask.shape[-2] == S + + # Caller's config must be untouched — other code paths still read it. + assert cfg._attn_implementation == "flash_attention_2"
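As a closing illustration of the shape invariant these tests protect, the 2D-vs-4D mismatch reproduces with plain tensor ops — a minimal sketch of the failure mode, not the transformers code path:

```python
# SDPA wants a mask that broadcasts to (B, H, S, S); an FA2-style 2D
# (B, S) padding mask cannot expand to that.
import torch

B, H, S = 2, 4, 8
mask_4d = torch.zeros(B, 1, S, S)   # SDPA format: broadcasts over heads
mask_2d = torch.zeros(B, S)         # FA2 format: padding mask only

mask_4d.expand(B, H, S, S)          # fine
try:
    mask_2d.expand(B, H, S, S)      # the bug class this patch removes
except RuntimeError as exc:
    # "The expanded size of the tensor (8) must match the existing size
    # (2) at non-singleton dimension 2..." — same family as the crash
    # quoted in the module docstring above.
    print(exc)
```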