improve: align aux-free telemetry with Trainer logging
@@ -30,7 +30,6 @@ Config keys
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
- moe_afb_telemetry_interval: emit router telemetry every N optimizer steps (defaults to `logging_steps` when unset).
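For orientation, the snippet below collects these keys in one place. It is illustrative only: the values are placeholders rather than recommendations, and the plain dict stands in for the training YAML, mirroring the dict-style configs used in the tests further down.

```python
# Illustrative aux-free MoE settings; values are placeholders, not tuned
# recommendations. The keys mirror the options documented above and the
# test configs shown later in this diff.
aux_free_cfg = {
    "moe_update_rate": 0.01,
    "moe_update_momentum": 0.9,
    "moe_bias_cap": 2.0,
    "moe_afb_warmup_steps": 100,   # delay before bias updates are applied
    "moe_bias_sync_group": "ep",   # reduce counts within expert-parallel groups
    "expert_parallel_size": 8,     # ranks per expert-parallel group
}
```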

Compatibility
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
@@ -38,7 +37,7 @@ Compatibility

Notes
- If you also enable Liger’s aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
- Telemetry: logs per-layer min/mean/max token loads, `|bias| max`, and bias sign flip fraction at the configured interval.
- Telemetry: logs per-layer min/mean/max token loads, `|bias| max`, and bias sign flip fraction using the Trainer’s `logging_steps` cadence.
- Sample packing: packed batches are compatible with aux-free routing. Because load counts are accumulated on-device per expert before reduction, packing tends to smooth token histograms and reduce bias oscillation. Keep `pad_to_sequence_len: true` when packing to preserve the target token budget per expert.
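A minimal sketch of that packing note, with placeholder keys in the same dict style (the `sample_packing` key name is assumed here, not taken from this diff):

```python
# Assumed packing-related keys (names are illustrative); the point is simply
# that pad_to_sequence_len stays enabled alongside packing so the per-expert
# token budget is preserved.
packing_cfg = {
    "sample_packing": True,
    "pad_to_sequence_len": True,
}
```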

Telemetry metrics
@@ -47,5 +46,5 @@ Telemetry metrics
- `moe_afb/l{idx}_bias_sign_flip_frac`: fraction of experts whose bias sign changed since the previous step (simple oscillation indicator).

Usage tips
- Leave `moe_afb_telemetry_interval` unset to log on the Trainer’s `logging_steps`. Increase the interval for large jobs to reduce log volume.
- Increase `logging_steps` if router telemetry becomes noisy for large jobs—the plugin follows the Trainer’s logging cadence.
- Compare aux-free vs. aux-loss load metrics by plotting the `load_*` series; aux-free typically tightens min/max spread without the auxiliary loss term.
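Since the new code path hands telemetry to `trainer.log`, the series end up in the Trainer's `log_history` and can be pulled out directly. The helper below is a sketch, not part of the plugin; it assumes only the `moe_afb/l{idx}_` key prefix documented above and `load_*`-style suffixes such as `load_max`.

```python
# Sketch only: collect per-layer aux-free telemetry from a transformers
# Trainer's log_history so the load_* series can be plotted or compared.
def collect_afb_series(log_history, metric="load_max"):
    series = {}
    for entry in log_history:
        step = entry.get("step")
        for key, value in entry.items():
            if key.startswith("moe_afb/") and key.endswith(metric):
                series.setdefault(key, []).append((step, value))
    return series

# e.g. load_spread = collect_afb_series(trainer.state.log_history, metric="load_min")
```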

@@ -17,7 +17,6 @@ class AuxFreeConfig:
bias_cap: float = 2.0
warmup_steps: int = 0
sync_group: str = "world" # or "ep"
telemetry_interval: Optional[int] = None


class AuxFreeState:

@@ -6,7 +6,7 @@ unbiased logits for mixture weights and per-expert biases for top-k selection.

from __future__ import annotations

from typing import Optional
from typing import Optional, Any

import torch
import torch.distributed as dist

@@ -29,22 +29,17 @@ LOG = get_logger(__name__)


class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
"""Post-step callback to update aux-free biases from accumulated expert counts.

Note: The current revision expects per-layer counts to be accumulated on each
MoE layer as a buffer named `_afb_counts` during forward (to be added with
routing patches in a follow-up).
"""
"""Post-step callback to update aux-free biases from accumulated expert counts."""

def __init__(
self,
shim: AuxFreeShim,
layer_modules: list[torch.nn.Module],
telemetry_interval: Optional[int] = None,
trainer: Any,
):
self.shim = shim
self.layer_modules = layer_modules
self.telemetry_interval = telemetry_interval
self.trainer = trainer
self._prev_bias_sign: dict[int, torch.Tensor] = {}
self._telemetry_buffer: dict[int, dict[str, float]] = {}

@@ -71,17 +66,14 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
# reset step counts
counts.zero_()

if self._should_log(args, state):
if self._telemetry_buffer:
if not hasattr(state, "log_history"):
state.log_history = []
log_entry = {"step": state.global_step}
for layer_idx, metrics in sorted(self._telemetry_buffer.items()):
prefix = f"moe_afb/l{layer_idx}_"
for key, value in metrics.items():
log_entry[f"{prefix}{key}"] = value
state.log_history.append(log_entry)
control.should_log = True
if self._should_log(args, state) and self._telemetry_buffer:
logs: dict[str, float] = {}
for layer_idx, metrics in sorted(self._telemetry_buffer.items()):
prefix = f"moe_afb/l{layer_idx}_"
for key, value in metrics.items():
logs[f"{prefix}{key}"] = value
if logs and hasattr(self.trainer, "log"):
self.trainer.log(logs)
self._telemetry_buffer.clear()
return control

@@ -123,15 +115,14 @@ class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
}

def _should_log(self, args, state) -> bool:
interval = (
self.telemetry_interval
if self.telemetry_interval is not None and self.telemetry_interval > 0
else getattr(args, "logging_steps", 0)
)
interval = max(1, int(interval)) if interval else 0
if interval == 0:
interval = getattr(args, "logging_steps", 0)
if not interval:
return False
return state.global_step % interval == 0
try:
interval = max(1, int(interval))
except (TypeError, ValueError):
return False
return interval > 0 and state.global_step % interval == 0
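To make the old-versus-new `_should_log` change easier to follow, the function below restates the new precedence rule on its own (an illustrative sketch, not the plugin's API): an explicit positive telemetry interval wins, otherwise the Trainer's `logging_steps` is used, and anything non-numeric disables telemetry.

```python
# Standalone restatement of the new cadence rule, for illustration only.
def telemetry_interval_for(explicit_interval, logging_steps):
    interval = (
        explicit_interval
        if explicit_interval is not None and explicit_interval > 0
        else logging_steps
    )
    try:
        return max(1, int(interval))
    except (TypeError, ValueError):
        return None  # telemetry disabled

assert telemetry_interval_for(None, 10) == 10  # follow the Trainer cadence
assert telemetry_interval_for(25, 10) == 25    # explicit override wins
assert telemetry_interval_for(None, None) is None
```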


class AuxFreeMoEPlugin(BasePlugin):
@@ -165,14 +156,12 @@ class AuxFreeMoEPlugin(BasePlugin):
bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
telemetry_interval = getattr(cfg, "moe_afb_telemetry_interval", None)
af_cfg = AuxFreeConfig(
rate=rate,
momentum=momentum,
bias_cap=bias_cap,
warmup_steps=warmup,
sync_group=sync_group,
telemetry_interval=telemetry_interval,
)

# Discover MoE layers and count experts for state sizing

@@ -249,7 +238,7 @@ class AuxFreeMoEPlugin(BasePlugin):
cb = MoeAuxFreeBiasUpdateCallback(
self._shim,
layers,
telemetry_interval=self._shim.state.cfg.telemetry_interval,
trainer,
)
LOG.info("AuxFreeMoE: registering post-step bias update callback")
return [cb]

@@ -52,7 +52,6 @@ class TestLlama4MoeAuxFree(unittest.TestCase):
"moe_update_rate": 0.01,
"moe_update_momentum": 0.9,
"moe_bias_cap": 2.0,
"moe_afb_telemetry_interval": 1,
}
)


@@ -53,7 +53,6 @@ class TestRingMoeAuxFree(unittest.TestCase):
"moe_update_rate": 0.01,
"moe_update_momentum": 0.9,
"moe_bias_cap": 2.0,
"moe_afb_telemetry_interval": 1,
}
)


@@ -149,9 +149,6 @@ def _build_mixtral_model():


def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
assert callbacks, "expected aux-free callback to be registered"
callback = callbacks[0]
if args is None:
args = SimpleNamespace(logging_steps=1)
if state is None:

@@ -163,6 +160,22 @@ def _run_callback(plugin, cfg, *, args=None, state=None, control=None):
should_save=False,
should_training_stop=False,
)

class DummyTrainer:
def __init__(self, state_obj, control_obj):
self.state = state_obj
self.control = control_obj

def log(self, logs):
output = dict(logs)
output["step"] = self.state.global_step
self.state.log_history.append(output)
self.control.should_log = True

dummy_trainer = DummyTrainer(state, control)
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=dummy_trainer)
assert callbacks, "expected aux-free callback to be registered"
callback = callbacks[0]
callback.on_step_end(args=args, state=state, control=control)
return state, control

@@ -277,7 +290,7 @@ class TestAuxFreeAdapters(unittest.TestCase):
def test_telemetry_logging(self):
model, layer = _build_mixtral_model()
layer.jitter_noise = 0.0
cfg = _cfg(moe_afb_telemetry_interval=1)
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)