"""
Correctness tests for the fused RMSNorm+RoPE Triton kernel.

Tests forward and backward against the reference Gemma4 implementation
(Gemma4RMSNorm + apply_rotary_pos_emb) across both sliding window
(head_dim=256) and global attention (head_dim=512) layer configurations.
"""
|
import pytest
import torch

torch.manual_seed(42)

# Every test in this module exercises a CUDA kernel; skip the whole file
# when no GPU is present.
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
|
def _reference_norm_rope(x, weight, cos, sin, eps):
    """Reference: separate Gemma4RMSNorm + apply_rotary_pos_emb."""
    from transformers.models.gemma4.modeling_gemma4 import (
        Gemma4RMSNorm,
        apply_rotary_pos_emb,
    )

    head_dim = x.shape[-1]
    norm = Gemma4RMSNorm(head_dim, eps=eps).to(x.device, x.dtype)
    norm.weight.data.copy_(weight)
    # Normalize first, then rotate — the two stock primitives in sequence.
    return apply_rotary_pos_emb(norm(x), cos, sin, unsqueeze_dim=2)
|
def _reference_norm_noscale(x, eps):
    """Reference: Gemma4RMSNorm with_scale=False."""
    from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm

    head_dim = x.shape[-1]
    module = Gemma4RMSNorm(head_dim, eps=eps, with_scale=False).to(x.device, x.dtype)
    return module(x)
|
def _reference_partial_norm_rope(x, weight, cos, sin, eps):
    """Reference: Gemma4RMSNorm over the full head_dim, then stock
    ``apply_rotary_pos_emb`` over the first ``cos.shape[-1]`` columns, with
    the trailing columns passed through unchanged. Mirrors how Llama-style
    partial rotary is layered on top of the stock RMSNorm + RoPE primitives.
    """
    from transformers.models.gemma4.modeling_gemma4 import (
        Gemma4RMSNorm,
        apply_rotary_pos_emb,
    )

    head_dim = x.shape[-1]
    n_rot = cos.shape[-1]
    norm = Gemma4RMSNorm(head_dim, eps=eps).to(x.device, x.dtype)
    norm.weight.data.copy_(weight)
    y = norm(x)
    if n_rot == head_dim:
        # Full rotary: identical to the non-partial reference path.
        return apply_rotary_pos_emb(y, cos, sin, unsqueeze_dim=2)
    rotated = apply_rotary_pos_emb(y[..., :n_rot], cos, sin, unsqueeze_dim=2)
    return torch.cat([rotated, y[..., n_rot:]], dim=-1)
|
@pytest.fixture(
    params=[
        (2, 64, 32, 256),  # sliding window layer shape
        (2, 64, 4, 512),  # global attention layer shape
        (1, 128, 16, 256),  # different batch/seq
        (1, 1, 1, 8),  # minimal size
    ],
    ids=["sliding_256", "global_512", "varied", "minimal"],
)
def shapes(request):
    """(batch, seq_len, num_heads, head_dim) tuples covering both layer kinds."""
    return request.param
|
@pytest.fixture(params=[torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
def dtype(request):
    """Half-precision dtypes the fused kernel is exercised with."""
    return request.param
|
class TestFusedRMSNormRoPEForward:
    """Forward pass correctness."""

    def test_matches_reference(self, shapes, dtype):
        """Fused output agrees (cosine similarity) with the reference path."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        batch, seq, heads, head_dim = shapes
        eps = 1e-6
        x = torch.randn(batch, seq, heads, head_dim, device="cuda", dtype=dtype)
        weight = torch.randn(head_dim, device="cuda", dtype=dtype)
        cos = torch.randn(batch, seq, head_dim, device="cuda", dtype=dtype)
        sin = torch.randn(batch, seq, head_dim, device="cuda", dtype=dtype)

        y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
        y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)

        # Compare in fp32 to avoid half-precision noise in the metric itself.
        cos_sim = torch.nn.functional.cosine_similarity(
            y_ref.flatten().float(), y_fused.flatten().float(), dim=0
        )
        assert cos_sim > 0.999, f"Forward cosine_sim={cos_sim:.6f}, expected > 0.999"

    def test_output_shape(self, shapes):
        """Output preserves the input's shape and dtype."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        batch, seq, heads, head_dim = shapes
        x = torch.randn(batch, seq, heads, head_dim, device="cuda", dtype=torch.bfloat16)
        weight = torch.randn(head_dim, device="cuda", dtype=torch.bfloat16)
        cos = torch.randn(batch, seq, head_dim, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(batch, seq, head_dim, device="cuda", dtype=torch.bfloat16)

        out = fused_rms_norm_rope(x, weight, cos, sin, eps=1e-6)
        assert out.shape == x.shape
        assert out.dtype == x.dtype
|
class TestFusedRMSNormRoPEBackward:
    """Backward pass correctness via gradient comparison."""

    @pytest.mark.parametrize(
        "B,S,H,D",
        [(2, 64, 32, 256), (2, 64, 4, 512)],
        ids=["sliding_256", "global_512"],
    )
    def test_x_grad_matches_reference(self, B, S, H, D):
        """Input gradients of the fused op track the reference primitives."""
        from transformers.models.gemma4.modeling_gemma4 import (
            Gemma4RMSNorm,
            apply_rotary_pos_emb,
        )

        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        eps = 1e-6
        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)

        # Reference backward through the stock norm + rotary primitives.
        x_ref = torch.randn(
            B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
        )
        norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
        norm_ref.weight.data.copy_(weight_init)
        apply_rotary_pos_emb(
            norm_ref(x_ref), cos, sin, unsqueeze_dim=2
        ).sum().backward()

        # Fused backward on an identical copy of the input.
        x_fused = x_ref.data.clone().requires_grad_(True)
        w_fused = weight_init.clone().requires_grad_(True)
        fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps).sum().backward()

        cos_sim_x = torch.nn.functional.cosine_similarity(
            x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
        )
        assert cos_sim_x > 0.999, f"x grad cosine_sim={cos_sim_x:.6f}, expected > 0.999"

    @pytest.mark.parametrize(
        "B,S,H,D",
        [(2, 64, 32, 256), (2, 64, 4, 512)],
        ids=["sliding_256", "global_512"],
    )
    def test_weight_grad_matches_reference(self, B, S, H, D):
        """Norm-weight gradients of the fused op track the reference."""
        from transformers.models.gemma4.modeling_gemma4 import (
            Gemma4RMSNorm,
            apply_rotary_pos_emb,
        )

        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        eps = 1e-6
        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)

        # Reference: let the norm module's .weight parameter collect grads.
        x_ref = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
        norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
        norm_ref.weight = torch.nn.Parameter(weight_init.clone())
        apply_rotary_pos_emb(
            norm_ref(x_ref), cos, sin, unsqueeze_dim=2
        ).sum().backward()

        # Fused path with a standalone weight tensor.
        w_fused = weight_init.clone().requires_grad_(True)
        fused_rms_norm_rope(x_ref.clone(), w_fused, cos, sin, eps=eps).sum().backward()

        cos_sim_w = torch.nn.functional.cosine_similarity(
            w_fused.grad.flatten().float(),
            norm_ref.weight.grad.flatten().float(),
            dim=0,
        )
        assert cos_sim_w > 0.995, (
            f"weight grad cosine_sim={cos_sim_w:.6f}, expected > 0.995"
        )

    def test_grad_flows(self):
        """Verify gradients are non-zero and finite."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        B, S, H, D = 1, 16, 4, 64
        x = torch.randn(
            B, S, H, D, device="cuda", dtype=torch.bfloat16, requires_grad=True
        )
        w = torch.randn(D, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)

        fused_rms_norm_rope(x, w, cos, sin, eps=1e-6).sum().backward()

        assert x.grad is not None, "x.grad is None"
        assert w.grad is not None, "w.grad is None"
        assert x.grad.isfinite().all(), "x.grad has non-finite values"
        assert w.grad.isfinite().all(), "w.grad has non-finite values"
        assert x.grad.abs().sum() > 0, "x.grad is all zeros"
        assert w.grad.abs().sum() > 0, "w.grad is all zeros"
|
class TestFusedRMSNormRoPEPartialRotary:
    """Partial-rotary: cos/sin last dim is smaller than head_dim.

    Compares against the original primitives (`Gemma4RMSNorm` +
    `apply_rotary_pos_emb`) applied to the rotated slice with the trailing
    columns passed through. Without the kernel fix this used to crash with
    `RuntimeError: shape '[..., D]' is invalid for input of size B*S*n_rot`.
    """

    @pytest.mark.parametrize(
        "B,S,H,D,n_rot",
        [
            (2, 16, 4, 64, 32),  # half rotary (Llama-style 0.5)
            (2, 16, 4, 64, 16),  # quarter rotary
            (2, 32, 8, 128, 64),  # half rotary, larger heads
            (1, 8, 2, 256, 64),  # 26B sliding-shape, 0.25 partial
            (1, 8, 2, 64, 64),  # n_rot == D: must still match full-rotary path
        ],
        ids=["half_64", "quarter_64", "half_128", "quarter_256", "full_64"],
    )
    def test_forward_matches_reference(self, B, S, H, D, n_rot):
        """Fused partial-rotary forward agrees with the layered reference."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        eps = 1e-6
        x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
        weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
        cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)

        y_ref = _reference_partial_norm_rope(x.clone(), weight, cos, sin, eps)
        y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)

        assert y_fused.shape == y_ref.shape == (B, S, H, D)
        cos_sim = torch.nn.functional.cosine_similarity(
            y_ref.flatten().float(), y_fused.flatten().float(), dim=0
        )
        assert cos_sim > 0.999, (
            f"partial rotary forward cosine_sim={cos_sim:.6f} "
            f"(B={B},S={S},H={H},D={D},n_rot={n_rot})"
        )

        # The pass-through tail must not receive any rotation (the original
        # bug class): it should match the reference RMSNorm output within
        # half-precision tolerance.
        torch.testing.assert_close(
            y_fused[..., n_rot:], y_ref[..., n_rot:], rtol=1e-2, atol=1e-2
        )

    @pytest.mark.parametrize(
        "B,S,H,D,n_rot",
        [(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
        ids=["half_64", "quarter_256"],
    )
    def test_x_grad_matches_reference(self, B, S, H, D, n_rot):
        """Input gradients through the partial-rotary path track the reference."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        eps = 1e-6
        cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
        weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
        x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)

        # Reference backward via the original primitives.
        x_ref = x_data.clone().requires_grad_(True)
        _reference_partial_norm_rope(
            x_ref, weight_init.clone(), cos, sin, eps
        ).sum().backward()

        # Fused backward on the same data.
        x_fused = x_data.clone().requires_grad_(True)
        w_fused = weight_init.clone().requires_grad_(True)
        fused_rms_norm_rope(x_fused, w_fused, cos, sin, eps=eps).sum().backward()

        cos_sim_x = torch.nn.functional.cosine_similarity(
            x_fused.grad.flatten().float(), x_ref.grad.flatten().float(), dim=0
        )
        assert cos_sim_x > 0.999, f"partial rotary x grad cosine_sim={cos_sim_x:.6f}"

    @pytest.mark.parametrize(
        "B,S,H,D,n_rot",
        [(2, 16, 4, 64, 32), (1, 8, 2, 256, 64)],
        ids=["half_64", "quarter_256"],
    )
    def test_weight_grad_matches_reference(self, B, S, H, D, n_rot):
        """Norm-weight gradients through the partial-rotary path track the reference."""
        from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm

        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        eps = 1e-6
        cos = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, n_rot, device="cuda", dtype=torch.bfloat16)
        weight_init = torch.randn(D, device="cuda", dtype=torch.bfloat16)
        x_data = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)

        # Reference: Gemma4RMSNorm whose .weight collects grads, then partial
        # rotary applied to the rotated slice.
        norm_ref = Gemma4RMSNorm(D, eps=eps).cuda().to(torch.bfloat16)
        norm_ref.weight = torch.nn.Parameter(weight_init.clone())
        normed = norm_ref(x_data)
        from transformers.models.gemma4.modeling_gemma4 import apply_rotary_pos_emb

        rotated = apply_rotary_pos_emb(normed[..., :n_rot], cos, sin, unsqueeze_dim=2)
        torch.cat([rotated, normed[..., n_rot:]], dim=-1).sum().backward()

        w_fused = weight_init.clone().requires_grad_(True)
        fused_rms_norm_rope(x_data.clone(), w_fused, cos, sin, eps=eps).sum().backward()

        cos_sim_w = torch.nn.functional.cosine_similarity(
            w_fused.grad.flatten().float(),
            norm_ref.weight.grad.flatten().float(),
            dim=0,
        )
        assert cos_sim_w > 0.995, (
            f"partial rotary weight grad cosine_sim={cos_sim_w:.6f}"
        )

    def test_full_rotary_unchanged_when_n_rot_equals_d(self):
        """Regression: passing cos/sin with shape == head_dim must still
        match the full-rotary reference (the partial-rotary code path must
        not perturb the existing full-rotary output)."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        B, S, H, D = 2, 16, 4, 64
        eps = 1e-6
        x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
        weight = torch.randn(D, device="cuda", dtype=torch.bfloat16)
        cos = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, D, device="cuda", dtype=torch.bfloat16)

        y_ref = _reference_norm_rope(x.clone(), weight, cos, sin, eps)
        y_fused = fused_rms_norm_rope(x.clone(), weight, cos, sin, eps=eps)
        cos_sim = torch.nn.functional.cosine_similarity(
            y_ref.flatten().float(), y_fused.flatten().float(), dim=0
        )
        assert cos_sim > 0.999, f"full-rotary regression cos_sim={cos_sim:.6f}"

    def test_validation_errors(self):
        """Wrapper rejects misshaped inputs cleanly (instead of a cryptic
        Triton crash deeper in the kernel)."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_rope

        B, S, H, D = 1, 4, 2, 64
        x = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
        w = torch.randn(D, device="cuda", dtype=torch.bfloat16)

        # n_rot > head_dim
        cos_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
        sin_big = torch.randn(B, S, D + 16, device="cuda", dtype=torch.bfloat16)
        with pytest.raises(ValueError, match="cannot exceed head_dim"):
            fused_rms_norm_rope(x, w, cos_big, sin_big)

        # cos/sin last-dim mismatch
        cos = torch.randn(B, S, 32, device="cuda", dtype=torch.bfloat16)
        sin = torch.randn(B, S, 16, device="cuda", dtype=torch.bfloat16)
        with pytest.raises(ValueError, match="same last dim"):
            fused_rms_norm_rope(x, w, cos, sin)

        # odd rotary dim
        cos_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
        sin_odd = torch.randn(B, S, 31, device="cuda", dtype=torch.bfloat16)
        with pytest.raises(ValueError, match="must be even"):
            fused_rms_norm_rope(x, w, cos_odd, sin_odd)
|
class TestFusedRMSNormNoScale:
    """Tests for v_norm (RMSNorm without learnable scale)."""

    def test_forward_matches_reference(self, shapes, dtype):
        """No-scale fused norm agrees with Gemma4RMSNorm(with_scale=False)."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale

        batch, seq, heads, head_dim = shapes
        eps = 1e-6
        x = torch.randn(batch, seq, heads, head_dim, device="cuda", dtype=dtype)

        y_ref = _reference_norm_noscale(x.clone(), eps)
        y_fused = fused_rms_norm_noscale(x.clone(), eps=eps)

        cos_sim = torch.nn.functional.cosine_similarity(
            y_ref.flatten().float(), y_fused.flatten().float(), dim=0
        )
        assert cos_sim > 0.999, f"v_norm cosine_sim={cos_sim:.6f}, expected > 0.999"

    def test_backward_flows(self):
        """Gradients through the no-scale norm are present, finite, nonzero."""
        from axolotl.kernels.gemma4_fused_rope import fused_rms_norm_noscale

        x = torch.randn(
            1, 16, 4, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True
        )
        fused_rms_norm_noscale(x, eps=1e-6).sum().backward()

        assert x.grad is not None
        assert x.grad.isfinite().all()
        assert x.grad.abs().sum() > 0