chore: lint
This commit is contained in:
@@ -63,7 +63,9 @@ class TestLlama4MoeAuxFree(unittest.TestCase):
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None)
|
||||
assert patched is not None, "Llama 4 MoE layer was not patched by aux-free plugin"
|
||||
assert patched is not None, (
|
||||
"Llama 4 MoE layer was not patched by aux-free plugin"
|
||||
)
|
||||
assert patched._afb_bias.ndim == 1
|
||||
assert patched._afb_counts.ndim == 1
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -67,7 +67,7 @@ class TestMoeAuxFree(unittest.TestCase):
|
||||
# Inspect model modules for a patched MoE layer
|
||||
patched = None
|
||||
for m in model.modules():
|
||||
if hasattr(m, "_afb_patched") and getattr(m, "_afb_patched") is True:
|
||||
if hasattr(m, "_afb_patched") and m._afb_patched is True:
|
||||
patched = m
|
||||
break
|
||||
assert patched is not None, "No MoE layer patched by aux-free plugin"
|
||||
|
||||
@@ -7,7 +7,7 @@ import unittest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
@@ -4,11 +4,9 @@ E2E smoke test for Aux-Loss-Free MoE routing on Qwen3-MoE tiny
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -65,7 +63,9 @@ class TestQwen3MoeAuxFree(unittest.TestCase):
|
||||
# check that at least one sparse MoE block has been patched
|
||||
found = False
|
||||
for m in model.modules():
|
||||
if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(m, "_afb_patched"):
|
||||
if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(
|
||||
m, "_afb_patched"
|
||||
):
|
||||
assert m._afb_patched is True
|
||||
assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1
|
||||
assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1
|
||||
|
||||
@@ -12,6 +12,7 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
try:
|
||||
from tbparse import SummaryReader
|
||||
except ImportError: # pragma: no cover - optional dependency
|
||||
@@ -189,7 +190,9 @@ def check_tensorboard(
|
||||
helper function to parse and check tensorboard logs
|
||||
"""
|
||||
if SummaryReader is None:
|
||||
raise unittest.SkipTest("tbparse is not installed; skipping tensorboard assertions")
|
||||
raise unittest.SkipTest(
|
||||
"tbparse is not installed; skipping tensorboard assertions"
|
||||
)
|
||||
tb_log_path = most_recent_subdir(temp_run_dir)
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
|
||||
Reference in New Issue
Block a user