address PR code review

This commit is contained in:
Wing Lian
2026-03-22 17:23:12 +00:00
parent 0a566d7a15
commit 6636e5de7e
9 changed files with 51 additions and 17 deletions

View File

@@ -4,6 +4,6 @@ from .args import AuxFreeRouterArgs
from .plugin import AuxFreeMoEPlugin from .plugin import AuxFreeMoEPlugin
__all__ = [ __all__ = [
"AuxFreeRouterArgs",
"AuxFreeMoEPlugin", "AuxFreeMoEPlugin",
"AuxFreeRouterArgs",
] ]

View File

@@ -88,12 +88,22 @@ class BaseMoEAdapter:
try: try:
model_or_layer.router_aux_loss_coef = 0.0 model_or_layer.router_aux_loss_coef = 0.0
except Exception: # pragma: no cover - non-critical except Exception: # pragma: no cover - non-critical
pass LOG.debug(
"disable_aux_loss: failed to set router_aux_loss_coef on %s",
type(model_or_layer).__name__,
exc_info=True,
)
def _register_aux_buffers( def _register_aux_buffers(
self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim
) -> None: ) -> None:
device = next(moe_layer.parameters(), torch.tensor(0)).device p = next(moe_layer.parameters(), None)
b = next(moe_layer.buffers(), None)
device = (
p.device
if p is not None
else (b.device if b is not None else torch.device("cpu"))
)
if not hasattr(moe_layer, "_afb_bias"): if not hasattr(moe_layer, "_afb_bias"):
moe_layer.register_buffer( moe_layer.register_buffer(
"_afb_bias", torch.zeros(handle.num_experts, device=device) "_afb_bias", torch.zeros(handle.num_experts, device=device)
@@ -275,7 +285,7 @@ class BailingAdapter(BaseMoEAdapter):
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype) scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
bias = moe_layer._afb_bias bias = moe_layer._afb_bias
biased_scores = scores_unbiased + bias biased_scores = scores_unbiased + bias
topk_vals, topk_idx = self.group_limited_topk(biased_scores) _, topk_idx = self.group_limited_topk(biased_scores)
weights = torch.gather(scores_unbiased, 1, topk_idx) weights = torch.gather(scores_unbiased, 1, topk_idx)
if self.top_k > 1: if self.top_k > 1:
denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20) denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20)
@@ -376,6 +386,8 @@ def discover_and_prepare_layers(
idx += 1 idx += 1
LOG.info( LOG.info(
f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing" "AuxFreeMoE: prepared %d %s layers for aux-free routing",
len(handles),
adapter.family,
) )
return handles return handles

View File

@@ -37,14 +37,14 @@ class AuxFreeRouterArgs(BaseModel):
moe_update_rate: float | None = Field( moe_update_rate: float | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
"description": "Per-step bias update rate (gamma). Recommended: 0.005–0.05. " "description": "Per-step bias update rate (gamma). Recommended: 0.005-0.05. "
"If unset, plugin default is 0.01." "If unset, plugin default is 0.01."
}, },
) )
moe_update_momentum: float | None = Field( moe_update_momentum: float | None = Field(
default=None, default=None,
json_schema_extra={ json_schema_extra={
"description": "EMA momentum for expert load smoothing (0–1). " "description": "EMA momentum for expert load smoothing (0-1). "
"If unset, plugin default is 0.9." "If unset, plugin default is 0.9."
}, },
) )

View File

@@ -67,7 +67,7 @@ class AuxFreeShim:
else: else:
b = self.state.bias[layer_idx] b = self.state.bias[layer_idx]
biased = logits + b # bias is a buffer biased = logits + b # bias is a buffer
topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1) _topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
chosen_logits = torch.gather(logits, -1, topk_idx) chosen_logits = torch.gather(logits, -1, topk_idx)
weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype) weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
return topk_idx, weights return topk_idx, weights

View File

@@ -239,11 +239,17 @@ class AuxFreeMoEPlugin(BasePlugin):
return dist.group.WORLD return dist.group.WORLD
rank = dist.get_rank() rank = dist.get_rank()
group_start = (rank // ep_size) * ep_size # All ranks must collectively create all EP subgroups in the same order
ranks = tuple(range(group_start, group_start + ep_size)) # to avoid deadlocks (dist.new_group is a collective operation).
if ranks not in self._ep_group_cache: world_size = world
self._ep_group_cache[ranks] = dist.new_group(ranks) my_group = None
return self._ep_group_cache[ranks] for group_start in range(0, world_size, ep_size):
ranks = tuple(range(group_start, group_start + ep_size))
if ranks not in self._ep_group_cache:
self._ep_group_cache[ranks] = dist.new_group(ranks)
if rank in ranks:
my_group = self._ep_group_cache[ranks]
return my_group
def add_callbacks_post_trainer(self, cfg, trainer): def add_callbacks_post_trainer(self, cfg, trainer):
if getattr(cfg, "moe_balance_type", None) != "noaux_tc": if getattr(cfg, "moe_balance_type", None) != "noaux_tc":

View File

@@ -177,7 +177,9 @@ def softmax_group_topk_routing(
score_mask = ( score_mask = (
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E) group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
) )
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) scores_for_choice = scores_for_choice.masked_fill(
~score_mask.bool(), -float("inf")
)
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1] topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
topk_weights = router_probs.gather(1, topk_indices) topk_weights = router_probs.gather(1, topk_indices)
@@ -275,7 +277,9 @@ def sigmoid_topk_routing(
score_mask = ( score_mask = (
group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E) group_mask.unsqueeze(-1).expand(-1, n_group, E // n_group).reshape(-1, E)
) )
scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) scores_for_choice = scores_for_choice.masked_fill(
~score_mask.bool(), -float("inf")
)
# Final topk from (possibly masked) scores # Final topk from (possibly masked) scores
topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1] topk_indices = torch.topk(scores_for_choice, k=K, dim=-1, sorted=False)[1]
@@ -316,6 +320,8 @@ def _accumulate_afb_counts(moe_block, topk_indices: torch.Tensor) -> None:
``aux_free_router`` plugin). The counts are later consumed by the ``aux_free_router`` plugin). The counts are later consumed by the
``MoeAuxFreeBiasUpdateCallback`` at each training step. ``MoeAuxFreeBiasUpdateCallback`` at each training step.
""" """
if hasattr(moe_block, "training") and not moe_block.training:
return
afb_counts = getattr(moe_block, "_afb_counts", None) afb_counts = getattr(moe_block, "_afb_counts", None)
if afb_counts is None: if afb_counts is None:
return return

View File

@@ -21,7 +21,7 @@ class TestLlama4MoeAuxFree(unittest.TestCase):
{ {
"base_model": "yujiepan/llama-4-tiny-random", "base_model": "yujiepan/llama-4-tiny-random",
"tokenizer_config": "yujiepan/llama-4-tiny-random", "tokenizer_config": "yujiepan/llama-4-tiny-random",
"trust_remote_code": True, "trust_remote_code": False,
"flash_attention": False, "flash_attention": False,
"sequence_len": 512, "sequence_len": 512,
"bf16": False, "bf16": False,

View File

@@ -3,8 +3,11 @@ Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-t
Checks that aux-free training loss does not degrade beyond a small tolerance. Checks that aux-free training loss does not degrade beyond a small tolerance.
""" """
import gc
import unittest import unittest
import torch
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -61,6 +64,11 @@ class TestMoeAuxParity(unittest.TestCase):
loss0 = _last_logged_loss(trainer0) loss0 = _last_logged_loss(trainer0)
assert loss0 is not None assert loss0 is not None
# Release baseline resources before starting aux-free run
del model0, trainer0, dataset_meta0
gc.collect()
torch.cuda.empty_cache()
# Aux-free: plugin + noaux_tc # Aux-free: plugin + noaux_tc
cfg1 = DictDefault(dict(base_cfg)) cfg1 = DictDefault(dict(base_cfg))
cfg1.output_dir = f"{temp_dir}/auxfree" cfg1.output_dir = f"{temp_dir}/auxfree"

View File

@@ -390,7 +390,9 @@ class TestAuxFreeAdapters(unittest.TestCase):
def test_ep_group_resolution_deferred_until_dist_ready(self): def test_ep_group_resolution_deferred_until_dist_ready(self):
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
dist.destroy_process_group() self.skipTest(
"Cannot safely test deferred EP group resolution when a process group is already initialized"
)
model, block = _build_bailing_model() model, block = _build_bailing_model()
cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1) cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)