Add ring/llama4 aux-free adapters and EP sync support
This commit is contained in:
@@ -19,7 +19,8 @@ Enable
|
||||
moe_update_momentum: 0.9 # default if unset
|
||||
moe_bias_cap: 2.0 # default if unset
|
||||
moe_afb_warmup_steps: 100 # optional
|
||||
moe_bias_sync_group: world # or 'ep' if expert-parallel is configured
|
||||
moe_bias_sync_group: world # or 'ep' if expert_parallel_size > 1
|
||||
expert_parallel_size: 1 # set to your EP width when using moe_bias_sync_group: ep
|
||||
|
||||
Config keys
|
||||
- moe_balance_type: gshard (auxiliary loss) | noaux_tc (aux-free). Default: model native.
|
||||
@@ -28,9 +29,10 @@ Config keys
|
||||
- moe_bias_cap: absolute clamp for bias. Default: 2.0.
|
||||
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
|
||||
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
|
||||
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
|
||||
|
||||
Compatibility
|
||||
- Targeted families: Mixtral, Qwen3-MoE. Jamba optional.
|
||||
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
|
||||
- Pass-through: Models with native aux-free routing (e.g., DeepSeek-V3) are left unmodified; only telemetry may be added in the future.
|
||||
|
||||
Notes
|
||||
|
||||
@@ -50,12 +50,7 @@ class BaseMoEAdapter:
|
||||
except Exception: # pragma: no cover - non-critical
|
||||
pass
|
||||
|
||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||
"""Attach per-layer buffers and mark as aux-free enabled.
|
||||
|
||||
Note: Forward rebind happens in concrete adapters once we implement full routing.
|
||||
For now, we only attach buffers as placeholders to minimize disturbance.
|
||||
"""
|
||||
def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle) -> None:
|
||||
device = next(moe_layer.parameters(), torch.tensor(0)).device
|
||||
if not hasattr(moe_layer, "_afb_bias"):
|
||||
moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device))
|
||||
@@ -65,6 +60,10 @@ class BaseMoEAdapter:
|
||||
moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device))
|
||||
moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined]
|
||||
moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined]
|
||||
|
||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
    """Attach per-layer buffers and mark as aux-free enabled.

    Args:
        moe_layer: the MoE module to instrument in place.
        shim: shared aux-free state (not read here; kept for adapter API parity).
    """
    # Buffers must exist before the rebound forward reads them.
    self._register_aux_buffers(moe_layer, handle)
    self._patch_forward_with_aux_free(moe_layer)
|
||||
|
||||
def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None:
|
||||
@@ -136,6 +135,102 @@ class Qwen3Adapter(MixtralAdapter):
|
||||
return getattr(getattr(model, "config", object()), "model_type", "") in ("qwen3_moe", "qwen2_moe")
|
||||
|
||||
|
||||
class BailingAdapter(BaseMoEAdapter):
    """Aux-free routing adapter for Bailing/Ring MoE models.

    Rebinds each sparse-MoE block's gate so expert *selection* uses
    bias-adjusted sigmoid scores while the *combine weights* are gathered
    from the unbiased scores (DeepSeek-style "noaux_tc" balancing).
    """

    family = "bailing_moe"

    def matches(self, model: nn.Module) -> bool:
        """Return True for Bailing MoE v1/v2 model types."""
        model_type = getattr(getattr(model, "config", object()), "model_type", "")
        return model_type in ("bailing_moe", "bailing_moe_v2")

    def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
        """Yield every Bailing sparse-MoE block (matched by class name)."""
        for m in model.modules():
            if m.__class__.__name__ == "BailingMoeV2SparseMoeBlock":
                yield m

    def get_num_experts(self, moe_layer: nn.Module) -> int:
        """Read the routed-expert count from the layer, falling back to its config.

        Raises:
            AttributeError: when neither the layer nor its config exposes
                ``num_experts`` (with a descriptive message instead of an
                opaque attribute error on ``None``).
        """
        if hasattr(moe_layer, "num_experts"):
            return int(getattr(moe_layer, "num_experts"))
        cfg = getattr(moe_layer, "config", None)
        if cfg is None or not hasattr(cfg, "num_experts"):
            raise AttributeError(
                f"BailingAdapter: cannot determine num_experts for {moe_layer.__class__.__name__}"
            )
        return int(cfg.num_experts)

    def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
        """Attach per-layer buffers, then rebind the gate forward."""
        self._register_aux_buffers(moe_layer, handle)
        self._patch_bailing_gate(moe_layer)

    def _patch_bailing_gate(self, moe_layer: nn.Module) -> None:
        """Replace the gate's forward with a bias-aware variant (idempotent)."""
        gate = getattr(moe_layer, "gate", None)
        if gate is None:
            LOG.info("BailingAdapter: layer missing gate; skipping aux-free patch")
            return
        if getattr(gate, "_afb_patched", False):
            return

        def afb_gate_forward(self, hidden_states: torch.Tensor):
            # Flatten to (tokens, hidden); reshape tolerates non-contiguous
            # inputs where .view would raise.
            flat = hidden_states.reshape(-1, hidden_states.shape[-1])
            # Compute routing logits in float32 for numerical stability.
            logits = F.linear(flat.float(), self.weight.float())
            # logits is already float32, so sigmoid applies directly.
            scores_unbiased = torch.sigmoid(logits)
            bias = getattr(moe_layer, "_afb_bias")
            # The bias only influences WHICH experts get selected; the
            # combine weights below come from the unbiased scores.
            biased_scores = scores_unbiased + bias
            topk_vals, topk_idx = self.group_limited_topk(biased_scores)
            weights = torch.gather(scores_unbiased, 1, topk_idx)
            if self.top_k > 1:
                denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20)
                weights = weights / denom
            weights = weights * self.routed_scaling_factor

            # Accumulate per-expert token counts for the bias-update callback.
            flat_topk = topk_idx.reshape(-1)
            counts = torch.bincount(flat_topk, minlength=bias.numel())
            getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))

            return topk_idx, weights.to(hidden_states.dtype), logits

        gate.forward = afb_gate_forward.__get__(gate, gate.__class__)  # type: ignore[attr-defined]
        setattr(gate, "_afb_patched", True)
|
||||
|
||||
|
||||
class Llama4Adapter(BaseMoEAdapter):
    """Aux-free routing adapter for Llama 4 text MoE layers.

    Rebinds each ``Llama4TextMoe`` router so expert selection uses
    bias-adjusted logits while the returned routing scores stay unbiased.
    """

    family = "llama4"

    def matches(self, model: nn.Module) -> bool:
        """Return True when the model config reports the llama4 model type."""
        return getattr(getattr(model, "config", object()), "model_type", "") == "llama4"

    def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
        """Yield every Llama4 text MoE layer (matched by class name)."""
        for m in model.modules():
            if m.__class__.__name__ == "Llama4TextMoe":
                yield m

    def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
        """Attach per-layer buffers, then rebind the router forward."""
        self._register_aux_buffers(moe_layer, handle)
        self._patch_llama4_router(moe_layer)

    def _patch_llama4_router(self, moe_layer: nn.Module) -> None:
        """Replace the router's forward with a bias-aware variant (idempotent)."""
        router = getattr(moe_layer, "router", None)
        if router is None:
            LOG.info("Llama4Adapter: layer missing router; skipping aux-free patch")
            return
        if getattr(router, "_afb_patched", False):
            return

        def afb_router_forward(self, hidden_states: torch.Tensor):
            # Flatten to (tokens, hidden); reshape tolerates non-contiguous
            # inputs where .view would raise.
            flat = hidden_states if hidden_states.dim() == 2 else hidden_states.reshape(-1, hidden_states.shape[-1])
            router_logits = F.linear(flat, self.weight, self.bias)
            bias = getattr(moe_layer, "_afb_bias")
            # Select experts on biased logits, but score them on the
            # unbiased logits so the bias never leaks into combine weights.
            biased_logits = router_logits + bias
            _, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
            unbiased_top = torch.gather(router_logits, 1, router_indices)
            router_scores = torch.full_like(router_logits, float("-inf"))
            router_scores.scatter_(1, router_indices, unbiased_top)
            # sigmoid(-inf) == 0, so non-selected experts get zero weight.
            router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)

            # Accumulate per-expert token counts for the bias-update callback.
            counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
            getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))

            return router_scores, router_logits

        router.forward = afb_router_forward.__get__(router, router.__class__)  # type: ignore[attr-defined]
        setattr(router, "_afb_patched", True)
|
||||
|
||||
|
||||
def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim) -> list[LayerHandle]:
|
||||
"""Discover MoE layers using the first matching adapter and attach per-layer buffers.
|
||||
|
||||
|
||||
@@ -9,13 +9,20 @@ from __future__ import annotations
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils.distributed import is_distributed
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .adapters import BaseMoEAdapter, MixtralAdapter, Qwen3Adapter, discover_and_prepare_layers
|
||||
from .adapters import (
|
||||
BailingAdapter,
|
||||
BaseMoEAdapter,
|
||||
Llama4Adapter,
|
||||
MixtralAdapter,
|
||||
Qwen3Adapter,
|
||||
discover_and_prepare_layers,
|
||||
)
|
||||
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
@@ -70,6 +77,7 @@ class AuxFreeMoEPlugin(BasePlugin):
|
||||
super().__init__()
|
||||
self._handles: list = []
|
||||
self._shim: Optional[AuxFreeShim] = None
|
||||
self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}
|
||||
|
||||
def post_model_build(self, cfg, model):
|
||||
# Enable only when explicitly requested
|
||||
@@ -98,7 +106,12 @@ class AuxFreeMoEPlugin(BasePlugin):
|
||||
)
|
||||
|
||||
# Discover layers to count the number and experts for state sizing
|
||||
adapters: list[BaseMoEAdapter] = [MixtralAdapter(), Qwen3Adapter()]
|
||||
adapters: list[BaseMoEAdapter] = [
|
||||
MixtralAdapter(),
|
||||
Qwen3Adapter(),
|
||||
BailingAdapter(),
|
||||
Llama4Adapter(),
|
||||
]
|
||||
|
||||
# For initial state sizing, we conservatively assume the first discovered layer defines nE
|
||||
n_layers = 0
|
||||
@@ -112,7 +125,8 @@ class AuxFreeMoEPlugin(BasePlugin):
|
||||
# we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
|
||||
n_experts = 2
|
||||
state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg)
|
||||
self._shim = AuxFreeShim(state=state, ep_group=None)
|
||||
ep_group = self._resolve_ep_group(cfg) if sync_group == "ep" else None
|
||||
self._shim = AuxFreeShim(state=state, ep_group=ep_group)
|
||||
|
||||
# Discover and prepare layers (attach per-layer buffers)
|
||||
self._handles = discover_and_prepare_layers(model, adapters, self._shim)
|
||||
@@ -121,6 +135,32 @@ class AuxFreeMoEPlugin(BasePlugin):
|
||||
f"AuxFreeMoE: enabled with rate={rate}, momentum={momentum}, cap={bias_cap}, warmup={warmup}, group={sync_group}"
|
||||
)
|
||||
|
||||
def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
    """Resolve (and cache) the expert-parallel process group for this rank.

    Returns ``None`` whenever EP sync cannot be honored, which makes the
    caller fall back to world-wide reduction.

    Note: ``dist.new_group`` is a collective — EVERY rank must create EVERY
    subgroup, in the same order, even groups it does not belong to.
    Creating only this rank's own group (as a naive implementation would)
    deadlocks or produces undefined behavior with more than one EP group.
    """
    if not dist.is_available() or not dist.is_initialized():
        LOG.warning("AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world")
        return None
    ep_size = getattr(cfg, "expert_parallel_size", None)
    if not ep_size or ep_size <= 1:
        LOG.warning("AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world")
        return None
    world = dist.get_world_size()
    if world % ep_size != 0:
        LOG.warning(
            "AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world",
            ep_size,
            world,
        )
        return None
    if ep_size == world:
        return dist.group.WORLD

    rank = dist.get_rank()
    own_group: Optional[dist.ProcessGroup] = None
    # Collectively create all EP subgroups in a deterministic order,
    # remembering the one that contains this rank.
    for group_start in range(0, world, ep_size):
        ranks = tuple(range(group_start, group_start + ep_size))
        if ranks not in self._ep_group_cache:
            self._ep_group_cache[ranks] = dist.new_group(list(ranks))
        if group_start <= rank < group_start + ep_size:
            own_group = self._ep_group_cache[ranks]
    return own_group
|
||||
|
||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
|
||||
return []
|
||||
|
||||
@@ -874,6 +874,12 @@ class AxolotlInputConfig(
|
||||
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
|
||||
},
|
||||
)
|
||||
expert_parallel_size: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of processes participating in expert-parallel collectives. Set >1 to form EP groups for aux-free reductions; defaults to world when unset."
|
||||
},
|
||||
)
|
||||
special_tokens: SpecialTokensConfig | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -1386,6 +1386,14 @@ class ComplexValidationMixin:
|
||||
self.tensor_parallel_size = 1
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_expert_parallel_size(self):
    """Normalize ``expert_parallel_size``: ``None`` becomes 1; values below 1 are rejected.

    Only ``None`` is treated as "unset" — an explicit 0 now raises instead
    of being silently coerced to 1 by truthiness, matching the error raised
    for other sub-1 values.
    """
    if getattr(self, "expert_parallel_size", None) is None:
        self.expert_parallel_size = 1
    elif self.expert_parallel_size < 1:
        raise ValueError("expert_parallel_size must be >= 1")
    return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_context_parallel_size(self):
|
||||
if self.sequence_parallel_degree and not self.context_parallel_size:
|
||||
|
||||
@@ -12,7 +12,10 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from tbparse import SummaryReader
|
||||
try:
|
||||
from tbparse import SummaryReader
|
||||
except ImportError: # pragma: no cover - optional dependency
|
||||
SummaryReader = None
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -185,6 +188,8 @@ def check_tensorboard(
|
||||
"""
|
||||
helper function to parse and check tensorboard logs
|
||||
"""
|
||||
if SummaryReader is None:
|
||||
raise unittest.SkipTest("tbparse is not installed; skipping tensorboard assertions")
|
||||
tb_log_path = most_recent_subdir(temp_run_dir)
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
|
||||
162
tests/unit/test_aux_free_adapters.py
Normal file
162
tests/unit/test_aux_free_adapters.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import sys
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from importlib import util as importlib_util
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
|
||||
|
||||
|
||||
def _cfg(**overrides):
|
||||
defaults = dict(
|
||||
moe_balance_type="noaux_tc",
|
||||
moe_update_rate=0.1,
|
||||
moe_update_momentum=0.9,
|
||||
moe_bias_cap=2.0,
|
||||
moe_afb_warmup_steps=0,
|
||||
moe_bias_sync_group="world",
|
||||
expert_parallel_size=1,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return SimpleNamespace(**defaults)
|
||||
|
||||
|
||||
def _load_dynamic_module(name, alias, path):
    """Import the file at *path* under dotted *name* (and short *alias*), caching in sys.modules."""
    if name not in sys.modules:
        spec = importlib_util.spec_from_file_location(name, path)
        module = importlib_util.module_from_spec(spec)
        # Register under both names BEFORE exec so intra-file imports resolve.
        sys.modules[name] = module
        sys.modules[alias] = module
        assert spec.loader is not None
        spec.loader.exec_module(module)
    return sys.modules[name]


def _load_bailing_modules():
    """Download Ring-mini-2.0's trust-remote-code files and import the config + MoE block classes.

    The configuration and modeling files follow the same load pattern, so the
    duplicated spec/exec logic lives in ``_load_dynamic_module``.
    """
    repo_dir = snapshot_download(
        repo_id="inclusionAI/Ring-mini-2.0",
        allow_patterns=[
            "configuration_bailing_moe_v2.py",
            "modeling_bailing_moe_v2.py",
            "__init__.py",
        ],
    )
    repo = Path(repo_dir)

    config_module = _load_dynamic_module(
        "bailing_moe_v2.configuration_bailing_moe_v2",
        "configuration_bailing_moe_v2",
        repo / "configuration_bailing_moe_v2.py",
    )
    modeling_module = _load_dynamic_module(
        "bailing_moe_v2.modeling_bailing_moe_v2",
        "modeling_bailing_moe_v2",
        repo / "modeling_bailing_moe_v2.py",
    )

    return config_module.BailingMoeV2Config, modeling_module.BailingMoeV2SparseMoeBlock
|
||||
|
||||
|
||||
def _build_bailing_model():
    """Construct a tiny Bailing sparse-MoE block wrapped in a minimal model.

    Returns the wrapper model and the raw MoE block for direct inspection.
    """
    config_cls, block_cls = _load_bailing_modules()
    moe_config = config_cls(
        hidden_size=16,
        intermediate_size=32,
        moe_intermediate_size=32,
        num_experts=4,
        num_shared_experts=None,
        num_experts_per_tok=2,
        n_group=1,
        topk_group=1,
        routed_scaling_factor=1.0,
    )
    moe_block = block_cls(moe_config)

    class _Wrapper(nn.Module):
        def __init__(self, layer):
            super().__init__()
            self.block = layer
            self.config = SimpleNamespace(model_type="bailing_moe")

        def forward(self, hidden_states):
            return self.block(hidden_states)

    return _Wrapper(moe_block), moe_block
|
||||
|
||||
|
||||
def _build_llama4_model():
    """Construct a tiny Llama4 text MoE layer wrapped in a minimal model.

    Returns the wrapper model and the raw MoE layer for direct inspection.
    """
    from transformers import Llama4TextConfig
    from transformers.models.llama4.modeling_llama4 import Llama4TextMoe

    tiny_config = Llama4TextConfig(
        hidden_size=16,
        intermediate_size=32,
        num_local_experts=4,
        num_attention_heads=2,
        num_key_value_heads=2,
        num_experts_per_tok=2,
    )
    moe_layer = Llama4TextMoe(tiny_config)

    class _Wrapper(nn.Module):
        def __init__(self, layer):
            super().__init__()
            self.moe = layer
            self.config = SimpleNamespace(model_type="llama4")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    return _Wrapper(moe_layer), moe_layer
|
||||
|
||||
|
||||
def _run_callback(plugin, cfg):
|
||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
||||
assert callbacks, "expected aux-free callback to be registered"
|
||||
callback = callbacks[0]
|
||||
dummy = SimpleNamespace()
|
||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
||||
|
||||
|
||||
class TestAuxFreeAdapters(unittest.TestCase):
    """Exercise the aux-free plugin end to end on tiny Bailing and Llama4 MoE layers."""

    def _assert_counts_then_bias_update(self, plugin, cfg, layer):
        # A forward pass must have accumulated routing counts; after the
        # callback fires, counts reset to zero and the EMA buffer moves.
        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
        _run_callback(plugin, cfg)
        self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
        self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)))

    def test_bailing_adapter_updates_counts_and_bias(self):
        model, block = _build_bailing_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        self.assertTrue(hasattr(block, "_afb_bias"))
        block(torch.randn(2, 3, block.config.hidden_size))
        self._assert_counts_then_bias_update(plugin, cfg, block)

    def test_llama4_adapter_biases_router_selection(self):
        model, layer = _build_llama4_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        self.assertTrue(hasattr(layer, "_afb_bias"))
        layer(torch.randn(2, 4, layer.hidden_dim))
        self._assert_counts_then_bias_update(plugin, cfg, layer)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user