feat: add sonicmoe (#3411)

* feat: add sonicmoe

* feat: add torch compile for routing

* feat: add routing smoke test

* feat: add qwen3_5_moe, qwen3_vl_moe, qwen3_omni_moe

* fix: disable mlp kernel for sonicmoe too

* feat: update to sonicmoe release

* chore: update import following new sonicmoe changes

* feat: update handling for blackwell

* feat: add sonicmoe e2e test

* fix: installation for updated sonicmoe

* fix: git commit

* fix: ignore py req and fix metadata

* fix: increase min hidden size to match sonicmoe kernel min

* fix: attempt properly interleave and handle unpatch mid-test

* chore: refactor teardown better

* chore: refactor to re-use rearrange

* fix: add idempotency guard

* fix: address comments on CI memory and interleave

* fix: tests grad, param doublewrapped
This commit is contained in:
NanoCode012
2026-03-06 01:43:31 +07:00
committed by GitHub
parent 1eaf4d7418
commit 6a8baf8fa7
12 changed files with 1698 additions and 42 deletions

View File

@@ -6,7 +6,7 @@
Unit tests for scattermoe-lora code-review fixes.
Tests cover:
- KernelsArgs validator: disable_mlp_kernel_scattermoe
- KernelsArgs validator: disable_mlp_kernel
- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
- ParallelExperts: scaling=0.0 not treated as falsy
- single2scatter: non-aligned K/N dimensions
@@ -20,12 +20,12 @@ import pytest
import torch
# ============================================================================
# 1. KernelsArgs: disable_mlp_kernel_scattermoe validator
# 1. KernelsArgs: disable_mlp_kernel validator
# ============================================================================
class TestKernelsArgsValidator:
"""Test that disable_mlp_kernel_scattermoe sets both flags correctly.
"""Test that disable_mlp_kernel sets both flags correctly.
These tests call the validator classmethod directly on raw dicts,
since lora_mlp_kernel / mlp_kernel are not declared model fields.
@@ -40,7 +40,7 @@ class TestKernelsArgsValidator:
"use_scattermoe": True,
"lora_mlp_kernel": True,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
result = KernelsArgs.disable_mlp_kernel(data)
assert result["lora_mlp_kernel"] is False
assert result["mlp_kernel"] is False
@@ -52,7 +52,7 @@ class TestKernelsArgsValidator:
"use_kernels": True,
"use_scattermoe": True,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
result = KernelsArgs.disable_mlp_kernel(data)
assert result["mlp_kernel"] is False
# lora_mlp_kernel was not in data, should not be added
assert "lora_mlp_kernel" not in result
@@ -66,7 +66,7 @@ class TestKernelsArgsValidator:
"use_scattermoe": True,
"lora_mlp_kernel": False,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
result = KernelsArgs.disable_mlp_kernel(data)
assert result["lora_mlp_kernel"] is False
def test_no_change_when_scattermoe_disabled(self):
@@ -78,7 +78,7 @@ class TestKernelsArgsValidator:
"use_scattermoe": False,
"lora_mlp_kernel": True,
}
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
result = KernelsArgs.disable_mlp_kernel(data)
assert result["lora_mlp_kernel"] is True