* feat: add sonicmoe * feat: add torch compile for routing * feat: add routing smoke test * feat: add qwen3_5_moe, qwen3_vl_moe, qwen3_omni_moe * fix: disable mlp kernel for sonicmoe too * feat: update to sonicmoe release * chore: update import following new sonicmoe changes * feat: update handling for blackwell * feat: add sonicmoe e2e test * fix: installation for updated sonicmoe * fix: git commit * fix: ignore py req and fix metadata * fix: increase min hidden size to match sonicmoe kernel min * fix: attempt properly interleave and handle unpatch mid-test * chore: refactor teardown better * chore: refactor to re-use rearrange * fix: add idempotency guard * fix: address comments on CI memory and interleave * fix: tests grad, param doublewrapped
324 lines
12 KiB
Python
324 lines
12 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# Copyright (c) Axolotl AI
|
|
# Licensed under the Apache License, Version 2.0
|
|
|
|
"""
|
|
Unit tests for scattermoe-lora code-review fixes.
|
|
|
|
Tests cover:
|
|
- KernelsArgs validator: disable_mlp_kernel
|
|
- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
|
|
- ParallelExperts: scaling=0.0 not treated as falsy
|
|
- single2scatter: non-aligned K/N dimensions
|
|
- group_compileable: coeff=None accepted
|
|
- HFScatterMoEGatedMLP / ScatterMoEGatedMLP: return value contract
|
|
"""
|
|
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
# ============================================================================
|
|
# 1. KernelsArgs: disable_mlp_kernel validator
|
|
# ============================================================================
|
|
|
|
|
|
class TestKernelsArgsValidator:
    """Verify that disable_mlp_kernel sets both kernel flags correctly.

    The validator classmethod is exercised directly on raw dicts because
    lora_mlp_kernel / mlp_kernel are not declared model fields.
    """

    def test_disables_lora_mlp_kernel_when_scattermoe(self):
        """lora_mlp_kernel=True gets set to False when use_scattermoe=True."""
        from axolotl.integrations.kernels.args import KernelsArgs

        raw = dict(use_kernels=True, use_scattermoe=True, lora_mlp_kernel=True)
        out = KernelsArgs.disable_mlp_kernel(raw)
        assert out["lora_mlp_kernel"] is False
        assert out["mlp_kernel"] is False

    def test_mlp_kernel_disabled_without_lora(self):
        """Even without lora_mlp_kernel, mlp_kernel should be disabled."""
        from axolotl.integrations.kernels.args import KernelsArgs

        raw = dict(use_kernels=True, use_scattermoe=True)
        out = KernelsArgs.disable_mlp_kernel(raw)
        assert out["mlp_kernel"] is False
        # lora_mlp_kernel was never present, so the validator must not add it.
        assert "lora_mlp_kernel" not in out

    def test_lora_mlp_kernel_false_unchanged(self):
        """lora_mlp_kernel=False should stay False (no warning, no change)."""
        from axolotl.integrations.kernels.args import KernelsArgs

        raw = dict(use_kernels=True, use_scattermoe=True, lora_mlp_kernel=False)
        out = KernelsArgs.disable_mlp_kernel(raw)
        assert out["lora_mlp_kernel"] is False

    def test_no_change_when_scattermoe_disabled(self):
        """When use_scattermoe is not True, nothing should be changed."""
        from axolotl.integrations.kernels.args import KernelsArgs

        raw = dict(use_kernels=True, use_scattermoe=False, lora_mlp_kernel=True)
        out = KernelsArgs.disable_mlp_kernel(raw)
        assert out["lora_mlp_kernel"] is True
|
|
|
|
|
class TestParallelExpertsScaling:
    """Test that scaling=0.0 is preserved and not overridden to 1.0."""

    @staticmethod
    def _import_parallel_experts():
        """Skip if triton is unavailable, then return the ParallelExperts class."""
        pytest.importorskip("triton")
        from axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops import (
            ParallelExperts,
        )

        return ParallelExperts

    @staticmethod
    def _captured_scaling(pe):
        """Run one forward pass with parallel_linear_lora mocked out and
        return the `scaling` kwarg it was called with.

        NOTE: mock `call_args.kwargs` and `call_args[1]` are the same
        mapping, so the previous two-way `or` fallback was dead code — and
        the truthiness-based `x or y` extraction would have masked a
        (buggy) scaling=0.0 in the None-default test. A single .get() is
        both sufficient and falsy-safe.
        """
        with patch(
            "axolotl.integrations.kernels.libs.scattermoe_lora.lora_ops.parallel_linear_lora"
        ) as mock_pll:
            mock_pll.return_value = torch.randn(4, 4)
            # Dummy routing tensors: 2 tokens, k=1, 2 experts.
            pe.forward(
                inputs=torch.randn(2, 4),
                k=1,
                sorted_expert_idxs=torch.tensor([0, 0, 1, 1]),
                sorted_scattered_idxs=torch.tensor([0, 1, 0, 1]),
                expert_offsets=torch.tensor([2, 4]),
            )
            return mock_pll.call_args.kwargs.get("scaling")

    def test_scaling_zero_preserved(self):
        """scaling=0.0 should be passed as 0.0, not replaced with 1.0."""
        ParallelExperts = self._import_parallel_experts()

        pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
        pe.set_lora(
            lora_A=torch.randn(4, 4),
            lora_B=torch.randn(4, 4),
            scaling=0.0,
        )
        assert pe._lora_scaling == 0.0

        scaling_val = self._captured_scaling(pe)
        assert scaling_val == 0.0, f"Expected scaling=0.0 but got {scaling_val}"

    def test_scaling_none_defaults_to_one(self):
        """scaling=None (no LoRA attached) should default to 1.0."""
        ParallelExperts = self._import_parallel_experts()

        pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
        # No set_lora called, so _lora_scaling is None.

        scaling_val = self._captured_scaling(pe)
        assert scaling_val == 1.0, (
            f"Expected scaling=1.0 for None but got {scaling_val}"
        )

    def test_scaling_positive_preserved(self):
        """Normal positive scaling should be preserved."""
        ParallelExperts = self._import_parallel_experts()

        pe = ParallelExperts(num_experts=2, input_size=4, output_size=4)
        pe.set_lora(
            lora_A=torch.randn(4, 4),
            lora_B=torch.randn(4, 4),
            scaling=0.5,
        )

        assert self._captured_scaling(pe) == 0.5
|
|
|
|
|
|
# ============================================================================
|
|
# 4. single2scatter: non-aligned K/N dimensions (GPU only)
|
|
# ============================================================================
|
|
|
|
|
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestSingle2ScatterBounds:
    """Test single2scatter with non-aligned dimensions.

    The three cases below were copy-paste identical except for (K, N), so
    the shared body lives in _assert_matches_reference.
    """

    @staticmethod
    def _assert_matches_reference(K, N):
        """Run single2scatter for one (1, K) input routed to 2 experts of
        shape (K, N) each and compare against a manual matmul reference."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.single import (
            single2scatter,
        )

        E = 2
        W = torch.randn(E, K, N, device="cuda", dtype=torch.float32)
        X = torch.randn(1, K, device="cuda", dtype=torch.float32)
        expert_idxs = torch.tensor([[0, 1]], device="cuda", dtype=torch.long)

        Y = single2scatter(X, W, expert_idxs)
        assert Y.shape == (2, N)

        # Verify each expert's row against a plain dense matmul.
        for e in range(E):
            torch.testing.assert_close(Y[e], X[0] @ W[e], atol=1e-2, rtol=1e-2)

    def test_non_aligned_k(self):
        """K not a multiple of BLOCK_K should produce correct results."""
        self._assert_matches_reference(K=100, N=128)  # K=100 not a multiple of 128

    def test_non_aligned_n(self):
        """N not a multiple of BLOCK_N should produce correct results."""
        self._assert_matches_reference(K=128, N=100)  # N=100 not a multiple of 128

    def test_non_aligned_both(self):
        """Both K and N not aligned should produce correct results."""
        self._assert_matches_reference(K=100, N=100)  # neither aligned to 128
|
|
|
|
|
|
# ============================================================================
|
|
# 5. group_compileable: coeff=None accepted
|
|
# ============================================================================
|
|
|
|
|
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestGroupCoeffNone:
    """Test that group() works with coeff=None."""

    def test_group_with_none_coeff(self):
        """group() should accept coeff=None without errors."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group

        rows, cols = 4, 32
        inputs = torch.randn(rows, cols, device="cuda", dtype=torch.float32)
        idxs = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long)

        # Passing coeff=None must not raise a TypeError.
        grouped = group(inputs, idxs, coeff=None, fan_out=1)
        assert grouped.shape == (rows, cols)

    def test_group_with_coeff(self):
        """group() should also work with actual coeff values."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import group

        rows, cols = 4, 32
        inputs = torch.randn(rows, cols, device="cuda", dtype=torch.float32)
        idxs = torch.tensor([0, 1, 2, 3], device="cuda", dtype=torch.long)
        weights = torch.ones(rows, device="cuda", dtype=torch.float32) * 0.5

        grouped = group(inputs, idxs, coeff=weights, fan_out=1)
        assert grouped.shape == (rows, cols)
|
|
|
|
|
|
# ============================================================================
|
|
# 6. Layer return value contracts
|
|
# ============================================================================
|
|
|
|
|
|
class TestLayerReturnValues:
    """Test that layer forward methods return the correct types."""

    def test_hf_scatter_moe_returns_single_tensor(self):
        """HFScatterMoEGatedMLP.forward should return a single tensor, not a tuple."""
        pytest.importorskip("triton")
        import inspect

        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            HFScatterMoEGatedMLP,
        )

        # forward is a staticmethod whose signature carries (self, layer_input).
        param_names = inspect.signature(HFScatterMoEGatedMLP.forward).parameters
        assert "self" in param_names
        assert "layer_input" in param_names

    def test_scatter_moe_gated_mlp_docstring_no_router_logits(self):
        """ScatterMoEGatedMLP.forward docstring should not mention router logits as return."""
        pytest.importorskip("triton")
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            ScatterMoEGatedMLP,
        )

        doc = ScatterMoEGatedMLP.forward.__doc__
        assert doc is not None
        # The docstring should describe the output tensor, never router logits.
        assert "output tensor" in doc.lower()
        assert "Router logits" not in doc, (
            "Docstring should not mention 'Router logits' in Returns section"
        )
|