update for transformers v5 for experts parameters and compose with moe kernels

This commit is contained in:
Wing Lian
2026-03-22 11:52:34 -04:00
parent 4009a2ba5f
commit 5acb1b0ade
5 changed files with 650 additions and 162 deletions

View File

@@ -2,14 +2,13 @@ import os
import sys
import tempfile
import unittest
from importlib import util as importlib_util
from pathlib import Path
from types import SimpleNamespace
import torch
import torch.distributed as dist
import torch.nn as nn
from importlib import util as importlib_util
from pathlib import Path
from huggingface_hub import snapshot_download
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
@@ -96,16 +95,21 @@ def _build_bailing_model():
def _build_llama4_model():
from transformers import Llama4TextConfig
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
config = Llama4TextConfig(
# Build config without __post_init__ validation (works around a
# huggingface_hub strict-dataclass type mismatch for layer_types).
config = object.__new__(__import__("transformers").Llama4TextConfig)
config.__dict__.update(
hidden_size=16,
intermediate_size=32,
num_local_experts=4,
num_attention_heads=2,
num_key_value_heads=2,
num_experts_per_tok=2,
num_hidden_layers=2,
hidden_act="silu",
layer_types=None,
)
layer = Llama4TextMoe(config)
@@ -148,6 +152,38 @@ def _build_mixtral_model():
return DummyModel(layer), layer
def _build_qwen35_moe_model():
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
Qwen3_5MoeTextConfig,
)
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
Qwen3_5MoeSparseMoeBlock,
)
config = Qwen3_5MoeTextConfig(
hidden_size=16,
moe_intermediate_size=32,
shared_expert_intermediate_size=32,
num_experts=4,
num_experts_per_tok=2,
num_attention_heads=2,
num_key_value_heads=2,
num_hidden_layers=2,
)
layer = Qwen3_5MoeSparseMoeBlock(config)
class DummyModel(nn.Module):
def __init__(self, moe_layer):
super().__init__()
self.moe = moe_layer
self.config = SimpleNamespace(model_type="qwen3_5_moe")
def forward(self, hidden_states):
return self.moe(hidden_states)
return DummyModel(layer), layer
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
if args is None:
args = SimpleNamespace(logging_steps=1)
@@ -194,7 +230,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
_run_callback(plugin, cfg)
self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
self.assertFalse(torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema)))
self.assertFalse(
torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema))
)
def test_llama4_adapter_biases_router_selection(self):
model, layer = _build_llama4_model()
@@ -209,7 +247,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
_run_callback(plugin, cfg)
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)))
self.assertFalse(
torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
)
def test_bias_warmup_respected(self):
model, block = _build_bailing_model()
@@ -224,33 +264,130 @@ class TestAuxFreeAdapters(unittest.TestCase):
# Warmup steps should leave bias untouched.
_step()
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
self.assertTrue(
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
)
_step()
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
self.assertTrue(
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
)
# Third step exceeds warmup -> bias should update.
_step()
self.assertGreater(torch.count_nonzero(block._afb_bias), 0)
def test_mixtral_adapter_respects_native_forward(self):
def test_mixtral_adapter_patches_router_not_forward(self):
"""Verify that aux-free patches the router (gate) only, and the
v5 block forward signature (single tensor return) is preserved."""
model, layer = _build_mixtral_model()
layer.jitter_noise = 0.0 # avoid stochasticity for comparison
hidden_dim = layer.config.hidden_size
hidden = torch.randn(2, 3, hidden_dim)
baseline_out, baseline_logits = layer(hidden.clone())
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
patched_out, patched_logits = layer(hidden.clone())
self.assertTrue(torch.allclose(baseline_out, patched_out))
self.assertTrue(torch.allclose(baseline_logits, patched_logits))
# Gate should be patched, not the block forward
self.assertTrue(getattr(layer.gate, "_afb_patched", False))
self.assertTrue(getattr(layer, "_afb_patched", False))
# v5 block forward returns a single tensor (not a tuple with logits)
hidden = torch.randn(2, 3, layer.config.hidden_size)
out = layer(hidden)
self.assertIsInstance(out, torch.Tensor)
self.assertEqual(out.shape, hidden.shape)
# Counts should have been accumulated
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
_run_callback(plugin, cfg)
def test_mixtral_adapter_bias_affects_selection(self):
"""When bias is large for one expert, it should be selected more often."""
model, layer = _build_mixtral_model()
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
# Set a large bias for expert 0 to force its selection
layer._afb_bias.zero_()
layer._afb_bias[0] = 10.0
hidden = torch.randn(2, 8, layer.config.hidden_size)
num_tokens = 2 * 8 # batch * seq
layer(hidden)
# With top_k=2, expert 0 should appear in every token's selection
# (once per token = num_tokens counts, not num_tokens * top_k)
counts = layer._afb_counts.clone()
self.assertEqual(
int(counts[0].item()),
num_tokens,
msg="Expert 0 should be selected for every token when heavily biased",
)
def test_qwen35_moe_adapter_patches_router_and_preserves_shared_expert(self):
"""Verify Qwen 3.5 MoE: router is patched, shared expert is untouched,
output includes shared expert contribution."""
model, layer = _build_qwen35_moe_model()
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
# Gate should be patched
self.assertTrue(getattr(layer.gate, "_afb_patched", False))
self.assertTrue(getattr(layer, "_afb_patched", False))
# Shared expert should be unmodified
self.assertTrue(hasattr(layer, "shared_expert"))
self.assertTrue(hasattr(layer, "shared_expert_gate"))
# Forward should return a single tensor (shared + routed)
hidden_size = layer.gate.hidden_dim
hidden = torch.randn(2, 3, hidden_size)
out = layer(hidden)
self.assertIsInstance(out, torch.Tensor)
self.assertEqual(out.shape, hidden.shape)
# Counts should have been accumulated
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
def test_qwen35_moe_adapter_bias_updates(self):
"""Full cycle: forward → callback → verify bias update for Qwen 3.5 MoE."""
model, layer = _build_qwen35_moe_model()
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
hidden_size = layer.gate.hidden_dim
hidden = torch.randn(2, 4, hidden_size)
layer(hidden)
# Bias should start at zero
self.assertTrue(
torch.allclose(layer._afb_bias, torch.zeros_like(layer._afb_bias))
)
_run_callback(plugin, cfg)
# After callback: counts reset, EMA updated, bias updated
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
self.assertFalse(
torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
)
def test_qwen35_moe_adapter_model_type_matching(self):
"""Verify the adapter matches both qwen3_5_moe and qwen3_5_moe_text."""
from axolotl.integrations.aux_free_router.adapters import Qwen35MoeAdapter
adapter = Qwen35MoeAdapter()
model_moe = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_5_moe"))
model_text = SimpleNamespace(
config=SimpleNamespace(model_type="qwen3_5_moe_text")
)
model_other = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_moe"))
self.assertTrue(adapter.matches(model_moe))
self.assertTrue(adapter.matches(model_text))
self.assertFalse(adapter.matches(model_other))
def test_ep_group_resolution_deferred_until_dist_ready(self):
if dist.is_available() and dist.is_initialized():
dist.destroy_process_group()
@@ -266,7 +403,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
tmp_init = tempfile.NamedTemporaryFile(delete=False)
tmp_init.close()
init_method = f"file://{tmp_init.name}"
dist.init_process_group(backend="gloo", init_method=init_method, world_size=1, rank=0)
dist.init_process_group(
backend="gloo", init_method=init_method, world_size=1, rank=0
)
try:
hidden = torch.randn(2, 3, block.config.hidden_size)
block(hidden)
@@ -289,7 +428,6 @@ class TestAuxFreeAdapters(unittest.TestCase):
def test_telemetry_logging(self):
model, layer = _build_mixtral_model()
layer.jitter_noise = 0.0
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
@@ -316,6 +454,211 @@ class TestAuxFreeAdapters(unittest.TestCase):
self.assertIn("moe_afb/l0_load_max", telemetry)
self.assertIn("moe_afb/l0_bias_abs_max", telemetry)
def test_get_num_experts_v5_attribute_paths(self):
"""Verify get_num_experts works with v5 attribute layout where
num_experts is on gate/experts sub-modules, not the block."""
from axolotl.integrations.aux_free_router.adapters import MixtralAdapter
adapter = MixtralAdapter()
# Simulates v5 MixtralSparseMoeBlock (num_experts on gate, not block)
block = SimpleNamespace(
gate=SimpleNamespace(num_experts=8),
experts=SimpleNamespace(num_experts=8),
)
self.assertEqual(adapter.get_num_experts(block), 8)
# Also works when num_experts is directly on block
block2 = SimpleNamespace(num_experts=4)
self.assertEqual(adapter.get_num_experts(block2), 4)
class TestAuxFreeKernelComposition(unittest.TestCase):
"""Tests that aux-free bias composes correctly with kernel routing."""
def test_sonicmoe_softmax_routing_with_afb_bias(self):
"""SonicMoE softmax routing should use biased selection / unbiased weights."""
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
num_experts = 4
top_k = 2
hidden_dim = 16
T = 6
# Build a mock MoE block with gate attributes
gate = nn.Linear(hidden_dim, num_experts, bias=False)
gate.top_k = top_k
gate.num_experts = num_experts
gate.norm_topk_prob = True
moe_block = SimpleNamespace(gate=gate)
hidden = torch.randn(T, hidden_dim)
# Baseline: no bias
scores_base, tok_base, exp_base, logits_base = softmax_topk_routing(
hidden, moe_block
)
self.assertEqual(scores_base.shape[0], T * top_k)
# Now register aux-free buffers and set heavy bias on expert 0
moe_block._afb_bias = torch.zeros(num_experts)
moe_block._afb_bias[0] = 100.0
moe_block._afb_counts = torch.zeros(num_experts)
scores_biased, tok_biased, exp_biased, logits_biased = softmax_topk_routing(
hidden, moe_block
)
# Expert 0 should be selected for every token
self.assertTrue(
(exp_biased == 0).any(),
"Expert 0 should appear in selections when heavily biased",
)
# Counts should have been accumulated
self.assertGreater(moe_block._afb_counts[0].item(), 0)
# Total counts should equal T * top_k
self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)
def test_sonicmoe_routing_without_bias_unchanged(self):
"""Without _afb_bias, routing should produce identical results."""
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
num_experts = 4
top_k = 2
hidden_dim = 16
gate = nn.Linear(hidden_dim, num_experts, bias=False)
gate.top_k = top_k
gate.num_experts = num_experts
gate.norm_topk_prob = True
moe_block = SimpleNamespace(gate=gate)
hidden = torch.randn(4, hidden_dim)
# Without _afb_bias attribute
scores1, _, exp1, _ = softmax_topk_routing(hidden, moe_block)
# With _afb_bias = zeros (should be equivalent)
moe_block._afb_bias = torch.zeros(num_experts)
moe_block._afb_counts = torch.zeros(num_experts)
scores2, _, exp2, _ = softmax_topk_routing(hidden, moe_block)
torch.testing.assert_close(scores1, scores2)
torch.testing.assert_close(exp1, exp2)
@unittest.skipUnless(
importlib_util.find_spec("triton") is not None,
"triton not installed (required by scattermoe)",
)
def test_scattermoe_softmax_routing_with_afb_bias(self):
"""ScatterMoE softmax routing should use biased selection / unbiased weights."""
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
_softmax_topk_route,
)
num_experts = 4
top_k = 2
hidden_dim = 16
T = 6
gate_weight = torch.randn(num_experts, hidden_dim)
base_gate = SimpleNamespace(
top_k=top_k,
num_experts=num_experts,
norm_topk_prob=True,
weight=gate_weight,
)
moe_block = SimpleNamespace()
hidden = torch.randn(T, hidden_dim)
# Baseline without bias
w_base, e_base, _, _ = _softmax_topk_route(
moe_block, base_gate, hidden, gate_weight, None
)
# With heavy bias on expert 0
moe_block._afb_bias = torch.zeros(num_experts)
moe_block._afb_bias[0] = 100.0
moe_block._afb_counts = torch.zeros(num_experts)
w_biased, e_biased, _, _ = _softmax_topk_route(
moe_block, base_gate, hidden, gate_weight, None
)
# Expert 0 should appear in all selections
self.assertTrue((e_biased == 0).any())
# Counts accumulated
self.assertGreater(moe_block._afb_counts[0].item(), 0)
self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)
def test_kernel_routing_skips_router_patch(self):
"""When a kernel backend has patched the block class, the adapter
should skip patching the router (buffers are still registered)."""
from axolotl.integrations.aux_free_router.adapters import MixtralAdapter
adapter = MixtralAdapter()
# Create a mock layer whose class has _original_forward (SonicMoE marker)
class PatchedBlock(nn.Module):
_original_forward = True # SonicMoE marker
def __init__(self):
super().__init__()
self.gate = nn.Linear(16, 4, bias=False)
self.gate.top_k = 2
self.gate.num_experts = 4
self.gate.hidden_dim = 16
self.experts = nn.Linear(16, 16) # placeholder
layer = PatchedBlock()
self.assertTrue(adapter.uses_kernel_routing(layer))
# Gate should NOT be patched (kernel handles routing)
self.assertFalse(getattr(layer.gate, "_afb_patched", False))
def test_adapter_buffers_registered_even_with_kernel(self):
"""Even when kernel routing is active, aux-free buffers must be
registered on the MoE block so the kernel routing can find them."""
from axolotl.integrations.aux_free_router.adapters import (
LayerHandle,
MixtralAdapter,
)
from axolotl.integrations.aux_free_router.core import (
AuxFreeConfig,
AuxFreeShim,
AuxFreeState,
)
class PatchedBlock(nn.Module):
_original_forward = True
def __init__(self):
super().__init__()
self.gate = nn.Linear(16, 4, bias=False)
self.gate.top_k = 2
self.gate.num_experts = 4
self.gate.hidden_dim = 16
self.experts = nn.Linear(16, 16)
layer = PatchedBlock()
adapter = MixtralAdapter()
cfg = AuxFreeConfig()
state = AuxFreeState(
num_layers=1, num_experts=4, device=torch.device("cpu"), cfg=cfg
)
shim = AuxFreeShim(state=state)
handle = LayerHandle(layer=layer, layer_idx=0, num_experts=4, top_k=2)
adapter.prepare(layer, handle, shim)
# Buffers should be registered for kernel routing to use
self.assertTrue(hasattr(layer, "_afb_bias"))
self.assertTrue(hasattr(layer, "_afb_counts"))
self.assertTrue(hasattr(layer, "_afb_ema"))
# But gate should NOT be patched
self.assertFalse(getattr(layer.gate, "_afb_patched", False))
if __name__ == "__main__":
unittest.main()