chore: lint
This commit is contained in:
@@ -69,4 +69,3 @@ class AuxFreeRouterArgs(BaseModel):
|
|||||||
"'ep' (expert-parallel group if available). Defaults to 'world' when unset."
|
"'ep' (expert-parallel group if available). Defaults to 'world' when unset."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
@@ -22,9 +23,17 @@ class AuxFreeConfig:
|
|||||||
class AuxFreeState:
|
class AuxFreeState:
|
||||||
"""Holds per-layer bias and EMA load buffers."""
|
"""Holds per-layer bias and EMA load buffers."""
|
||||||
|
|
||||||
def __init__(self, num_layers: int, num_experts: int, device: torch.device, cfg: AuxFreeConfig):
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_layers: int,
|
||||||
|
num_experts: int,
|
||||||
|
device: torch.device,
|
||||||
|
cfg: AuxFreeConfig,
|
||||||
|
):
|
||||||
self.bias = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
|
self.bias = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
|
||||||
self.ema_load = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
|
self.ema_load = [
|
||||||
|
torch.zeros(num_experts, device=device) for _ in range(num_layers)
|
||||||
|
]
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.steps = 0
|
self.steps = 0
|
||||||
|
|
||||||
@@ -48,11 +57,13 @@ class AuxFreeShim:
|
|||||||
self._prev_bias_sign: dict[int, torch.Tensor] = {}
|
self._prev_bias_sign: dict[int, torch.Tensor] = {}
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
def select_experts(
|
||||||
|
self, layer_idx: int, logits: torch.Tensor, top_k: int
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""Returns (topk_indices, weights) using biased selection and unbiased weights."""
|
"""Returns (topk_indices, weights) using biased selection and unbiased weights."""
|
||||||
module = self._layer_modules.get(layer_idx)
|
module = self._layer_modules.get(layer_idx)
|
||||||
if module is not None and hasattr(module, "_afb_bias"):
|
if module is not None and hasattr(module, "_afb_bias"):
|
||||||
b = getattr(module, "_afb_bias")
|
b = module._afb_bias
|
||||||
else:
|
else:
|
||||||
b = self.state.bias[layer_idx]
|
b = self.state.bias[layer_idx]
|
||||||
biased = logits + b # bias is a buffer
|
biased = logits + b # bias is a buffer
|
||||||
@@ -64,8 +75,8 @@ class AuxFreeShim:
|
|||||||
def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
|
def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
|
||||||
"""Bind model buffers so shim updates stay in sync with patched layers."""
|
"""Bind model buffers so shim updates stay in sync with patched layers."""
|
||||||
self._layer_modules[layer_idx] = module
|
self._layer_modules[layer_idx] = module
|
||||||
bias = getattr(module, "_afb_bias")
|
bias = module._afb_bias
|
||||||
ema = getattr(module, "_afb_ema")
|
ema = module._afb_ema
|
||||||
# Keep state views pointing to the same tensors to avoid drift.
|
# Keep state views pointing to the same tensors to avoid drift.
|
||||||
if layer_idx < len(self.state.bias):
|
if layer_idx < len(self.state.bias):
|
||||||
self.state.bias[layer_idx] = bias
|
self.state.bias[layer_idx] = bias
|
||||||
@@ -100,8 +111,8 @@ class AuxFreeShim:
|
|||||||
return
|
return
|
||||||
module = self._layer_modules.get(layer_idx)
|
module = self._layer_modules.get(layer_idx)
|
||||||
if module is not None and hasattr(module, "_afb_ema"):
|
if module is not None and hasattr(module, "_afb_ema"):
|
||||||
ema = getattr(module, "_afb_ema")
|
ema = module._afb_ema
|
||||||
bias = getattr(module, "_afb_bias")
|
bias = module._afb_bias
|
||||||
else:
|
else:
|
||||||
ema = self.state.ema_load[layer_idx]
|
ema = self.state.ema_load[layer_idx]
|
||||||
bias = self.state.bias[layer_idx]
|
bias = self.state.bias[layer_idx]
|
||||||
|
|||||||
@@ -63,7 +63,9 @@ class TestLlama4MoeAuxFree(unittest.TestCase):
|
|||||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None)
|
patched = next((m for m in model.modules() if hasattr(m, "_afb_bias")), None)
|
||||||
assert patched is not None, "Llama 4 MoE layer was not patched by aux-free plugin"
|
assert patched is not None, (
|
||||||
|
"Llama 4 MoE layer was not patched by aux-free plugin"
|
||||||
|
)
|
||||||
assert patched._afb_bias.ndim == 1
|
assert patched._afb_bias.ndim == 1
|
||||||
assert patched._afb_counts.ndim == 1
|
assert patched._afb_counts.ndim == 1
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import torch
|
|||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -67,7 +67,7 @@ class TestMoeAuxFree(unittest.TestCase):
|
|||||||
# Inspect model modules for a patched MoE layer
|
# Inspect model modules for a patched MoE layer
|
||||||
patched = None
|
patched = None
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if hasattr(m, "_afb_patched") and getattr(m, "_afb_patched") is True:
|
if hasattr(m, "_afb_patched") and m._afb_patched is True:
|
||||||
patched = m
|
patched = m
|
||||||
break
|
break
|
||||||
assert patched is not None, "No MoE layer patched by aux-free plugin"
|
assert patched is not None, "No MoE layer patched by aux-free plugin"
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import unittest
|
|||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import with_temp_dir
|
from .utils import with_temp_dir
|
||||||
|
|||||||
@@ -4,11 +4,9 @@ E2E smoke test for Aux-Loss-Free MoE routing on Qwen3-MoE tiny
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -65,7 +63,9 @@ class TestQwen3MoeAuxFree(unittest.TestCase):
|
|||||||
# check that at least one sparse MoE block has been patched
|
# check that at least one sparse MoE block has been patched
|
||||||
found = False
|
found = False
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(m, "_afb_patched"):
|
if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(
|
||||||
|
m, "_afb_patched"
|
||||||
|
):
|
||||||
assert m._afb_patched is True
|
assert m._afb_patched is True
|
||||||
assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1
|
assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1
|
||||||
assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1
|
assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from tbparse import SummaryReader
|
from tbparse import SummaryReader
|
||||||
except ImportError: # pragma: no cover - optional dependency
|
except ImportError: # pragma: no cover - optional dependency
|
||||||
@@ -189,7 +190,9 @@ def check_tensorboard(
|
|||||||
helper function to parse and check tensorboard logs
|
helper function to parse and check tensorboard logs
|
||||||
"""
|
"""
|
||||||
if SummaryReader is None:
|
if SummaryReader is None:
|
||||||
raise unittest.SkipTest("tbparse is not installed; skipping tensorboard assertions")
|
raise unittest.SkipTest(
|
||||||
|
"tbparse is not installed; skipping tensorboard assertions"
|
||||||
|
)
|
||||||
tb_log_path = most_recent_subdir(temp_run_dir)
|
tb_log_path = most_recent_subdir(temp_run_dir)
|
||||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||||
reader = SummaryReader(event_file)
|
reader = SummaryReader(event_file)
|
||||||
|
|||||||
Reference in New Issue
Block a user