feat: add sonicmoe (#3411)
* feat: add sonicmoe * feat: add torch compile for routing * feat: add routing smoke test * feat: add qwen3_5_moe, qwen3_vl_moe, qwen3_omni_moe * fix: disable mlp kernel for sonicmoe too * feat: update to sonicmoe release * chore: update import following new sonicmoe changes * feat: update handling for blackwell * feat: add sonicmoe e2e test * fix: installation for updated sonicmoe * fix: git commit * fix: ignore py req and fix metadata * fix: increase min hidden size to match sonicmoe kernel min * fix: attempt properly interleave and handle unpatch mid-test * chore: refactor teardown better * chore: refactor to re-use rearrange * fix: add idempotency guard * fix: address comments on CI memory and interleave * fix: tests grad, param doublewrapped
This commit is contained in:
@@ -6,7 +6,7 @@
|
||||
Unit tests for scattermoe-lora code-review fixes.
|
||||
|
||||
Tests cover:
|
||||
- KernelsArgs validator: disable_mlp_kernel_scattermoe
|
||||
- KernelsArgs validator: disable_mlp_kernel
|
||||
- CPU_Offloaded_Gradient_Checkpointer: tuple vs plain tensor backward
|
||||
- ParallelExperts: scaling=0.0 not treated as falsy
|
||||
- single2scatter: non-aligned K/N dimensions
|
||||
@@ -20,12 +20,12 @@ import pytest
|
||||
import torch
|
||||
|
||||
# ============================================================================
|
||||
# 1. KernelsArgs: disable_mlp_kernel_scattermoe validator
|
||||
# 1. KernelsArgs: disable_mlp_kernel validator
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestKernelsArgsValidator:
|
||||
"""Test that disable_mlp_kernel_scattermoe sets both flags correctly.
|
||||
"""Test that disable_mlp_kernel sets both flags correctly.
|
||||
|
||||
These tests call the validator classmethod directly on raw dicts,
|
||||
since lora_mlp_kernel / mlp_kernel are not declared model fields.
|
||||
@@ -40,7 +40,7 @@ class TestKernelsArgsValidator:
|
||||
"use_scattermoe": True,
|
||||
"lora_mlp_kernel": True,
|
||||
}
|
||||
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
|
||||
result = KernelsArgs.disable_mlp_kernel(data)
|
||||
assert result["lora_mlp_kernel"] is False
|
||||
assert result["mlp_kernel"] is False
|
||||
|
||||
@@ -52,7 +52,7 @@ class TestKernelsArgsValidator:
|
||||
"use_kernels": True,
|
||||
"use_scattermoe": True,
|
||||
}
|
||||
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
|
||||
result = KernelsArgs.disable_mlp_kernel(data)
|
||||
assert result["mlp_kernel"] is False
|
||||
# lora_mlp_kernel was not in data, should not be added
|
||||
assert "lora_mlp_kernel" not in result
|
||||
@@ -66,7 +66,7 @@ class TestKernelsArgsValidator:
|
||||
"use_scattermoe": True,
|
||||
"lora_mlp_kernel": False,
|
||||
}
|
||||
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
|
||||
result = KernelsArgs.disable_mlp_kernel(data)
|
||||
assert result["lora_mlp_kernel"] is False
|
||||
|
||||
def test_no_change_when_scattermoe_disabled(self):
|
||||
@@ -78,7 +78,7 @@ class TestKernelsArgsValidator:
|
||||
"use_scattermoe": False,
|
||||
"lora_mlp_kernel": True,
|
||||
}
|
||||
result = KernelsArgs.disable_mlp_kernel_scattermoe(data)
|
||||
result = KernelsArgs.disable_mlp_kernel(data)
|
||||
assert result["lora_mlp_kernel"] is True
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user