diff --git a/src/axolotl/integrations/aux_free_router/README.md b/src/axolotl/integrations/aux_free_router/README.md
index 7776ff94e..f254e253e 100644
--- a/src/axolotl/integrations/aux_free_router/README.md
+++ b/src/axolotl/integrations/aux_free_router/README.md
@@ -19,7 +19,8 @@ Enable
   moe_update_momentum: 0.9 # default if unset
   moe_bias_cap: 2.0 # default if unset
   moe_afb_warmup_steps: 100 # optional
-  moe_bias_sync_group: world # or 'ep' if expert-parallel is configured
+  moe_bias_sync_group: world # or 'ep' if expert_parallel_size > 1
+  expert_parallel_size: 1 # set to your EP width when using moe_bias_sync_group: ep
 
 Config keys
 - moe_balance_type: gshard (auxiliary loss) | noaux_tc (aux-free). Default: model native.
@@ -28,9 +29,10 @@ Config keys
 - moe_bias_cap: absolute clamp for bias. Default: 2.0.
 - moe_afb_warmup_steps: delay before applying updates. Default: 0.
 - moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
+- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
 
 Compatibility
-- Targeted families: Mixtral, Qwen3-MoE. Jamba optional.
+- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
 - Pass-through: Models with native aux-free routing (e.g., DeepSeek-V3) are left unmodified; only telemetry may be added in future.
 
 Notes
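For review convenience, the keys documented above can be read together as one configuration. A minimal sketch, written as the plain Python mapping the plugin reads rather than the usual axolotl YAML; the values are illustrative only, and `expert_parallel_size: 4` assumes a world size it divides:

    aux_free_cfg = {
        "moe_balance_type": "noaux_tc",   # switch from auxiliary loss to aux-free bias updates
        "moe_update_rate": 0.1,
        "moe_update_momentum": 0.9,
        "moe_bias_cap": 2.0,
        "moe_afb_warmup_steps": 100,
        "moe_bias_sync_group": "ep",      # reduce expert counts inside each EP group
        "expert_parallel_size": 4,        # must divide the world size
    }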
- """ + def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle) -> None: device = next(moe_layer.parameters(), torch.tensor(0)).device if not hasattr(moe_layer, "_afb_bias"): moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device)) @@ -65,6 +60,10 @@ class BaseMoEAdapter: moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device)) moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined] moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined] + + def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None: + """Attach per-layer buffers and mark as aux-free enabled.""" + self._register_aux_buffers(moe_layer, handle) self._patch_forward_with_aux_free(moe_layer) def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None: @@ -136,6 +135,102 @@ class Qwen3Adapter(MixtralAdapter): return getattr(getattr(model, "config", object()), "model_type", "") in ("qwen3_moe", "qwen2_moe") +class BailingAdapter(BaseMoEAdapter): + family = "bailing_moe" + + def matches(self, model: nn.Module) -> bool: + model_type = getattr(getattr(model, "config", object()), "model_type", "") + return model_type in ("bailing_moe", "bailing_moe_v2") + + def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: + for m in model.modules(): + if m.__class__.__name__ == "BailingMoeV2SparseMoeBlock": + yield m + + def get_num_experts(self, moe_layer: nn.Module) -> int: + if hasattr(moe_layer, "num_experts"): + return int(getattr(moe_layer, "num_experts")) + cfg = getattr(moe_layer, "config", None) + return int(getattr(cfg, "num_experts")) + + def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None: + self._register_aux_buffers(moe_layer, handle) + self._patch_bailing_gate(moe_layer) + + def _patch_bailing_gate(self, moe_layer: nn.Module) -> None: + gate = getattr(moe_layer, "gate", None) + if gate is None: + LOG.info("BailingAdapter: layer missing gate; skipping aux-free patch") + return + if getattr(gate, "_afb_patched", False): + return + + def afb_gate_forward(self, hidden_states: torch.Tensor): + flat = hidden_states.view(-1, hidden_states.shape[-1]) + logits = F.linear(flat.float(), self.weight.float()) + scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype) + bias = getattr(moe_layer, "_afb_bias") + biased_scores = scores_unbiased + bias + topk_vals, topk_idx = self.group_limited_topk(biased_scores) + weights = torch.gather(scores_unbiased, 1, topk_idx) + if self.top_k > 1: + denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20) + weights = weights / denom + weights = weights * self.routed_scaling_factor + + flat_topk = topk_idx.reshape(-1) + counts = torch.bincount(flat_topk, minlength=bias.numel()) + getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype)) + + return topk_idx, weights.to(hidden_states.dtype), logits + + gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined] + setattr(gate, "_afb_patched", True) + + +class Llama4Adapter(BaseMoEAdapter): + family = "llama4" + + def matches(self, model: nn.Module) -> bool: + return getattr(getattr(model, "config", object()), "model_type", "") == "llama4" + + def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: + for m in model.modules(): + if m.__class__.__name__ == "Llama4TextMoe": + yield m + + def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None: + 
diff --git a/src/axolotl/integrations/aux_free_router/plugin.py b/src/axolotl/integrations/aux_free_router/plugin.py
index 3cd6b4f6c..6f1c2a633 100644
--- a/src/axolotl/integrations/aux_free_router/plugin.py
+++ b/src/axolotl/integrations/aux_free_router/plugin.py
@@ -9,13 +9,20 @@ from __future__ import annotations
 from typing import Optional
 
 import torch
+import torch.distributed as dist
 from transformers.trainer_callback import TrainerCallback
 
 from axolotl.integrations.base import BasePlugin
-from axolotl.utils.distributed import is_distributed
 from axolotl.utils.logging import get_logger
 
-from .adapters import BaseMoEAdapter, MixtralAdapter, Qwen3Adapter, discover_and_prepare_layers
+from .adapters import (
+    BailingAdapter,
+    BaseMoEAdapter,
+    Llama4Adapter,
+    MixtralAdapter,
+    Qwen3Adapter,
+    discover_and_prepare_layers,
+)
 from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
 
 LOG = get_logger(__name__)
@@ -70,6 +77,7 @@ class AuxFreeMoEPlugin(BasePlugin):
         super().__init__()
         self._handles: list = []
         self._shim: Optional[AuxFreeShim] = None
+        self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}
 
     def post_model_build(self, cfg, model):
         # Enable only when explicitly requested
@@ -98,7 +106,12 @@ class AuxFreeMoEPlugin(BasePlugin):
         )
 
         # Discover layers to count the number and experts for state sizing
-        adapters: list[BaseMoEAdapter] = [MixtralAdapter(), Qwen3Adapter()]
+        adapters: list[BaseMoEAdapter] = [
+            MixtralAdapter(),
+            Qwen3Adapter(),
+            BailingAdapter(),
+            Llama4Adapter(),
+        ]
 
         # For initial state sizing, we conservatively assume the first discovered layer defines nE
         n_layers = 0
@@ -112,7 +125,8 @@
             # we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
             n_experts = 2
         state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg)
-        self._shim = AuxFreeShim(state=state, ep_group=None)
+        ep_group = self._resolve_ep_group(cfg) if sync_group == "ep" else None
+        self._shim = AuxFreeShim(state=state, ep_group=ep_group)
 
         # Discover and prepare layers (attach per-layer buffers)
         self._handles = discover_and_prepare_layers(model, adapters, self._shim)
@@ -121,6 +135,34 @@
             f"AuxFreeMoE: enabled with rate={rate}, momentum={momentum}, cap={bias_cap}, warmup={warmup}, group={sync_group}"
         )
 
+    def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
+        if not dist.is_available() or not dist.is_initialized():
+            LOG.warning("AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world")
+            return None
+        ep_size = getattr(cfg, "expert_parallel_size", None)
+        if not ep_size or ep_size <= 1:
+            LOG.warning("AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world")
+            return None
+        world = dist.get_world_size()
+        if world % ep_size != 0:
+            LOG.warning(
+                "AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world",
+                ep_size,
+                world,
+            )
+            return None
+        if ep_size == world:
+            return dist.group.WORLD
+
+        rank = dist.get_rank()
+        for group_start in range(0, world, ep_size):
+            ranks = tuple(range(group_start, group_start + ep_size))
+            if ranks not in self._ep_group_cache:
+                # new_group() is collective: every rank must create every EP group, in the same order
+                self._ep_group_cache[ranks] = dist.new_group(ranks)
+        start = (rank // ep_size) * ep_size
+        return self._ep_group_cache[tuple(range(start, start + ep_size))]
+
     def add_callbacks_post_trainer(self, cfg, trainer):
         if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
             return []
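`_resolve_ep_group` partitions ranks into contiguous blocks of `expert_parallel_size`, so with a world size of 8 and `expert_parallel_size: 4` the count reductions run over ranks {0,1,2,3} and {4,5,6,7}. A standalone sketch of that partitioning, arithmetic only (no process groups are created here; the values are illustrative):

    world_size, ep_size = 8, 4
    ep_groups = [
        tuple(range(start, start + ep_size))
        for start in range(0, world_size, ep_size)
    ]
    assert ep_groups == [(0, 1, 2, 3), (4, 5, 6, 7)]
    # rank 5 belongs to the group starting at (5 // ep_size) * ep_size == 4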
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index bda30cf15..1ba657b5c 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -874,6 +874,12 @@ class AxolotlInputConfig(
             "description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
         },
     )
+    expert_parallel_size: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Number of processes participating in expert-parallel collectives. Set >1 to form EP groups for aux-free reductions; defaults to world when unset."
+        },
+    )
     special_tokens: SpecialTokensConfig | None = Field(
         default=None,
         json_schema_extra={
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 8ff61b370..4304df579 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -1386,6 +1386,14 @@ class ComplexValidationMixin:
             self.tensor_parallel_size = 1
         return self
 
+    @model_validator(mode="after")
+    def check_expert_parallel_size(self):
+        if not getattr(self, "expert_parallel_size", None):
+            self.expert_parallel_size = 1
+        elif self.expert_parallel_size < 1:
+            raise ValueError("expert_parallel_size must be >= 1")
+        return self
+
     @model_validator(mode="after")
     def check_context_parallel_size(self):
         if self.sequence_parallel_degree and not self.context_parallel_size:
diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py
index 268311295..ef6ed3dfd 100644
--- a/tests/e2e/utils.py
+++ b/tests/e2e/utils.py
@@ -12,7 +12,10 @@
 from pathlib import Path
 
 import torch
 from packaging import version
-from tbparse import SummaryReader
+try:
+    from tbparse import SummaryReader
+except ImportError:  # pragma: no cover - optional dependency
+    SummaryReader = None
 
 from axolotl.utils.dict import DictDefault
@@ -185,6 +188,8 @@ def check_tensorboard(
     """
     helper function to parse and check tensorboard logs
     """
+    if SummaryReader is None:
+        raise unittest.SkipTest("tbparse is not installed; skipping tensorboard assertions")
     tb_log_path = most_recent_subdir(temp_run_dir)
     event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
     reader = SummaryReader(event_file)
diff --git a/tests/unit/test_aux_free_adapters.py b/tests/unit/test_aux_free_adapters.py
new file mode 100644
index 000000000..ad7d91e48
--- /dev/null
+++ b/tests/unit/test_aux_free_adapters.py
@@ -0,0 +1,162 @@
+import sys
+import unittest
+from types import SimpleNamespace
+
+import torch
+import torch.nn as nn
+from importlib import util as importlib_util
+from pathlib import Path
+
+from huggingface_hub import snapshot_download
+
+from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
+
+
+def _cfg(**overrides):
+    defaults = dict(
+        moe_balance_type="noaux_tc",
+        moe_update_rate=0.1,
+        moe_update_momentum=0.9,
+        moe_bias_cap=2.0,
+        moe_afb_warmup_steps=0,
+        moe_bias_sync_group="world",
+        expert_parallel_size=1,
+    )
+    defaults.update(overrides)
+    return SimpleNamespace(**defaults)
+
+
+def _load_bailing_modules():
+    repo_dir = snapshot_download(
+        repo_id="inclusionAI/Ring-mini-2.0",
+        allow_patterns=[
+            "configuration_bailing_moe_v2.py",
+            "modeling_bailing_moe_v2.py",
+            "__init__.py",
+        ],
+    )
+    repo = Path(repo_dir)
+    config_path = repo / "configuration_bailing_moe_v2.py"
+    modeling_path = repo / "modeling_bailing_moe_v2.py"
+
+    config_name = "bailing_moe_v2.configuration_bailing_moe_v2"
+    if config_name not in sys.modules:
+        spec = importlib_util.spec_from_file_location(config_name, config_path)
+        module = importlib_util.module_from_spec(spec)
+        sys.modules[config_name] = module
+        sys.modules["configuration_bailing_moe_v2"] = module
+        assert spec.loader is not None
+        spec.loader.exec_module(module)
+    config_module = sys.modules[config_name]
+
+    modeling_name = "bailing_moe_v2.modeling_bailing_moe_v2"
+    if modeling_name not in sys.modules:
+        spec = importlib_util.spec_from_file_location(modeling_name, modeling_path)
+        module = importlib_util.module_from_spec(spec)
+        sys.modules[modeling_name] = module
+        sys.modules["modeling_bailing_moe_v2"] = module
+        assert spec.loader is not None
+        spec.loader.exec_module(module)
+    modeling_module = sys.modules[modeling_name]
+
+    BailingMoeV2Config = config_module.BailingMoeV2Config
+    BailingMoeV2SparseMoeBlock = modeling_module.BailingMoeV2SparseMoeBlock
+
+    return BailingMoeV2Config, BailingMoeV2SparseMoeBlock
+
+
+def _build_bailing_model():
+    BailingConfig, BailingBlock = _load_bailing_modules()
+    config = BailingConfig(
+        hidden_size=16,
+        intermediate_size=32,
+        moe_intermediate_size=32,
+        num_experts=4,
+        num_shared_experts=None,
+        num_experts_per_tok=2,
+        n_group=1,
+        topk_group=1,
+        routed_scaling_factor=1.0,
+    )
+    block = BailingBlock(config)
+
+    class DummyModel(nn.Module):
+        def __init__(self, layer):
+            super().__init__()
+            self.block = layer
+            self.config = SimpleNamespace(model_type="bailing_moe")
+
+        def forward(self, hidden_states):
+            return self.block(hidden_states)
+
+    return DummyModel(block), block
+
+
+def _build_llama4_model():
+    from transformers import Llama4TextConfig
+    from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
+
+    config = Llama4TextConfig(
+        hidden_size=16,
+        intermediate_size=32,
+        num_local_experts=4,
+        num_attention_heads=2,
+        num_key_value_heads=2,
+        num_experts_per_tok=2,
+    )
+    layer = Llama4TextMoe(config)
+
+    class DummyModel(nn.Module):
+        def __init__(self, moe_layer):
+            super().__init__()
+            self.moe = moe_layer
+            self.config = SimpleNamespace(model_type="llama4")
+
+        def forward(self, hidden_states):
+            return self.moe(hidden_states)
+
+    return DummyModel(layer), layer
+
+
+def _run_callback(plugin, cfg):
+    callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
+    assert callbacks, "expected aux-free callback to be registered"
+    callback = callbacks[0]
+    dummy = SimpleNamespace()
+    callback.on_step_end(args=dummy, state=dummy, control=dummy)
+
+
+class TestAuxFreeAdapters(unittest.TestCase):
+    def test_bailing_adapter_updates_counts_and_bias(self):
+        model, block = _build_bailing_model()
+        cfg = _cfg()
+        plugin = AuxFreeMoEPlugin()
+        plugin.post_model_build(cfg, model)
+
+        self.assertTrue(hasattr(block, "_afb_bias"))
+        hidden = torch.randn(2, 3, block.config.hidden_size)
+        block(hidden)
+        self.assertGreater(torch.count_nonzero(block._afb_counts), 0)
+
+        _run_callback(plugin, cfg)
+        self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
+        self.assertFalse(torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema)))
+
+    def test_llama4_adapter_biases_router_selection(self):
+        model, layer = _build_llama4_model()
+        cfg = _cfg()
+        plugin = AuxFreeMoEPlugin()
+        plugin.post_model_build(cfg, model)
+
+        self.assertTrue(hasattr(layer, "_afb_bias"))
+        hidden = torch.randn(2, 4, layer.hidden_dim)
+        layer(hidden)
+        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
+
+        _run_callback(plugin, cfg)
+        self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
+        self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)))
+
+
+if __name__ == "__main__":
+    unittest.main()
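The shim and callback that consume these buffers live in core.py, which is not part of this diff. For orientation only, here is a rough sketch of the style of update the accumulated counts are typically folded into, assuming a DeepSeek-V3-style sign update driven by the rate/momentum/cap settings documented above; the function name and exact formula are illustrative, not the shim's actual code. In the distributed case the counts would first be all-reduced over the configured sync group (world, or the EP group resolved in plugin.py) before this runs.

    import torch

    def aux_free_bias_step(bias, ema, counts, rate=0.1, momentum=0.9, cap=2.0):
        """One per-layer update: nudge bias toward under-loaded experts, then reset counts."""
        ema.mul_(momentum).add_(counts.float(), alpha=1.0 - momentum)  # smooth per-expert load
        violation = ema.mean() - ema          # positive for under-loaded experts
        bias.add_(rate * torch.sign(violation))
        bias.clamp_(-cap, cap)                # honor moe_bias_cap
        counts.zero_()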