tests: extend aux-free coverage
- add warmup, EP sync, and mixtral parity unit checks
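The warmup check assumes gating along the lines of the sketch below (illustrative only, not the plugin's actual implementation; _WarmupGate and apply_bias_update are hypothetical names): bias updates are skipped while the step counter is within moe_afb_warmup_steps and resume on the following step, which is why the test expects the bias to stay at zero for two steps and become non-zero on the third.

class _WarmupGate:
    """Hypothetical sketch of the gating that test_bias_warmup_respected asserts."""

    def __init__(self, warmup_steps):
        self.warmup_steps = warmup_steps
        self.step = 0

    def on_step_end(self, **kwargs):
        self.step += 1
        if self.step <= self.warmup_steps:
            return  # still in warmup: leave the routing bias at zero
        self.apply_bias_update()  # hypothetical hook that adjusts the bias

    def apply_bias_update(self):
        pass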
@@ -1,8 +1,11 @@
import os
import sys
import tempfile
import unittest
from types import SimpleNamespace

import torch
import torch.distributed as dist
import torch.nn as nn
from importlib import util as importlib_util
from pathlib import Path
@@ -118,6 +121,33 @@ def _build_llama4_model():
    return DummyModel(layer), layer


def _build_mixtral_model():
    from transformers import MixtralConfig
    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

    config = MixtralConfig(
        hidden_size=16,
        intermediate_size=32,
        num_local_experts=4,
        num_experts_per_tok=2,
        num_attention_heads=2,
        num_key_value_heads=2,
    )
    layer = MixtralSparseMoeBlock(config)
    layer.config = config

    class DummyModel(nn.Module):
        def __init__(self, moe_layer):
            super().__init__()
            self.moe = moe_layer
            self.config = SimpleNamespace(model_type="mixtral")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    return DummyModel(layer), layer


def _run_callback(plugin, cfg):
    callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
    assert callbacks, "expected aux-free callback to be registered"
@@ -157,6 +187,81 @@ class TestAuxFreeAdapters(unittest.TestCase):
        self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
        self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)))

    def test_bias_warmup_respected(self):
        model, block = _build_bailing_model()
        cfg = _cfg(moe_afb_warmup_steps=2)
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
        self.assertTrue(callbacks)
        callback = callbacks[0]
        dummy = SimpleNamespace()

        def _step():
            hidden = torch.randn(2, 3, block.config.hidden_size)
            block(hidden)
            callback.on_step_end(args=dummy, state=dummy, control=dummy)

        # Warmup steps should leave bias untouched.
        _step()
        self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))

        _step()
        self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))

        # Third step exceeds warmup -> bias should update.
        _step()
        self.assertGreater(torch.count_nonzero(block._afb_bias), 0)

    def test_mixtral_adapter_respects_native_forward(self):
        model, layer = _build_mixtral_model()
        layer.jitter_noise = 0.0  # avoid stochasticity for comparison

        hidden_dim = layer.config.hidden_size
        hidden = torch.randn(2, 3, hidden_dim)
        baseline_out, baseline_logits = layer(hidden.clone())

        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        patched_out, patched_logits = layer(hidden.clone())
        self.assertTrue(torch.allclose(baseline_out, patched_out))
        self.assertTrue(torch.allclose(baseline_logits, patched_logits))
        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
        _run_callback(plugin, cfg)

    def test_ep_group_resolution_deferred_until_dist_ready(self):
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()

        model, block = _build_bailing_model()
        cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)

        self.assertIsNotNone(plugin._shim)
        self.assertIsNone(plugin._shim.ep_group)

        callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
        self.assertTrue(callbacks)
        callback = callbacks[0]
        dummy = SimpleNamespace()

        tmp_init = tempfile.NamedTemporaryFile(delete=False)
        tmp_init.close()
        init_method = f"file://{tmp_init.name}"
        dist.init_process_group(backend="gloo", init_method=init_method, world_size=1, rank=0)
        try:
            hidden = torch.randn(2, 3, block.config.hidden_size)
            block(hidden)
            callback.on_step_end(args=dummy, state=dummy, control=dummy)
            self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
        finally:
            dist.destroy_process_group()
            os.unlink(tmp_init.name)


if __name__ == "__main__":
    unittest.main()