tests: extend aux-free coverage

- add warmup, EP sync, and mixtral parity unit checks
This commit is contained in:
lhl
2025-10-28 08:08:13 +00:00
committed by Wing Lian
parent a0019021dd
commit 949cdf01eb

View File

@@ -1,8 +1,11 @@
import os
import sys import sys
import tempfile
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
import torch import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from importlib import util as importlib_util from importlib import util as importlib_util
from pathlib import Path from pathlib import Path
@@ -118,6 +121,33 @@ def _build_llama4_model():
return DummyModel(layer), layer return DummyModel(layer), layer
def _build_mixtral_model():
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
config = MixtralConfig(
hidden_size=16,
intermediate_size=32,
num_local_experts=4,
num_experts_per_tok=2,
num_attention_heads=2,
num_key_value_heads=2,
)
layer = MixtralSparseMoeBlock(config)
layer.config = config
class DummyModel(nn.Module):
def __init__(self, moe_layer):
super().__init__()
self.moe = moe_layer
self.config = SimpleNamespace(model_type="mixtral")
def forward(self, hidden_states):
return self.moe(hidden_states)
return DummyModel(layer), layer
def _run_callback(plugin, cfg): def _run_callback(plugin, cfg):
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace()) callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
assert callbacks, "expected aux-free callback to be registered" assert callbacks, "expected aux-free callback to be registered"
@@ -157,6 +187,81 @@ class TestAuxFreeAdapters(unittest.TestCase):
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0) self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))) self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)))
def test_bias_warmup_respected(self):
model, block = _build_bailing_model()
cfg = _cfg(moe_afb_warmup_steps=2)
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
self.assertTrue(callbacks)
callback = callbacks[0]
dummy = SimpleNamespace()
def _step():
hidden = torch.randn(2, 3, block.config.hidden_size)
block(hidden)
callback.on_step_end(args=dummy, state=dummy, control=dummy)
# Warmup steps should leave bias untouched.
_step()
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
_step()
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
# Third step exceeds warmup -> bias should update.
_step()
self.assertGreater(torch.count_nonzero(block._afb_bias), 0)
def test_mixtral_adapter_respects_native_forward(self):
model, layer = _build_mixtral_model()
layer.jitter_noise = 0.0 # avoid stochasticity for comparison
hidden_dim = layer.config.hidden_size
hidden = torch.randn(2, 3, hidden_dim)
baseline_out, baseline_logits = layer(hidden.clone())
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
patched_out, patched_logits = layer(hidden.clone())
self.assertTrue(torch.allclose(baseline_out, patched_out))
self.assertTrue(torch.allclose(baseline_logits, patched_logits))
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
_run_callback(plugin, cfg)
def test_ep_group_resolution_deferred_until_dist_ready(self):
if dist.is_available() and dist.is_initialized():
dist.destroy_process_group()
model, block = _build_bailing_model()
cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
self.assertIsNotNone(plugin._shim)
self.assertIsNone(plugin._shim.ep_group)
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
self.assertTrue(callbacks)
callback = callbacks[0]
dummy = SimpleNamespace()
tmp_init = tempfile.NamedTemporaryFile(delete=False)
tmp_init.close()
init_method = f"file://{tmp_init.name}"
dist.init_process_group(backend="gloo", init_method=init_method, world_size=1, rank=0)
try:
hidden = torch.randn(2, 3, block.config.hidden_size)
block(hidden)
callback.on_step_end(args=dummy, state=dummy, control=dummy)
self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
finally:
dist.destroy_process_group()
os.unlink(tmp_init.name)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()