diff --git a/src/axolotl/integrations/aux_free_router/args.py b/src/axolotl/integrations/aux_free_router/args.py
index d284d4d66..0f4cc4035 100644
--- a/src/axolotl/integrations/aux_free_router/args.py
+++ b/src/axolotl/integrations/aux_free_router/args.py
@@ -69,4 +69,3 @@ class AuxFreeRouterArgs(BaseModel):
             "'ep' (expert-parallel group if available). Defaults to 'world' when unset."
         },
     )
-
diff --git a/src/axolotl/integrations/aux_free_router/core.py b/src/axolotl/integrations/aux_free_router/core.py
index 30b180547..34c9db2a8 100644
--- a/src/axolotl/integrations/aux_free_router/core.py
+++ b/src/axolotl/integrations/aux_free_router/core.py
@@ -5,6 +5,7 @@ from typing import Optional
 
 import torch
 import torch.distributed as dist
+
 from axolotl.utils.logging import get_logger
 
 LOG = get_logger(__name__)
@@ -22,9 +23,17 @@ class AuxFreeConfig:
 
 class AuxFreeState:
     """Holds per-layer bias and EMA load buffers."""
 
-    def __init__(self, num_layers: int, num_experts: int, device: torch.device, cfg: AuxFreeConfig):
+    def __init__(
+        self,
+        num_layers: int,
+        num_experts: int,
+        device: torch.device,
+        cfg: AuxFreeConfig,
+    ):
         self.bias = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
-        self.ema_load = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
+        self.ema_load = [
+            torch.zeros(num_experts, device=device) for _ in range(num_layers)
+        ]
         self.cfg = cfg
         self.steps = 0
@@ -48,11 +57,13 @@ class AuxFreeShim:
         self._prev_bias_sign: dict[int, torch.Tensor] = {}
 
     @torch.no_grad()
-    def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
+    def select_experts(
+        self, layer_idx: int, logits: torch.Tensor, top_k: int
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Returns (topk_indices, weights) using biased selection and unbiased weights."""
         module = self._layer_modules.get(layer_idx)
         if module is not None and hasattr(module, "_afb_bias"):
-            b = getattr(module, "_afb_bias")
+            b = module._afb_bias
         else:
             b = self.state.bias[layer_idx]
         biased = logits + b  # bias is a buffer
@@ -64,8 +75,8 @@ class AuxFreeShim:
     def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
         """Bind model buffers so shim updates stay in sync with patched layers."""
         self._layer_modules[layer_idx] = module
-        bias = getattr(module, "_afb_bias")
-        ema = getattr(module, "_afb_ema")
+        bias = module._afb_bias
+        ema = module._afb_ema
         # Keep state views pointing to the same tensors to avoid drift.
         if layer_idx < len(self.state.bias):
             self.state.bias[layer_idx] = bias
@@ -100,8 +111,8 @@ class AuxFreeShim:
             return
         module = self._layer_modules.get(layer_idx)
         if module is not None and hasattr(module, "_afb_ema"):
-            ema = getattr(module, "_afb_ema")
-            bias = getattr(module, "_afb_bias")
+            ema = module._afb_ema
+            bias = module._afb_bias
         else:
             ema = self.state.ema_load[layer_idx]
             bias = self.state.bias[layer_idx]
diff --git a/tests/e2e/test_llama4_moe_aux_free.py b/tests/e2e/test_llama4_moe_aux_free.py
index 9385f7733..f431f619e 100644
--- a/tests/e2e/test_llama4_moe_aux_free.py
+++ b/tests/e2e/test_llama4_moe_aux_free.py
@@ -63,7 +63,9 @@ class TestLlama4MoeAuxFree(unittest.TestCase):
         model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
 
         patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None)
-        assert patched is not None, "Llama 4 MoE layer was not patched by aux-free plugin"
+        assert patched is not None, (
+            "Llama 4 MoE layer was not patched by aux-free plugin"
+        )
         assert patched._afb_bias.ndim == 1
         assert patched._afb_counts.ndim == 1
         check_model_output_exists(temp_dir, cfg)
diff --git a/tests/e2e/test_moe_aux_free.py b/tests/e2e/test_moe_aux_free.py
index 237ca490c..6ffee20ea 100644
--- a/tests/e2e/test_moe_aux_free.py
+++ b/tests/e2e/test_moe_aux_free.py
@@ -8,7 +8,7 @@ import torch
 
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
-from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
+from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
 from axolotl.utils.dict import DictDefault
 
 from .utils import check_model_output_exists, with_temp_dir
@@ -67,7 +67,7 @@ class TestMoeAuxFree(unittest.TestCase):
         # Inspect model modules for a patched MoE layer
         patched = None
         for m in model.modules():
-            if hasattr(m, "_afb_patched") and getattr(m, "_afb_patched") is True:
+            if hasattr(m, "_afb_patched") and m._afb_patched is True:
                 patched = m
                 break
         assert patched is not None, "No MoE layer patched by aux-free plugin"
diff --git a/tests/e2e/test_moe_aux_parity.py b/tests/e2e/test_moe_aux_parity.py
index d048bacdd..7f71c54ea 100644
--- a/tests/e2e/test_moe_aux_parity.py
+++ b/tests/e2e/test_moe_aux_parity.py
@@ -7,7 +7,7 @@ import unittest
 
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
-from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
+from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
 from axolotl.utils.dict import DictDefault
 
 from .utils import with_temp_dir
diff --git a/tests/e2e/test_qwen3_moe_aux_free.py b/tests/e2e/test_qwen3_moe_aux_free.py
index 8f968e9ff..012bd0471 100644
--- a/tests/e2e/test_qwen3_moe_aux_free.py
+++ b/tests/e2e/test_qwen3_moe_aux_free.py
@@ -4,11 +4,9 @@ E2E smoke test for Aux-Loss-Free MoE routing on Qwen3-MoE tiny
 
 import unittest
 
-import torch
-
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
-from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
+from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
 from axolotl.utils.dict import DictDefault
 
 from .utils import check_model_output_exists, with_temp_dir
@@ -65,7 +63,9 @@ class TestQwen3MoeAuxFree(unittest.TestCase):
         # check that at least one sparse MoE block has been patched
         found = False
         for m in model.modules():
-            if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(m, "_afb_patched"):
+            if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(
+                m, "_afb_patched"
+            ):
                 assert m._afb_patched is True
                 assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1
                 assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1
diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py
index ef6ed3dfd..755a65fef 100644
--- a/tests/e2e/utils.py
+++ b/tests/e2e/utils.py
@@ -12,6 +12,7 @@ from pathlib import Path
 
 import torch
 from packaging import version
+
 try:
     from tbparse import SummaryReader
 except ImportError:  # pragma: no cover - optional dependency
@@ -189,7 +190,9 @@ def check_tensorboard(
     helper function to parse and check tensorboard logs
     """
     if SummaryReader is None:
-        raise unittest.SkipTest("tbparse is not installed; skipping tensorboard assertions")
+        raise unittest.SkipTest(
+            "tbparse is not installed; skipping tensorboard assertions"
+        )
     tb_log_path = most_recent_subdir(temp_run_dir)
     event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
     reader = SummaryReader(event_file)
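
Reviewer note: the `select_experts` docstring touched above ("biased selection and unbiased weights") is the core of the aux-free routing trick, but only the `biased = logits + b` line is visible in this diff. Below is a minimal sketch of that idea under stated assumptions; the function name, the top-k call, and the softmax normalization are illustrative choices, not the plugin's actual implementation.

```python
import torch


def select_experts_sketch(
    logits: torch.Tensor,  # [tokens, num_experts] raw router logits
    bias: torch.Tensor,    # [num_experts] per-expert bias (e.g. the _afb_bias buffer)
    top_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Selection uses the biased scores, so experts with a positive bias
    # (typically the underloaded ones) are chosen more often.
    biased = logits + bias
    _, topk_idx = biased.topk(top_k, dim=-1)
    # Gating weights come from the *unbiased* logits, so the bias only changes
    # which experts are picked, never how much a picked expert contributes.
    weights = torch.softmax(logits, dim=-1).gather(-1, topk_idx)
    return topk_idx, weights
```

Whether the gathered weights are then renormalized over the selected k experts is model-specific and deliberately left out of this sketch.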