From 6636e5de7e18fc816d26e740ddc1c6bc80918ec4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 22 Mar 2026 17:23:12 +0000 Subject: [PATCH] address PR code review --- .../integrations/aux_free_router/__init__.py | 2 +- .../integrations/aux_free_router/adapters.py | 20 +++++++++++++++---- .../integrations/aux_free_router/args.py | 4 ++-- .../integrations/aux_free_router/core.py | 2 +- .../integrations/aux_free_router/plugin.py | 16 ++++++++++----- .../integrations/kernels/sonicmoe/routing.py | 10 ++++++++-- tests/e2e/test_llama4_moe_aux_free.py | 2 +- tests/e2e/test_moe_aux_parity.py | 8 ++++++++ tests/unit/test_aux_free_adapters.py | 4 +++- 9 files changed, 51 insertions(+), 17 deletions(-) diff --git a/src/axolotl/integrations/aux_free_router/__init__.py b/src/axolotl/integrations/aux_free_router/__init__.py index 8eac77224..4640b71b9 100644 --- a/src/axolotl/integrations/aux_free_router/__init__.py +++ b/src/axolotl/integrations/aux_free_router/__init__.py @@ -4,6 +4,6 @@ from .args import AuxFreeRouterArgs from .plugin import AuxFreeMoEPlugin __all__ = [ - "AuxFreeRouterArgs", "AuxFreeMoEPlugin", + "AuxFreeRouterArgs", ] diff --git a/src/axolotl/integrations/aux_free_router/adapters.py b/src/axolotl/integrations/aux_free_router/adapters.py index 0df987fe3..71dace555 100644 --- a/src/axolotl/integrations/aux_free_router/adapters.py +++ b/src/axolotl/integrations/aux_free_router/adapters.py @@ -88,12 +88,22 @@ class BaseMoEAdapter: try: model_or_layer.router_aux_loss_coef = 0.0 except Exception: # pragma: no cover - non-critical - pass + LOG.debug( + "disable_aux_loss: failed to set router_aux_loss_coef on %s", + type(model_or_layer).__name__, + exc_info=True, + ) def _register_aux_buffers( self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim ) -> None: - device = next(moe_layer.parameters(), torch.tensor(0)).device + p = next(moe_layer.parameters(), None) + b = next(moe_layer.buffers(), None) + device = ( + p.device + if p is not None + else (b.device if b is not None else torch.device("cpu")) + ) if not hasattr(moe_layer, "_afb_bias"): moe_layer.register_buffer( "_afb_bias", torch.zeros(handle.num_experts, device=device) @@ -275,7 +285,7 @@ class BailingAdapter(BaseMoEAdapter): scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype) bias = moe_layer._afb_bias biased_scores = scores_unbiased + bias - topk_vals, topk_idx = self.group_limited_topk(biased_scores) + _, topk_idx = self.group_limited_topk(biased_scores) weights = torch.gather(scores_unbiased, 1, topk_idx) if self.top_k > 1: denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20) @@ -376,6 +386,8 @@ def discover_and_prepare_layers( idx += 1 LOG.info( - f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing" + "AuxFreeMoE: prepared %d %s layers for aux-free routing", + len(handles), + adapter.family, ) return handles diff --git a/src/axolotl/integrations/aux_free_router/args.py b/src/axolotl/integrations/aux_free_router/args.py index 0f4cc4035..2a2b17dc7 100644 --- a/src/axolotl/integrations/aux_free_router/args.py +++ b/src/axolotl/integrations/aux_free_router/args.py @@ -37,14 +37,14 @@ class AuxFreeRouterArgs(BaseModel): moe_update_rate: float | None = Field( default=None, json_schema_extra={ - "description": "Per-step bias update rate (gamma). Recommended: 0.005–0.05. " + "description": "Per-step bias update rate (gamma). Recommended: 0.005-0.05. " "If unset, plugin default is 0.01." }, ) moe_update_momentum: float | None = Field( default=None, json_schema_extra={ - "description": "EMA momentum for expert load smoothing (0–1). " + "description": "EMA momentum for expert load smoothing (0-1). " "If unset, plugin default is 0.9." }, ) diff --git a/src/axolotl/integrations/aux_free_router/core.py b/src/axolotl/integrations/aux_free_router/core.py index 34c9db2a8..a9c37d14e 100644 --- a/src/axolotl/integrations/aux_free_router/core.py +++ b/src/axolotl/integrations/aux_free_router/core.py @@ -67,7 +67,7 @@ class AuxFreeShim: else: b = self.state.bias[layer_idx] biased = logits + b # bias is a buffer - topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1) + _topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1) chosen_logits = torch.gather(logits, -1, topk_idx) weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype) return topk_idx, weights diff --git a/src/axolotl/integrations/aux_free_router/plugin.py b/src/axolotl/integrations/aux_free_router/plugin.py index fd39c7dfc..fd56321c9 100644 --- a/src/axolotl/integrations/aux_free_router/plugin.py +++ b/src/axolotl/integrations/aux_free_router/plugin.py @@ -239,11 +239,17 @@ class AuxFreeMoEPlugin(BasePlugin): return dist.group.WORLD rank = dist.get_rank() - group_start = (rank // ep_size) * ep_size - ranks = tuple(range(group_start, group_start + ep_size)) - if ranks not in self._ep_group_cache: - self._ep_group_cache[ranks] = dist.new_group(ranks) - return self._ep_group_cache[ranks] + # All ranks must collectively create all EP subgroups in the same order + # to avoid deadlocks (dist.new_group is a collective operation). + world_size = world + my_group = None + for group_start in range(0, world_size, ep_size): + ranks = tuple(range(group_start, group_start + ep_size)) + if ranks not in self._ep_group_cache: + self._ep_group_cache[ranks] = dist.new_group(ranks) + if rank in ranks: + my_group = self._ep_group_cache[ranks] + return my_group def add_callbacks_post_trainer(self, cfg, trainer): if getattr(cfg, "moe_balance_type", None) != "noaux_tc": diff --git a/src/axolotl/integrations/kernels/sonicmoe/routing.py b/src/axolotl/integrations/kernels/sonicmoe/routing.py index 0b0da9fc8..49483e822 100644 --- a/src/axolotl/integrations/kernels/sonicmoe/routing.py +++ b/src/axolotl/integrations/kernels/sonicmoe/routing.py @@ -177,7 +177,9 @@ def softmax_group_topk_routing( score_mask = ( group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E) ) - scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + scores_for_choice = scores_for_choice.masked_fill( + ~score_mask.bool(), -float("inf") + ) topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1] topk_weights = router_probs.gather(1, topk_indices) @@ -275,7 +277,9 @@ def sigmoid_topk_routing( score_mask = ( group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E) ) - scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) + scores_for_choice = scores_for_choice.masked_fill( + ~score_mask.bool(), -float("inf") + ) # Final topk from (possibly masked) scores topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1] @@ -316,6 +320,8 @@ def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None: ``aux_free_router`` plugin). The counts are later consumed by the ``MoeAuxFreeBiasUpdateCallback`` at each training step. """ + if hasattr(moe_block, "training") and not moe_block.training: + return afb_counts = getattr(moe_block, "_afb_counts", None) if afb_counts is None: return diff --git a/tests/e2e/test_llama4_moe_aux_free.py b/tests/e2e/test_llama4_moe_aux_free.py index f431f619e..f307a8331 100644 --- a/tests/e2e/test_llama4_moe_aux_free.py +++ b/tests/e2e/test_llama4_moe_aux_free.py @@ -21,7 +21,7 @@ class TestLlama4MoeAuxFree(unittest.TestCase): { "base_model": "yujiepan/llama-4-tiny-random", "tokenizer_config": "yujiepan/llama-4-tiny-random", - "trust_remote_code": True, + "trust_remote_code": False, "flash_attention": False, "sequence_len": 512, "bf16": False, diff --git a/tests/e2e/test_moe_aux_parity.py b/tests/e2e/test_moe_aux_parity.py index 7f71c54ea..25d70baa1 100644 --- a/tests/e2e/test_moe_aux_parity.py +++ b/tests/e2e/test_moe_aux_parity.py @@ -3,8 +3,11 @@ Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-t Checks that aux-free training loss does not degrade beyond a small tolerance. """ +import gc import unittest +import torch + from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, prepare_plugins, validate_config @@ -61,6 +64,11 @@ class TestMoeAuxParity(unittest.TestCase): loss0 = _last_logged_loss(trainer0) assert loss0 is not None + # Release baseline resources before starting aux-free run + del model0, trainer0, dataset_meta0 + gc.collect() + torch.cuda.empty_cache() + # Aux-free: plugin + noaux_tc cfg1 = DictDefault(dict(base_cfg)) cfg1.output_dir = f"{temp_dir}/auxfree" diff --git a/tests/unit/test_aux_free_adapters.py b/tests/unit/test_aux_free_adapters.py index 43457679f..43b77c8c2 100644 --- a/tests/unit/test_aux_free_adapters.py +++ b/tests/unit/test_aux_free_adapters.py @@ -390,7 +390,9 @@ class TestAuxFreeAdapters(unittest.TestCase): def test_ep_group_resolution_deferred_until_dist_ready(self): if dist.is_available() and dist.is_initialized(): - dist.destroy_process_group() + self.skipTest( + "Cannot safely test deferred EP group resolution when a process group is already initialized" + ) model, block = _build_bailing_model() cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)