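"""Tests for AuxFreeMoEPlugin, the aux-free (bias-based) MoE router balancer.

Covers adapter patching for Bailing MoE v2, Llama 4, and Mixtral sparse MoE
blocks, the bias-update warmup schedule, and deferred resolution of the
expert-parallel sync group until torch.distributed is initialized.
"""
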
import os
import sys
import tempfile
import unittest
from importlib import util as importlib_util
from pathlib import Path
from types import SimpleNamespace

import torch
import torch.distributed as dist
import torch.nn as nn
from huggingface_hub import snapshot_download

from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin


def _cfg(**overrides):
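    """Build a minimal plugin config; keyword overrides replace the defaults."""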
    defaults = dict(
        moe_balance_type="noaux_tc",
        moe_update_rate=0.1,
        moe_update_momentum=0.9,
        moe_bias_cap=2.0,
        moe_afb_warmup_steps=0,
        moe_bias_sync_group="world",
        expert_parallel_size=1,
    )
    defaults.update(overrides)
    return SimpleNamespace(**defaults)


def _load_bailing_modules():
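    """Download the Ring-mini-2.0 remote-code files and import the MoE classes.

    Each module is registered in sys.modules under both its package-qualified
    and bare name so that imports inside the downloaded files resolve either
    way; repeat calls reuse the already-imported modules.
    """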
    repo_dir = snapshot_download(
        repo_id="inclusionAI/Ring-mini-2.0",
        allow_patterns=[
            "configuration_bailing_moe_v2.py",
            "modeling_bailing_moe_v2.py",
            "__init__.py",
        ],
    )
    repo = Path(repo_dir)
    config_path = repo / "configuration_bailing_moe_v2.py"
    modeling_path = repo / "modeling_bailing_moe_v2.py"

    config_name = "bailing_moe_v2.configuration_bailing_moe_v2"
    if config_name not in sys.modules:
        spec = importlib_util.spec_from_file_location(config_name, config_path)
        module = importlib_util.module_from_spec(spec)
        sys.modules[config_name] = module
        sys.modules["configuration_bailing_moe_v2"] = module
        assert spec.loader is not None
        spec.loader.exec_module(module)
    config_module = sys.modules[config_name]

    modeling_name = "bailing_moe_v2.modeling_bailing_moe_v2"
    if modeling_name not in sys.modules:
        spec = importlib_util.spec_from_file_location(modeling_name, modeling_path)
        module = importlib_util.module_from_spec(spec)
        sys.modules[modeling_name] = module
        sys.modules["modeling_bailing_moe_v2"] = module
        assert spec.loader is not None
        spec.loader.exec_module(module)
    modeling_module = sys.modules[modeling_name]

    BailingMoeV2Config = config_module.BailingMoeV2Config
    BailingMoeV2SparseMoeBlock = modeling_module.BailingMoeV2SparseMoeBlock

    return BailingMoeV2Config, BailingMoeV2SparseMoeBlock


def _build_bailing_model():
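    """Return a dummy model wrapping a tiny Bailing MoE v2 sparse block, plus the block."""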
    BailingConfig, BailingBlock = _load_bailing_modules()
    config = BailingConfig(
        hidden_size=16,
        intermediate_size=32,
        moe_intermediate_size=32,
        num_experts=4,
        num_shared_experts=None,
        num_experts_per_tok=2,
        n_group=1,
        topk_group=1,
        routed_scaling_factor=1.0,
    )
    block = BailingBlock(config)

    class DummyModel(nn.Module):
        def __init__(self, layer):
            super().__init__()
            self.block = layer
            self.config = SimpleNamespace(model_type="bailing_moe")

        def forward(self, hidden_states):
            return self.block(hidden_states)

    return DummyModel(block), block


def _build_llama4_model():
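    """Return a dummy model wrapping a tiny Llama4TextMoe layer, plus the layer."""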
    from transformers import Llama4TextConfig
    from transformers.models.llama4.modeling_llama4 import Llama4TextMoe

    config = Llama4TextConfig(
        hidden_size=16,
        intermediate_size=32,
        num_local_experts=4,
        num_attention_heads=2,
        num_key_value_heads=2,
        num_experts_per_tok=2,
    )
    layer = Llama4TextMoe(config)

    class DummyModel(nn.Module):
        def __init__(self, moe_layer):
            super().__init__()
            self.moe = moe_layer
            self.config = SimpleNamespace(model_type="llama4")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    return DummyModel(layer), layer


def _build_mixtral_model():
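    """Return a dummy model wrapping a tiny MixtralSparseMoeBlock, plus the layer."""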
    from transformers import MixtralConfig
    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

    config = MixtralConfig(
        hidden_size=16,
        intermediate_size=32,
        num_local_experts=4,
        num_experts_per_tok=2,
        num_attention_heads=2,
        num_key_value_heads=2,
    )
    layer = MixtralSparseMoeBlock(config)
    # MixtralSparseMoeBlock does not keep a reference to its config; attach one
    # so the tests can read hidden_size off the layer.
    layer.config = config

    class DummyModel(nn.Module):
        def __init__(self, moe_layer):
            super().__init__()
            self.moe = moe_layer
            self.config = SimpleNamespace(model_type="mixtral")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    return DummyModel(layer), layer


def _run_callback(plugin, cfg):
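    """Fetch the plugin's trainer callback and fire a single on_step_end."""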
    callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
    assert callbacks, "expected aux-free callback to be registered"
    callback = callbacks[0]
    dummy = SimpleNamespace()
    callback.on_step_end(args=dummy, state=dummy, control=dummy)


class TestAuxFreeAdapters(unittest.TestCase):
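    """Adapter patching, warmup, and distributed-group behavior of AuxFreeMoEPlugin."""
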
    def test_bailing_adapter_updates_counts_and_bias(self):
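        """Patching adds _afb_bias; forwards accumulate _afb_counts, and the
        step callback resets them while updating the EMA."""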
        model, block = _build_bailing_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        self.assertTrue(hasattr(block, "_afb_bias"))
        hidden = torch.randn(2, 3, block.config.hidden_size)
        block(hidden)
        self.assertGreater(torch.count_nonzero(block._afb_counts), 0)

        _run_callback(plugin, cfg)
        self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
        self.assertFalse(torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema)))

    def test_llama4_adapter_biases_router_selection(self):
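        """Same contract as the Bailing adapter, exercised on Llama4TextMoe."""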
        model, layer = _build_llama4_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        self.assertTrue(hasattr(layer, "_afb_bias"))
        hidden = torch.randn(2, 4, layer.hidden_dim)
        layer(hidden)
        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)

        _run_callback(plugin, cfg)
        self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
        self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)))

    def test_bias_warmup_respected(self):
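        """Bias updates must be suppressed for the first moe_afb_warmup_steps steps."""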
        model, block = _build_bailing_model()
        cfg = _cfg(moe_afb_warmup_steps=2)
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
        self.assertTrue(callbacks)
        callback = callbacks[0]
        dummy = SimpleNamespace()

        def _step():
            hidden = torch.randn(2, 3, block.config.hidden_size)
            block(hidden)
            callback.on_step_end(args=dummy, state=dummy, control=dummy)

        # The first two steps fall inside the warmup window and should leave
        # the bias untouched.
        _step()
        self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))

        _step()
        self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))

        # The third step exceeds the warmup, so the bias should update.
        _step()
        self.assertGreater(torch.count_nonzero(block._afb_bias), 0)

    def test_mixtral_adapter_respects_native_forward(self):
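        """With a freshly initialized (zero) bias, the patched Mixtral forward
        must reproduce the native block's outputs exactly."""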
        model, layer = _build_mixtral_model()
        layer.jitter_noise = 0.0  # avoid stochasticity for the comparison

        hidden_dim = layer.config.hidden_size
        hidden = torch.randn(2, 3, hidden_dim)
        baseline_out, baseline_logits = layer(hidden.clone())

        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        # With the bias still zero-initialized, patching must not change the
        # outputs, only add count tracking.
        patched_out, patched_logits = layer(hidden.clone())
        self.assertTrue(torch.allclose(baseline_out, patched_out))
        self.assertTrue(torch.allclose(baseline_logits, patched_logits))
        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
        _run_callback(plugin, cfg)

    def test_ep_group_resolution_deferred_until_dist_ready(self):
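        """With moe_bias_sync_group="ep", group resolution must wait until
        torch.distributed is initialized."""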
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()

        model, block = _build_bailing_model()
        cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        # Before torch.distributed is initialized, the shim must not have
        # resolved an expert-parallel group yet.
        self.assertIsNotNone(plugin._shim)
        self.assertIsNone(plugin._shim.ep_group)

        callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
        self.assertTrue(callbacks)
        callback = callbacks[0]
        dummy = SimpleNamespace()

        # Spin up a single-rank gloo process group via file:// rendezvous.
        tmp_init = tempfile.NamedTemporaryFile(delete=False)
        tmp_init.close()
        init_method = f"file://{tmp_init.name}"
        dist.init_process_group(backend="gloo", init_method=init_method, world_size=1, rank=0)
        try:
            hidden = torch.randn(2, 3, block.config.hidden_size)
            block(hidden)
            callback.on_step_end(args=dummy, state=dummy, control=dummy)
            self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
        finally:
            dist.destroy_process_group()
            os.unlink(tmp_init.name)


if __name__ == "__main__":
    unittest.main()