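"""Unit tests for the aux-free MoE router plugin.

Covers adapter wiring for Bailing MoE v2 (Ring), Llama 4, Mixtral, and
Qwen 3.5 MoE blocks; bias/EMA/count bookkeeping driven by the trainer
callback; and composition with SonicMoE / ScatterMoE kernel routing.
"""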
import os
import sys
import tempfile
import unittest
from importlib import util as importlib_util
from pathlib import Path
from types import SimpleNamespace

import torch
import torch.distributed as dist
import torch.nn as nn
from huggingface_hub import snapshot_download

from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin


def _cfg(**overrides):
    defaults = dict(
        moe_balance_type="noaux_tc",
        moe_update_rate=0.1,
        moe_update_momentum=0.9,
        moe_bias_cap=2.0,
        moe_afb_warmup_steps=0,
        moe_bias_sync_group="world",
        expert_parallel_size=1,
    )
    defaults.update(overrides)
    return SimpleNamespace(**defaults)


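# The Bailing MoE v2 (Ring-mini) modules ship with the checkpoint rather than
# with transformers, so they are downloaded and imported by file path. Each
# module is registered in sys.modules under both its package-qualified and
# bare name (the modeling file may import the config under either).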
def _load_bailing_modules():
    repo_dir = snapshot_download(
        repo_id="inclusionAI/Ring-mini-2.0",
        allow_patterns=[
            "configuration_bailing_moe_v2.py",
            "modeling_bailing_moe_v2.py",
            "__init__.py",
        ],
    )
    repo = Path(repo_dir)
    config_path = repo / "configuration_bailing_moe_v2.py"
    modeling_path = repo / "modeling_bailing_moe_v2.py"

    config_name = "bailing_moe_v2.configuration_bailing_moe_v2"
    if config_name not in sys.modules:
        spec = importlib_util.spec_from_file_location(config_name, config_path)
        module = importlib_util.module_from_spec(spec)
        sys.modules[config_name] = module
        sys.modules["configuration_bailing_moe_v2"] = module
        assert spec.loader is not None
        spec.loader.exec_module(module)
    config_module = sys.modules[config_name]

    modeling_name = "bailing_moe_v2.modeling_bailing_moe_v2"
    if modeling_name not in sys.modules:
        spec = importlib_util.spec_from_file_location(modeling_name, modeling_path)
        module = importlib_util.module_from_spec(spec)
        sys.modules[modeling_name] = module
        sys.modules["modeling_bailing_moe_v2"] = module
        assert spec.loader is not None
        spec.loader.exec_module(module)
    modeling_module = sys.modules[modeling_name]

    BailingMoeV2Config = config_module.BailingMoeV2Config
    BailingMoeV2SparseMoeBlock = modeling_module.BailingMoeV2SparseMoeBlock

    return BailingMoeV2Config, BailingMoeV2SparseMoeBlock


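# Tiny single-layer builders: each wraps one MoE block in a DummyModel whose
# config.model_type is set so the plugin can pick the matching adapter
# without constructing a full model.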
def _build_bailing_model():
    BailingConfig, BailingBlock = _load_bailing_modules()
    config = BailingConfig(
        hidden_size=16,
        intermediate_size=32,
        moe_intermediate_size=32,
        num_experts=4,
        num_shared_experts=None,
        num_experts_per_tok=2,
        n_group=1,
        topk_group=1,
        routed_scaling_factor=1.0,
    )
    block = BailingBlock(config)

    class DummyModel(nn.Module):
        def __init__(self, layer):
            super().__init__()
            self.block = layer
            self.config = SimpleNamespace(model_type="bailing_moe")

        def forward(self, hidden_states):
            return self.block(hidden_states)

    return DummyModel(block), block


def _build_llama4_model():
    from transformers.models.llama4.modeling_llama4 import Llama4TextMoe

    # Build config without __post_init__ validation (works around a
    # huggingface_hub strict-dataclass type mismatch for layer_types).
    config = object.__new__(__import__("transformers").Llama4TextConfig)
    config.__dict__.update(
        hidden_size=16,
        intermediate_size=32,
        num_local_experts=4,
        num_attention_heads=2,
        num_key_value_heads=2,
        num_experts_per_tok=2,
        num_hidden_layers=2,
        hidden_act="silu",
        layer_types=None,
    )
    layer = Llama4TextMoe(config)

    class DummyModel(nn.Module):
        def __init__(self, moe_layer):
            super().__init__()
            self.moe = moe_layer
            self.config = SimpleNamespace(model_type="llama4")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    return DummyModel(layer), layer


def _build_mixtral_model():
    from transformers import MixtralConfig
    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

    config = MixtralConfig(
        hidden_size=16,
        intermediate_size=32,
        num_local_experts=4,
        num_experts_per_tok=2,
        num_attention_heads=2,
        num_key_value_heads=2,
    )
    layer = MixtralSparseMoeBlock(config)
    layer.config = config

    class DummyModel(nn.Module):
        def __init__(self, moe_layer):
            super().__init__()
            self.moe = moe_layer
            self.config = SimpleNamespace(model_type="mixtral")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    return DummyModel(layer), layer


def _build_qwen35_moe_model():
    from transformers.models.qwen3_5_moe.configuration_qwen3_5_moe import (
        Qwen3_5MoeTextConfig,
    )
    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
        Qwen3_5MoeSparseMoeBlock,
    )

    config = Qwen3_5MoeTextConfig(
        hidden_size=16,
        moe_intermediate_size=32,
        shared_expert_intermediate_size=32,
        num_experts=4,
        num_experts_per_tok=2,
        num_attention_heads=2,
        num_key_value_heads=2,
        num_hidden_layers=2,
    )
    layer = Qwen3_5MoeSparseMoeBlock(config)

    class DummyModel(nn.Module):
        def __init__(self, moe_layer):
            super().__init__()
            self.moe = moe_layer
            self.config = SimpleNamespace(model_type="qwen3_5_moe")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    return DummyModel(layer), layer


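# Drives one training step's worth of the plugin callback using stand-in
# objects for the HF Trainer's args/state/control, and returns them so tests
# can inspect logging side effects.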
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
    if args is None:
        args = SimpleNamespace(logging_steps=1)
    if state is None:
        state = SimpleNamespace(global_step=1, log_history=[])
    if control is None:
        control = SimpleNamespace(
            should_log=False,
            should_evaluate=False,
            should_save=False,
            should_training_stop=False,
        )

    class DummyTrainer:
        def __init__(self, state_obj, control_obj):
            self.state = state_obj
            self.control = control_obj

        def log(self, logs):
            output = dict(logs)
            output["step"] = self.state.global_step
            self.state.log_history.append(output)
            self.control.should_log = True

    dummy_trainer = DummyTrainer(state, control)
    callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=dummy_trainer)
    assert callbacks, "expected aux-free callback to be registered"
    callback = callbacks[0]
    callback.on_step_end(args=args, state=state, control=control)
    return state, control


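# Per-architecture adapter tests: buffers are registered at model-build time,
# counts accumulate during forward, and the step-end callback resets counts
# while updating the EMA and bias.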
class TestAuxFreeAdapters(unittest.TestCase):
    def test_bailing_adapter_updates_counts_and_bias(self):
        model, block = _build_bailing_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        self.assertTrue(hasattr(block, "_afb_bias"))
        hidden = torch.randn(2, 3, block.config.hidden_size)
        block(hidden)
        self.assertGreater(torch.count_nonzero(block._afb_counts), 0)

        _run_callback(plugin, cfg)
        self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
        self.assertFalse(
            torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema))
        )

    def test_llama4_adapter_biases_router_selection(self):
        model, layer = _build_llama4_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        self.assertTrue(hasattr(layer, "_afb_bias"))
        hidden = torch.randn(2, 4, layer.hidden_dim)
        layer(hidden)
        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)

        _run_callback(plugin, cfg)
        self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
        self.assertFalse(
            torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
        )

    def test_bias_warmup_respected(self):
        model, block = _build_bailing_model()
        cfg = _cfg(moe_afb_warmup_steps=2)
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        def _step():
            hidden = torch.randn(2, 3, block.config.hidden_size)
            block(hidden)
            _run_callback(plugin, cfg)

        # Warmup steps should leave bias untouched.
        _step()
        self.assertTrue(
            torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
        )

        _step()
        self.assertTrue(
            torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
        )

        # Third step exceeds warmup -> bias should update.
        _step()
        self.assertGreater(torch.count_nonzero(block._afb_bias), 0)

    def test_mixtral_adapter_patches_router_not_forward(self):
        """Verify that aux-free patches the router (gate) only, and the
        v5 block forward signature (single tensor return) is preserved."""
        model, layer = _build_mixtral_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        # Gate should be patched, not the block forward
        self.assertTrue(getattr(layer.gate, "_afb_patched", False))
        self.assertTrue(getattr(layer, "_afb_patched", False))

        # v5 block forward returns a single tensor (not a tuple with logits)
        hidden = torch.randn(2, 3, layer.config.hidden_size)
        out = layer(hidden)
        self.assertIsInstance(out, torch.Tensor)
        self.assertEqual(out.shape, hidden.shape)

        # Counts should have been accumulated
        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
        _run_callback(plugin, cfg)

    def test_mixtral_adapter_bias_affects_selection(self):
        """When bias is large for one expert, it should be selected more often."""
        model, layer = _build_mixtral_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        # Set a large bias for expert 0 to force its selection
        layer._afb_bias.zero_()
        layer._afb_bias[0] = 10.0

        hidden = torch.randn(2, 8, layer.config.hidden_size)
        num_tokens = 2 * 8  # batch * seq
        layer(hidden)

        # With top_k=2, expert 0 should appear in every token's selection
        # (once per token = num_tokens counts, not num_tokens * top_k)
        counts = layer._afb_counts.clone()
        self.assertEqual(
            int(counts[0].item()),
            num_tokens,
            msg="Expert 0 should be selected for every token when heavily biased",
        )

    def test_qwen35_moe_adapter_patches_router_and_preserves_shared_expert(self):
        """Verify Qwen 3.5 MoE: router is patched, shared expert is untouched,
        output includes shared expert contribution."""
        model, layer = _build_qwen35_moe_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        # Gate should be patched
        self.assertTrue(getattr(layer.gate, "_afb_patched", False))
        self.assertTrue(getattr(layer, "_afb_patched", False))
        # Shared expert should be unmodified
        self.assertTrue(hasattr(layer, "shared_expert"))
        self.assertTrue(hasattr(layer, "shared_expert_gate"))

        # Forward should return a single tensor (shared + routed)
        hidden_size = layer.gate.hidden_dim
        hidden = torch.randn(2, 3, hidden_size)
        out = layer(hidden)
        self.assertIsInstance(out, torch.Tensor)
        self.assertEqual(out.shape, hidden.shape)

        # Counts should have been accumulated
        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)

    def test_qwen35_moe_adapter_bias_updates(self):
        """Full cycle: forward → callback → verify bias update for Qwen 3.5 MoE."""
        model, layer = _build_qwen35_moe_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        hidden_size = layer.gate.hidden_dim
        hidden = torch.randn(2, 4, hidden_size)
        layer(hidden)

        # Bias should start at zero
        self.assertTrue(
            torch.allclose(layer._afb_bias, torch.zeros_like(layer._afb_bias))
        )

        _run_callback(plugin, cfg)

        # After callback: counts reset, EMA updated, bias updated
        self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
        self.assertFalse(
            torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
        )

    def test_qwen35_moe_adapter_model_type_matching(self):
        """Verify the adapter matches both qwen3_5_moe and qwen3_5_moe_text."""
        from axolotl.integrations.aux_free_router.adapters import Qwen35MoeAdapter

        adapter = Qwen35MoeAdapter()

        model_moe = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_5_moe"))
        model_text = SimpleNamespace(
            config=SimpleNamespace(model_type="qwen3_5_moe_text")
        )
        model_other = SimpleNamespace(config=SimpleNamespace(model_type="qwen3_moe"))

        self.assertTrue(adapter.matches(model_moe))
        self.assertTrue(adapter.matches(model_text))
        self.assertFalse(adapter.matches(model_other))

    def test_ep_group_resolution_deferred_until_dist_ready(self):
        if dist.is_available() and dist.is_initialized():
            self.skipTest(
                "Cannot safely test deferred EP group resolution when a process group is already initialized"
            )

        model, block = _build_bailing_model()
        cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        self.assertIsNotNone(plugin._shim)
        self.assertIsNone(plugin._shim.ep_group)

        tmp_init = tempfile.NamedTemporaryFile(delete=False)
        tmp_init.close()
        init_method = f"file://{tmp_init.name}"
        dist.init_process_group(
            backend="gloo", init_method=init_method, world_size=1, rank=0
        )
        try:
            hidden = torch.randn(2, 3, block.config.hidden_size)
            block(hidden)
            _run_callback(
                plugin,
                cfg,
                args=SimpleNamespace(logging_steps=1),
                state=SimpleNamespace(global_step=1, log_history=[]),
                control=SimpleNamespace(
                    should_log=False,
                    should_evaluate=False,
                    should_save=False,
                    should_training_stop=False,
                ),
            )
            self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
        finally:
            dist.destroy_process_group()
            os.unlink(tmp_init.name)

    def test_telemetry_logging(self):
        model, layer = _build_mixtral_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        hidden_dim = layer.config.hidden_size
        hidden = torch.randn(2, 3, hidden_dim)
        layer(hidden)

        args = SimpleNamespace(logging_steps=1)
        state = SimpleNamespace(global_step=1, log_history=[])
        control = SimpleNamespace(
            should_log=False,
            should_evaluate=False,
            should_save=False,
            should_training_stop=False,
        )
        _run_callback(plugin, cfg, args=args, state=state, control=control)

        self.assertTrue(control.should_log)
        self.assertTrue(state.log_history)
        telemetry = state.log_history[-1]
        self.assertEqual(telemetry["step"], state.global_step)
        self.assertIn("moe_afb/l0_load_min", telemetry)
        self.assertIn("moe_afb/l0_load_max", telemetry)
        self.assertIn("moe_afb/l0_bias_abs_max", telemetry)

    def test_get_num_experts_v5_attribute_paths(self):
        """Verify get_num_experts works with v5 attribute layout where
        num_experts is on gate/experts sub-modules, not the block."""
        from axolotl.integrations.aux_free_router.adapters import MixtralAdapter

        adapter = MixtralAdapter()

        # Simulates v5 MixtralSparseMoeBlock (num_experts on gate, not block)
        block = SimpleNamespace(
            gate=SimpleNamespace(num_experts=8),
            experts=SimpleNamespace(num_experts=8),
        )
        self.assertEqual(adapter.get_num_experts(block), 8)

        # Also works when num_experts is directly on block
        block2 = SimpleNamespace(num_experts=4)
        self.assertEqual(adapter.get_num_experts(block2), 4)


class TestAuxFreeKernelComposition(unittest.TestCase):
    """Tests that aux-free bias composes correctly with kernel routing."""

    def test_sonicmoe_softmax_routing_with_afb_bias(self):
        """SonicMoE softmax routing should use biased selection / unbiased weights."""
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        num_experts = 4
        top_k = 2
        hidden_dim = 16
        T = 6

        # Build a mock MoE block with gate attributes
        gate = nn.Linear(hidden_dim, num_experts, bias=False)
        gate.top_k = top_k
        gate.num_experts = num_experts
        gate.norm_topk_prob = True

        moe_block = SimpleNamespace(gate=gate)
        hidden = torch.randn(T, hidden_dim)

        # Baseline: no bias
        scores_base, tok_base, exp_base, logits_base = softmax_topk_routing(
            hidden, moe_block
        )
        self.assertEqual(scores_base.shape[0], T * top_k)

        # Now register aux-free buffers and set heavy bias on expert 0
        moe_block._afb_bias = torch.zeros(num_experts)
        moe_block._afb_bias[0] = 100.0
        moe_block._afb_counts = torch.zeros(num_experts)

        scores_biased, tok_biased, exp_biased, logits_biased = softmax_topk_routing(
            hidden, moe_block
        )

        # Expert 0 should appear among the selected experts
        self.assertTrue(
            (exp_biased == 0).any(),
            "Expert 0 should appear in selections when heavily biased",
        )
        # Counts should have been accumulated
        self.assertGreater(moe_block._afb_counts[0].item(), 0)
        # Total counts should equal T * top_k
        self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)

    def test_sonicmoe_routing_without_bias_unchanged(self):
        """Without _afb_bias, routing should produce identical results."""
        from axolotl.integrations.kernels.sonicmoe.routing import softmax_topk_routing

        num_experts = 4
        top_k = 2
        hidden_dim = 16

        gate = nn.Linear(hidden_dim, num_experts, bias=False)
        gate.top_k = top_k
        gate.num_experts = num_experts
        gate.norm_topk_prob = True

        moe_block = SimpleNamespace(gate=gate)
        hidden = torch.randn(4, hidden_dim)

        # Without _afb_bias attribute
        scores1, _, exp1, _ = softmax_topk_routing(hidden, moe_block)

        # With _afb_bias = zeros (should be equivalent)
        moe_block._afb_bias = torch.zeros(num_experts)
        moe_block._afb_counts = torch.zeros(num_experts)
        scores2, _, exp2, _ = softmax_topk_routing(hidden, moe_block)

        torch.testing.assert_close(scores1, scores2)
        torch.testing.assert_close(exp1, exp2)

    @unittest.skipUnless(
        importlib_util.find_spec("triton") is not None,
        "triton not installed (required by scattermoe)",
    )
    def test_scattermoe_softmax_routing_with_afb_bias(self):
        """ScatterMoE softmax routing should use biased selection / unbiased weights."""
        from axolotl.integrations.kernels.libs.scattermoe_lora.layers import (
            _softmax_topk_route,
        )

        num_experts = 4
        top_k = 2
        hidden_dim = 16
        T = 6

        gate_weight = torch.randn(num_experts, hidden_dim)
        base_gate = SimpleNamespace(
            top_k=top_k,
            num_experts=num_experts,
            norm_topk_prob=True,
            weight=gate_weight,
        )

        moe_block = SimpleNamespace()
        hidden = torch.randn(T, hidden_dim)

        # Baseline without bias
        w_base, e_base, _, _ = _softmax_topk_route(
            moe_block, base_gate, hidden, gate_weight, None
        )

        # With heavy bias on expert 0
        moe_block._afb_bias = torch.zeros(num_experts)
        moe_block._afb_bias[0] = 100.0
        moe_block._afb_counts = torch.zeros(num_experts)

        w_biased, e_biased, _, _ = _softmax_topk_route(
            moe_block, base_gate, hidden, gate_weight, None
        )

        # Expert 0 should appear in all selections
        self.assertTrue((e_biased == 0).any())
        # Counts accumulated
        self.assertGreater(moe_block._afb_counts[0].item(), 0)
        self.assertEqual(int(moe_block._afb_counts.sum().item()), T * top_k)

    def test_kernel_routing_skips_router_patch(self):
        """When a kernel backend has patched the block class, the adapter
        should skip patching the router (buffers are still registered)."""
        from axolotl.integrations.aux_free_router.adapters import MixtralAdapter

        adapter = MixtralAdapter()

        # Create a mock layer whose class has _original_forward (SonicMoE marker)
        class PatchedBlock(nn.Module):
            _original_forward = True  # SonicMoE marker

            def __init__(self):
                super().__init__()
                self.gate = nn.Linear(16, 4, bias=False)
                self.gate.top_k = 2
                self.gate.num_experts = 4
                self.gate.hidden_dim = 16
                self.experts = nn.Linear(16, 16)  # placeholder

        layer = PatchedBlock()
        self.assertTrue(adapter.uses_kernel_routing(layer))

        # Gate should NOT be patched (kernel handles routing)
        self.assertFalse(getattr(layer.gate, "_afb_patched", False))

    def test_adapter_buffers_registered_even_with_kernel(self):
        """Even when kernel routing is active, aux-free buffers must be
        registered on the MoE block so the kernel routing can find them."""
        from axolotl.integrations.aux_free_router.adapters import (
            LayerHandle,
            MixtralAdapter,
        )
        from axolotl.integrations.aux_free_router.core import (
            AuxFreeConfig,
            AuxFreeShim,
            AuxFreeState,
        )

        class PatchedBlock(nn.Module):
            _original_forward = True

            def __init__(self):
                super().__init__()
                self.gate = nn.Linear(16, 4, bias=False)
                self.gate.top_k = 2
                self.gate.num_experts = 4
                self.gate.hidden_dim = 16
                self.experts = nn.Linear(16, 16)

        layer = PatchedBlock()
        adapter = MixtralAdapter()
        cfg = AuxFreeConfig()
        state = AuxFreeState(
            num_layers=1, num_experts=4, device=torch.device("cpu"), cfg=cfg
        )
        shim = AuxFreeShim(state=state)
        handle = LayerHandle(layer=layer, layer_idx=0, num_experts=4, top_k=2)

        adapter.prepare(layer, handle, shim)

        # Buffers should be registered for kernel routing to use
        self.assertTrue(hasattr(layer, "_afb_bias"))
        self.assertTrue(hasattr(layer, "_afb_counts"))
        self.assertTrue(hasattr(layer, "_afb_ema"))
        # But gate should NOT be patched
        self.assertFalse(getattr(layer.gate, "_afb_patched", False))


if __name__ == "__main__":
    unittest.main()