[gemma4] fix fused RMSNorm+RoPE on hybrid attention models
- Kernel: fused_rms_norm_rope crashed when cos.shape[-1] < x.shape[-1]. Triton forward/backward take an n_rot runtime arg that restricts rotate_half to [0, n_rot) and treats trailing cols as RMSNorm-only pass-through (cos=1, sin=0 defaults). Wrapper also expands cos/sin that broadcast over batch. - Forward: _make_fused_forward used a stale shared_kv_states kwarg the current decoder layer no longer passes. Now mirrors stock attention, reading/writing past_key_values.shared_layers.
This commit is contained in:
@@ -242,6 +242,87 @@ class ProducerConfig:
|
||||
)
|
||||
|
||||
|
||||
class _GroupShardedSampler:
|
||||
"""Rank-aware shard of a ``RepeatSampler`` that preserves GRPO groups.
|
||||
|
||||
``RepeatSampler`` yields ``num_generations`` consecutive copies of
|
||||
each prompt, forming a GRPO group. For distributed training each
|
||||
rank must see a disjoint slice of prompts (otherwise every rank
|
||||
dogpiles on the first 1/world_size of the batch) while keeping each
|
||||
group intact on a single rank so advantage normalization sees all
|
||||
peer generations.
|
||||
|
||||
``accelerator.prepare(DataLoader)`` does not handle this correctly
|
||||
for custom samplers with ``split_batches=False`` (the default): it
|
||||
leaves the sampler alone and every rank replays identical indices.
|
||||
This wrapper fixes that by consuming the inner sampler's full
|
||||
output, chunking it into ``num_generations``-sized groups, and
|
||||
round-robining whole groups across ranks.
|
||||
|
||||
Intended to be used ONLY when distributed training is active
|
||||
(``num_replicas > 1``); for single-rank it is a no-op but still
|
||||
correct.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inner: Any,
|
||||
num_generations: int,
|
||||
rank: int,
|
||||
num_replicas: int,
|
||||
):
|
||||
if num_generations < 1:
|
||||
raise ValueError(f"num_generations must be >= 1, got {num_generations}")
|
||||
if num_replicas < 1:
|
||||
raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
|
||||
if not (0 <= rank < num_replicas):
|
||||
raise ValueError(
|
||||
f"rank must be in [0, {num_replicas}), got {rank}"
|
||||
)
|
||||
self.inner = inner
|
||||
self.num_generations = num_generations
|
||||
self.rank = rank
|
||||
self.num_replicas = num_replicas
|
||||
|
||||
def __iter__(self):
|
||||
all_indices = list(self.inner)
|
||||
if len(all_indices) % self.num_generations != 0:
|
||||
raise ValueError(
|
||||
f"inner sampler yielded {len(all_indices)} indices, "
|
||||
f"not a multiple of num_generations={self.num_generations}"
|
||||
)
|
||||
# Chunk the flat index sequence into groups of num_generations
|
||||
# consecutive indices. ``RepeatSampler`` guarantees that each
|
||||
# group contains num_generations copies of the same prompt id.
|
||||
groups = [
|
||||
all_indices[i : i + self.num_generations]
|
||||
for i in range(0, len(all_indices), self.num_generations)
|
||||
]
|
||||
# Round-robin whole groups across ranks. Round-robin (vs.
|
||||
# contiguous chunking) preserves approximate shuffled order on
|
||||
# each rank even when the group count is small relative to the
|
||||
# world size.
|
||||
for group in groups[self.rank :: self.num_replicas]:
|
||||
yield from group
|
||||
|
||||
def __len__(self):
|
||||
try:
|
||||
inner_len = len(self.inner)
|
||||
except TypeError:
|
||||
# Non-sized inner sampler — we can't know the per-rank
|
||||
# length without materializing. Return 0 as a hint that the
|
||||
# DataLoader should fall back to iteration.
|
||||
return 0
|
||||
total_groups = inner_len // self.num_generations
|
||||
# Ceiling division for the trailing groups that don't divide
|
||||
# evenly — extra groups go to the first ``total_groups %
|
||||
# num_replicas`` ranks, matching the round-robin above.
|
||||
my_groups = (
|
||||
total_groups + self.num_replicas - self.rank - 1
|
||||
) // self.num_replicas
|
||||
return my_groups * self.num_generations
|
||||
|
||||
|
||||
class DataProducer(ABC):
|
||||
"""Abstract base class for online data producers.
|
||||
|
||||
@@ -556,6 +637,34 @@ class GRPODataProducer(BaseDataProducer):
|
||||
seed=self._seed,
|
||||
)
|
||||
|
||||
# Shard the sampler across distributed ranks so each rank sees
|
||||
# a disjoint slice of prompts. ``RepeatSampler`` groups each
|
||||
# prompt with ``num_generations`` consecutive copies — our
|
||||
# wrapper round-robins WHOLE groups across ranks so all
|
||||
# generations of a given prompt stay on the same rank (needed
|
||||
# for GRPO advantage normalization within a group).
|
||||
#
|
||||
# Without this, ``accelerator.prepare(dl)`` with the default
|
||||
# ``split_batches=False`` leaves the custom sampler alone, so
|
||||
# every rank iterates the identical index sequence and the
|
||||
# cluster dogpiles on the first 1/world_size of the prompts.
|
||||
num_replicas = max(1, trainer.accelerator.num_processes)
|
||||
if num_replicas > 1:
|
||||
sampler = _GroupShardedSampler(
|
||||
inner=sampler,
|
||||
num_generations=self._num_generations,
|
||||
rank=trainer.accelerator.process_index,
|
||||
num_replicas=num_replicas,
|
||||
)
|
||||
logger.info(
|
||||
"[RANK:%d] _GroupShardedSampler active "
|
||||
"(num_replicas=%d, num_generations=%d, gen_batch=%d)",
|
||||
trainer.accelerator.process_index,
|
||||
num_replicas,
|
||||
self._num_generations,
|
||||
self._generation_batch_size,
|
||||
)
|
||||
|
||||
# Use identity collator (same as stock GRPOTrainer)
|
||||
def _identity(x):
|
||||
return x
|
||||
@@ -574,12 +683,11 @@ class GRPODataProducer(BaseDataProducer):
|
||||
rank=trainer.args.process_index,
|
||||
),
|
||||
)
|
||||
self._prompt_dl = trainer.accelerator.prepare(dl)
|
||||
|
||||
# Don't let accelerator track this dataloader
|
||||
acc_dls = trainer.accelerator._dataloaders
|
||||
if self._prompt_dl in acc_dls:
|
||||
acc_dls.remove(self._prompt_dl)
|
||||
# Skip accelerator.prepare — we're handling per-rank sharding
|
||||
# ourselves via ``_GroupShardedSampler``. ``prepare()`` would
|
||||
# otherwise try to wrap the DataLoader with its own sharding
|
||||
# logic which does not understand our group structure.
|
||||
self._prompt_dl = dl
|
||||
|
||||
self._prompt_iter = iter(self._prompt_dl)
|
||||
|
||||
|
||||
@@ -110,11 +110,35 @@ class NemoGymDataProducer(GRPODataProducer):
|
||||
item["agent_ref"] = full_item["agent_ref"]
|
||||
dataset_items.append(item)
|
||||
|
||||
# Expand by num_generations (agent produces one rollout per call)
|
||||
expanded_items = []
|
||||
for item in dataset_items:
|
||||
for _ in range(self._num_generations):
|
||||
expanded_items.append(item)
|
||||
# NOTE: do NOT re-expand by num_generations here.
|
||||
# ``RepeatSampler(mini_repeat_count=num_generations)`` already
|
||||
# yields ``num_generations`` consecutive copies of each unique
|
||||
# prompt, so ``inputs`` is a list of ``(unique_prompts_per_rank *
|
||||
# num_generations)`` items — one entry per rollout. Expanding
|
||||
# again here would fire ``num_generations^2`` rollouts per
|
||||
# prompt per rank and make every step dogpile on a handful of
|
||||
# tasks.
|
||||
expanded_items = dataset_items
|
||||
|
||||
# Diagnostic: log what this rank is about to fire.
|
||||
try:
|
||||
import collections
|
||||
iid_counts = collections.Counter()
|
||||
for it in dataset_items:
|
||||
iid_counts[
|
||||
(it.get("responses_create_params", {}).get("metadata") or {}).get(
|
||||
"instance_id"
|
||||
)
|
||||
] += 1
|
||||
LOG.info(
|
||||
"[RANK:%d] produce(): firing %d agent /run calls covering %d unique prompts: %s",
|
||||
trainer.accelerator.process_index,
|
||||
len(dataset_items),
|
||||
len(iid_counts),
|
||||
list(iid_counts.most_common(5)),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Call NeMo Gym agents
|
||||
loop = asyncio.new_event_loop()
|
||||
@@ -140,6 +164,7 @@ class NemoGymDataProducer(GRPODataProducer):
|
||||
logprobs_list = []
|
||||
rewards_list = []
|
||||
|
||||
num_turns_list: list[int] = []
|
||||
for resp in responses:
|
||||
parsed = _parse_agent_response(resp, eos_token_id)
|
||||
prompt_ids_list.append(parsed["prompt_ids"])
|
||||
@@ -147,6 +172,7 @@ class NemoGymDataProducer(GRPODataProducer):
|
||||
env_mask_list.append(parsed["env_mask"])
|
||||
logprobs_list.append(parsed["logprobs"])
|
||||
rewards_list.append(parsed["reward"])
|
||||
num_turns_list.append(parsed.get("num_turns", 0))
|
||||
|
||||
# Pad to tensors
|
||||
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
|
||||
@@ -179,22 +205,49 @@ class NemoGymDataProducer(GRPODataProducer):
|
||||
tool_mask = [torch.tensor(m, device=device) for m in env_mask_list]
|
||||
tool_mask = pad(tool_mask, padding_value=1, padding_side="right")
|
||||
|
||||
# Inject rewards into inputs so _compute_deferred_scores can use them
|
||||
# The deferred scoring path calls _calculate_rewards which reads reward_funcs.
|
||||
# Our passthrough reward_fn reads "env_reward" from kwargs.
|
||||
# Inject per-rollout reward + num_turns into each input. Since
|
||||
# ``RepeatSampler`` already yields ``num_generations`` copies of
|
||||
# each prompt, ``inputs`` has ONE entry per rollout (matching
|
||||
# ``rewards_list`` 1:1). No per-prompt grouping happens here —
|
||||
# GRPO advantage normalization is the trainer's job downstream.
|
||||
assert len(inputs) == len(rewards_list), (
|
||||
f"rewards/inputs length mismatch: "
|
||||
f"{len(rewards_list)} rewards vs {len(inputs)} inputs"
|
||||
)
|
||||
for i, inp in enumerate(inputs):
|
||||
# Each input gets rewards for its num_generations rollouts
|
||||
start = i * self._num_generations
|
||||
end = start + self._num_generations
|
||||
inp["env_reward"] = rewards_list[start:end]
|
||||
inp["env_reward"] = rewards_list[i]
|
||||
inp["num_turns"] = num_turns_list[i]
|
||||
|
||||
# One expanded_input per rollout (already correct count because
|
||||
# inputs has num_generations copies baked in by the sampler).
|
||||
expanded_inputs = [dict(inp) for inp in inputs]
|
||||
|
||||
# Log rollout-level stats to wandb from rank 0. These are the
|
||||
# true agent-side metrics (not the tokenized TRL view) — so
|
||||
# num_turns reflects how many /run iterations each rollout
|
||||
# actually took before finishing or hitting max_turns.
|
||||
if is_main and num_turns_list:
|
||||
try:
|
||||
import wandb
|
||||
|
||||
if wandb.run is not None:
|
||||
import statistics as _stats
|
||||
|
||||
nonzero = sum(1 for r in rewards_list if r > 0)
|
||||
log_payload = {
|
||||
"rollout/num_turns/mean": float(_stats.mean(num_turns_list)),
|
||||
"rollout/num_turns/min": float(min(num_turns_list)),
|
||||
"rollout/num_turns/max": float(max(num_turns_list)),
|
||||
"rollout/reward/mean": float(_stats.mean(rewards_list)),
|
||||
"rollout/reward/nonzero_frac": (
|
||||
nonzero / len(rewards_list) if rewards_list else 0.0
|
||||
),
|
||||
"rollout/n_samples": float(len(rewards_list)),
|
||||
}
|
||||
wandb.log(log_payload, commit=False)
|
||||
except Exception as exc: # never let metric logging break training
|
||||
LOG.warning("rollout wandb log failed: %s", exc)
|
||||
|
||||
# Expand inputs to match expanded rollouts (num_generations copies)
|
||||
expanded_inputs = []
|
||||
for inp in inputs:
|
||||
for g in range(self._num_generations):
|
||||
expanded_inp = dict(inp)
|
||||
expanded_inp["env_reward"] = inp["env_reward"][g]
|
||||
expanded_inputs.append(expanded_inp)
|
||||
|
||||
# Decode completions for reward functions
|
||||
completions = trainer.processing_class.batch_decode(
|
||||
|
||||
@@ -53,6 +53,7 @@ def _rms_norm_rope_forward_kernel(
|
||||
RSTD_ptr,
|
||||
RSTD_row_stride,
|
||||
n_cols,
|
||||
n_rot,
|
||||
n_heads,
|
||||
eps,
|
||||
HAS_WEIGHT: tl.constexpr,
|
||||
@@ -60,28 +61,35 @@ def _rms_norm_rope_forward_kernel(
|
||||
):
|
||||
"""
|
||||
Fused forward:
|
||||
x_norm = x / rms(x) [* weight] (RMSNorm)
|
||||
y = x_norm * cos + rotate_half(x_norm) * sin (RoPE)
|
||||
x_norm = x / rms(x) [* weight] (RMSNorm, full n_cols)
|
||||
y[..., :n_rot] = rope(x_norm[..., :n_rot])
|
||||
y[..., n_rot:] = x_norm[..., n_rot:] (pass-through for partial rotary)
|
||||
|
||||
rotate_half swaps first/second halves and negates the first:
|
||||
rotate_half([a, b]) = [-b, a]
|
||||
rotate_half swaps first/second halves and negates the first, restricted
|
||||
to the rotary span [0, n_rot):
|
||||
rotate_half([a, b]) = [-b, a] where len(a) = len(b) = n_rot/2
|
||||
|
||||
For the partial-rotary pass-through region we load cos with default 1.0
|
||||
and sin with default 0.0 outside [0, n_rot), so the same formula
|
||||
`Y = X_norm * cos + X_rot_norm * sin` collapses to `Y = X_norm`.
|
||||
|
||||
cos/sin are indexed by row_idx // n_heads to handle per-head broadcast
|
||||
(cos/sin have shape (B*S, D) while X has shape (B*S*H, D)).
|
||||
(cos/sin have shape (B*S, n_rot) while X has shape (B*S*H, n_cols)).
|
||||
"""
|
||||
row_idx = tl.program_id(0).to(tl.int64)
|
||||
# cos/sin row: divide by n_heads since cos/sin are (B*S, D)
|
||||
# cos/sin row: divide by n_heads since cos/sin are (B*S, n_rot)
|
||||
cs_row_idx = row_idx // n_heads
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
half_dim = n_cols // 2
|
||||
rot_mask_col = col_offsets < n_rot
|
||||
half_rot = n_rot // 2
|
||||
|
||||
# Load input row
|
||||
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
|
||||
X_dtype = X_row.dtype
|
||||
X_fp32 = X_row.to(tl.float32)
|
||||
|
||||
# RMSNorm: compute 1/rms
|
||||
# RMSNorm: compute 1/rms over the full row (rotary + pass-through)
|
||||
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
|
||||
rstd = rsqrt(mean_sq + eps)
|
||||
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
|
||||
@@ -94,33 +102,38 @@ def _rms_norm_rope_forward_kernel(
|
||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0).to(tl.float32)
|
||||
X_norm = X_norm * W_row
|
||||
|
||||
# RoPE: load cos/sin (broadcast across heads)
|
||||
# RoPE: load cos/sin (broadcast across heads). For col >= n_rot we get
|
||||
# cos=1, sin=0 so the formula leaves X_norm untouched.
|
||||
cos_row = tl.load(
|
||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
|
||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
|
||||
mask=rot_mask_col,
|
||||
other=1.0,
|
||||
).to(tl.float32)
|
||||
sin_row = tl.load(
|
||||
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets, mask=mask, other=0
|
||||
SIN_ptr + cs_row_idx * SIN_row_stride + col_offsets,
|
||||
mask=rot_mask_col,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
# rotate_half: for col < half_dim, take -X_norm[col + half_dim]
|
||||
# for col >= half_dim, take X_norm[col - half_dim]
|
||||
# rotate_half within [0, n_rot):
|
||||
# for col < half_rot: take -X_norm[col + half_rot]
|
||||
# for col in [half_rot, n_rot): take X_norm[col - half_rot]
|
||||
# For col >= n_rot the rotation is irrelevant (sin = 0 zeros it out).
|
||||
rot_offsets = tl.where(
|
||||
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
|
||||
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
|
||||
)
|
||||
rot_mask = rot_offsets < n_cols
|
||||
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
|
||||
X_rot = tl.load(
|
||||
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_mask & mask, other=0
|
||||
X_ptr + row_idx * X_row_stride + rot_offsets, mask=rot_load_mask, other=0
|
||||
).to(tl.float32)
|
||||
# Re-normalize the rotated values
|
||||
X_rot_norm = X_rot * rstd
|
||||
if HAS_WEIGHT:
|
||||
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_mask & mask, other=0).to(
|
||||
tl.float32
|
||||
)
|
||||
W_rot = tl.load(W_ptr + rot_offsets, mask=rot_load_mask, other=0).to(tl.float32)
|
||||
X_rot_norm = X_rot_norm * W_rot
|
||||
|
||||
# Negate the first half (rotate_half negates x2, which becomes the first half)
|
||||
sign = tl.where(col_offsets < half_dim, -1.0, 1.0)
|
||||
sign = tl.where(col_offsets < half_rot, -1.0, 1.0)
|
||||
X_rot_norm = X_rot_norm * sign
|
||||
|
||||
# Final RoPE: y = x_norm * cos + rotate_half(x_norm) * sin
|
||||
@@ -153,13 +166,21 @@ def _rms_norm_rope_backward_kernel(
|
||||
dW_row_stride,
|
||||
n_rows,
|
||||
n_cols,
|
||||
n_rot,
|
||||
n_heads,
|
||||
rows_per_program,
|
||||
HAS_WEIGHT: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Backward for Y = RoPE(RMSNorm(X, W))
|
||||
Backward for Y = RoPE(RMSNorm(X, W)) with optional partial rotary
|
||||
(`n_rot <= n_cols`).
|
||||
|
||||
For col < n_rot the standard RoPE adjoint applies. For col >= n_rot the
|
||||
output is just the normalized row, so dN[col] = dY[col] (achieved by
|
||||
loading cos with default 1.0 and forcing the rotate-half contribution
|
||||
to zero outside the rotary span).
|
||||
|
||||
cos/sin indexed by row_idx // n_heads for per-head broadcast.
|
||||
"""
|
||||
row_block_id = tl.program_id(0).to(tl.int64)
|
||||
@@ -167,7 +188,8 @@ def _rms_norm_rope_backward_kernel(
|
||||
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
half_dim = n_cols // 2
|
||||
rot_mask_col = col_offsets < n_rot
|
||||
half_rot = n_rot // 2
|
||||
|
||||
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||
|
||||
@@ -186,33 +208,37 @@ def _rms_norm_rope_backward_kernel(
|
||||
rstd = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
|
||||
|
||||
cos_row = tl.load(
|
||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets, mask=mask, other=0
|
||||
COS_ptr + cs_row_idx * COS_row_stride + col_offsets,
|
||||
mask=rot_mask_col,
|
||||
other=1.0,
|
||||
).to(tl.float32)
|
||||
|
||||
# dN = dY * cos + rotate_half^T(dY * sin)
|
||||
# dN = dY * cos + rotate_half^T(dY * sin) (within the rotary span)
|
||||
# rotate_half^T([a, b]) = [b, -a] (adjoint of rotate_half)
|
||||
#
|
||||
# Compute rotate_half_transpose(dY * sin) by loading dY and sin at
|
||||
# rotated offsets directly: dY[rot] * sin[rot] * adj_sign
|
||||
# This is equivalent to rotating (dY * sin) because the rotation
|
||||
# just permutes which elements are multiplied.
|
||||
# For col >= n_rot the formula must collapse to dN = dY (since the
|
||||
# forward is just a pass-through). cos defaults to 1.0 above; the
|
||||
# rotate-half contribution is masked to zero below.
|
||||
rot_offsets = tl.where(
|
||||
col_offsets < half_dim, col_offsets + half_dim, col_offsets - half_dim
|
||||
col_offsets < half_rot, col_offsets + half_rot, col_offsets - half_rot
|
||||
)
|
||||
rot_mask = rot_offsets < n_cols
|
||||
rot_load_mask = (rot_offsets < n_cols) & rot_mask_col
|
||||
dY_rot = tl.load(
|
||||
dY_ptr + row_idx * dY_row_stride + rot_offsets,
|
||||
mask=rot_mask & mask,
|
||||
mask=rot_load_mask,
|
||||
other=0,
|
||||
).to(tl.float32)
|
||||
sin_rot = tl.load(
|
||||
SIN_ptr + cs_row_idx * SIN_row_stride + rot_offsets,
|
||||
mask=rot_mask & mask,
|
||||
mask=rot_load_mask,
|
||||
other=0,
|
||||
).to(tl.float32)
|
||||
|
||||
adj_sign = tl.where(col_offsets < half_dim, 1.0, -1.0)
|
||||
dN = dY_row * cos_row + dY_rot * sin_rot * adj_sign
|
||||
adj_sign = tl.where(col_offsets < half_rot, 1.0, -1.0)
|
||||
rotate_term = dY_rot * sin_rot * adj_sign
|
||||
# Zero out rotate-half contribution outside the rotary span.
|
||||
rotate_term = tl.where(rot_mask_col, rotate_term, 0.0)
|
||||
dN = dY_row * cos_row + rotate_term
|
||||
|
||||
# Pre-weight normalized: n = rstd * x
|
||||
n = X_row * rstd
|
||||
@@ -241,15 +267,17 @@ def _rms_norm_rope_backward_kernel(
|
||||
)
|
||||
|
||||
|
||||
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
|
||||
def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads, n_rot):
|
||||
"""
|
||||
Args:
|
||||
X: (B*S*H, head_dim) — contiguous, flattened from (B, S, H, D)
|
||||
W: (head_dim,) or None — RMSNorm weight
|
||||
cos: (B*S, head_dim) — position embeddings (broadcast across heads)
|
||||
sin: (B*S, head_dim) — position embeddings (broadcast across heads)
|
||||
cos: (B*S, n_rot) — position embeddings (broadcast across heads)
|
||||
sin: (B*S, n_rot) — position embeddings (broadcast across heads)
|
||||
eps: float
|
||||
n_heads: int — number of attention heads (for cos/sin indexing)
|
||||
n_rot: int — rotary dim (== head_dim for full rotary, < head_dim for
|
||||
partial rotary). Must be even and ``<= head_dim``.
|
||||
Returns:
|
||||
Y, X_saved, RSTD, BLOCK_SIZE, num_warps
|
||||
"""
|
||||
@@ -273,6 +301,7 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
|
||||
RSTD,
|
||||
RSTD.stride(0),
|
||||
n_cols,
|
||||
n_rot,
|
||||
n_heads,
|
||||
eps,
|
||||
HAS_WEIGHT=has_weight,
|
||||
@@ -282,7 +311,9 @@ def rms_norm_rope_forward(X, W, cos, sin, eps, n_heads):
|
||||
return Y, X, RSTD, BLOCK_SIZE, num_warps
|
||||
|
||||
|
||||
def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_warps):
|
||||
def rms_norm_rope_backward(
|
||||
dY, X, W, cos, sin, RSTD, n_heads, n_rot, BLOCK_SIZE, num_warps
|
||||
):
|
||||
n_rows, n_cols = dY.shape
|
||||
has_weight = W is not None
|
||||
|
||||
@@ -315,6 +346,7 @@ def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_wa
|
||||
_dW.stride(0),
|
||||
n_rows,
|
||||
n_cols,
|
||||
n_rot,
|
||||
n_heads,
|
||||
rows_per_program,
|
||||
HAS_WEIGHT=has_weight,
|
||||
@@ -329,13 +361,14 @@ def rms_norm_rope_backward(dY, X, W, cos, sin, RSTD, n_heads, BLOCK_SIZE, num_wa
|
||||
class FusedRMSNormRoPEFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@ensure_contiguous
|
||||
def forward(ctx, X, W, cos, sin, eps, n_heads):
|
||||
def forward(ctx, X, W, cos, sin, eps, n_heads, n_rot):
|
||||
"""
|
||||
X: (B*S*H, head_dim)
|
||||
W: (head_dim,) or None
|
||||
cos: (B*S, head_dim) — broadcast across heads
|
||||
sin: (B*S, head_dim) — broadcast across heads
|
||||
X: (B*S*H, head_dim)
|
||||
W: (head_dim,) or None
|
||||
cos: (B*S, n_rot) — broadcast across heads
|
||||
sin: (B*S, n_rot) — broadcast across heads
|
||||
n_heads: int
|
||||
n_rot: int — rotary dim (<= head_dim)
|
||||
"""
|
||||
Y, X_saved, RSTD, BLOCK_SIZE, num_warps = rms_norm_rope_forward(
|
||||
X,
|
||||
@@ -344,11 +377,13 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function):
|
||||
sin,
|
||||
eps,
|
||||
n_heads,
|
||||
n_rot,
|
||||
)
|
||||
ctx.eps = eps
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.n_heads = n_heads
|
||||
ctx.n_rot = n_rot
|
||||
ctx.has_weight = W is not None
|
||||
ctx.save_for_backward(X_saved, W, cos, sin, RSTD)
|
||||
return Y
|
||||
@@ -365,21 +400,26 @@ class FusedRMSNormRoPEFunction(torch.autograd.Function):
|
||||
sin,
|
||||
RSTD,
|
||||
ctx.n_heads,
|
||||
ctx.n_rot,
|
||||
ctx.BLOCK_SIZE,
|
||||
ctx.num_warps,
|
||||
)
|
||||
return dX, dW, None, None, None, None
|
||||
return dX, dW, None, None, None, None, None
|
||||
|
||||
|
||||
def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
|
||||
"""
|
||||
Apply fused RMSNorm + RoPE.
|
||||
Apply fused RMSNorm + (partial) RoPE.
|
||||
|
||||
Args:
|
||||
x: (batch, seq_len, num_heads, head_dim) — after projection + view
|
||||
weight: (head_dim,) — RMSNorm weight, or None for no-scale norm
|
||||
cos: (batch, seq_len, head_dim) — from RotaryEmbedding
|
||||
sin: (batch, seq_len, head_dim) — from RotaryEmbedding
|
||||
cos: (batch, seq_len, n_rot) — from RotaryEmbedding. ``n_rot``
|
||||
must be even and ``<= head_dim``. When ``n_rot < head_dim``
|
||||
the trailing ``head_dim - n_rot`` columns are RMSNorm-only
|
||||
(partial-rotary pass-through), matching stock Gemma 4 with
|
||||
``partial_rotary_factor < 1.0``.
|
||||
sin: (batch, seq_len, n_rot) — same shape as ``cos``
|
||||
eps: float — RMSNorm epsilon
|
||||
|
||||
Returns:
|
||||
@@ -387,14 +427,38 @@ def fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6):
|
||||
"""
|
||||
shape = x.shape # (B, S, H, D)
|
||||
B, S, H, D = shape
|
||||
n_rot = cos.shape[-1]
|
||||
if sin.shape[-1] != n_rot:
|
||||
raise ValueError(
|
||||
f"cos and sin must have the same last dim, got cos={cos.shape[-1]} "
|
||||
f"sin={sin.shape[-1]}"
|
||||
)
|
||||
if n_rot > D:
|
||||
raise ValueError(f"rotary dim ({n_rot}) cannot exceed head_dim ({D})")
|
||||
if n_rot % 2 != 0:
|
||||
raise ValueError(f"rotary dim must be even, got {n_rot}")
|
||||
|
||||
# Flatten to 2D: (B*S*H, D)
|
||||
x_flat = x.reshape(-1, D).contiguous()
|
||||
# Flatten cos/sin to (B*S, D) — the kernel will handle per-head broadcast
|
||||
# by dividing the row_idx by H to get the cos/sin row
|
||||
cos_flat = cos.reshape(B * S, D).contiguous()
|
||||
sin_flat = sin.reshape(B * S, D).contiguous()
|
||||
# cos/sin may broadcast over the batch dim (e.g. (1, S, n_rot) when
|
||||
# all sequences share the same rotary positions). The kernel needs a
|
||||
# dense (B*S, n_rot) buffer so that row_idx // n_heads maps cleanly
|
||||
# onto a single (b, s) pair, so expand-then-contiguous to materialize
|
||||
# the per-batch broadcast. Expand is a no-op when B == cos.shape[0].
|
||||
if cos.shape[0] != B:
|
||||
if cos.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"cos/sin batch dim ({cos.shape[0]}) must be 1 or equal "
|
||||
f"to x batch dim ({B})"
|
||||
)
|
||||
cos = cos.expand(B, S, n_rot)
|
||||
sin = sin.expand(B, S, n_rot)
|
||||
cos_flat = cos.reshape(B * S, n_rot).contiguous()
|
||||
sin_flat = sin.reshape(B * S, n_rot).contiguous()
|
||||
|
||||
y_flat = FusedRMSNormRoPEFunction.apply(x_flat, weight, cos_flat, sin_flat, eps, H)
|
||||
y_flat = FusedRMSNormRoPEFunction.apply(
|
||||
x_flat, weight, cos_flat, sin_flat, eps, H, n_rot
|
||||
)
|
||||
return y_flat.view(shape)
|
||||
|
||||
|
||||
|
||||
@@ -156,6 +156,14 @@ class PatchManager:
|
||||
# which would clobber any earlier fix.
|
||||
self._fix_nemotron_h_conversion_mapping()
|
||||
|
||||
# Gemma 4 hybrid attention runs here in post-build (NOT post-load):
|
||||
# the per-layer ``self_attn.config._attn_implementation="sdpa"``
|
||||
# override needs to walk the raw model tree, which is broken by
|
||||
# the post-load PEFT wrapping. The accompanying
|
||||
# ``patch_gemma4_hybrid_mask`` monkey-patch is module-level and
|
||||
# installation-time-independent, so both halves of the fix live
|
||||
# cleanly in the same call even though one is instance-scoped
|
||||
# and the other is module-scoped.
|
||||
self._apply_gemma_hybrid_attention(model)
|
||||
self._finalize_moe_expert_quantization(model)
|
||||
|
||||
@@ -173,12 +181,23 @@ class PatchManager:
|
||||
which exceeds flash attention's supported size. This patch loads the model
|
||||
with flash_attention_2 for the sliding window layers (head_dim=256), then
|
||||
gives each global layer a shallow-copied config with _attn_implementation="sdpa".
|
||||
|
||||
We also install :func:`axolotl.monkeypatch.gemma4_hybrid_mask.patch_gemma4_hybrid_mask`
|
||||
which fixes the corresponding mask construction inside
|
||||
``Gemma4TextModel.forward``. Without it, the per-layer SDPA config
|
||||
override is not enough — the forward still builds a 2D FA2-format mask
|
||||
at the model level and the SDPA layers crash at long context lengths
|
||||
with ``RuntimeError: The expanded size of the tensor ... must match``.
|
||||
"""
|
||||
if not self.cfg.gemma4_hybrid_attn_impl:
|
||||
return
|
||||
|
||||
import copy
|
||||
|
||||
from axolotl.monkeypatch.gemma4_hybrid_mask import patch_gemma4_hybrid_mask
|
||||
|
||||
patch_gemma4_hybrid_mask()
|
||||
|
||||
# Navigate to the module that has 'layers' - varies by model structure:
|
||||
# Gemma4ForConditionalGeneration -> .model (Gemma4Model) -> .language_model (Gemma4TextModel) -> .layers
|
||||
# Gemma4ForCausalLM -> .model (Gemma4TextModel) -> .layers
|
||||
@@ -392,6 +411,14 @@ class PatchManager:
|
||||
patch_qwen3_5_vlm_flash_attention()
|
||||
|
||||
if self.cfg.model_config_type in ("gemma4", "gemma4_text"):
|
||||
# The fused attn path is now compatible with
|
||||
# ``gemma4_hybrid_attn_impl``: the kernel handles partial
|
||||
# rotary (cos.shape[-1] < head_dim) and the fused forward
|
||||
# mirrors the current ``Gemma4TextAttention.forward`` API
|
||||
# for shared kv (read from / write to
|
||||
# ``past_key_values.shared_layers``). See
|
||||
# ``src/axolotl/kernels/GEMMA4_FUSED_ROPE_HYBRID_ATTN_BUG.md``
|
||||
# for the history.
|
||||
from axolotl.monkeypatch.models.gemma4.fused_attn import (
|
||||
patch_gemma4_fused_attn,
|
||||
)
|
||||
|
||||
115
src/axolotl/monkeypatch/gemma4_hybrid_mask.py
Normal file
115
src/axolotl/monkeypatch/gemma4_hybrid_mask.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Hybrid attention mask fix for Gemma 4.
|
||||
|
||||
Gemma 4 has full-attention (global) layers with ``head_dim=512`` which
|
||||
exceeds flash-attention-2's supported size. Axolotl's hybrid-attention
|
||||
patch in ``patch_manager._apply_gemma_hybrid_attention`` works around
|
||||
this by forcing ``_attn_implementation="sdpa"`` on each global layer's
|
||||
``self_attn.config``, leaving sliding-window layers on FA2.
|
||||
|
||||
The per-layer config override alone is insufficient, however:
|
||||
``Gemma4TextModel.forward`` builds a single ``causal_mask_mapping`` dict
|
||||
using the **model-level** config and passes the mapped mask to each
|
||||
decoder layer. With FA2 still set at the model level, the ``full_attention``
|
||||
entry in that mapping is a 2D mask (FA2 format), but SDPA needs a 4D mask.
|
||||
The global layers then fail with::
|
||||
|
||||
RuntimeError: The expanded size of the tensor (S) must match the existing
|
||||
size (B) at non-singleton dimension 2. Target sizes: [B, H, S, S]. Tensor
|
||||
sizes: [B, S]
|
||||
|
||||
...when the sequence length grows past roughly 7k tokens.
|
||||
|
||||
This module fixes the symptom by monkey-patching ``create_causal_mask`` in
|
||||
``transformers.models.gemma4.modeling_gemma4``'s module namespace — NOT
|
||||
the original in ``masking_utils``. The wrapper forces
|
||||
``_attn_implementation="sdpa"`` on a shallow-copied config before calling
|
||||
through, so the ``full_attention`` mask built inside ``Gemma4TextModel.forward``
|
||||
is always 4D/SDPA-compatible. ``create_sliding_window_causal_mask`` is left
|
||||
alone, so sliding-window layers continue to receive FA2-format masks.
|
||||
|
||||
The patch is idempotent. Install once per process, before any Gemma 4
|
||||
forward pass runs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
_PATCH_APPLIED = False
|
||||
|
||||
|
||||
def patch_gemma4_hybrid_mask() -> bool:
|
||||
"""Install the Gemma 4 hybrid-attention mask fix.
|
||||
|
||||
Returns ``True`` if the patch was installed (or was already installed),
|
||||
``False`` if the target module could not be imported (e.g. transformers
|
||||
version predates Gemma 4) — in which case nothing is done and the
|
||||
caller can continue unaffected.
|
||||
"""
|
||||
global _PATCH_APPLIED
|
||||
if _PATCH_APPLIED:
|
||||
return True
|
||||
|
||||
try:
|
||||
from transformers.models.gemma4 import modeling_gemma4
|
||||
except ImportError:
|
||||
LOG.debug(
|
||||
"gemma4_hybrid_mask: transformers.models.gemma4 not importable, "
|
||||
"skipping. This is fine for non-Gemma4 training."
|
||||
)
|
||||
return False
|
||||
|
||||
if not hasattr(modeling_gemma4, "create_causal_mask"):
|
||||
LOG.warning(
|
||||
"gemma4_hybrid_mask: modeling_gemma4 has no 'create_causal_mask' "
|
||||
"binding, skipping. Transformers API may have changed."
|
||||
)
|
||||
return False
|
||||
|
||||
original = modeling_gemma4.create_causal_mask
|
||||
|
||||
def hybrid_create_causal_mask(config: Any, *args: Any, **kwargs: Any):
|
||||
"""Wrapper that forces SDPA format for the full-attention mask.
|
||||
|
||||
The global layers were patched to SDPA by
|
||||
``_apply_gemma_hybrid_attention``, so their mask must be 4D. The
|
||||
original ``create_causal_mask`` dispatches on
|
||||
``config._attn_implementation``; we shadow that with a local
|
||||
override.
|
||||
"""
|
||||
sdpa_config = copy.copy(config)
|
||||
sdpa_config._attn_implementation = "sdpa"
|
||||
return original(sdpa_config, *args, **kwargs)
|
||||
|
||||
# Preserve the original reference on the wrapper for tests / teardown.
|
||||
hybrid_create_causal_mask._axolotl_original = original # type: ignore[attr-defined]
|
||||
|
||||
modeling_gemma4.create_causal_mask = hybrid_create_causal_mask
|
||||
_PATCH_APPLIED = True
|
||||
LOG.info(
|
||||
"gemma4_hybrid_mask: patched modeling_gemma4.create_causal_mask to "
|
||||
"force SDPA-format masks for full-attention layers"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def unpatch_gemma4_hybrid_mask() -> None:
|
||||
"""Restore the original ``create_causal_mask``. Useful for tests."""
|
||||
global _PATCH_APPLIED
|
||||
if not _PATCH_APPLIED:
|
||||
return
|
||||
try:
|
||||
from transformers.models.gemma4 import modeling_gemma4
|
||||
except ImportError:
|
||||
_PATCH_APPLIED = False
|
||||
return
|
||||
current = modeling_gemma4.create_causal_mask
|
||||
original = getattr(current, "_axolotl_original", None)
|
||||
if original is not None:
|
||||
modeling_gemma4.create_causal_mask = original
|
||||
_PATCH_APPLIED = False
|
||||
@@ -30,7 +30,6 @@ def _make_fused_forward(original_forward):
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
|
||||
past_key_values=None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
@@ -66,8 +65,14 @@ def _make_fused_forward(original_forward):
|
||||
query_states = query_states.transpose(1, 2)
|
||||
|
||||
# ---- K/V path ----
|
||||
if self.is_kv_shared_layer:
|
||||
key_states, value_states = shared_kv_states[self.kv_shared_layer_index]
|
||||
# Current transformers stores shared kv on `past_key_values.shared_layers`
|
||||
# (the legacy `shared_kv_states` decoder kwarg was removed). We mirror
|
||||
# the stock attention forward exactly so the dispatch is identical
|
||||
# regardless of whether the model was patched.
|
||||
if self.is_kv_shared_layer and past_key_values is not None:
|
||||
key_states, value_states = past_key_values.shared_layers[
|
||||
self.kv_shared_layer_index
|
||||
]
|
||||
key_states = key_states.to(query_states.device)
|
||||
value_states = value_states.to(query_states.device)
|
||||
else:
|
||||
@@ -101,12 +106,18 @@ def _make_fused_forward(original_forward):
|
||||
value_states = fused_rms_norm_noscale(value_states, eps=eps)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
if past_key_values is not None and not self.is_kv_shared_layer:
|
||||
key_states, value_states = past_key_values.update(
|
||||
key_states, value_states, self.layer_idx
|
||||
)
|
||||
if self.store_full_length_kv:
|
||||
shared_kv_states[self.layer_idx] = key_states, value_states
|
||||
if past_key_values is not None:
|
||||
if not self.is_kv_shared_layer:
|
||||
key_states, value_states = past_key_values.update(
|
||||
key_states, value_states, self.layer_idx
|
||||
)
|
||||
if self.store_full_length_kv:
|
||||
if not hasattr(past_key_values, "shared_layers"):
|
||||
past_key_values.shared_layers = {}
|
||||
past_key_values.shared_layers[self.layer_idx] = (
|
||||
key_states,
|
||||
value_states,
|
||||
)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
|
||||
Reference in New Issue
Block a user