axolotl/tests/e2e/integrations/test_sonicmoe.py
NanoCode012 6a8baf8fa7 feat: add sonicmoe (#3411)
* feat: add sonicmoe

* feat: add torch compile for routing

* feat: add routing smoke test

* feat: add qwen3_5_moe, qwen3_vl_moe, qwen3_omni_moe

* fix: disable mlp kernel for sonicmoe too

* feat: update to sonicmoe release

* chore: update import following new sonicmoe changes

* feat: update handling for blackwell

* feat: add sonicmoe e2e test

* fix: installation for updated sonicmoe

* fix: git commit

* fix: ignore py req and fix metadata

* fix: increase min hidden size to match sonicmoe kernel min

* fix: attempt properly interleave and handle unpatch mid-test

* chore: refactor teardown better

* chore: refactor to re-use rearrange

* fix: add idempotency guard

* fix: address comments on CI memory and interleave

* fix: tests grad, param doublewrapped
2026-03-05 13:43:31 -05:00


"""
End-to-end gradient and convergence tests for SonicMoE integration.
Requires:
- H100/H200 GPU (SonicMoE CUTLASS kernels target sm_90)
- sonicmoe package installed
- transformers with Qwen3MoE support
Usage:
pytest tests/e2e/integrations/test_sonicmoe.py -v -s
"""
import importlib.util
import math

import pytest
import torch

_sonicmoe_available = importlib.util.find_spec("sonicmoe") is not None
_is_hopper = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
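
# Note: the capability check is an exact match for Hopper (sm_90); newer
# architectures such as Blackwell are deliberately not matched here. Per the
# commit log, Blackwell gets separate handling inside the integration itself.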

pytestmark = [
    pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA GPU"),
    pytest.mark.skipif(
        not _is_hopper, reason="SonicMoE CUTLASS kernels require Hopper (sm_90)"
    ),
    pytest.mark.skipif(not _sonicmoe_available, reason="SonicMoE not installed"),
]
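# A module-level pytestmark applies these skip conditions to every test below.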


def _create_tiny_qwen3_config():
    """Create a minimal Qwen3MoE config for fast testing."""
    from transformers import AutoConfig

    config = AutoConfig.for_model("qwen3_moe")
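    # Everything is shrunk for CI speed except hidden_size: per the commit
    # log, 512 is kept to satisfy the SonicMoE kernel's minimum hidden size.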
    config.hidden_size = 512
    config.intermediate_size = 1024
    config.moe_intermediate_size = 64
    config.num_attention_heads = 16
    config.num_key_value_heads = 2
    config.head_dim = 32
    config.num_hidden_layers = 2
    config.num_experts = 8
    config.num_experts_per_tok = 2
    config.vocab_size = 1000
    config.max_position_embeddings = 128
    config.norm_topk_prob = True
    config.torch_dtype = torch.bfloat16
    return config


def _interleave_gate_up_weights(model):
    """Interleave all gate_up_proj parameters in-place for SonicMoE."""
    from axolotl.integrations.kernels.sonicmoe.weight_converter import (
        interleave_gate_up,
    )

    with torch.no_grad():
        for name, param in model.named_parameters():
            if "gate_up_proj" in name:
                param.copy_(interleave_gate_up(param))
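

# For illustration only: a hypothetical sketch of what "interleaving" means
# here. It assumes the fused gate_up_proj stores the gate and up halves
# block-concatenated along the output-feature dim, and that SonicMoE wants
# them pairwise interleaved ([g0, u0, g1, u1, ...]). The authoritative
# conversion is interleave_gate_up/deinterleave_gate_up from the axolotl
# integration; this sketch is not used by the tests.
def _reference_interleave_sketch(gate_up: torch.Tensor) -> torch.Tensor:
    gate, up = gate_up.chunk(2, dim=-1)  # split fused [gate | up] halves
    return torch.stack((gate, up), dim=-1).flatten(-2)  # pairwise interleave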


def _unpatch_sonicmoe():
    """Restore original forward on the MoE block class if it was patched."""
    from axolotl.integrations.kernels.constants import resolve_moe_block_classes

    for moe_cls in resolve_moe_block_classes("qwen3_moe"):
        if hasattr(moe_cls, "_original_forward"):
            moe_cls.forward = moe_cls._original_forward
            del moe_cls._original_forward
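
# The patch side (axolotl.integrations.kernels.sonicmoe.patch) presumably
# installs its forward with a matching idempotency guard, roughly:
#
#     if not hasattr(moe_cls, "_original_forward"):
#         moe_cls._original_forward = moe_cls.forward
#         moe_cls.forward = sonicmoe_forward  # hypothetical name
#
# which is exactly what _unpatch_sonicmoe() above reverses.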


class TestSonicMoEForwardCorrectness:
    """Verify the SonicMoE-patched model produces the same output as the original."""

    def teardown_method(self):
        _unpatch_sonicmoe()

    def test_forward_output_matches(self):
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")

        # Original model
        model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        with torch.no_grad():
            out_orig = model_orig(input_ids)

        # Patched model (same weights, interleaved for SonicMoE)
        model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        model_patched.load_state_dict(model_orig.state_dict())
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model_patched)
        with torch.no_grad():
            out_patched = model_patched(input_ids)

        max_diff = (out_orig.logits - out_patched.logits).abs().max().item()
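        # bf16 with different GEMM implementations (cuBLAS vs the SonicMoE
        # CUTLASS kernels) can diverge noticeably, hence the loose tolerances
        # below (same rationale as the gradient comparison further down).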
        assert torch.allclose(
            out_orig.logits, out_patched.logits, atol=1e-1, rtol=1e-1
        ), f"Output mismatch: max diff={max_diff:.6f}"


class TestSonicMoEGradientCorrectness:
    """Compare gradients between the original and SonicMoE-patched forward passes."""

    def teardown_method(self):
        _unpatch_sonicmoe()

    def test_gradients_match(self):
        """Verify all parameter gradients match between original and patched."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
        from axolotl.integrations.kernels.sonicmoe.weight_converter import (
            deinterleave_gate_up,
        )

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")

        # ---------- Original model ----------
        model_orig = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        out_orig = model_orig(input_ids, labels=input_ids)
        out_orig.loss.backward()
        grads_orig = {
            n: p.grad.float().clone()
            for n, p in model_orig.named_parameters()
            if p.grad is not None
        }
        loss_orig = out_orig.loss.item()

        # ---------- SonicMoE-patched model (same weights, interleaved) ----------
        model_patched = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        model_patched.load_state_dict(model_orig.state_dict())
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model_patched)
        out_patched = model_patched(input_ids, labels=input_ids)
        out_patched.loss.backward()

        grads_patched = {}
        for n, p in model_patched.named_parameters():
            if p.grad is None:
                continue
            g = p.grad.float().clone()
            # gate_up_proj grads are in the interleaved layout; de-interleave
            # them to match the original model's layout
            if "gate_up_proj" in n:
                g = deinterleave_gate_up(g)
            grads_patched[n] = g
        loss_patched = out_patched.loss.item()

        # ---------- Compare ----------
        assert abs(loss_orig - loss_patched) < 0.5, (
            f"Loss mismatch: orig={loss_orig:.4f}, patched={loss_patched:.4f}"
        )

        # All parameters with gradients in the original should have them in the patched model
        missing = set(grads_orig.keys()) - set(grads_patched.keys())
        assert not missing, f"Missing gradients in patched model: {missing}"

        # Compare gradient values.
        # bf16 with different GEMM impls (cuBLAS vs CUTLASS) can diverge,
        # so use a generous tolerance: flag only if both rel >10% AND abs >1e-2
        mismatches = []
        for name in grads_orig:
            if name not in grads_patched:
                continue
            g_orig = grads_orig[name]
            g_patched = grads_patched[name]
            max_diff = (g_orig - g_patched).abs().max().item()
            rel_diff = max_diff / (g_orig.abs().max().item() + 1e-8)
            if rel_diff > 0.1 and max_diff > 1e-2:
                mismatches.append(
                    f"  {name}: max_abs_diff={max_diff:.6f}, rel_diff={rel_diff:.4f}"
                )
        assert not mismatches, (
            "Gradient mismatches (rel_diff > 10% and abs_diff > 1e-2):\n"
            + "\n".join(mismatches)
        )

    def test_router_weights_receive_gradients(self):
        """Verify that router (gate) weights get non-zero gradients."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model)

        out = model(input_ids, labels=input_ids)
        out.loss.backward()

        gate_grads_found = False
        for name, param in model.named_parameters():
            if "gate" in name and "weight" in name:
                gate_grads_found = True
                assert param.grad is not None, f"No gradient for router: {name}"
                assert param.grad.abs().max() > 0, f"Zero gradient for router: {name}"
        assert gate_grads_found, "No gate.weight parameters found in model"


class TestSonicMoETrainingConvergence:
    """Verify loss decreases during training with SonicMoE."""

    def teardown_method(self):
        _unpatch_sonicmoe()

    def test_loss_decreases(self):
        """Run 30 training steps, verify loss decreases and no NaN/Inf."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model)

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
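        # Note: plain bf16 training without a GradScaler is intentional; bf16
        # keeps fp32's exponent range, so the loss scaling that fp16 needs is
        # unnecessary here.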
        losses = []
        for step in range(30):
            out = model(input_ids, labels=input_ids)
            loss = out.loss
            assert not math.isnan(loss.item()), f"NaN loss at step {step}"
            assert not math.isinf(loss.item()), f"Inf loss at step {step}"
            losses.append(loss.item())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        assert losses[-1] < losses[0], (
            f"Loss did not decrease: first={losses[0]:.4f}, last={losses[-1]:.4f}"
        )

    def test_expert_weights_update(self):
        """Verify expert weights change during training (not frozen)."""
        from transformers import AutoModelForCausalLM

        from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe

        config = _create_tiny_qwen3_config()
        input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
        model = AutoModelForCausalLM.from_config(config).cuda().bfloat16()
        patch_sonicmoe("qwen3_moe")
        _interleave_gate_up_weights(model)

        # Snapshot expert weights before training
        expert_weights_before = {}
        for name, param in model.named_parameters():
            if "experts" in name:
                expert_weights_before[name] = param.data.clone()
        assert expert_weights_before, "No expert parameters found"

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
        for _ in range(5):
            out = model(input_ids, labels=input_ids)
            out.loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Check that expert weights changed
        changed = 0
        for name, param in model.named_parameters():
            if name in expert_weights_before:
                if not torch.equal(param.data, expert_weights_before[name]):
                    changed += 1
        assert changed > 0, "No expert weights changed after 5 training steps"