feat: add sonicmoe fused lora support (#3519)

* feat: add sonicmoe fused lora support

* fix: forgot to add file

* feat: add test

* feat: add lora support for other routes

* fix: add int8 lora support

* fix: add qwen35_moe interleave support

* fix: qwen3_5_moe loss

* chore: lint

* address some pr comments

* fix test imports

* add support matrix for moe kernels [skip ci]

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
NanoCode012
2026-04-02 19:53:48 +07:00
committed by GitHub
parent 16e32232fb
commit 842fa039dd
16 changed files with 1249 additions and 126 deletions

View File

@@ -51,7 +51,7 @@ def _create_tiny_qwen3_config():
def _interleave_gate_up_weights(model):
"""Interleave all gate_up_proj parameters in-place for SonicMoE."""
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
interleave_gate_up,
)
@@ -80,7 +80,7 @@ class TestSonicMoEForwardCorrectness:
def test_forward_output_matches(self):
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
@@ -117,8 +117,8 @@ class TestSonicMoEGradientCorrectness:
"""Verify all parameter gradients match between original and patched."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
from axolotl.integrations.kernels.sonicmoe.weight_converter import (
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
from axolotl.integrations.kernels.libs.sonicmoe.weight_converter import (
deinterleave_gate_up,
)
@@ -191,7 +191,7 @@ class TestSonicMoEGradientCorrectness:
"""Verify that router (gate) weights get non-zero gradients."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
@@ -223,7 +223,7 @@ class TestSonicMoETrainingConvergence:
"""Run 30 training steps, verify loss decreases and no NaN/Inf."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")
@@ -254,7 +254,7 @@ class TestSonicMoETrainingConvergence:
"""Verify expert weights change during training (not frozen)."""
from transformers import AutoModelForCausalLM
from axolotl.integrations.kernels.sonicmoe.patch import patch_sonicmoe
from axolotl.integrations.kernels.libs.sonicmoe.patch import patch_sonicmoe
config = _create_tiny_qwen3_config()
input_ids = torch.randint(0, config.vocab_size, (2, 32), device="cuda")