"""
Correctness tests for fused RMSNorm + SiLU Gate kernel.
Tests against the eager Qwen3_5RMSNormGated implementation.
"""

import pytest
import torch
import torch.nn.functional as F

pytest.importorskip("triton", reason="triton required for fused kernels")

if not torch.cuda.is_available():
    pytest.skip("CUDA required for fused kernel tests", allow_module_level=True)

from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
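
# The fused module is exercised below as a drop-in replacement for the eager
# reference: construct it with the hidden size and call it as module(x, gate=g).
# A minimal sketch of the call pattern the tests assume (shapes and dtype are
# illustrative, not requirements of the kernel):
#
#   x = torch.randn(2, 16, 2560, dtype=torch.bfloat16, device="cuda")
#   g = torch.randn_like(x)
#   norm = FusedRMSNormGated(2560).to(dtype=torch.bfloat16, device="cuda")
#   y = norm(x, gate=g)  # same shape and dtype as x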


class EagerRMSNormGated(torch.nn.Module):
    """Reference implementation matching Qwen3_5RMSNormGated exactly."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states, gate=None):
        input_dtype = hidden_states.dtype
        # Compute the RMS statistic in float32 for numerical stability.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.weight * hidden_states.to(input_dtype)
        # Apply the SiLU gate (evaluated in float32) to the normalized output.
        hidden_states = hidden_states * F.silu(gate.to(torch.float32))
        return hidden_states.to(input_dtype)
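
# In closed form, the reference above computes per hidden vector x with gate g:
#   y = weight * (x / sqrt(mean(x**2) + eps)) * silu(g)
# with the normalization and the gate evaluated in float32 and the result cast
# back to the input dtype; the fused kernel is expected to reproduce this.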


def _sync_weights(eager_mod, fused_mod):
    """Copy weights from eager to fused module."""
    fused_mod.weight.data.copy_(eager_mod.weight.data)


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
    "shape",
    [
        (2, 128, 256),
        (4, 64, 512),
        (1, 32, 1024),
        (2, 16, 2560),  # Qwen3.5-4B hidden_size
        (2, 16, 4096),  # Qwen3.5-9B hidden_size
        (1, 8, 5120),  # Qwen3.5-27B hidden_size
        (4, 16, 2048),  # Qwen3.5-35B-A3B (MoE) hidden_size
        (4, 16, 3072),  # Qwen3.5-122B-A10B (MoE) hidden_size
    ],
)
class TestRMSNormGatedForward:
    """Forward-pass parity between the fused kernel and the eager reference."""

    def test_output_matches_eager(self, dtype, shape):
        torch.manual_seed(42)
        B, T, H = shape
        X = torch.randn(B, T, H, dtype=dtype, device="cuda")
        G = torch.randn(B, T, H, dtype=dtype, device="cuda")
        eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
        fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
        _sync_weights(eager, fused)
        y_eager = eager(X, gate=G)
        y_fused = fused(X, gate=G)
        # float32 should match tightly; bf16/fp16 accumulate rounding
        # differences between the fused and eager paths, so use looser bounds.
        if dtype == torch.float32:
            torch.testing.assert_close(y_fused, y_eager, atol=1e-5, rtol=1e-5)
        else:
            torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)

    def test_output_shape(self, dtype, shape):
        B, T, H = shape
        X = torch.randn(B, T, H, dtype=dtype, device="cuda")
        G = torch.randn(B, T, H, dtype=dtype, device="cuda")
        fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
        y = fused(X, gate=G)
        assert y.shape == (B, T, H)
        assert y.dtype == dtype


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
    "shape",
    [
        (2, 32, 256),
        (2, 16, 512),
        (2, 16, 2560),  # Qwen3.5-4B
        (1, 8, 4096),  # Qwen3.5-9B
        (1, 8, 5120),  # Qwen3.5-27B
        (2, 16, 2048),  # Qwen3.5-35B-A3B (MoE)
        (2, 16, 3072),  # Qwen3.5-122B-A10B (MoE)
    ],
)
class TestRMSNormGatedBackward:
    """Gradient parity for the input, the gate, and the norm weight."""

    def test_grad_x(self, dtype, shape):
        torch.manual_seed(42)
        B, T, H = shape
        X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
        G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
        X_ref = X.detach().clone().requires_grad_(True)
        G_ref = G.detach().clone().requires_grad_(True)
        eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
        fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
        _sync_weights(eager, fused)
        y_eager = eager(X_ref, gate=G_ref)
        y_fused = fused(X, gate=G)
        grad_out = torch.randn_like(y_eager)
        y_eager.backward(grad_out)
        y_fused.backward(grad_out)
        if dtype == torch.float32:
            atol, rtol = 1e-4, 1e-4
        else:
            atol, rtol = 5e-2, 5e-2
        torch.testing.assert_close(X.grad, X_ref.grad, atol=atol, rtol=rtol)

    def test_grad_gate(self, dtype, shape):
        torch.manual_seed(42)
        B, T, H = shape
        X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
        G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
        X_ref = X.detach().clone().requires_grad_(True)
        G_ref = G.detach().clone().requires_grad_(True)
        eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
        fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
        _sync_weights(eager, fused)
        y_eager = eager(X_ref, gate=G_ref)
        y_fused = fused(X, gate=G)
        grad_out = torch.randn_like(y_eager)
        y_eager.backward(grad_out)
        y_fused.backward(grad_out)
        if dtype == torch.float32:
            atol, rtol = 1e-4, 1e-4
        else:
            atol, rtol = 5e-2, 5e-2
        torch.testing.assert_close(G.grad, G_ref.grad, atol=atol, rtol=rtol)

    def test_grad_weight(self, dtype, shape):
        torch.manual_seed(42)
        B, T, H = shape
        X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
        G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
        X_ref = X.detach().clone().requires_grad_(True)
        G_ref = G.detach().clone().requires_grad_(True)
        eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
        fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
        _sync_weights(eager, fused)
        y_eager = eager(X_ref, gate=G_ref)
        y_fused = fused(X, gate=G)
        grad_out = torch.randn_like(y_eager)
        y_eager.backward(grad_out)
        y_fused.backward(grad_out)
        if dtype == torch.float32:
            atol, rtol = 1e-4, 1e-4
        else:
            atol, rtol = 5e-2, 5e-2
        torch.testing.assert_close(
            fused.weight.grad, eager.weight.grad, atol=atol, rtol=rtol
        )


class TestRMSNormGatedEdgeCases:
    """Edge cases: missing gate, flattened 2D inputs, and non-default weights."""

    def test_gate_none_raises(self):
        fused = FusedRMSNormGated(256).cuda()
        X = torch.randn(2, 4, 256, device="cuda")
        with pytest.raises(ValueError, match="requires a gate tensor"):
            fused(X, gate=None)

    def test_2d_input(self):
        """Test with (BxT, H) shaped input instead of (B, T, H)."""
        torch.manual_seed(42)
        H = 512
        X = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
        G = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
        X_ref = X.detach().clone().requires_grad_(True)
        G_ref = G.detach().clone().requires_grad_(True)
        eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
        fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
        _sync_weights(eager, fused)
        y_eager = eager(X_ref, gate=G_ref)
        y_fused = fused(X, gate=G)
        torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
        grad_out = torch.randn_like(y_eager)
        y_eager.backward(grad_out)
        y_fused.backward(grad_out)
        torch.testing.assert_close(X.grad, X_ref.grad, atol=5e-2, rtol=5e-2)
        torch.testing.assert_close(G.grad, G_ref.grad, atol=5e-2, rtol=5e-2)

    def test_random_weight_init(self):
        """Test with non-default weight values."""
        torch.manual_seed(123)
        H = 256
        X = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
        G = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
        eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
        # Randomize weights so the comparison is not trivial with all-ones init.
        eager.weight.data = torch.randn_like(eager.weight.data)
        fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
        _sync_weights(eager, fused)
        y_eager = eager(X, gate=G)
        y_fused = fused(X, gate=G)
        torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
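

# To run only this test module (adjust the path to your checkout layout):
#   pytest tests/kernels/test_rms_norm_gated.py -q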