update for transformers v5 for experts parameters and compose with moe kernels
This commit is contained in:
@@ -2,14 +2,13 @@ import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from importlib import util as importlib_util
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from importlib import util as importlib_util
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
|
||||
@@ -96,16 +95,21 @@ def _build_bailing_model():
|
||||
|
||||
|
||||
def _build_llama4_model():
|
||||
from transformers import Llama4TextConfig
|
||||
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
|
||||
|
||||
config = Llama4TextConfig(
|
||||
# Build config without __post_init__ validation (works around a
|
||||
# huggingface_hub strict-dataclass type mismatch for layer_types).
|
||||
config = object.__new__(__import__("transformers").Llama4TextConfig)
|
||||
config.__dict__.update(
|
||||
hidden_size=16,
|
||||
intermediate_size=32,
|
||||
num_local_experts=4,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
num_experts_per_tok=2,
|
||||
num_hidden_layers=2,
|
||||
hidden_act="silu",
|
||||
layer_types=None,
|
||||
)
|
||||
layer = Llama4TextMoe(config)
|
||||
|
||||
@@ -148,6 +152,38 @@ def _build_mixtral_model():
|
||||
return DummyModel(layer), layer
|
||||
|
||||
|
||||
def _build_qwen35_moe_model():
|
||||
from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
|
||||
Qwen3_5MoeTextConfig,
|
||||
)
|
||||
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
|
||||
Qwen3_5MoeSparseMoeBlock,
|
||||
)
|
||||
|
||||
config = Qwen3_5MoeTextConfig(
|
||||
hidden_size=16,
|
||||
moe_intermediate_size=32,
|
||||
shared_expert_intermediate_size=32,
|
||||
num_experts=4,
|
||||
num_experts_per_tok=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
num_hidden_layers=2,
|
||||
)
|
||||
layer = Qwen3_5MoeSparseMoeBlock(config)
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
def __init__(self, moe_layer):
|
||||
super().__init__()
|
||||
self.moe = moe_layer
|
||||
self.config = SimpleNamespace(model_type="qwen3_5_moe")
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return self.moe(hidden_states)
|
||||
|
||||
return DummyModel(layer), layer
|
||||
|
||||
|
||||
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
|
||||
if args is None:
|
||||
args = SimpleNamespace(logging_steps=1)
|
||||
@@ -194,7 +230,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
|
||||
_run_callback(plugin, cfg)
|
||||
self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
|
||||
self.assertFalse(torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema)))
|
||||
self.assertFalse(
|
||||
torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema))
|
||||
)
|
||||
|
||||
def test_llama4_adapter_biases_router_selection(self):
|
||||
model, layer = _build_llama4_model()
|
||||
@@ -209,7 +247,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
|
||||
_run_callback(plugin, cfg)
|
||||
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
|
||||
self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)))
|
||||
self.assertFalse(
|
||||
torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
|
||||
)
|
||||
|
||||
def test_bias_warmup_respected(self):
|
||||
model, block = _build_bailing_model()
|
||||
@@ -224,33 +264,130 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
|
||||
# Warmup steps should leave bias untouched.
|
||||
_step()
|
||||
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
|
||||
self.assertTrue(
|
||||
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
|
||||
)
|
||||
|
||||
_step()
|
||||
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
|
||||
self.assertTrue(
|
||||
torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
|
||||
)
|
||||
|
||||
# Third step exceeds warmup -> bias should update.
|
||||
_step()
|
||||
self.assertGreater(torch.count_nonzero(block._afb_bias), 0)
|
||||
|
||||
def test_mixtral_adapter_respects_native_forward(self):
|
||||
def test_mixtral_adapter_patches_router_not_forward(self):
|
||||
"""Verify that aux-free patches the router (gate) only, and the
|
||||
v5 block forward signature (single tensor return) is preserved."""
|
||||
model, layer = _build_mixtral_model()
|
||||
layer.jitter_noise = 0.0 # avoid stochasticity for comparison
|
||||
|
||||
hidden_dim = layer.config.hidden_size
|
||||
hidden = torch.randn(2, 3, hidden_dim)
|
||||
baseline_out, baseline_logits = layer(hidden.clone())
|
||||
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
patched_out, patched_logits = layer(hidden.clone())
|
||||
self.assertTrue(torch.allclose(baseline_out, patched_out))
|
||||
self.assertTrue(torch.allclose(baseline_logits, patched_logits))
|
||||
# Gate should be patched, not the block forward
|
||||
self.assertTrue(getattr(layer.gate, "_afb_patched", False))
|
||||
self.assertTrue(getattr(layer, "_afb_patched", False))
|
||||
|
||||
# v5 block forward returns a single tensor (not a tuple with logits)
|
||||
hidden = torch.randn(2, 3, layer.config.hidden_size)
|
||||
out = layer(hidden)
|
||||
self.assertIsInstance(out, torch.Tensor)
|
||||
self.assertEqual(out.shape, hidden.shape)
|
||||
|
||||
# Counts should have been accumulated
|
||||
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||
_run_callback(plugin, cfg)
|
||||
|
||||
def test_mixtral_adapter_bias_affects_selection(self):
|
||||
"""When bias is large for one expert, it should be selected more often."""
|
||||
model, layer = _build_mixtral_model()
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
# Set a large bias for expert 0 to force its selection
|
||||
layer._afb_bias.zero_()
|
||||
layer._afb_bias[0] = 10.0
|
||||
|
||||
hidden = torch.randn(2, 8, layer.config.hidden_size)
|
||||
num_tokens = 2 * 8 # batch * seq
|
||||
layer(hidden)
|
||||
|
||||
# With top_k=2, expert 0 should appear in every token's selection
|
||||
# (once per token = num_tokens counts, not num_tokens * top_k)
|
||||
counts = layer._afb_counts.clone()
|
||||
self.assertEqual(
|
||||
int(counts[0].item()),
|
||||
num_tokens,
|
||||
msg="Expert 0 should be selected for every token when heavily biased",
|
||||
)
|
||||
|
||||
def test_qwen35_moe_adapter_patches_router_and_preserves_shared_expert(self):
|
||||
"""Verify Qwen 3.5 MoE: router is patched, shared expert is untouched,
|
||||
output includes shared expert contribution."""
|
||||
model, layer = _build_qwen35_moe_model()
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
# Gate should be patched
|
||||
self.assertTrue(getattr(layer.gate, "_afb_patched", False))
|
||||
self.assertTrue(getattr(layer, "_afb_patched", False))
|
||||
# Shared expert should be unmodified
|
||||
self.assertTrue(hasattr(layer, "shared_expert"))
|
||||
self.assertTrue(hasattr(layer, "shared_expert_gate"))
|
||||
|
||||
# Forward should return a single tensor (shared + routed)
|
||||
hidden_size = layer.gate.hidden_dim
|
||||
hidden = torch.randn(2, 3, hidden_size)
|
||||
out = layer(hidden)
|
||||
self.assertIsInstance(out, torch.Tensor)
|
||||
self.assertEqual(out.shape, hidden.shape)
|
||||
|
||||
# Counts should have been accumulated
|
||||
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||
|
||||
def test_qwen35_moe_adapter_bias_updates(self):
|
||||
"""Full cycle: forward → callback → verify bias update for Qwen 3.5 MoE."""
|
||||
model, layer = _build_qwen35_moe_model()
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
hidden_size = layer.gate.hidden_dim
|
||||
hidden = torch.randn(2, 4, hidden_size)
|
||||
layer(hidden)
|
||||
|
||||
# Bias should start at zero
|
||||
self.assertTrue(
|
||||
torch.allclose(layer._afb_bias, torch.zeros_like(layer._afb_bias))
|
||||
)
|
||||
|
||||
_run_callback(plugin, cfg)
|
||||
|
||||
# After callback: counts reset, EMA updated, bias updated
|
||||
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
|
||||
self.assertFalse(
|
||||
torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
|
||||
)
|
||||
|
||||
def test_qwen35_moe_adapter_model_type_matching(self):
|
||||
"""Verify the adapter matches both qwen3_5_moe and qwen3_5_moe_text."""
|
||||
from axolotl.integrations.aux_free_router.adapters import Qwen35MoeAdapter
|
||||
|
||||
adapter = Qwen35MoeAdapter()
|
||||
|
||||
model_moe = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_5_moe"))
|
||||
model_text = SimpleNamespace(
|
||||
config=SimpleNamespace(model_type="qwen3_5_moe_text")
|
||||
)
|
||||
model_other = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_moe"))
|
||||
|
||||
self.assertTrue(adapter.matches(model_moe))
|
||||
self.assertTrue(adapter.matches(model_text))
|
||||
self.assertFalse(adapter.matches(model_other))
|
||||
|
||||
def test_ep_group_resolution_deferred_until_dist_ready(self):
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
@@ -266,7 +403,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
tmp_init = tempfile.NamedTemporaryFile(delete=False)
|
||||
tmp_init.close()
|
||||
init_method = f"file://{tmp_init.name}"
|
||||
dist.init_process_group(backend="gloo", init_method=init_method, world_size=1, rank=0)
|
||||
dist.init_process_group(
|
||||
backend="gloo", init_method=init_method, world_size=1, rank=0
|
||||
)
|
||||
try:
|
||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||
block(hidden)
|
||||
@@ -289,7 +428,6 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
|
||||
def test_telemetry_logging(self):
|
||||
model, layer = _build_mixtral_model()
|
||||
layer.jitter_noise = 0.0
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
@@ -316,6 +454,211 @@ class TestAuxFreeAdapters(unittest.TestCase):
|
||||
self.assertIn("moe_afb/l0_load_max", telemetry)
|
||||
self.assertIn("moe_afb/l0_bias_abs_max", telemetry)
|
||||
|
||||
def test_get_num_experts_v5_attribute_paths(self):
|
||||
"""Verify get_num_experts works with v5 attribute layout where
|
||||
num_experts is on gate/experts sub-modules, not the block."""
|
||||
from axolotl.integrations.aux_free_router.adapters import MixtralAdapter
|
||||
|
||||
adapter = MixtralAdapter()
|
||||
|
||||
# Simulates v5 MixtralSparseMoeBlock (num_experts on gate, not block)
|
||||
block = SimpleNamespace(
|
||||
gate=SimpleNamespace(num_experts=8),
|
||||
experts=SimpleNamespace(num_experts=8),
|
||||
)
|
||||
self.assertEqual(adapter.get_num_experts(block), 8)
|
||||
|
||||
# Also works when num_experts is directly on block
|
||||
block2 = SimpleNamespace(num_experts=4)
|
||||
self.assertEqual(adapter.get_num_experts(block2), 4)
|
||||
|
||||
|
||||
class TestAuxFreeKernelComposition(unittest.TestCase):
|
||||
"""Tests that aux-free bias composes correctly with kernel routing."""
|
||||
|
||||
def test_sonicmoe_softmax_routing_with_afb_bias(self):
|
||||
"""SonicMoE softmax routing should use biased selection / unbiased weights."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
|
||||
num_experts = 4
|
||||
top_k = 2
|
||||
hidden_dim = 16
|
||||
T = 6
|
||||
|
||||
# Build a mock MoE block with gate attributes
|
||||
gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
||||
gate.top_k = top_k
|
||||
gate.num_experts = num_experts
|
||||
gate.norm_topk_prob = True
|
||||
|
||||
moe_block = SimpleNamespace(gate=gate)
|
||||
hidden = torch.randn(T, hidden_dim)
|
||||
|
||||
# Baseline: no bias
|
||||
scores_base, tok_base, exp_base, logits_base = softmax_topk_routing(
|
||||
hidden, moe_block
|
||||
)
|
||||
self.assertEqual(scores_base.shape[0], T * top_k)
|
||||
|
||||
# Now register aux-free buffers and set heavy bias on expert 0
|
||||
moe_block._afb_bias = torch.zeros(num_experts)
|
||||
moe_block._afb_bias[0] = 100.0
|
||||
moe_block._afb_counts = torch.zeros(num_experts)
|
||||
|
||||
scores_biased, tok_biased, exp_biased, logits_biased = softmax_topk_routing(
|
||||
hidden, moe_block
|
||||
)
|
||||
|
||||
# Expert 0 should be selected for every token
|
||||
self.assertTrue(
|
||||
(exp_biased == 0).any(),
|
||||
"Expert 0 should appear in selections when heavily biased",
|
||||
)
|
||||
# Counts should have been accumulated
|
||||
self.assertGreater(moe_block._afb_counts[0].item(), 0)
|
||||
# Total counts should equal T * top_k
|
||||
self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)
|
||||
|
||||
def test_sonicmoe_routing_without_bias_unchanged(self):
|
||||
"""Without _afb_bias, routing should produce identical results."""
|
||||
from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing
|
||||
|
||||
num_experts = 4
|
||||
top_k = 2
|
||||
hidden_dim = 16
|
||||
|
||||
gate = nn.Linear(hidden_dim, num_experts, bias=False)
|
||||
gate.top_k = top_k
|
||||
gate.num_experts = num_experts
|
||||
gate.norm_topk_prob = True
|
||||
|
||||
moe_block = SimpleNamespace(gate=gate)
|
||||
hidden = torch.randn(4, hidden_dim)
|
||||
|
||||
# Without _afb_bias attribute
|
||||
scores1, _, exp1, _ = softmax_topk_routing(hidden, moe_block)
|
||||
|
||||
# With _afb_bias = zeros (should be equivalent)
|
||||
moe_block._afb_bias = torch.zeros(num_experts)
|
||||
moe_block._afb_counts = torch.zeros(num_experts)
|
||||
scores2, _, exp2, _ = softmax_topk_routing(hidden, moe_block)
|
||||
|
||||
torch.testing.assert_close(scores1, scores2)
|
||||
torch.testing.assert_close(exp1, exp2)
|
||||
|
||||
@unittest.skipUnless(
|
||||
importlib_util.find_spec("triton") is not None,
|
||||
"triton not installed (required by scattermoe)",
|
||||
)
|
||||
def test_scattermoe_softmax_routing_with_afb_bias(self):
|
||||
"""ScatterMoE softmax routing should use biased selection / unbiased weights."""
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
|
||||
_softmax_topk_route,
|
||||
)
|
||||
|
||||
num_experts = 4
|
||||
top_k = 2
|
||||
hidden_dim = 16
|
||||
T = 6
|
||||
|
||||
gate_weight = torch.randn(num_experts, hidden_dim)
|
||||
base_gate = SimpleNamespace(
|
||||
top_k=top_k,
|
||||
num_experts=num_experts,
|
||||
norm_topk_prob=True,
|
||||
weight=gate_weight,
|
||||
)
|
||||
|
||||
moe_block = SimpleNamespace()
|
||||
hidden = torch.randn(T, hidden_dim)
|
||||
|
||||
# Baseline without bias
|
||||
w_base, e_base, _, _ = _softmax_topk_route(
|
||||
moe_block, base_gate, hidden, gate_weight, None
|
||||
)
|
||||
|
||||
# With heavy bias on expert 0
|
||||
moe_block._afb_bias = torch.zeros(num_experts)
|
||||
moe_block._afb_bias[0] = 100.0
|
||||
moe_block._afb_counts = torch.zeros(num_experts)
|
||||
|
||||
w_biased, e_biased, _, _ = _softmax_topk_route(
|
||||
moe_block, base_gate, hidden, gate_weight, None
|
||||
)
|
||||
|
||||
# Expert 0 should appear in all selections
|
||||
self.assertTrue((e_biased == 0).any())
|
||||
# Counts accumulated
|
||||
self.assertGreater(moe_block._afb_counts[0].item(), 0)
|
||||
self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)
|
||||
|
||||
def test_kernel_routing_skips_router_patch(self):
|
||||
"""When a kernel backend has patched the block class, the adapter
|
||||
should skip patching the router (buffers are still registered)."""
|
||||
from axolotl.integrations.aux_free_router.adapters import MixtralAdapter
|
||||
|
||||
adapter = MixtralAdapter()
|
||||
|
||||
# Create a mock layer whose class has _original_forward (SonicMoE marker)
|
||||
class PatchedBlock(nn.Module):
|
||||
_original_forward = True # SonicMoE marker
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gate = nn.Linear(16, 4, bias=False)
|
||||
self.gate.top_k = 2
|
||||
self.gate.num_experts = 4
|
||||
self.gate.hidden_dim = 16
|
||||
self.experts = nn.Linear(16, 16) # placeholder
|
||||
|
||||
layer = PatchedBlock()
|
||||
self.assertTrue(adapter.uses_kernel_routing(layer))
|
||||
|
||||
# Gate should NOT be patched (kernel handles routing)
|
||||
self.assertFalse(getattr(layer.gate, "_afb_patched", False))
|
||||
|
||||
def test_adapter_buffers_registered_even_with_kernel(self):
|
||||
"""Even when kernel routing is active, aux-free buffers must be
|
||||
registered on the MoE block so the kernel routing can find them."""
|
||||
from axolotl.integrations.aux_free_router.adapters import (
|
||||
LayerHandle,
|
||||
MixtralAdapter,
|
||||
)
|
||||
from axolotl.integrations.aux_free_router.core import (
|
||||
AuxFreeConfig,
|
||||
AuxFreeShim,
|
||||
AuxFreeState,
|
||||
)
|
||||
|
||||
class PatchedBlock(nn.Module):
|
||||
_original_forward = True
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gate = nn.Linear(16, 4, bias=False)
|
||||
self.gate.top_k = 2
|
||||
self.gate.num_experts = 4
|
||||
self.gate.hidden_dim = 16
|
||||
self.experts = nn.Linear(16, 16)
|
||||
|
||||
layer = PatchedBlock()
|
||||
adapter = MixtralAdapter()
|
||||
cfg = AuxFreeConfig()
|
||||
state = AuxFreeState(
|
||||
num_layers=1, num_experts=4, device=torch.device("cpu"), cfg=cfg
|
||||
)
|
||||
shim = AuxFreeShim(state=state)
|
||||
handle = LayerHandle(layer=layer, layer_idx=0, num_experts=4, top_k=2)
|
||||
|
||||
adapter.prepare(layer, handle, shim)
|
||||
|
||||
# Buffers should be registered for kernel routing to use
|
||||
self.assertTrue(hasattr(layer, "_afb_bias"))
|
||||
self.assertTrue(hasattr(layer, "_afb_counts"))
|
||||
self.assertTrue(hasattr(layer, "_afb_ema"))
|
||||
# But gate should NOT be patched
|
||||
self.assertFalse(getattr(layer.gate, "_afb_patched", False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user