Add ring/llama4 aux-free adapters and EP sync support

This commit is contained in:
lhl
2025-10-27 14:42:36 +09:00
committed by Wing Lian
parent 3e4688289c
commit 2af7475fdf
7 changed files with 331 additions and 13 deletions

View File

@@ -19,7 +19,8 @@ Enable
moe_update_momentum: 0.9 # default if unset
moe_bias_cap: 2.0 # default if unset
moe_afb_warmup_steps: 100 # optional
moe_bias_sync_group: world # or 'ep' if expert-parallel is configured
moe_bias_sync_group: world # or 'ep' if expert_parallel_size > 1
expert_parallel_size: 1 # set to your EP width when using moe_bias_sync_group: ep
Config keys
- moe_balance_type: gshard (auxiliary loss) | noaux_tc (aux-free). Default: model native.
@@ -28,9 +29,10 @@ Config keys
- moe_bias_cap: absolute clamp for bias. Default: 2.0.
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
Compatibility
- Targeted families: Mixtral, Qwen3-MoE. Jamba optional.
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
- Pass-through: Models with native aux-free routing (e.g., DeepSeek-V3) are left unmodified; only telemetry may be added in future.
Notes

View File

@@ -50,12 +50,7 @@ class BaseMoEAdapter:
except Exception: # pragma: no cover - non-critical
pass
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
"""Attach per-layer buffers and mark as aux-free enabled.
Note: Forward rebind happens in concrete adapters once we implement full routing.
For now, we only attach buffers as placeholders to minimize disturbance.
"""
def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle) -> None:
device = next(moe_layer.parameters(), torch.tensor(0)).device
if not hasattr(moe_layer, "_afb_bias"):
moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device))
@@ -65,6 +60,10 @@ class BaseMoEAdapter:
moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device))
moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined]
moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined]
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
"""Attach per-layer buffers and mark as aux-free enabled."""
self._register_aux_buffers(moe_layer, handle)
self._patch_forward_with_aux_free(moe_layer)
def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None:
@@ -136,6 +135,102 @@ class Qwen3Adapter(MixtralAdapter):
return getattr(getattr(model, "config", object()), "model_type", "") in ("qwen3_moe", "qwen2_moe")
class BailingAdapter(BaseMoEAdapter):
family = "bailing_moe"
def matches(self, model: nn.Module) -> bool:
model_type = getattr(getattr(model, "config", object()), "model_type", "")
return model_type in ("bailing_moe", "bailing_moe_v2")
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
for m in model.modules():
if m.__class__.__name__ == "BailingMoeV2SparseMoeBlock":
yield m
def get_num_experts(self, moe_layer: nn.Module) -> int:
if hasattr(moe_layer, "num_experts"):
return int(getattr(moe_layer, "num_experts"))
cfg = getattr(moe_layer, "config", None)
return int(getattr(cfg, "num_experts"))
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
self._register_aux_buffers(moe_layer, handle)
self._patch_bailing_gate(moe_layer)
def _patch_bailing_gate(self, moe_layer: nn.Module) -> None:
gate = getattr(moe_layer, "gate", None)
if gate is None:
LOG.info("BailingAdapter: layer missing gate; skipping aux-free patch")
return
if getattr(gate, "_afb_patched", False):
return
def afb_gate_forward(self, hidden_states: torch.Tensor):
flat = hidden_states.view(-1, hidden_states.shape[-1])
logits = F.linear(flat.float(), self.weight.float())
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
bias = getattr(moe_layer, "_afb_bias")
biased_scores = scores_unbiased + bias
topk_vals, topk_idx = self.group_limited_topk(biased_scores)
weights = torch.gather(scores_unbiased, 1, topk_idx)
if self.top_k > 1:
denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20)
weights = weights / denom
weights = weights * self.routed_scaling_factor
flat_topk = topk_idx.reshape(-1)
counts = torch.bincount(flat_topk, minlength=bias.numel())
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
return topk_idx, weights.to(hidden_states.dtype), logits
gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
setattr(gate, "_afb_patched", True)
class Llama4Adapter(BaseMoEAdapter):
family = "llama4"
def matches(self, model: nn.Module) -> bool:
return getattr(getattr(model, "config", object()), "model_type", "") == "llama4"
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
for m in model.modules():
if m.__class__.__name__ == "Llama4TextMoe":
yield m
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
self._register_aux_buffers(moe_layer, handle)
self._patch_llama4_router(moe_layer)
def _patch_llama4_router(self, moe_layer: nn.Module) -> None:
router = getattr(moe_layer, "router", None)
if router is None:
LOG.info("Llama4Adapter: layer missing router; skipping aux-free patch")
return
if getattr(router, "_afb_patched", False):
return
def afb_router_forward(self, hidden_states: torch.Tensor):
flat = hidden_states if hidden_states.dim() == 2 else hidden_states.view(-1, hidden_states.shape[-1])
router_logits = F.linear(flat, self.weight, self.bias)
bias = getattr(moe_layer, "_afb_bias")
biased_logits = router_logits + bias
_, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
unbiased_top = torch.gather(router_logits, 1, router_indices)
router_scores = torch.full_like(router_logits, float("-inf"))
router_scores.scatter_(1, router_indices, unbiased_top)
router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)
counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
return router_scores, router_logits
router.forward = afb_router_forward.__get__(router, router.__class__) # type: ignore[attr-defined]
setattr(router, "_afb_patched", True)
def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim) -> list[LayerHandle]:
"""Discover MoE layers using the first matching adapter and attach per-layer buffers.

View File

@@ -9,13 +9,20 @@ from __future__ import annotations
from typing import Optional
import torch
import torch.distributed as dist
from transformers.trainer_callback import TrainerCallback
from axolotl.integrations.base import BasePlugin
from axolotl.utils.distributed import is_distributed
from axolotl.utils.logging import get_logger
from .adapters import BaseMoEAdapter, MixtralAdapter, Qwen3Adapter, discover_and_prepare_layers
from .adapters import (
BailingAdapter,
BaseMoEAdapter,
Llama4Adapter,
MixtralAdapter,
Qwen3Adapter,
discover_and_prepare_layers,
)
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
LOG = get_logger(__name__)
@@ -70,6 +77,7 @@ class AuxFreeMoEPlugin(BasePlugin):
super().__init__()
self._handles: list = []
self._shim: Optional[AuxFreeShim] = None
self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}
def post_model_build(self, cfg, model):
# Enable only when explicitly requested
@@ -98,7 +106,12 @@ class AuxFreeMoEPlugin(BasePlugin):
)
# Discover layers to count the number and experts for state sizing
adapters: list[BaseMoEAdapter] = [MixtralAdapter(), Qwen3Adapter()]
adapters: list[BaseMoEAdapter] = [
MixtralAdapter(),
Qwen3Adapter(),
BailingAdapter(),
Llama4Adapter(),
]
# For initial state sizing, we conservatively assume the first discovered layer defines nE
n_layers = 0
@@ -112,7 +125,8 @@ class AuxFreeMoEPlugin(BasePlugin):
# we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
n_experts = 2
state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg)
self._shim = AuxFreeShim(state=state, ep_group=None)
ep_group = self._resolve_ep_group(cfg) if sync_group == "ep" else None
self._shim = AuxFreeShim(state=state, ep_group=ep_group)
# Discover and prepare layers (attach per-layer buffers)
self._handles = discover_and_prepare_layers(model, adapters, self._shim)
@@ -121,6 +135,32 @@ class AuxFreeMoEPlugin(BasePlugin):
f"AuxFreeMoE: enabled with rate={rate}, momentum={momentum}, cap={bias_cap}, warmup={warmup}, group={sync_group}"
)
def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
if not dist.is_available() or not dist.is_initialized():
LOG.warning("AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world")
return None
ep_size = getattr(cfg, "expert_parallel_size", None)
if not ep_size or ep_size <= 1:
LOG.warning("AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world")
return None
world = dist.get_world_size()
if world % ep_size != 0:
LOG.warning(
"AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world",
ep_size,
world,
)
return None
if ep_size == world:
return dist.group.WORLD
rank = dist.get_rank()
group_start = (rank // ep_size) * ep_size
ranks = tuple(range(group_start, group_start + ep_size))
if ranks not in self._ep_group_cache:
self._ep_group_cache[ranks] = dist.new_group(ranks)
return self._ep_group_cache[ranks]
def add_callbacks_post_trainer(self, cfg, trainer):
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
return []

View File

@@ -874,6 +874,12 @@ class AxolotlInputConfig(
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
},
)
expert_parallel_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of processes participating in expert-parallel collectives. Set >1 to form EP groups for aux-free reductions; defaults to world when unset."
},
)
special_tokens: SpecialTokensConfig | None = Field(
default=None,
json_schema_extra={

View File

@@ -1386,6 +1386,14 @@ class ComplexValidationMixin:
self.tensor_parallel_size = 1
return self
@model_validator(mode="after")
def check_expert_parallel_size(self):
if not getattr(self, "expert_parallel_size", None):
self.expert_parallel_size = 1
elif self.expert_parallel_size < 1:
raise ValueError("expert_parallel_size must be >= 1")
return self
@model_validator(mode="after")
def check_context_parallel_size(self):
if self.sequence_parallel_degree and not self.context_parallel_size:

View File

@@ -12,7 +12,10 @@ from pathlib import Path
import torch
from packaging import version
from tbparse import SummaryReader
try:
from tbparse import SummaryReader
except ImportError: # pragma: no cover - optional dependency
SummaryReader = None
from axolotl.utils.dict import DictDefault
@@ -185,6 +188,8 @@ def check_tensorboard(
"""
helper function to parse and check tensorboard logs
"""
if SummaryReader is None:
raise unittest.SkipTest("tbparse is not installed; skipping tensorboard assertions")
tb_log_path = most_recent_subdir(temp_run_dir)
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)

View File

@@ -0,0 +1,162 @@
import sys
import unittest
from types import SimpleNamespace
import torch
import torch.nn as nn
from importlib import util as importlib_util
from pathlib import Path
from huggingface_hub import snapshot_download
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
def _cfg(**overrides):
defaults = dict(
moe_balance_type="noaux_tc",
moe_update_rate=0.1,
moe_update_momentum=0.9,
moe_bias_cap=2.0,
moe_afb_warmup_steps=0,
moe_bias_sync_group="world",
expert_parallel_size=1,
)
defaults.update(overrides)
return SimpleNamespace(**defaults)
def _load_bailing_modules():
repo_dir = snapshot_download(
repo_id="inclusionAI/Ring-mini-2.0",
allow_patterns=[
"configuration_bailing_moe_v2.py",
"modeling_bailing_moe_v2.py",
"__init__.py",
],
)
repo = Path(repo_dir)
config_path = repo / "configuration_bailing_moe_v2.py"
modeling_path = repo / "modeling_bailing_moe_v2.py"
config_name = "bailing_moe_v2.configuration_bailing_moe_v2"
if config_name not in sys.modules:
spec = importlib_util.spec_from_file_location(config_name, config_path)
module = importlib_util.module_from_spec(spec)
sys.modules[config_name] = module
sys.modules["configuration_bailing_moe_v2"] = module
assert spec.loader is not None
spec.loader.exec_module(module)
config_module = sys.modules[config_name]
modeling_name = "bailing_moe_v2.modeling_bailing_moe_v2"
if modeling_name not in sys.modules:
spec = importlib_util.spec_from_file_location(modeling_name, modeling_path)
module = importlib_util.module_from_spec(spec)
sys.modules[modeling_name] = module
sys.modules["modeling_bailing_moe_v2"] = module
assert spec.loader is not None
spec.loader.exec_module(module)
modeling_module = sys.modules[modeling_name]
BailingMoeV2Config = config_module.BailingMoeV2Config
BailingMoeV2SparseMoeBlock = modeling_module.BailingMoeV2SparseMoeBlock
return BailingMoeV2Config, BailingMoeV2SparseMoeBlock
def _build_bailing_model():
BailingConfig, BailingBlock = _load_bailing_modules()
config = BailingConfig(
hidden_size=16,
intermediate_size=32,
moe_intermediate_size=32,
num_experts=4,
num_shared_experts=None,
num_experts_per_tok=2,
n_group=1,
topk_group=1,
routed_scaling_factor=1.0,
)
block = BailingBlock(config)
class DummyModel(nn.Module):
def __init__(self, layer):
super().__init__()
self.block = layer
self.config = SimpleNamespace(model_type="bailing_moe")
def forward(self, hidden_states):
return self.block(hidden_states)
return DummyModel(block), block
def _build_llama4_model():
from transformers import Llama4TextConfig
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
config = Llama4TextConfig(
hidden_size=16,
intermediate_size=32,
num_local_experts=4,
num_attention_heads=2,
num_key_value_heads=2,
num_experts_per_tok=2,
)
layer = Llama4TextMoe(config)
class DummyModel(nn.Module):
def __init__(self, moe_layer):
super().__init__()
self.moe = moe_layer
self.config = SimpleNamespace(model_type="llama4")
def forward(self, hidden_states):
return self.moe(hidden_states)
return DummyModel(layer), layer
def _run_callback(plugin, cfg):
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
assert callbacks, "expected aux-free callback to be registered"
callback = callbacks[0]
dummy = SimpleNamespace()
callback.on_step_end(args=dummy, state=dummy, control=dummy)
class TestAuxFreeAdapters(unittest.TestCase):
def test_bailing_adapter_updates_counts_and_bias(self):
model, block = _build_bailing_model()
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
self.assertTrue(hasattr(block, "_afb_bias"))
hidden = torch.randn(2, 3, block.config.hidden_size)
block(hidden)
self.assertGreater(torch.count_nonzero(block._afb_counts), 0)
_run_callback(plugin, cfg)
self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
self.assertFalse(torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema)))
def test_llama4_adapter_biases_router_selection(self):
model, layer = _build_llama4_model()
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
self.assertTrue(hasattr(layer, "_afb_bias"))
hidden = torch.randn(2, 4, layer.hidden_dim)
layer(hidden)
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
_run_callback(plugin, cfg)
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)))
if __name__ == "__main__":
unittest.main()