improve: align aux-free telemetry with Trainer logging

This commit is contained in:
lhl
2025-11-11 17:00:48 +00:00
committed by Wing Lian
parent 966a4555db
commit 676d5e855d
6 changed files with 39 additions and 41 deletions

View File

@@ -30,7 +30,6 @@ Config keys
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
- moe_afb_telemetry_interval: emit router telemetry every N optimizer steps (defaults to `logging_steps` when unset).
Compatibility
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
@@ -38,7 +37,7 @@ Compatibility
Notes
- If you also enable Ligers aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
- Telemetry: logs per-layer min/mean/max token loads, `|bias| max`, and bias sign flip fraction at the configured interval.
- Telemetry: logs per-layer min/mean/max token loads, `|bias| max`, and bias sign flip fraction using the Trainers `logging_steps` cadence.
- Sample packing: packed batches are compatible with aux-free routing. Because load counts are accumulated on-device per expert before reduction, packing tends to smooth token histograms and reduce bias oscillation. Keep `pad_to_sequence_len: true` when packing to preserve the target token budget per expert.
Telemetry metrics
@@ -47,5 +46,5 @@ Telemetry metrics
- `moe_afb/l{idx}_bias_sign_flip_frac`: fraction of experts whose bias sign changed since the previous step (simple oscillation indicator).
Usage tips
- Leave `moe_afb_telemetry_interval` unset to log on the Trainers `logging_steps`. Increase the interval for large jobs to reduce log volume.
- Increase `logging_steps` if router telemetry becomes noisy for large jobs—the plugin follows the Trainers logging cadence.
- Compare aux-free vs. aux-loss load metrics by plotting the `load_*` series; aux-free typically tightens min/max spread without the auxiliary loss term.

View File

@@ -17,7 +17,6 @@ class AuxFreeConfig:
bias_cap: float = 2.0
warmup_steps: int = 0
sync_group: str = "world" # or "ep"
telemetry_interval: Optional[int] = None
class AuxFreeState:

View File

@@ -6,7 +6,7 @@ unbiased logits for mixture weights and per-expert biases for top-k selection.
from __future__ import annotations
from typing import Optional
from typing import Optional, Any
import torch
import torch.distributed as dist
@@ -29,22 +29,17 @@ LOG = get_logger(__name__)
class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
"""Post-step callback to update aux-free biases from accumulated expert counts.
Note: The current revision expects per-layer counts to be accumulated on each
MoE layer as a buffer named `_afb_counts` during forward (to be added with
routing patches in a follow-up).
"""
"""Post-step callback to update aux-free biases from accumulated expert counts."""
def __init__(
self,
shim: AuxFreeShim,
layer_modules: list[torch.nn.Module],
telemetry_interval: Optional[int] = None,
trainer: Any,
):
self.shim = shim
self.layer_modules = layer_modules
self.telemetry_interval = telemetry_interval
self.trainer = trainer
self._prev_bias_sign: dict[int, torch.Tensor] = {}
self._telemetry_buffer: dict[int, dict[str, float]] = {}
@@ -71,17 +66,14 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
# reset step counts
counts.zero_()
if self._should_log(args, state):
if self._telemetry_buffer:
if not hasattr(state, "log_history"):
state.log_history = []
log_entry = {"step": state.global_step}
for layer_idx, metrics in sorted(self._telemetry_buffer.items()):
prefix = f"moe_afb/l{layer_idx}_"
for key, value in metrics.items():
log_entry[f"{prefix}{key}"] = value
state.log_history.append(log_entry)
control.should_log = True
if self._should_log(args, state) and self._telemetry_buffer:
logs: dict[str, float] = {}
for layer_idx, metrics in sorted(self._telemetry_buffer.items()):
prefix = f"moe_afb/l{layer_idx}_"
for key, value in metrics.items():
logs[f"{prefix}{key}"] = value
if logs and hasattr(self.trainer, "log"):
self.trainer.log(logs)
self._telemetry_buffer.clear()
return control
@@ -123,15 +115,14 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
}
def _should_log(self, args, state) -> bool:
interval = (
self.telemetry_interval
if self.telemetry_interval is not None and self.telemetry_interval > 0
else getattr(args, "logging_steps", 0)
)
interval = max(1, int(interval)) if interval else 0
if interval == 0:
interval = getattr(args, "logging_steps", 0)
if not interval:
return False
return state.global_step % interval == 0
try:
interval = max(1, int(interval))
except (TypeError, ValueError):
return False
return interval > 0 and state.global_step % interval == 0
class AuxFreeMoEPlugin(BasePlugin):
@@ -165,14 +156,12 @@ class AuxFreeMoEPlugin(BasePlugin):
bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
telemetry_interval = getattr(cfg, "moe_afb_telemetry_interval", None)
af_cfg = AuxFreeConfig(
rate=rate,
momentum=momentum,
bias_cap=bias_cap,
warmup_steps=warmup,
sync_group=sync_group,
telemetry_interval=telemetry_interval,
)
# Discover layers to count the number and experts for state sizing
@@ -249,7 +238,7 @@ class AuxFreeMoEPlugin(BasePlugin):
cb = MoeAuxFreeBiasUpdateCallback(
self._shim,
layers,
telemetry_interval=self._shim.state.cfg.telemetry_interval,
trainer,
)
LOG.info("AuxFreeMoE: registering post-step bias update callback")
return [cb]

View File

@@ -52,7 +52,6 @@ class TestLlama4MoeAuxFree(unittest.TestCase):
"moe_update_rate": 0.01,
"moe_update_momentum": 0.9,
"moe_bias_cap": 2.0,
"moe_afb_telemetry_interval": 1,
}
)

View File

@@ -53,7 +53,6 @@ class TestRingMoeAuxFree(unittest.TestCase):
"moe_update_rate": 0.01,
"moe_update_momentum": 0.9,
"moe_bias_cap": 2.0,
"moe_afb_telemetry_interval": 1,
}
)

View File

@@ -149,9 +149,6 @@ def _build_mixtral_model():
def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
assert callbacks, "expected aux-free callback to be registered"
callback = callbacks[0]
if args is None:
args = SimpleNamespace(logging_steps=1)
if state is None:
@@ -163,6 +160,22 @@ def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
should_save=False,
should_training_stop=False,
)
class DummyTrainer:
def __init__(self, state_obj, control_obj):
self.state = state_obj
self.control = control_obj
def log(self, logs):
output = dict(logs)
output["step"] = self.state.global_step
self.state.log_history.append(output)
self.control.should_log = True
dummy_trainer = DummyTrainer(state, control)
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=dummy_trainer)
assert callbacks, "expected aux-free callback to be registered"
callback = callbacks[0]
callback.on_step_end(args=args, state=state, control=control)
return state, control
@@ -277,7 +290,7 @@ class TestAuxFreeAdapters(unittest.TestCase):
def test_telemetry_logging(self):
model, layer = _build_mixtral_model()
layer.jitter_noise = 0.0
cfg = _cfg(moe_afb_telemetry_interval=1)
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)