chore: lint

This commit is contained in:
Wing Lian
2026-03-22 12:02:44 -04:00
parent 5acb1b0ade
commit 0a566d7a15
7 changed files with 33 additions and 18 deletions

View File

@@ -69,4 +69,3 @@ class AuxFreeRouterArgs(BaseModel):
"'ep' (expert-parallel group if available). Defaults to 'world' when unset." "'ep' (expert-parallel group if available). Defaults to 'world' when unset."
}, },
) )

View File

@@ -5,6 +5,7 @@ from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
LOG = get_logger(__name__) LOG = get_logger(__name__)
@@ -22,9 +23,17 @@ class AuxFreeConfig:
class AuxFreeState: class AuxFreeState:
"""Holds per-layer bias and EMA load buffers.""" """Holds per-layer bias and EMA load buffers."""
def __init__(self, num_layers: int, num_experts: int, device: torch.device, cfg: AuxFreeConfig): def __init__(
self,
num_layers: int,
num_experts: int,
device: torch.device,
cfg: AuxFreeConfig,
):
self.bias = [torch.zeros(num_experts, device=device) for _ in range(num_layers)] self.bias = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
self.ema_load = [torch.zeros(num_experts, device=device) for _ in range(num_layers)] self.ema_load = [
torch.zeros(num_experts, device=device) for _ in range(num_layers)
]
self.cfg = cfg self.cfg = cfg
self.steps = 0 self.steps = 0
@@ -48,11 +57,13 @@ class AuxFreeShim:
self._prev_bias_sign: dict[int, torch.Tensor] = {} self._prev_bias_sign: dict[int, torch.Tensor] = {}
@torch.no_grad() @torch.no_grad()
def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]: def select_experts(
self, layer_idx: int, logits: torch.Tensor, top_k: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""Returns (topk_indices, weights) using biased selection and unbiased weights.""" """Returns (topk_indices, weights) using biased selection and unbiased weights."""
module = self._layer_modules.get(layer_idx) module = self._layer_modules.get(layer_idx)
if module is not None and hasattr(module, "_afb_bias"): if module is not None and hasattr(module, "_afb_bias"):
b = getattr(module, "_afb_bias") b = module._afb_bias
else: else:
b = self.state.bias[layer_idx] b = self.state.bias[layer_idx]
biased = logits + b # bias is a buffer biased = logits + b # bias is a buffer
@@ -64,8 +75,8 @@ class AuxFreeShim:
def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None: def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
"""Bind model buffers so shim updates stay in sync with patched layers.""" """Bind model buffers so shim updates stay in sync with patched layers."""
self._layer_modules[layer_idx] = module self._layer_modules[layer_idx] = module
bias = getattr(module, "_afb_bias") bias = module._afb_bias
ema = getattr(module, "_afb_ema") ema = module._afb_ema
# Keep state views pointing to the same tensors to avoid drift. # Keep state views pointing to the same tensors to avoid drift.
if layer_idx < len(self.state.bias): if layer_idx < len(self.state.bias):
self.state.bias[layer_idx] = bias self.state.bias[layer_idx] = bias
@@ -100,8 +111,8 @@ class AuxFreeShim:
return return
module = self._layer_modules.get(layer_idx) module = self._layer_modules.get(layer_idx)
if module is not None and hasattr(module, "_afb_ema"): if module is not None and hasattr(module, "_afb_ema"):
ema = getattr(module, "_afb_ema") ema = module._afb_ema
bias = getattr(module, "_afb_bias") bias = module._afb_bias
else: else:
ema = self.state.ema_load[layer_idx] ema = self.state.ema_load[layer_idx]
bias = self.state.bias[layer_idx] bias = self.state.bias[layer_idx]

View File

@@ -63,7 +63,9 @@ class TestLlama4MoeAuxFree(unittest.TestCase):
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta) model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None) patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None)
assert patched is not None, "Llama 4 MoE layer was not patched by aux-free plugin" assert patched is not None, (
"Llama 4 MoE layer was not patched by aux-free plugin"
)
assert patched._afb_bias.ndim == 1 assert patched._afb_bias.ndim == 1
assert patched._afb_counts.ndim == 1 assert patched._afb_counts.ndim == 1
check_model_output_exists(temp_dir, cfg) check_model_output_exists(temp_dir, cfg)

View File

@@ -8,7 +8,7 @@ import torch
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir from .utils import check_model_output_exists, with_temp_dir
@@ -67,7 +67,7 @@ class TestMoeAuxFree(unittest.TestCase):
# Inspect model modules for a patched MoE layer # Inspect model modules for a patched MoE layer
patched = None patched = None
for m in model.modules(): for m in model.modules():
if hasattr(m, "_afb_patched") and getattr(m, "_afb_patched") is True: if hasattr(m, "_afb_patched") and m._afb_patched is True:
patched = m patched = m
break break
assert patched is not None, "No MoE layer patched by aux-free plugin" assert patched is not None, "No MoE layer patched by aux-free plugin"

View File

@@ -7,7 +7,7 @@ import unittest
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir from .utils import with_temp_dir

View File

@@ -4,11 +4,9 @@ E2E smoke test for Aux-Loss-Free MoE routing on Qwen3-MoE tiny
import unittest import unittest
import torch
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir from .utils import check_model_output_exists, with_temp_dir
@@ -65,7 +63,9 @@ class TestQwen3MoeAuxFree(unittest.TestCase):
# check that at least one sparse MoE block has been patched # check that at least one sparse MoE block has been patched
found = False found = False
for m in model.modules(): for m in model.modules():
if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(m, "_afb_patched"): if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(
m, "_afb_patched"
):
assert m._afb_patched is True assert m._afb_patched is True
assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1 assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1
assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1 assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1

View File

@@ -12,6 +12,7 @@ from pathlib import Path
import torch import torch
from packaging import version from packaging import version
try: try:
from tbparse import SummaryReader from tbparse import SummaryReader
except ImportError: # pragma: no cover - optional dependency except ImportError: # pragma: no cover - optional dependency
@@ -189,7 +190,9 @@ def check_tensorboard(
helper function to parse and check tensorboard logs helper function to parse and check tensorboard logs
""" """
if SummaryReader is None: if SummaryReader is None:
raise unittest.SkipTest("tbparse is not installed; skipping tensorboard assertions") raise unittest.SkipTest(
"tbparse is not installed; skipping tensorboard assertions"
)
tb_log_path = most_recent_subdir(temp_run_dir) tb_log_path = most_recent_subdir(temp_run_dir)
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0]) event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file) reader = SummaryReader(event_file)