address PR code review

Wing Lian
2026-03-22 17:23:12 +00:00
parent 0a566d7a15
commit 6636e5de7e
9 changed files with 51 additions and 17 deletions

View File

@@ -4,6 +4,6 @@ from .args import AuxFreeRouterArgs
 from .plugin import AuxFreeMoEPlugin

 __all__ = [
-    "AuxFreeRouterArgs",
     "AuxFreeMoEPlugin",
+    "AuxFreeRouterArgs",
 ]

View File

@@ -88,12 +88,22 @@ class BaseMoEAdapter:
         try:
             model_or_layer.router_aux_loss_coef = 0.0
         except Exception:  # pragma: no cover - non-critical
-            pass
+            LOG.debug(
+                "disable_aux_loss: failed to set router_aux_loss_coef on %s",
+                type(model_or_layer).__name__,
+                exc_info=True,
+            )

     def _register_aux_buffers(
         self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
     ) -> None:
-        device = next(moe_layer.parameters(), torch.tensor(0)).device
+        p = next(moe_layer.parameters(), None)
+        b = next(moe_layer.buffers(), None)
+        device = (
+            p.device
+            if p is not None
+            else (b.device if b is not None else torch.device("cpu"))
+        )
         if not hasattr(moe_layer, "_afb_bias"):
             moe_layer.register_buffer(
                 "_afb_bias", torch.zeros(handle.num_experts, device=device)
@@ -275,7 +285,7 @@ class BailingAdapter(BaseMoEAdapter):
         scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
         bias = moe_layer._afb_bias
         biased_scores = scores_unbiased + bias
-        topk_vals, topk_idx = self.group_limited_topk(biased_scores)
+        _, topk_idx = self.group_limited_topk(biased_scores)
         weights = torch.gather(scores_unbiased, 1, topk_idx)
         if self.top_k > 1:
             denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20)
@@ -376,6 +386,8 @@ def discover_and_prepare_layers(
         idx += 1
     LOG.info(
-        f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing"
+        "AuxFreeMoE: prepared %d %s layers for aux-free routing",
+        len(handles),
+        adapter.family,
     )
     return handles
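
Note on the LOG.info change: %-style arguments are interpolated lazily by the logging module, only when a handler actually emits the record, whereas an f-string is always formatted first. A standalone illustration with made-up values:

import logging

LOG = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Lazy %-formatting: the message is built only if the record is emitted.
LOG.info("AuxFreeMoE: prepared %d %s layers for aux-free routing", 4, "bailing_moe")

# An f-string is formatted eagerly, even when the record is filtered out:
LOG.debug(f"prepared {4} layers")  # built, then discarded at INFO level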

View File

@@ -37,14 +37,14 @@ class AuxFreeRouterArgs(BaseModel):
     moe_update_rate: float | None = Field(
         default=None,
         json_schema_extra={
-            "description": "Per-step bias update rate (gamma). Recommended: 0.005–0.05. "
+            "description": "Per-step bias update rate (gamma). Recommended: 0.005-0.05. "
             "If unset, plugin default is 0.01."
         },
     )
     moe_update_momentum: float | None = Field(
         default=None,
         json_schema_extra={
-            "description": "EMA momentum for expert load smoothing (0–1). "
+            "description": "EMA momentum for expert load smoothing (0-1). "
             "If unset, plugin default is 0.9."
         },
     )
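
For reference, these two knobs drive the aux-free bias update: the update rate (gamma) scales the per-step nudge applied to each expert's bias, and the momentum smooths the per-expert load estimate. A rough sketch of how such an update could look, in the spirit of loss-free balancing; the helper name, signature, and sign convention below are assumptions, not the plugin's actual code:

import torch

def update_expert_bias(
    bias: torch.Tensor,         # per-expert routing bias (e.g. an _afb_bias buffer)
    counts: torch.Tensor,       # tokens routed to each expert this step
    ema_load: torch.Tensor,     # smoothed per-expert load, updated in place
    update_rate: float = 0.01,  # moe_update_rate (gamma)
    momentum: float = 0.9,      # moe_update_momentum
) -> torch.Tensor:
    # Smooth the instantaneous load with an EMA.
    ema_load.mul_(momentum).add_(counts.float(), alpha=1.0 - momentum)
    # Nudge under-loaded experts up and over-loaded experts down.
    error = ema_load.mean() - ema_load
    return bias + update_rate * torch.sign(error)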

View File

@@ -67,7 +67,7 @@ class AuxFreeShim:
         else:
             b = self.state.bias[layer_idx]
         biased = logits + b  # bias is a buffer
-        topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
+        _topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
         chosen_logits = torch.gather(logits, -1, topk_idx)
         weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
         return topk_idx, weights
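
As in the BailingAdapter hunk above, the bias only steers which experts get selected; the returned weights always come from the unbiased logits. A toy illustration (standalone tensors, not the shim's interface):

import torch

logits = torch.tensor([[2.0, 1.5, 0.2, -0.1]])  # router logits for one token
bias = torch.tensor([0.0, 0.0, 1.9, 0.0])       # aux-free balancing bias

# Selection uses the biased scores: picks experts 2 and 0 instead of 0 and 1.
_, topk_idx = torch.topk(logits + bias, k=2, dim=-1)

# Weights come from the original logits, so the bias never changes them.
chosen = torch.gather(logits, -1, topk_idx)
weights = torch.softmax(chosen.float(), dim=-1)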

View File

@@ -239,11 +239,17 @@ class AuxFreeMoEPlugin(BasePlugin):
             return dist.group.WORLD
         rank = dist.get_rank()
-        group_start = (rank // ep_size) * ep_size
-        ranks = tuple(range(group_start, group_start + ep_size))
-        if ranks not in self._ep_group_cache:
-            self._ep_group_cache[ranks] = dist.new_group(ranks)
-        return self._ep_group_cache[ranks]
+        # All ranks must collectively create all EP subgroups in the same order
+        # to avoid deadlocks (dist.new_group is a collective operation).
+        world_size = world
+        my_group = None
+        for group_start in range(0, world_size, ep_size):
+            ranks = tuple(range(group_start, group_start + ep_size))
+            if ranks not in self._ep_group_cache:
+                self._ep_group_cache[ranks] = dist.new_group(ranks)
+            if rank in ranks:
+                my_group = self._ep_group_cache[ranks]
+        return my_group

     def add_callbacks_post_trainer(self, cfg, trainer):
         if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
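
The loop matters because dist.new_group is collective: every rank must create every EP subgroup, in the same order, and then keep only the one containing its own rank. A pure-Python worked example (world size and ep_size are illustrative; no process group is needed to follow the logic):

world_size, ep_size = 8, 4
for rank in range(world_size):
    my_ranks = None
    for group_start in range(0, world_size, ep_size):
        ranks = tuple(range(group_start, group_start + ep_size))
        # dist.new_group(ranks) would be called here on every rank
        if rank in ranks:
            my_ranks = ranks
    assert my_ranks == ((0, 1, 2, 3) if rank < 4 else (4, 5, 6, 7))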

View File

@@ -177,7 +177,9 @@ def softmax_group_topk_routing(
     score_mask = (
         group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
     )
-    scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+    scores_for_choice = scores_for_choice.masked_fill(
+        ~score_mask.bool(), -float("inf")
+    )
     topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
     topk_weights = router_probs.gather(1, topk_indices)
@@ -275,7 +277,9 @@ def sigmoid_topk_routing(
     score_mask = (
         group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
     )
-    scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+    scores_for_choice = scores_for_choice.masked_fill(
+        ~score_mask.bool(), -float("inf")
+    )

     # Final topk from (possibly masked) scores
     topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
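
The move from a 0.0 fill to -inf matters whenever the surviving scores can fall below zero (for example after a negative balancing bias has been added): with 0.0, a masked-out expert can still win the top-k; with -inf it never can. A toy illustration:

import torch

scores = torch.tensor([[-0.2, 0.3, -0.1, -0.5]])        # toy scores, can be negative
score_mask = torch.tensor([[True, True, True, False]])  # expert 3 is masked out

# 0.0 fill: masked expert 3 (now 0.0) outranks the unmasked experts 0 and 2.
bad = scores.masked_fill(~score_mask, 0.0)
print(torch.topk(bad, k=2).indices)    # tensor([[1, 3]])

# -inf fill: masked experts can never be selected.
good = scores.masked_fill(~score_mask, -float("inf"))
print(torch.topk(good, k=2).indices)   # tensor([[1, 2]])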
@@ -316,6 +320,8 @@ def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
     ``aux_free_router`` plugin). The counts are later consumed by the
     ``MoeAuxFreeBiasUpdateCallback`` at each training step.
     """
+    if hasattr(moe_block, "training") and not moe_block.training:
+        return
     afb_counts = getattr(moe_block, "_afb_counts", None)
     if afb_counts is None:
         return
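
The counts are just a per-expert tally of how many tokens selected that expert in the forward pass; the callback then reads (and typically resets) them once per optimizer step. A hedged sketch of the accumulation, assuming _afb_counts is a 1-D buffer with one slot per expert:

import torch

num_experts = 8
afb_counts = torch.zeros(num_experts)                  # e.g. moe_block._afb_counts
topk_indices = torch.tensor([[0, 3], [3, 5], [7, 0]])  # (tokens, top_k) expert ids

# Tally how many tokens chose each expert and accumulate into the buffer.
afb_counts += torch.bincount(topk_indices.reshape(-1), minlength=num_experts).float()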

View File

@@ -21,7 +21,7 @@ class TestLlama4MoeAuxFree(unittest.TestCase):
             {
                 "base_model": "yujiepan/llama-4-tiny-random",
                 "tokenizer_config": "yujiepan/llama-4-tiny-random",
-                "trust_remote_code": True,
+                "trust_remote_code": False,
                 "flash_attention": False,
                 "sequence_len": 512,
                 "bf16": False,

View File

@@ -3,8 +3,11 @@ Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-t
 Checks that aux-free training loss does not degrade beyond a small tolerance.
 """

+import gc
 import unittest

+import torch
+
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -61,6 +64,11 @@ class TestMoeAuxParity(unittest.TestCase):
         loss0 = _last_logged_loss(trainer0)
         assert loss0 is not None

+        # Release baseline resources before starting aux-free run
+        del model0, trainer0, dataset_meta0
+        gc.collect()
+        torch.cuda.empty_cache()
+
         # Aux-free: plugin + noaux_tc
         cfg1 = DictDefault(dict(base_cfg))
         cfg1.output_dir = f"{temp_dir}/auxfree"

View File

@@ -390,7 +390,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
     def test_ep_group_resolution_deferred_until_dist_ready(self):
         if dist.is_available() and dist.is_initialized():
-            dist.destroy_process_group()
+            self.skipTest(
+                "Cannot safely test deferred EP group resolution when a process group is already initialized"
+            )
         model, block = _build_bailing_model()
         cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)