From 46d677876e7f52f38aea91a65bca25a660adfceb Mon Sep 17 00:00:00 2001 From: lhl Date: Tue, 28 Oct 2025 10:01:07 +0000 Subject: [PATCH] tests: add ring and llama4 aux-free smokes - add hf tiny model coverage for Ring 2.0 and Llama 4 adapters - broaden bailing adapter detection for ring configs --- .../integrations/aux_free_router/adapters.py | 10 ++- tests/e2e/test_llama4_moe_aux_free.py | 74 ++++++++++++++++++ tests/e2e/test_ring_moe_aux_free.py | 75 +++++++++++++++++++ 3 files changed, 157 insertions(+), 2 deletions(-) create mode 100644 tests/e2e/test_llama4_moe_aux_free.py create mode 100644 tests/e2e/test_ring_moe_aux_free.py diff --git a/src/axolotl/integrations/aux_free_router/adapters.py b/src/axolotl/integrations/aux_free_router/adapters.py index e349c268b..ea49c6936 100644 --- a/src/axolotl/integrations/aux_free_router/adapters.py +++ b/src/axolotl/integrations/aux_free_router/adapters.py @@ -189,8 +189,14 @@ class BailingAdapter(BaseMoEAdapter): family = "bailing_moe" def matches(self, model: nn.Module) -> bool: - model_type = getattr(getattr(model, "config", object()), "model_type", "") - return model_type in ("bailing_moe", "bailing_moe_v2") + cfg = getattr(model, "config", None) + if cfg is None: + return False + model_type = getattr(cfg, "model_type", "") or "" + if model_type in ("bailing_moe", "bailing_moe_v2", "ring_moe", "ring"): + return True + cfg_name = cfg.__class__.__name__.lower() + return "bailingmoev2" in cfg_name or "ring" in cfg_name def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: for m in model.modules(): diff --git a/tests/e2e/test_llama4_moe_aux_free.py b/tests/e2e/test_llama4_moe_aux_free.py new file mode 100644 index 000000000..bc5341dd9 --- /dev/null +++ b/tests/e2e/test_llama4_moe_aux_free.py @@ -0,0 +1,74 @@ +""" +E2E smoke test for Llama 4 aux-loss-free routing via plugin +""" + +import unittest + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, prepare_plugins, validate_config +from axolotl.utils.dict import DictDefault + +from .utils import check_model_output_exists, with_temp_dir + + +class TestLlama4MoeAuxFree(unittest.TestCase): + """Smoke test to ensure aux-free plugin patches Llama 4 MoE correctly.""" + + @with_temp_dir + def test_llama4_aux_free_smoke(self, temp_dir): + cfg = DictDefault( + { + "base_model": "yujiepan/llama-4-tiny-random", + "tokenizer_config": "yujiepan/llama-4-tiny-random", + "trust_remote_code": True, + "flash_attention": False, + "sequence_len": 512, + "bf16": False, + "fp16": False, + "val_set_size": 0.02, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 1e-5, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 5, + "save_steps": 0, + "eval_steps": 0, + "save_first_step": False, + "plugins": [ + "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin", + ], + "moe_balance_type": "noaux_tc", + "moe_update_rate": 0.01, + "moe_update_momentum": 0.9, + "moe_bias_cap": 2.0, + "moe_afb_telemetry_interval": 1, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + prepare_plugins(cfg) + dataset_meta = load_datasets(cfg=cfg) + + model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) + + patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None) + assert patched is not None, "Llama 4 MoE layer was not patched by aux-free plugin" + assert patched._afb_bias.ndim == 1 + assert patched._afb_counts.ndim == 1 + check_model_output_exists(temp_dir, cfg) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/e2e/test_ring_moe_aux_free.py b/tests/e2e/test_ring_moe_aux_free.py new file mode 100644 index 000000000..1905582d7 --- /dev/null +++ b/tests/e2e/test_ring_moe_aux_free.py @@ -0,0 +1,75 @@ +""" +E2E smoke test for Ring 2.0 aux-loss-free routing via plugin +""" + +import unittest + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, prepare_plugins, validate_config +from axolotl.utils.dict import DictDefault + +from .utils import check_model_output_exists, with_temp_dir + + +class TestRingMoeAuxFree(unittest.TestCase): + """Smoke test to ensure aux-free plugin patches Ring Mini 2.0 correctly.""" + + @with_temp_dir + def test_ring_aux_free_smoke(self, temp_dir): + cfg = DictDefault( + { + "base_model": "yujiepan/ring-tiny-random", + "tokenizer_config": "yujiepan/ring-tiny-random", + "trust_remote_code": True, + "flash_attention": False, + "sequence_len": 512, + "bf16": False, + "fp16": False, + "val_set_size": 0.02, + "special_tokens": {}, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "micro_batch_size": 2, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 1e-5, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "max_steps": 5, + "save_steps": 0, + "eval_steps": 0, + "save_first_step": False, + # Aux-free plugin config + "plugins": [ + "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin", + ], + "moe_balance_type": "noaux_tc", + "moe_update_rate": 0.01, + "moe_update_momentum": 0.9, + "moe_bias_cap": 2.0, + "moe_afb_telemetry_interval": 1, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + prepare_plugins(cfg) + dataset_meta = load_datasets(cfg=cfg) + + model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) + + patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None) + assert patched is not None, "Ring MoE layer was not patched by aux-free plugin" + assert patched._afb_bias.ndim == 1 + assert patched._afb_counts.ndim == 1 + check_model_output_exists(temp_dir, cfg) + + +if __name__ == "__main__": + unittest.main()