Files
axolotl/tests/integrations/test_scattermoe_lora.py
NanoCode012 6a8baf8fa7 feat: add sonicmoe (#3411)
* feat: add sonicmoe

* feat: add torch compile for routing

* feat: add routing smoke test

* feat: add qwen3_5_moe, qwen3_vl_moe, qwen3_omni_moe

* fix: disable mlp kernel for sonicmoe too

* feat: update to sonicmoe release

* chore: update import following new sonicmoe changes

* feat: update handling for blackwell

* feat: add sonicmoe e2e test

* fix: installation for updated sonicmoe

* fix: git commit

* fix: ignore py req and fix metadata

* fix: increase min hidden size to match sonicmoe kernel min

* fix: attempt properly interleave and handle unpatch mid-test

* chore: refactor teardown better

* chore: refactor to re-use rearrange

* fix: add idempotency guard

* fix: address comments on CI memory and interleave

* fix: tests grad, param doublewrapped
2026-03-05 13:43:31 -05:00

324 lines
12 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
Unit tests for scattermoe-lora code-review fixes.
Tests cover:
- KernelsArgs validator: disable_mlp_kernel
- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
- ParallelExperts: scaling=0.0 not treated as falsy
- single2scatter: non-aligned K/N dimensions
- group_compileable: coeff=None accepted
- HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract
"""
from unittest.mock import patch
import pytest
import torch
# ============================================================================
# 1. KernelsArgs: disable_mlp_kernel validator
# ============================================================================
class TestKernelsArgsValidator:
    """Exercise ``KernelsArgs.disable_mlp_kernel`` on plain config dicts.

    The validator is called directly as a classmethod with a raw ``dict``
    because ``lora_mlp_kernel`` / ``mlp_kernel`` are not declared model fields.
    """

    @staticmethod
    def _validate(**flags):
        """Run the validator on ``{"use_kernels": True, **flags}``.

        Args:
            **flags: extra config entries, inserted after ``use_kernels`` in
                the order given.

        Returns:
            Whatever ``KernelsArgs.disable_mlp_kernel`` returns for that dict.
        """
        from axolotl.integrations.kernels.args import KernelsArgs

        raw_cfg = {"use_kernels": True, **flags}
        return KernelsArgs.disable_mlp_kernel(raw_cfg)

    def test_disables_lora_mlp_kernel_when_scattermoe(self):
        """A truthy lora_mlp_kernel is forced to False when use_scattermoe=True."""
        validated = self._validate(use_scattermoe=True, lora_mlp_kernel=True)

        assert validated["lora_mlp_kernel"] is False
        assert validated["mlp_kernel"] is False

    def test_mlp_kernel_disabled_without_lora(self):
        """mlp_kernel is turned off even if lora_mlp_kernel was never supplied."""
        validated = self._validate(use_scattermoe=True)

        assert validated["mlp_kernel"] is False
        # The validator must not invent a lora_mlp_kernel key that was absent.
        assert "lora_mlp_kernel" not in validated

    def test_lora_mlp_kernel_false_unchanged(self):
        """An explicit lora_mlp_kernel=False stays False."""
        validated = self._validate(use_scattermoe=True, lora_mlp_kernel=False)

        assert validated["lora_mlp_kernel"] is False

    def test_no_change_when_scattermoe_disabled(self):
        """With use_scattermoe=False the lora_mlp_kernel flag is left alone."""
        validated = self._validate(use_scattermoe=False, lora_mlp_kernel=True)

        assert validated["lora_mlp_kernel"] is True
# ============================================================================
# 3. ParallelExperts: scaling=0.0 not treated as falsy
# ============================================================================
class TestParallelExpertsScaling:
    """Check which ``scaling`` value ParallelExperts hands to the LoRA kernel.

    ``parallel_linear_lora`` is patched out in every test, so no kernel runs;
    the tests only inspect the keyword arguments of the recorded call.
    """

    # Dotted path of the kernel entry point replaced by a mock in each test.
    _PLL_TARGET = (
        "axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora"
    )

    @staticmethod
    def _forward_and_capture_call(pe):
        """Run ``pe.forward`` with dummy routing and return the recorded call.

        Args:
            pe: a ``ParallelExperts`` instance (2 experts, 4 -> 4 features in
                the tests below).

        Returns:
            The ``call_args`` of the patched ``parallel_linear_lora``.

        Raises:
            AssertionError: if ``forward`` never invoked the patched kernel.
                Without this check a missing call would surface as an opaque
                ``AttributeError`` on ``None``.
        """
        with patch(TestParallelExpertsScaling._PLL_TARGET) as mock_pll:
            mock_pll.return_value = torch.randn(4, 4)
            # Dummy routing tensors: two slots routed to each of two experts.
            pe.forward(
                inputs=torch.randn(2, 4),
                k=1,
                sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),
                sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),
                expert_offsets=torch.tensor([2, 4]),
            )
        mock_pll.assert_called()
        return mock_pll.call_args

    def test_scaling_zero_preserved(self):
        """scaling=0.0 should be passed as 0.0, not replaced with 1.0."""
        pytest.importorskip("triton")
        from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
            ParallelExperts,
        )

        pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
        pe.set_lora(
            lora_A=torch.randn(4, 4),
            lora_B=torch.randn(4, 4),
            scaling=0.0,
        )
        assert pe._lora_scaling == 0.0

        recorded = self._forward_and_capture_call(pe)
        # Read the keyword once; `call_args.kwargs` and `call_args[1]` are the
        # same mapping, so a second lookup adds nothing.
        scaling_val = recorded.kwargs.get("scaling")
        assert scaling_val == 0.0, f"Expected scaling=0.0 but got {recorded}"

    def test_scaling_none_defaults_to_one(self):
        """scaling=None (no LoRA attached) should default to 1.0."""
        pytest.importorskip("triton")
        from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
            ParallelExperts,
        )

        # set_lora is deliberately not called, so no LoRA scaling is configured.
        pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)

        recorded = self._forward_and_capture_call(pe)
        scaling_val = recorded.kwargs.get("scaling")
        assert scaling_val == 1.0, (
            f"Expected scaling=1.0 for None but got {scaling_val}"
        )

    def test_scaling_positive_preserved(self):
        """Normal positive scaling should be preserved."""
        pytest.importorskip("triton")
        from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
            ParallelExperts,
        )

        pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
        pe.set_lora(
            lora_A=torch.randn(4, 4),
            lora_B=torch.randn(4, 4),
            scaling=0.5,
        )

        recorded = self._forward_and_capture_call(pe)
        scaling_val = recorded.kwargs.get("scaling")
        assert scaling_val == 0.5
# ============================================================================
# 4. single2scatter: non-aligned K/N dimensions (GPU only)
# ============================================================================
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestSingle2ScatterBounds:
    """Compare single2scatter with a dense matmul for non-aligned K and N.

    The dimension 100 is used as the "non-aligned" size; the tests treat 128
    as the kernel block size (presumably BLOCK_K / BLOCK_N — confirm against
    the kernel source).
    """

    @staticmethod
    def _check_against_matmul(in_dim, out_dim):
        """Run single2scatter for one token over two experts and verify it.

        Args:
            in_dim: size of the reduction dimension (K).
            out_dim: size of the output dimension (N).
        """
        from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import (
            single2scatter,
        )

        num_experts = 2
        weights = torch.randn(
            num_experts, in_dim, out_dim, device="cuda", dtype=torch.float32
        )
        token = torch.randn(1, in_dim, device="cuda", dtype=torch.float32)
        routing = torch.tensor([[0, 1]], device="cuda", dtype=torch.long)

        scattered = single2scatter(token, weights, routing)
        assert scattered.shape == (2, out_dim)

        # Reference: plain matmul of the single token against each expert.
        references = [token[0] @ weights[expert] for expert in range(num_experts)]
        for slot, reference in enumerate(references):
            torch.testing.assert_close(
                scattered[slot], reference, atol=1e-2, rtol=1e-2
            )

    def test_non_aligned_k(self):
        """K=100 (not a multiple of 128) with N=128 matches the reference."""
        self._check_against_matmul(100, 128)

    def test_non_aligned_n(self):
        """N=100 (not a multiple of 128) with K=128 matches the reference."""
        self._check_against_matmul(128, 100)

    def test_non_aligned_both(self):
        """K=100 and N=100, neither a multiple of 128, match the reference."""
        self._check_against_matmul(100, 100)
# ============================================================================
# 5. group_compileable: coeff=None accepted
# ============================================================================
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestGroupCoeffNone:
    """Smoke tests for ``group()`` with and without a coefficient tensor."""

    def test_group_with_none_coeff(self):
        """Passing coeff=None must not raise and must keep the input shape."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group

        rows, cols = 4, 32
        activations = torch.randn(rows, cols, device="cuda", dtype=torch.float32)
        expert_order = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long)

        # A TypeError here would mean None is not accepted for coeff.
        grouped = group(activations, expert_order, coeff=None, fan_out=1)

        assert grouped.shape == (rows, cols)

    def test_group_with_coeff(self):
        """A real coefficient tensor is accepted and the shape is unchanged."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group

        rows, cols = 4, 32
        activations = torch.randn(rows, cols, device="cuda", dtype=torch.float32)
        expert_order = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long)
        half_weights = torch.full((rows,), 0.5, device="cuda", dtype=torch.float32)

        grouped = group(activations, expert_order, coeff=half_weights, fan_out=1)

        assert grouped.shape == (rows, cols)
# ============================================================================
# 6. Layer return value contracts
# ============================================================================
class TestLayerReturnValues:
    """Contract checks on the scattermoe layer ``forward`` methods.

    NOTE: neither test runs a forward pass; they only inspect the signature
    and the docstring of the respective ``forward``.
    """

    def test_hf_scatter_moe_returns_single_tensor(self):
        """HFScatterMoEGatedMLP.forward exposes ``self`` and ``layer_input``.

        NOTE(review): despite the name, the return value itself is not
        checked here; only the parameter names are.
        """
        pytest.importorskip("triton")
        import inspect

        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
        )

        # forward is expected to take an explicit (self, layer_input) pair.
        forward_params = inspect.signature(HFScatterMoEGatedMLP.forward).parameters
        for expected_name in ("self", "layer_input"):
            assert expected_name in forward_params

    def test_scatter_moe_gated_mlp_docstring_no_router_logits(self):
        """ScatterMoEGatedMLP.forward's docstring names an output tensor only."""
        pytest.importorskip("triton")
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            ScatterMoEGatedMLP,
        )

        forward_doc = ScatterMoEGatedMLP.forward.__doc__
        assert forward_doc is not None
        # Case-insensitive match covers both "Output tensor" and "output tensor".
        assert "output tensor" in forward_doc.lower()
        assert "Router logits" not in forward_doc, (
            "Docstring should not mention 'Router logits' in Returns section"
        )