diff --git a/tests/unit/test_aux_free_adapters.py b/tests/unit/test_aux_free_adapters.py
index ad7d91e48..214126892 100644
--- a/tests/unit/test_aux_free_adapters.py
+++ b/tests/unit/test_aux_free_adapters.py
@@ -1,8 +1,11 @@
+import os
 import sys
+import tempfile
 import unittest
 from types import SimpleNamespace
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from importlib import util as importlib_util
 from pathlib import Path
@@ -118,6 +121,33 @@ def _build_llama4_model():
     return DummyModel(layer), layer
 
 
+def _build_mixtral_model():
+    from transformers import MixtralConfig
+    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+    config = MixtralConfig(
+        hidden_size=16,
+        intermediate_size=32,
+        num_local_experts=4,
+        num_experts_per_tok=2,
+        num_attention_heads=2,
+        num_key_value_heads=2,
+    )
+    layer = MixtralSparseMoeBlock(config)
+    layer.config = config
+
+    class DummyModel(nn.Module):
+        def __init__(self, moe_layer):
+            super().__init__()
+            self.moe = moe_layer
+            self.config = SimpleNamespace(model_type="mixtral")
+
+        def forward(self, hidden_states):
+            return self.moe(hidden_states)
+
+    return DummyModel(layer), layer
+
+
 def _run_callback(plugin, cfg):
     callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
     assert callbacks, "expected aux-free callback to be registered"
@@ -157,6 +187,81 @@ class TestAuxFreeAdapters(unittest.TestCase):
         self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
         self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)))
 
+    def test_bias_warmup_respected(self):
+        model, block = _build_bailing_model()
+        cfg = _cfg(moe_afb_warmup_steps=2)
+        plugin = AuxFreeMoEPlugin()
+        plugin.post_model_build(cfg, model)
+
+        callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
+        self.assertTrue(callbacks)
+        callback = callbacks[0]
+        dummy = SimpleNamespace()
+
+        def _step():
+            hidden = torch.randn(2, 3, block.config.hidden_size)
+            block(hidden)
+            callback.on_step_end(args=dummy, state=dummy, control=dummy)
+
+        # Warmup steps should leave the bias untouched.
+        _step()
+        self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
+
+        _step()
+        self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
+
+        # The third step exceeds the warmup, so the bias should update.
+        _step()
+        self.assertGreater(torch.count_nonzero(block._afb_bias), 0)
+
+    def test_mixtral_adapter_respects_native_forward(self):
+        model, layer = _build_mixtral_model()
+        layer.jitter_noise = 0.0  # avoid stochasticity for comparison
+
+        hidden_dim = layer.config.hidden_size
+        hidden = torch.randn(2, 3, hidden_dim)
+        baseline_out, baseline_logits = layer(hidden.clone())
+
+        cfg = _cfg()
+        plugin = AuxFreeMoEPlugin()
+        plugin.post_model_build(cfg, model)
+
+        patched_out, patched_logits = layer(hidden.clone())
+        self.assertTrue(torch.allclose(baseline_out, patched_out))
+        self.assertTrue(torch.allclose(baseline_logits, patched_logits))
+        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
+        _run_callback(plugin, cfg)
+
+    def test_ep_group_resolution_deferred_until_dist_ready(self):
+        if dist.is_available() and dist.is_initialized():
+            dist.destroy_process_group()
+
+        model, block = _build_bailing_model()
+        cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
+        plugin = AuxFreeMoEPlugin()
+        plugin.post_model_build(cfg, model)
+
+        self.assertIsNotNone(plugin._shim)
+        self.assertIsNone(plugin._shim.ep_group)
+
+        callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
+        self.assertTrue(callbacks)
+        callback = callbacks[0]
+        dummy = SimpleNamespace()
+
+        tmp_init = tempfile.NamedTemporaryFile(delete=False)
+        tmp_init.close()
+        init_method = f"file://{tmp_init.name}"
+        dist.init_process_group(backend="gloo", init_method=init_method, world_size=1, rank=0)
+        try:
+            hidden = torch.randn(2, 3, block.config.hidden_size)
+            block(hidden)
+            callback.on_step_end(args=dummy, state=dummy, control=dummy)
+            self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
+        finally:
+            dist.destroy_process_group()
+            os.unlink(tmp_init.name)
+
 
 if __name__ == "__main__":
     unittest.main()