Compare commits

..

4 Commits

Author SHA1 Message Date
lhl
949cdf01eb tests: extend aux-free coverage
- add warmup, EP sync, and mixtral parity unit checks
2026-03-22 11:16:02 -04:00
lhl
a0019021dd aux_free_router: sync shim state
- drive warmup-aware bias updates and register live buffers
2026-03-22 11:16:02 -04:00
lhl
2af7475fdf Add ring/llama4 aux-free adapters and EP sync support 2026-03-22 11:16:02 -04:00
lhl
3e4688289c feat(moe-aux-loss-free): aux-free MoE plugin (Mixtral/Qwen3), EMA bias updates, config keys; E2E smoke + parity tests 2026-03-22 11:16:02 -04:00
33 changed files with 1719 additions and 3946 deletions

View File

@@ -3,8 +3,7 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
set -o pipefail
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
curl --silent -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
# hf download "NousResearch/Meta-Llama-3-8B"
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
# hf download "microsoft/Phi-4-reasoning"

View File

@@ -37,7 +37,6 @@ coverage:
only_pulls: false
flags: null
paths: null
informational: true
parsers:
gcov:

View File

@@ -0,0 +1,40 @@
# Aux-Loss-Free MoE Router Plugin
This integration adds an aux-loss-free (AFB) gating option to compatible MoE architectures without forking model code.
Summary
- Bias only affects expert selection (top-k); mixture weights come from unbiased logits.
- Per-expert token loads are accumulated on device and reduced across DP or EP groups.
- Bias is updated post-optimizer step outside autograd using EMA-smoothed loads.
- Existing aux loss is disabled when aux-free is enabled to avoid double signals.
Enable
- Add the plugin to your YAML, then set the aux-free toggle:
plugins:
- axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin
moe_balance_type: noaux_tc
moe_update_rate: 0.01 # default if unset
moe_update_momentum: 0.9 # default if unset
moe_bias_cap: 2.0 # default if unset
moe_afb_warmup_steps: 100 # optional
moe_bias_sync_group: world # or 'ep' if expert_parallel_size > 1
expert_parallel_size: 1 # set to your EP width when using moe_bias_sync_group: ep
Config keys
- moe_balance_type: gshard (auxiliary loss) | noaux_tc (aux-free). Default: model native.
- moe_update_rate: bias update rate (gamma). Default: 0.01.
- moe_update_momentum: EMA momentum for load smoothing. Default: 0.9.
- moe_bias_cap: absolute clamp for bias. Default: 2.0.
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
Compatibility
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
- Pass-through: Models with native aux-free routing (e.g., DeepSeek-V3) are left unmodified; only telemetry may be added in future.
Notes
- If you also enable Ligers aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
- Telemetry: future updates will log per-expert loads and bias magnitudes.

View File

@@ -0,0 +1,2 @@
"""Aux-loss-free (AFB) MoE router integration package."""

View File

@@ -0,0 +1,317 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Optional
import torch
from torch import nn
import torch.nn.functional as F
from axolotl.utils.logging import get_logger
from .core import AuxFreeShim
LOG = get_logger(__name__)
@dataclass
class LayerHandle:
layer: nn.Module
layer_idx: int
num_experts: int
top_k: int
class BaseMoEAdapter:
"""Base adapter that discovers MoE layers and wraps their forward.
Concrete adapters should implement discovery and per-layer attribute extraction.
"""
family: str = "generic"
def matches(self, model: nn.Module) -> bool: # pragma: no cover - thin shim
return False
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: # pragma: no cover
return []
def get_top_k(self, moe_layer: nn.Module) -> int: # pragma: no cover
return int(getattr(moe_layer, "num_experts_per_tok", getattr(moe_layer, "top_k", 2)))
def get_num_experts(self, moe_layer: nn.Module) -> int: # pragma: no cover
return int(getattr(moe_layer, "num_experts"))
def disable_aux_loss(self, model_or_layer: nn.Module) -> None:
# Best-effort: zero router aux loss coef if present
if hasattr(model_or_layer, "router_aux_loss_coef"):
try:
setattr(model_or_layer, "router_aux_loss_coef", 0.0)
except Exception: # pragma: no cover - non-critical
pass
def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
device = next(moe_layer.parameters(), torch.tensor(0)).device
if not hasattr(moe_layer, "_afb_bias"):
moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device))
if not hasattr(moe_layer, "_afb_counts"):
moe_layer.register_buffer("_afb_counts", torch.zeros(handle.num_experts, device=device))
if not hasattr(moe_layer, "_afb_ema"):
moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device))
moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined]
moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined]
shim.register_layer_buffers(handle.layer_idx, moe_layer)
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
"""Attach per-layer buffers and mark as aux-free enabled."""
self._register_aux_buffers(moe_layer, handle, shim)
self._patch_forward_with_aux_free(moe_layer)
def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None:
"""Replace the layer's forward with an aux-free gating version.
Assumes the layer exposes attributes:
- gate: linear router projecting hidden to num_experts
- num_experts: int
- experts: iterable of expert modules taking (tokens, H) -> (tokens, H)
"""
if getattr(moe_layer, "_afb_patched", False):
return
if not hasattr(moe_layer, "gate") or not hasattr(moe_layer, "experts"):
LOG.info("AuxFreeMoE: layer missing gate/experts; skipping forward patch")
return
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
# hidden_states: (B, T, H)
bsz, seqlen, hdim = hidden_states.shape
hs = hidden_states.view(-1, hdim)
logits = self.gate(hs)
# selection uses biased logits; weights from unbiased logits
bias = getattr(self, "_afb_bias")
top_k = int(getattr(self, "_afb_top_k", 2))
biased = logits + bias # broadcast over tokens
topk_vals, topk_idx = torch.topk(biased, k=top_k, dim=-1, sorted=False)
chosen_logits = torch.gather(logits, -1, topk_idx)
weights = torch.softmax(chosen_logits.float(), dim=-1)
weights = weights.to(hs.dtype)
# accumulate counts for bias update callback
flat_idx = topk_idx.reshape(-1)
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
getattr(self, "_afb_counts").add_(counts.to(getattr(self, "_afb_counts").dtype))
# dispatch tokens to experts
hs_rep = hs.repeat_interleave(top_k, dim=0)
y = torch.empty_like(hs_rep)
for eid in range(int(self.num_experts)):
mask = flat_idx == eid
if mask.any():
y[mask] = self.experts[eid](hs_rep[mask])
y = (y.view(-1, top_k, hdim) * weights.unsqueeze(-1)).sum(dim=1)
out = y.view(bsz, seqlen, hdim)
return (out, logits)
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
setattr(moe_layer, "_afb_patched", True)
class MixtralAdapter(BaseMoEAdapter):
family = "mixtral"
def matches(self, model: nn.Module) -> bool:
return getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
self._register_aux_buffers(moe_layer, handle, shim)
self._patch_mixtral_forward(moe_layer, shim)
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
for m in model.modules():
if m.__class__.__name__.endswith("SparseMoeBlock"):
yield m
def _patch_mixtral_forward(self, moe_layer: nn.Module, shim: AuxFreeShim) -> None:
if getattr(moe_layer, "_afb_patched", False):
return
shim_ref = shim
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and getattr(self, "jitter_noise", 0) > 0:
hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
1.0 - self.jitter_noise, 1.0 + self.jitter_noise
)
flat_states = hidden_states.view(-1, hidden_dim)
router_logits = self.gate(flat_states)
layer_idx = int(getattr(self, "_afb_layer_idx", 0))
top_k = int(getattr(self, "_afb_top_k", self.top_k))
selected_experts, routing_weights = shim_ref.select_experts(layer_idx, router_logits, top_k)
routing_weights = routing_weights.to(flat_states.dtype)
flat_idx = selected_experts.reshape(-1)
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
self._afb_counts.add_(counts.to(self._afb_counts.dtype))
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=flat_states.dtype,
device=flat_states.device,
)
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = flat_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
setattr(moe_layer, "_afb_patched", True)
class Qwen3Adapter(MixtralAdapter):
family = "qwen3_moe"
def matches(self, model: nn.Module) -> bool:
return getattr(getattr(model, "config", object()), "model_type", "") in ("qwen3_moe", "qwen2_moe")
class BailingAdapter(BaseMoEAdapter):
family = "bailing_moe"
def matches(self, model: nn.Module) -> bool:
model_type = getattr(getattr(model, "config", object()), "model_type", "")
return model_type in ("bailing_moe", "bailing_moe_v2")
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
for m in model.modules():
if m.__class__.__name__ == "BailingMoeV2SparseMoeBlock":
yield m
def get_num_experts(self, moe_layer: nn.Module) -> int:
if hasattr(moe_layer, "num_experts"):
return int(getattr(moe_layer, "num_experts"))
cfg = getattr(moe_layer, "config", None)
return int(getattr(cfg, "num_experts"))
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
self._register_aux_buffers(moe_layer, handle, shim)
self._patch_bailing_gate(moe_layer)
def _patch_bailing_gate(self, moe_layer: nn.Module) -> None:
gate = getattr(moe_layer, "gate", None)
if gate is None:
LOG.info("BailingAdapter: layer missing gate; skipping aux-free patch")
return
if getattr(gate, "_afb_patched", False):
return
def afb_gate_forward(self, hidden_states: torch.Tensor):
flat = hidden_states.view(-1, hidden_states.shape[-1])
logits = F.linear(flat.float(), self.weight.float())
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
bias = getattr(moe_layer, "_afb_bias")
biased_scores = scores_unbiased + bias
topk_vals, topk_idx = self.group_limited_topk(biased_scores)
weights = torch.gather(scores_unbiased, 1, topk_idx)
if self.top_k > 1:
denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20)
weights = weights / denom
weights = weights * self.routed_scaling_factor
flat_topk = topk_idx.reshape(-1)
counts = torch.bincount(flat_topk, minlength=bias.numel())
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
return topk_idx, weights.to(hidden_states.dtype), logits
gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
setattr(gate, "_afb_patched", True)
class Llama4Adapter(BaseMoEAdapter):
family = "llama4"
def matches(self, model: nn.Module) -> bool:
return getattr(getattr(model, "config", object()), "model_type", "") == "llama4"
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
for m in model.modules():
if m.__class__.__name__ == "Llama4TextMoe":
yield m
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
self._register_aux_buffers(moe_layer, handle, shim)
self._patch_llama4_router(moe_layer)
def _patch_llama4_router(self, moe_layer: nn.Module) -> None:
router = getattr(moe_layer, "router", None)
if router is None:
LOG.info("Llama4Adapter: layer missing router; skipping aux-free patch")
return
if getattr(router, "_afb_patched", False):
return
def afb_router_forward(self, hidden_states: torch.Tensor):
flat = hidden_states if hidden_states.dim() == 2 else hidden_states.view(-1, hidden_states.shape[-1])
router_logits = F.linear(flat, self.weight, self.bias)
bias = getattr(moe_layer, "_afb_bias")
biased_logits = router_logits + bias
_, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
unbiased_top = torch.gather(router_logits, 1, router_indices)
router_scores = torch.full_like(router_logits, float("-inf"))
router_scores.scatter_(1, router_indices, unbiased_top)
router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)
counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
return router_scores, router_logits
router.forward = afb_router_forward.__get__(router, router.__class__) # type: ignore[attr-defined]
setattr(router, "_afb_patched", True)
def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim) -> list[LayerHandle]:
"""Discover MoE layers using the first matching adapter and attach per-layer buffers.
Returns a list of layer handles for later routing patching and updates.
"""
handles: list[LayerHandle] = []
adapter: Optional[BaseMoEAdapter] = None
for a in adapters:
if a.matches(model):
adapter = a
break
if adapter is None:
LOG.info("AuxFreeMoE: no matching adapter found; skipping aux-free routing")
return handles
# disable aux loss at model level if possible
adapter.disable_aux_loss(getattr(model, "config", model))
idx = 0
for layer in adapter.find_moe_layers(model):
try:
top_k = adapter.get_top_k(layer)
nE = adapter.get_num_experts(layer)
except Exception:
continue
handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=nE, top_k=top_k)
adapter.prepare(layer, handle, shim)
handles.append(handle)
idx += 1
LOG.info(f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing")
return handles

View File

@@ -0,0 +1,150 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import torch.distributed as dist
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@dataclass
class AuxFreeConfig:
rate: float = 0.01
momentum: float = 0.9
bias_cap: float = 2.0
warmup_steps: int = 0
sync_group: str = "world" # or "ep"
class AuxFreeState:
"""Holds per-layer bias and EMA load buffers."""
def __init__(self, num_layers: int, num_experts: int, device: torch.device, cfg: AuxFreeConfig):
self.bias = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
self.ema_load = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
self.cfg = cfg
self.steps = 0
class AuxFreeShim:
"""Model-agnostic shim for aux-loss-free expert selection and bias updates."""
def __init__(
self,
state: AuxFreeState,
ep_group: Optional[dist.ProcessGroup] = None,
ep_size: Optional[int] = None,
):
self.state = state
self.ep_group = ep_group
self._ep_size = ep_size
self._ep_group_pending = (
self.state.cfg.sync_group == "ep" and self.ep_group is None
)
self._layer_modules: dict[int, torch.nn.Module] = {}
@torch.no_grad()
def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
"""Returns (topk_indices, weights) using biased selection and unbiased weights."""
module = self._layer_modules.get(layer_idx)
if module is not None and hasattr(module, "_afb_bias"):
b = getattr(module, "_afb_bias")
else:
b = self.state.bias[layer_idx]
biased = logits + b # bias is a buffer
topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
chosen_logits = torch.gather(logits, -1, topk_idx)
weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
return topk_idx, weights
def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
"""Bind model buffers so shim updates stay in sync with patched layers."""
self._layer_modules[layer_idx] = module
bias = getattr(module, "_afb_bias")
ema = getattr(module, "_afb_ema")
# Keep state views pointing to the same tensors to avoid drift.
if layer_idx < len(self.state.bias):
self.state.bias[layer_idx] = bias
if layer_idx < len(self.state.ema_load):
self.state.ema_load[layer_idx] = ema
def begin_step(self) -> None:
"""Call once per optimizer step before per-layer updates."""
self.state.steps += 1
@torch.no_grad()
def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
self._maybe_init_ep_group()
if not dist.is_available() or not dist.is_initialized():
return counts
group = self.ep_group if self.ep_group is not None else dist.group.WORLD
dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=group)
return counts
@torch.no_grad()
def update_bias(self, layer_idx: int, step_counts: torch.Tensor, tokens_seen: int):
"""Apply EMA-smoothed bias update toward uniform target, with clamp and optional mean-centering."""
cfg = self.state.cfg
if self.state.steps <= cfg.warmup_steps:
return
nE = step_counts.numel()
if tokens_seen <= 0:
return
module = self._layer_modules.get(layer_idx)
if module is not None and hasattr(module, "_afb_ema"):
ema = getattr(module, "_afb_ema")
bias = getattr(module, "_afb_bias")
else:
ema = self.state.ema_load[layer_idx]
bias = self.state.bias[layer_idx]
counts = step_counts.to(ema.device)
freq = counts.float() / float(tokens_seen)
ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
target = 1.0 / float(nE)
delta = cfg.rate * (target - ema)
# optional mean-centering to keep sum(bias) ~ 0
delta = delta - delta.mean()
bias.add_(delta)
if cfg.bias_cap is not None and cfg.bias_cap > 0:
bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
def _maybe_init_ep_group(self) -> None:
if not self._ep_group_pending:
return
if not dist.is_available() or not dist.is_initialized():
return
ep_size = self._ep_size
if not ep_size or ep_size <= 1:
LOG.warning(
"AuxFreeMoE: moe_bias_sync_group='ep' requested but expert_parallel_size<=1; defaulting to world group"
)
self.ep_group = dist.group.WORLD
self._ep_group_pending = False
return
world = dist.get_world_size()
if world % ep_size != 0:
LOG.warning(
"AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world group",
ep_size,
world,
)
self.ep_group = dist.group.WORLD
self._ep_group_pending = False
return
if ep_size == world:
self.ep_group = dist.group.WORLD
else:
rank = dist.get_rank()
group_start = (rank // ep_size) * ep_size
ranks = tuple(range(group_start, group_start + ep_size))
self.ep_group = dist.new_group(ranks)
LOG.info(
"AuxFreeMoE: initialized expert-parallel reduction group (size=%s, world=%s)",
ep_size,
dist.get_world_size(),
)
self._ep_group_pending = False

View File

@@ -0,0 +1,175 @@
"""Aux-loss-free MoE Router Plugin for Axolotl.
This plugin wires an aux-free gating option into compatible MoE models using
unbiased logits for mixture weights and per-expert biases for top-k selection.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.distributed as dist
from transformers.trainer_callback import TrainerCallback
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
from .adapters import (
BailingAdapter,
BaseMoEAdapter,
Llama4Adapter,
MixtralAdapter,
Qwen3Adapter,
discover_and_prepare_layers,
)
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
LOG = get_logger(__name__)
class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
"""Post-step callback to update aux-free biases from accumulated expert counts.
Note: The current revision expects per-layer counts to be accumulated on each
MoE layer as a buffer named `_afb_counts` during forward (to be added with
routing patches in a follow-up).
"""
def __init__(self, shim: AuxFreeShim, layer_modules: list[torch.nn.Module]):
self.shim = shim
self.layer_modules = layer_modules
def on_step_end(self, args, state, control, **kwargs): # noqa: D401
# Iterate prepared MoE layers and apply the bias update rule.
self.shim.begin_step()
for layer in self.layer_modules:
if not hasattr(layer, "_afb_counts") or not hasattr(layer, "_afb_layer_idx"):
continue
counts = getattr(layer, "_afb_counts")
if counts is None:
continue
counts = self.shim.all_reduce_counts(counts)
layer_idx = getattr(layer, "_afb_layer_idx", None)
if layer_idx is None:
counts.zero_()
continue
bias = getattr(layer, "_afb_bias")
counts_for_update = counts.to(bias.device)
tokens_seen = int(counts_for_update.sum().item())
# local layer-state EMA and bias update
self.shim.update_bias(layer_idx, counts_for_update, tokens_seen)
# reset step counts
counts.zero_()
return control
class AuxFreeMoEPlugin(BasePlugin):
"""Plugin that enables aux-loss-free routing when configured."""
def __init__(self):
super().__init__()
self._handles: list = []
self._shim: Optional[AuxFreeShim] = None
self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}
def post_model_build(self, cfg, model):
# Enable only when explicitly requested
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
return
# Be conservative — skip known native aux-free families
native_auxfree = getattr(getattr(model, "config", object()), "model_type", "") in (
"deepseek_v3",
"glm4_moe",
)
if native_auxfree:
LOG.info("AuxFreeMoE: model reports native aux-free routing; skipping patching")
return
# Build aux-free state and shim
rate = cfg.moe_update_rate if cfg.moe_update_rate is not None else 0.01
momentum = (
cfg.moe_update_momentum if cfg.moe_update_momentum is not None else 0.9
)
bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
af_cfg = AuxFreeConfig(
rate=rate, momentum=momentum, bias_cap=bias_cap, warmup_steps=warmup, sync_group=sync_group
)
# Discover layers to count the number and experts for state sizing
adapters: list[BaseMoEAdapter] = [
MixtralAdapter(),
Qwen3Adapter(),
BailingAdapter(),
Llama4Adapter(),
]
# For initial state sizing, we conservatively assume the first discovered layer defines nE
n_layers = 0
n_experts = None
for m in model.modules():
n_layers += 1 # upper bound — we will re-use bias slots sparsely
device = next(model.parameters(), torch.tensor(0)).device
if n_layers <= 0:
n_layers = 1
if n_experts is None:
# we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
n_experts = 2
state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg)
ep_size = getattr(cfg, "expert_parallel_size", None)
ep_group = None
if sync_group == "ep":
if dist.is_available() and dist.is_initialized():
ep_group = self._resolve_ep_group(cfg)
else:
LOG.info(
"AuxFreeMoE: deferring expert-parallel group resolution until torch.distributed initializes"
)
self._shim = AuxFreeShim(state=state, ep_group=ep_group, ep_size=ep_size)
# Discover and prepare layers (attach per-layer buffers)
self._handles = discover_and_prepare_layers(model, adapters, self._shim)
LOG.info(
f"AuxFreeMoE: enabled with rate={rate}, momentum={momentum}, cap={bias_cap}, warmup={warmup}, group={sync_group}"
)
def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
if not dist.is_available() or not dist.is_initialized():
LOG.warning("AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world")
return None
ep_size = getattr(cfg, "expert_parallel_size", None)
if not ep_size or ep_size <= 1:
LOG.warning("AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world")
return None
world = dist.get_world_size()
if world % ep_size != 0:
LOG.warning(
"AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world",
ep_size,
world,
)
return None
if ep_size == world:
return dist.group.WORLD
rank = dist.get_rank()
group_start = (rank // ep_size) * ep_size
ranks = tuple(range(group_start, group_start + ep_size))
if ranks not in self._ep_group_cache:
self._ep_group_cache[ranks] = dist.new_group(ranks)
return self._ep_group_cache[ranks]
def add_callbacks_post_trainer(self, cfg, trainer):
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
return []
if self._shim is None:
return []
# gather concrete layer modules from handles
layers = [h.layer for h in self._handles]
cb = MoeAuxFreeBiasUpdateCallback(self._shim, layers)
LOG.info("AuxFreeMoE: registering post-step bias update callback")
return [cb]

View File

@@ -36,8 +36,6 @@ SPARSE_MOE_BLOCK = {
"glm4v_moe": "Glm4vMoeTextMoE",
# sigmoid -> topk routing (no group selection)
"minimax_m2": "MiniMaxM2SparseMoeBlock",
# sigmoid -> topk routing, non-gated experts (up_proj + down_proj, no gate_up_proj)
"nemotron_h": "NemotronHMoE",
# Models below need custom routing (not yet implemented):
# "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk
# "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group)

View File

@@ -168,9 +168,6 @@ def _unwrap_experts_lora(experts_module):
-> base_layer: ParamWrapper(gate_up_proj)
-> base_layer: OlmoeExperts (the real module)
For non-gated experts (e.g. NemotronH), the chain targets ``up_proj``
instead of ``gate_up_proj``.
This function walks the chain, collects LoRA params keyed by
``parameter_name``, and returns the base experts module.
@@ -179,7 +176,6 @@ def _unwrap_experts_lora(experts_module):
Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
A/B are already in scattermoe layout.
For non-gated experts, ``gup_lora`` holds the ``up_proj`` LoRA.
"""
# Collect ParamWrapper layers by their parameter_name
wrappers = {}
@@ -199,15 +195,13 @@ def _unwrap_experts_lora(experts_module):
num_experts = getattr(base_experts, "num_experts", None)
if num_experts is None:
# Fallback: infer from parameter shape
for attr in ("gate_up_proj", "up_proj"):
param = getattr(base_experts, attr, None)
if param is not None:
num_experts = param.shape[0]
break
gup = getattr(base_experts, "gate_up_proj", None)
if gup is not None:
num_experts = gup.shape[0]
# Extract gate_up_proj or up_proj LoRA (needs A<->B swap due to transposition)
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
gup_lora = None
gup_wrapper = wrappers.get("gate_up_proj") or wrappers.get("up_proj")
gup_wrapper = wrappers.get("gate_up_proj")
if gup_wrapper is not None:
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
if lora_A is not None:
@@ -447,12 +441,10 @@ class HFScatterMoEGatedMLP(nn.Module):
Supports:
* **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2, NemotronH
* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
* **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
* **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
* **Non-gated experts**: NemotronH (up_proj + down_proj, no gate_up_proj)
* **Latent projections**: NemotronH (fc1/fc2_latent_proj wrapping experts)
"""
@staticmethod
@@ -475,7 +467,7 @@ class HFScatterMoEGatedMLP(nn.Module):
hidden_states_flat = layer_input.view(-1, hidden_dim)
# ====================================================================
# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3, NemotronH)
# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
# ====================================================================
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
@@ -497,22 +489,6 @@ class HFScatterMoEGatedMLP(nn.Module):
# ====================================================================
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
# ====================================================================
# Detect non-gated experts (e.g. NemotronH: up_proj + down_proj only)
# ====================================================================
is_gated = hasattr(experts, "gate_up_proj")
up_proj_attr = "gate_up_proj" if is_gated else "up_proj"
# ====================================================================
# Optional latent projection (NemotronH: fc1/fc2_latent_proj)
# ====================================================================
fc1_latent_proj = getattr(self, "fc1_latent_proj", None)
fc2_latent_proj = getattr(self, "fc2_latent_proj", None)
expert_input = hidden_states_flat
if fc1_latent_proj is not None and not isinstance(fc1_latent_proj, nn.Identity):
expert_input = fc1_latent_proj(hidden_states_flat)
# ====================================================================
# Selective expert weight dequantization
# ====================================================================
@@ -522,7 +498,7 @@ class HFScatterMoEGatedMLP(nn.Module):
use_selective = (
getattr(self, "_use_selective_dequant", False)
and hasattr(experts, "parametrizations")
and up_proj_attr in experts.parametrizations
and "gate_up_proj" in experts.parametrizations
)
if use_selective:
@@ -541,11 +517,11 @@ class HFScatterMoEGatedMLP(nn.Module):
num_experts,
)
# Dequantize only active experts' weights
up_W = selective_expert_weights(
gate_up_W = selective_expert_weights(
experts,
up_proj_attr,
"gate_up_proj",
active_experts,
).transpose(2, 1)
).transpose(2, 1) # [num_active, hidden, 2*inter]
# Remap LoRA weights to match compact expert indices
if gup_lora is not None:
@@ -562,18 +538,18 @@ class HFScatterMoEGatedMLP(nn.Module):
sei_gup = remapped_expert_idxs
eo_gup = compact_offsets
else:
up_W = getattr(experts, up_proj_attr).transpose(2, 1)
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
sei_gup = sorted_expert_idxs
eo_gup = expert_offsets
# ====================================================================
# Up projection (gated: gate_up_proj; non-gated: up_proj)
# Gate + Up projection
# ====================================================================
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
up_out = parallel_linear_lora(
expert_input,
up_W,
gup = parallel_linear_lora(
hidden_states_flat,
gate_up_W,
top_k,
sei_gup,
sorted_scattered_idxs,
@@ -587,9 +563,9 @@ class HFScatterMoEGatedMLP(nn.Module):
use_fused_gather=True,
)
else:
up_out = parallel_linear(
expert_input,
up_W,
gup = parallel_linear(
hidden_states_flat,
gate_up_W,
top_k,
sei_gup,
sorted_scattered_idxs,
@@ -598,14 +574,8 @@ class HFScatterMoEGatedMLP(nn.Module):
grouped_out=True,
)
# ====================================================================
# Activation: gated (act_fn(gate) * up) vs non-gated (act_fn(up))
# ====================================================================
if is_gated:
gates, h = up_out.chunk(2, dim=-1)
h = experts.act_fn(gates) * h
else:
h = experts.act_fn(up_out)
gates, h = gup.chunk(2, dim=-1)
h = experts.act_fn(gates) * h
# ====================================================================
# Down projection
@@ -665,12 +635,6 @@ class HFScatterMoEGatedMLP(nn.Module):
gates=routing_weights,
)
# ====================================================================
# Optional latent projection back to hidden_size (NemotronH)
# ====================================================================
if fc2_latent_proj is not None and not isinstance(fc2_latent_proj, nn.Identity):
expert_output = fc2_latent_proj(expert_output)
# ====================================================================
# Combine with shared expert and reshape
# ====================================================================

View File

@@ -30,15 +30,6 @@ class LigerArgs(BaseModel):
liger_rope: bool | None = None
liger_rms_norm: bool | None = None
liger_rms_norm_gated: bool | None = Field(
default=None,
json_schema_extra={
"description": (
"Enables fused RMSNorm+SiLU gate Triton kernel for models with "
"gated RMSNorm (e.g. Qwen3.5 / Qwen3.5 MoE linear attention layers)."
)
},
)
liger_layer_norm: bool | None = None
liger_swiglu: bool | None = None
liger_glu_activation: bool | None = None

View File

@@ -1,175 +0,0 @@
"""
Liger FLCE for Qwen3.5. Based on transformers v5.3.0.
"""
import sys
from copy import deepcopy
from typing import Optional, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> CausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_liger_kernel_to_qwen3_5(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
rms_norm_gated: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 models.
Note: Qwen3_5RMSNorm uses zero-init weight with offset 1.0 (like Gemma),
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
Qwen3_5RMSNormGated (used in linear attention layers). Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3_5.modeling_qwen3_5 # noqa: F401
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
modeling_qwen3_5 = sys.modules["transformers.models.qwen3_5.modeling_qwen3_5"]
if rms_norm:
# Qwen3_5RMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
class LigerRMSNormForQwen3_5(LigerRMSNorm):
def __init__(self, dim, eps=1e-6, **kwargs):
super().__init__(
dim,
eps=eps,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3_5
if rms_norm_gated:
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
modeling_qwen3_5.Qwen3_5RMSNormGated = FusedRMSNormGated
if glu_activation:
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
config = deepcopy(config)
if intermediate_size is not None:
config.intermediate_size = intermediate_size
return LigerSwiGLUMLP(config, **kwargs)
modeling_qwen3_5.Qwen3_5MLP = _liger_swiglu_mlp_wrapper
if layer_norm:
modeling_qwen3_5.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_qwen3_5.Qwen3_5ForCausalLM.forward = lce_forward

View File

@@ -1,198 +0,0 @@
"""
Liger FLCE for Qwen3.5 MoE. Based on transformers v5.3.0.
"""
import sys
from copy import deepcopy
from typing import Optional, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values=None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> MoeCausalLMOutputWithPast:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
"""
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
load_balancing_loss_func,
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = None
loss = None
# if in training mode, don't materialize logits
if self.training and (labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference mode materialize logits
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits,
labels,
self.vocab_size,
**kwargs,
)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
def apply_liger_kernel_to_qwen3_5_moe(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = False,
rms_norm: bool = False,
rms_norm_gated: bool = False,
glu_activation: bool = False,
layer_norm: bool = False,
**kwargs,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 MoE models.
Note: Qwen3_5MoeRMSNorm uses zero-init weight with offset 1.0 (like Gemma),
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is False.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
Qwen3_5MoeRMSNormGated (used in linear attention layers). Default is False.
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
"""
import transformers.models.qwen3_5_moe.modeling_qwen3_5_moe # noqa: F401
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
modeling_mod = sys.modules["transformers.models.qwen3_5_moe.modeling_qwen3_5_moe"]
if rms_norm:
# Qwen3_5MoeRMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
class LigerRMSNormForQwen3_5Moe(LigerRMSNorm):
def __init__(self, dim, eps=1e-6, **kwargs):
super().__init__(
dim,
eps=eps,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
modeling_mod.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3_5Moe
if rms_norm_gated:
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
modeling_mod.Qwen3_5MoeRMSNormGated = FusedRMSNormGated
if glu_activation:
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
config = deepcopy(config)
if intermediate_size is not None:
config.intermediate_size = intermediate_size
return LigerSwiGLUMLP(config, **kwargs)
modeling_mod.Qwen3_5MoeMLP = _liger_swiglu_mlp_wrapper
if layer_norm:
modeling_mod.nn.LayerNorm = LigerLayerNorm
if cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
modeling_mod.Qwen3_5MoeForCausalLM.forward = lce_forward

View File

@@ -174,19 +174,6 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_5":
from axolotl.integrations.liger.models.qwen3_5 import (
apply_liger_kernel_to_qwen3_5,
)
apply_liger_kernel_to_qwen3_5(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_moe":
from axolotl.integrations.liger.models.qwen3_moe import (
apply_liger_kernel_to_qwen3_moe,
@@ -199,19 +186,6 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_5_moe":
from axolotl.integrations.liger.models.qwen3_5_moe import (
apply_liger_kernel_to_qwen3_5_moe,
)
apply_liger_kernel_to_qwen3_5_moe(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "granitemoe":
from liger_kernel.transformers import apply_liger_kernel_to_granite

View File

@@ -1,147 +0,0 @@
"""
Triton kernels for DoRA (Weight-Decomposed Low-Rank Adaptation).
Fuses the weight norm computation and magnitude scaling to avoid
materializing the full [out_features, in_features] combined weight matrix.
The B@A product is computed row-by-row inside the kernel.
"""
import torch
import triton
import triton.language as tl
from .quantize import dequantize
@triton.jit
def _dora_fused_norm_kernel(
# Pointers
W_ptr, # base weight [out, in] (dequantized, row-major)
B_ptr, # LoRA B [out, rank] (row-major)
A_ptr, # LoRA A [rank, in] (row-major)
mag_ptr, # magnitude vector [out]
out_ptr, # output mag_norm_scale [out]
# Shapes
out_features,
in_features,
rank,
# Scaling
lora_scale, # float scaling factor
# Block sizes
BLOCK_IN: tl.constexpr,
BLOCK_R: tl.constexpr, # >= rank, power of 2
):
"""Compute mag_norm_scale[i] = magnitude[i] / ||W[i,:] + s * (B[i,:] @ A)[:] ||_2
Each program handles one output row. B[row,:] is loaded once (small),
then we tile over in_features computing the dot product with A[:,tile]
and accumulating the squared norm.
This avoids materializing the full [out, in] B@A matrix.
"""
row = tl.program_id(0)
if row >= out_features:
return
# Accumulate squared norm across tiles of in_features
norm_sq_acc = tl.zeros([BLOCK_IN], dtype=tl.float32)
for start in range(0, in_features, BLOCK_IN):
cols = start + tl.arange(0, BLOCK_IN)
col_mask = cols < in_features
# Load W[row, cols]
w_vals = tl.load(
W_ptr + row * in_features + cols,
mask=col_mask,
other=0.0,
).to(tl.float32)
# Compute (B[row,:] @ A[:, cols]) for this tile
# Load B[row, r] as scalar and A[r, cols] as vector for each r
ba_vals = tl.zeros([BLOCK_IN], dtype=tl.float32)
for r in tl.static_range(BLOCK_R):
# Load scalar B[row, r]
b_val = tl.load(
B_ptr + row * rank + r,
mask=(r < rank),
other=0.0,
).to(tl.float32)
# Load vector A[r, cols]
a_vals = tl.load(
A_ptr + r * in_features + cols,
mask=(col_mask & (r < rank)),
other=0.0,
).to(tl.float32)
ba_vals += b_val * a_vals
# Combined: W + s * (B @ A)
combined = w_vals + lora_scale * ba_vals
# Accumulate squared values
norm_sq_acc += tl.where(col_mask, combined * combined, 0.0)
# Reduce to scalar norm
norm_sq = tl.sum(norm_sq_acc, axis=0)
norm = tl.sqrt(norm_sq + 1e-12) # epsilon for numerical stability
# Load magnitude and compute scale
mag = tl.load(mag_ptr + row).to(tl.float32)
scale = mag / norm
tl.store(out_ptr + row, scale)
def triton_dora_scale(
W: torch.Tensor,
W_quant,
A: torch.Tensor,
B: torch.Tensor,
s: float,
magnitude: torch.Tensor,
dtype: torch.dtype,
) -> torch.Tensor:
"""Compute DoRA mag_norm_scale using fused Triton kernel.
Computes B@A row-by-row inside the kernel, avoiding the full
[out_features, in_features] materialization.
Args:
W: base weight [out, in] (possibly quantized)
W_quant: quantization state
A: LoRA A [rank, in]
B: LoRA B [out, rank]
s: LoRA scaling factor
magnitude: learned magnitude [out]
dtype: compute dtype
Returns:
mag_norm_scale: [out] tensor = magnitude / ||W + s * B @ A||_2
"""
# Dequantize W to [out, in]
W_full = dequantize(W.t(), W_quant).t().contiguous().to(dtype)
out_features, in_features = W_full.shape
rank = A.shape[0]
out = torch.empty(out_features, dtype=dtype, device=W.device)
# Block sizes
BLOCK_IN = triton.next_power_of_2(min(in_features, 2048))
BLOCK_R = triton.next_power_of_2(rank)
_dora_fused_norm_kernel[(out_features,)](
W_full,
B.contiguous().to(dtype),
A.contiguous().to(dtype),
magnitude.contiguous(),
out,
out_features=out_features,
in_features=in_features,
rank=rank,
lora_scale=s,
BLOCK_IN=BLOCK_IN,
BLOCK_R=BLOCK_R,
)
return out.detach()

File diff suppressed because it is too large Load Diff

View File

@@ -105,10 +105,6 @@ def dequantize(
# Extract quantization state
if not isinstance(quant_state, list):
# New style quant_state class
# Non-double-quantized models have offset=None and state2=None
if quant_state.offset is None or quant_state.state2 is None:
# Fall back to bitsandbytes standard dequantize
return bnb.functional.dequantize_4bit(W, quant_state, quant_type="nf4")
absmax = quant_state.absmax.to(target_device)
shape = quant_state.shape
dtype = quant_state.dtype

View File

@@ -1,333 +0,0 @@
"""
Fused RMSNorm + SiLU Gate Triton kernel.
Computes: Y = (W + offset) * RMSNorm(X) * silu(G)
where RMSNorm(X) = X / sqrt(mean(X^2) + eps)
and silu(G) = G * sigmoid(G)
Used by Qwen3.5's GatedDeltaNet linear attention layers (Qwen3_5RMSNormGated).
"""
import math
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import (
calculate_settings,
compare_version,
ensure_contiguous,
torch_to_triton_dtype,
)
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
@triton.jit
def _rms_norm_gated_forward_kernel(
Y_ptr,
Y_row_stride,
X_ptr,
X_row_stride,
G_ptr,
G_row_stride,
W_ptr,
W_row_stride,
RSTD_ptr,
RSTD_row_stride,
n_cols,
eps,
offset,
BLOCK_SIZE: tl.constexpr,
):
"""
Y = (W + offset) * (X / RMS(X)) * silu(G)
All computation done in fp32 (Gemma-style), result cast to input dtype.
"""
row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
G_row = tl.load(G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0)
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
X_row_dtype = X_row.dtype
# Cast everything to fp32
X_fp32 = X_row.to(tl.float32)
G_fp32 = G_row.to(tl.float32)
W_fp32 = W_row.to(tl.float32)
# RMS norm
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
rstd = rsqrt(mean_sq + eps)
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
X_norm = X_fp32 * rstd
# SiLU gate: silu(G) = G * sigmoid(G)
sig_G = tl.sigmoid(G_fp32)
silu_G = G_fp32 * sig_G
# Fused output
Y_row = (offset + W_fp32) * X_norm * silu_G
tl.store(
Y_ptr + row_idx * Y_row_stride + col_offsets,
Y_row.to(X_row_dtype),
mask=mask,
)
@triton.jit
def _rms_norm_gated_backward_kernel(
dY_ptr,
dY_row_stride,
dX_ptr,
dX_row_stride,
dG_ptr,
dG_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
G_ptr,
G_row_stride,
W_ptr,
W_row_stride,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows,
n_cols,
offset,
rows_per_program,
BLOCK_SIZE: tl.constexpr,
):
"""
Backward for Y = (W + offset) * (X * RSTD) * silu(G)
dW = sum_batch(dY * X_norm * silu(G))
dG = dY * (W + offset) * X_norm * silu'(G)
where silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
dX = RSTD * (m - (1/N) * RSTD^2 * dot(m, X) * X)
where m = dY * (W + offset) * silu(G)
"""
row_block_id = tl.program_id(0).to(tl.int64)
row_start = row_block_id * rows_per_program
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
W_row = W_row.to(tl.float32) + offset
for row_idx in range(row_start, row_end):
dY_row = tl.load(
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0.0
)
X_row = tl.load(
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0.0
)
G_row = tl.load(
G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0.0
)
rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
# Cast to fp32
dY_fp32 = dY_row.to(tl.float32)
X_fp32 = X_row.to(tl.float32)
G_fp32 = G_row.to(tl.float32)
# Recompute intermediates
X_norm = X_fp32 * rstd_row
sig_G = tl.sigmoid(G_fp32)
silu_G = G_fp32 * sig_G
# dW: accumulate dY * X_norm * silu(G)
dW_acc += dY_fp32 * X_norm * silu_G
# dG: dY * (W + offset) * X_norm * silu'(G)
# silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
silu_prime_G = sig_G * (1.0 + G_fp32 * (1.0 - sig_G))
dG_row = dY_fp32 * W_row * X_norm * silu_prime_G
tl.store(
dG_ptr + row_idx * dG_row_stride + col_offsets,
dG_row.to(X_dtype),
mask=mask,
)
# dX: standard RMSNorm backward with effective gradient m = dY * W * silu(G)
m = dY_fp32 * W_row * silu_G
dX_row = rstd_row * m
dX_row += rstd_row * (
-(1.0 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_fp32, axis=0) * X_fp32
)
tl.store(
dX_ptr + row_idx * dX_row_stride + col_offsets,
dX_row.to(X_dtype),
mask=mask,
)
tl.store(
dW_ptr + row_block_id * dW_row_stride + col_offsets,
dW_acc,
mask=mask,
)
def rms_norm_gated_forward(X, G, W, eps, offset):
shape = X.shape
dim = shape[-1]
X = X.view(-1, dim)
G = G.view(-1, dim)
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
assert X.shape[1] == W.shape[0], (
f"Incompatible hidden size: X.shape[1]={X.shape[1]} vs W.shape[0]={W.shape[0]}"
)
assert X.shape == G.shape, (
f"X and G must have same shape, got {X.shape} and {G.shape}"
)
_rms_norm_gated_forward_kernel[(n_rows,)](
Y,
Y.stride(0),
X,
X.stride(0),
G,
G.stride(0),
W,
W.stride(0),
RSTD,
RSTD.stride(0),
n_cols,
eps,
offset,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return Y.view(*shape), X, G, RSTD, BLOCK_SIZE, num_warps
def rms_norm_gated_backward(dY, X, G, W, RSTD, offset, BLOCK_SIZE, num_warps):
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
n_rows, n_cols = dY.shape
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
dX = torch.empty_like(dY)
dG = torch.empty_like(dY)
rows_per_program = math.ceil(n_rows / sm_count)
grid = (sm_count,)
_rms_norm_gated_backward_kernel[grid](
dY,
dY.stride(0),
dX,
dX.stride(0),
dG,
dG.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
G,
G.stride(0),
W,
W.stride(0),
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0),
n_rows,
n_cols,
offset,
rows_per_program,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
dX = dX.view(*shape)
dG = dG.view(*shape)
dW = _dW.sum(dim=0).to(W.dtype)
return dX, dG, dW
class FusedRMSNormGatedFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, X, G, W, eps, offset=0.0):
"""
X: (B, T, H) or (BxT, H) — input hidden states
G: (B, T, H) or (BxT, H) — gate tensor
W: (H,) — weight parameter
"""
Y, X, G, RSTD, BLOCK_SIZE, num_warps = rms_norm_gated_forward(
X, G, W, eps, offset
)
ctx.offset = offset
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.save_for_backward(X, G, W, RSTD)
return Y
@staticmethod
@ensure_contiguous
def backward(ctx, dY):
X, G, W, RSTD = ctx.saved_tensors
dX, dG, dW = rms_norm_gated_backward(
dY, X, G, W, RSTD, ctx.offset, ctx.BLOCK_SIZE, ctx.num_warps
)
return dX, dG, dW, None, None
class FusedRMSNormGated(torch.nn.Module):
"""
Fused RMSNorm + SiLU Gate.
Computes: Y = W * RMSNorm(X) * silu(G)
Drop-in replacement for Qwen3_5RMSNormGated with matching
init signature: __init__(hidden_size, eps=1e-6, **kwargs)
and forward signature: forward(hidden_states, gate=None)
"""
def __init__(self, hidden_size, eps=1e-6, offset=0.0, **kwargs):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
self.offset = offset
def forward(self, hidden_states, gate=None):
if gate is None:
raise ValueError("FusedRMSNormGated requires a gate tensor")
if hidden_states.device.type != "cuda":
raise ValueError(
f"FusedRMSNormGated requires CUDA tensors, got device={hidden_states.device}"
)
return FusedRMSNormGatedFunction.apply(
hidden_states, gate, self.weight, self.variance_epsilon, self.offset
)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

View File

@@ -12,7 +12,6 @@ from torch import nn
from transformers import AutoConfig
from axolotl.kernels.lora import (
apply_lora_embedding,
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
@@ -371,13 +370,13 @@ def apply_lora_kernel_patches(
active_adapter = model.active_adapter
lora_config = model.model.peft_config[active_adapter]
# Log what features are active
if lora_config.lora_dropout > 0:
LOG.info(f"LoRA kernels: dropout={lora_config.lora_dropout} enabled")
if lora_config.bias != "none":
LOG.info(f"LoRA kernels: bias={lora_config.bias} enabled")
if lora_config.use_dora:
LOG.info("LoRA kernels: DoRA enabled")
# Only patch if conditions are met
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
if not can_patch:
LOG.warning("Cannot patch layers - requires no dropout and no bias")
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
return model
# This needs to be reset after patching
original_level = LOG.getEffectiveLevel()
@@ -420,33 +419,44 @@ def apply_lora_kernel_patches(
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A") for module in layer_modules
hasattr(module, "lora_A")
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_qkv:
# Add optimized implementation
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA adapters"
"Cannot patch some attention QKV projections - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(hasattr(module, "lora_A") for module in layer_modules)
can_patch_o = all(
hasattr(module, "lora_A")
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
if can_patch_o:
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA adapters"
"Cannot patch some attention output projection - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
)
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
if cfg.lora_mlp_kernel:
# MLP patching
can_patch_mlp = all(
hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj)
hasattr(proj, "lora_A")
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
)
if can_patch_mlp:
@@ -454,50 +464,15 @@ def apply_lora_kernel_patches(
layer.mlp.forward = types.MethodType(apply_fn, mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters"
"Cannot patch some MLP layers - requires LoRA adapters and no "
"lora_magnitude_vector (DoRA)"
)
# Patch embedding layers (model-level, not per-layer)
if cfg.lora_embedding_kernel:
_patch_embedding_layers(model, cfg)
LOG.setLevel(original_level)
return model
def _patch_embedding_layers(model: PeftModelForCausalLM, cfg: DictDefault):
"""Patch embedding layers with fused LoRA kernel.
Handles both embed_tokens (nn.Embedding with lora_embedding_A/B) and
lm_head (nn.Linear with lora_A/B, used when tied embeddings are untied by PEFT).
"""
pretrained_model = model.model
patched = 0
# Find embedding modules - check common locations
for attr_path in [
("model", "embed_tokens"),
("model", "language_model", "embed_tokens"),
]:
parent = pretrained_model
for attr in attr_path:
parent = getattr(parent, attr, None)
if parent is None:
break
if parent is not None and hasattr(parent, "lora_embedding_A"):
LOG.info(f"Patching embedding layer: {'.'.join(attr_path)}")
parent.forward = types.MethodType(apply_lora_embedding, parent)
patched += 1
# lm_head with LoRA is a Linear layer - already handled by LoRA_O/LoRA_W kernels
# when included in target_modules. No special embedding handling needed since
# PEFT wraps it as a Linear (not Embedding) even for tied models.
if not patched:
LOG.debug("No embedding layers with LoRA found to patch")
class FakeMLP(nn.Module):
"""
placeholder MLP for triton patching

View File

@@ -703,12 +703,6 @@ class AxolotlInputConfig(
"description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
},
)
lora_embedding_kernel: bool | None = Field(
default=None,
json_schema_extra={
"description": "Apply custom LoRA autograd function for embedding layers. See: https://docs.axolotl.ai/docs/lora_optims.html"
},
)
chunked_cross_entropy: bool | None = Field(
default=None,
@@ -764,6 +758,44 @@ class AxolotlInputConfig(
llama4_linearized_experts: bool | None = None
# MoE aux-loss-free (AFB) toggles
moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
default=None,
json_schema_extra={
"description": "MoE load balancing strategy: 'gshard' for auxiliary loss, 'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. Defaults to model's native behavior when unset.",
},
)
moe_update_rate: float | None = Field(
default=None,
json_schema_extra={
"description": "Per-step bias update rate (gamma). Recommended: 0.0050.05. If unset, plugin default is 0.01.",
},
)
moe_update_momentum: float | None = Field(
default=None,
json_schema_extra={
"description": "EMA momentum for expert load smoothing (01). If unset, plugin default is 0.9.",
},
)
moe_bias_cap: float | None = Field(
default=None,
json_schema_extra={
"description": "Absolute clamp for expert bias magnitude. If unset, plugin default is 2.0.",
},
)
moe_afb_warmup_steps: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of initial steps to delay aux-free bias updates, allowing routing to stabilize. If unset, plugin default is 0.",
},
)
moe_bias_sync_group: Literal["world", "ep"] | None = Field(
default=None,
json_schema_extra={
"description": "Reduction group for expert load counts: 'world' (DP) or 'ep' (expert-parallel group if available). Defaults to 'world' when unset.",
},
)
deepspeed: str | dict[str, Any] | None = Field(
default=None,
json_schema_extra={
@@ -842,6 +874,12 @@ class AxolotlInputConfig(
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
},
)
expert_parallel_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of processes participating in expert-parallel collectives. Set >1 to form EP groups for aux-free reductions; defaults to world when unset."
},
)
special_tokens: SpecialTokensConfig | None = Field(
default=None,
json_schema_extra={
@@ -1319,7 +1357,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
or data.get("lora_embedding_kernel")
):
capabilities = data.get("capabilities")
is_fsdp = data.get("fsdp_config") is not None
@@ -1367,12 +1404,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("adapter") in ["lora", "qlora"]:
# Skip if already set, using unsloth optimizations, or using 8-bit
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
kernel_fields = [
"lora_mlp_kernel",
"lora_qkv_kernel",
"lora_o_kernel",
"lora_embedding_kernel",
]
kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
if (
any(data.get(k) is not None for k in kernel_fields)
or any(data.get(k) for k in unsloth_fields)
@@ -1385,38 +1417,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("trust_remote_code"):
return data
# Skip auto-enable for MoE models when native grouped_mm is unavailable
# (torch < 2.9). The grouped_mm fallback in transformers uses torch.mm
# with out= which bypasses autocast and fails on mixed dtypes during eval.
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
has_grouped_mm = version.parse(torch_version) >= version.parse("2.9.0")
if not has_grouped_mm:
is_moe = False
model_type = data.get("model_config_type", "")
if model_type and "moe" in model_type.lower():
is_moe = True
if not is_moe:
try:
from transformers import AutoConfig
base_model = data.get("base_model")
if base_model:
auto_cfg = AutoConfig.from_pretrained(
base_model, trust_remote_code=False
)
if getattr(auto_cfg, "num_local_experts", None) or getattr(
auto_cfg, "num_experts", None
):
is_moe = True
except Exception: # pylint: disable=broad-exception-caught
pass
if is_moe:
return data
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
if data.get("lora_dropout") != 0:
return data
# Check multi-GPU compatibility
capabilities = data.get("capabilities")
@@ -1439,9 +1442,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("lora_o_kernel") is None:
data["lora_o_kernel"] = True
if data.get("lora_embedding_kernel") is None:
data["lora_embedding_kernel"] = True
LOG.warning(
"Auto-enabling LoRA kernel optimizations for faster training. "
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "

View File

@@ -681,7 +681,15 @@ class LoRAValidationMixin:
@model_validator(mode="before")
@classmethod
def check_lora_kernels_dora(cls, data):
# DoRA is now supported by lora kernels
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
) and data.get("peft_use_dora"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with DoRA at the moment."
)
return data
@model_validator(mode="before")
@@ -1378,6 +1386,14 @@ class ComplexValidationMixin:
self.tensor_parallel_size = 1
return self
@model_validator(mode="after")
def check_expert_parallel_size(self):
if not getattr(self, "expert_parallel_size", None):
self.expert_parallel_size = 1
elif self.expert_parallel_size < 1:
raise ValueError("expert_parallel_size must be >= 1")
return self
@model_validator(mode="after")
def check_context_parallel_size(self):
if self.sequence_parallel_degree and not self.context_parallel_size:

View File

@@ -153,7 +153,7 @@ class TestLoraFP8Guard(unittest.TestCase):
proj.base_layer = base_layer
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
W, b, quant_state, A, B, s = get_lora_parameters(proj)
# quant_state should be None since weight is bf16, not FP8
self.assertIsNone(quant_state)
@@ -174,7 +174,7 @@ class TestLoraFP8Guard(unittest.TestCase):
scale_inv = torch.ones(1)
base_layer.weight_scale_inv = scale_inv
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
W, b, quant_state, A, B, s = get_lora_parameters(proj)
self.assertIs(quant_state, scale_inv)

View File

@@ -102,7 +102,7 @@ def mock_proj():
def test_get_lora_parameters(mock_proj):
"""Tests get_lora_parameters function"""
# Test with LoRA enabled
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
W, b, _, A, B, s = get_lora_parameters(mock_proj)
assert isinstance(W, torch.Tensor)
assert W.shape == (128, 64)
@@ -113,13 +113,13 @@ def test_get_lora_parameters(mock_proj):
# Test with LoRA disabled
mock_proj.disable_adapters = True
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
W, b, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
# Test with merged state
mock_proj.disable_adapters = False
mock_proj.merged = True
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
W, b, _, A, B, s = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
@@ -176,31 +176,24 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
X.requires_grad = True
output = LoRA_MLP.apply(
X,
None, # X_drop
gate_proj.weight,
gate_proj.bias,
None, # gate_quant
None, # gate_A
None, # gate_B
None, # gate_scale
None, # gate_lora_bias
None, # gate_magnitude
up_proj.weight,
up_proj.bias,
None, # up_quant
None, # up_A
None, # up_B
None, # up_scale
None, # up_lora_bias
None, # up_magnitude
down_proj.weight,
down_proj.bias,
None, # down_quant
None, # down_A
None, # down_B
None, # down_scale
None, # down_lora_bias
None, # down_magnitude
activation_forward,
activation_backward,
True, # inplace
@@ -254,31 +247,24 @@ def test_lora_mlp_with_adapters(
# Forward pass with adapters
output = LoRA_MLP.apply(
X,
None, # X_drop
gate_proj.weight,
gate_proj.bias,
None,
gate_A,
gate_B,
scale,
None, # gate_lora_bias
None, # gate_magnitude
up_proj.weight,
up_proj.bias,
None,
up_A,
up_B,
scale,
None, # up_lora_bias
None, # up_magnitude
down_proj.weight,
down_proj.bias,
None,
down_A,
down_B,
scale,
None, # down_lora_bias
None, # down_magnitude
activation_forward,
activation_backward,
True,
@@ -348,32 +334,25 @@ def test_lora_qkv(sample_tensors):
Q1, K1, V1 = LoRA_QKV.apply(
X,
None, # X_drop
q_weight,
None,
None,
None,
None,
None,
None,
None, # Q: weight, bias, quant, A, B, scale, lora_bias, magnitude
k_weight,
None,
None,
None,
None,
None,
None,
None, # K
v_weight,
None,
None,
None,
None,
None,
None,
None, # V
True, # inplace
True,
)
assert Q1.shape == K1.shape == V1.shape == X.shape
@@ -387,32 +366,25 @@ def test_lora_qkv(sample_tensors):
# Test with LoRA adapters
Q2, K2, V2 = LoRA_QKV.apply(
X,
None, # X_drop
q_weight,
None,
None,
q_A,
q_B,
scale,
None,
None, # Q
k_weight,
None,
None,
k_A,
k_B,
scale,
None,
None, # K
v_weight,
None,
None,
v_A,
v_B,
scale,
None,
None, # V
True, # inplace
True,
)
assert Q2.shape == K2.shape == V2.shape == X.shape
@@ -455,9 +427,7 @@ def test_lora_o(sample_tensors):
# Test forward pass
X.requires_grad = True
output = LoRA_O.apply(
X, None, W, b, None, A, B, scale, None, None
) # X_drop, ..., lora_bias, magnitude
output = LoRA_O.apply(X, W, b, None, A, B, scale)
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
@@ -572,7 +542,6 @@ def test_inplace_operations(sample_tensors, apply_function):
"down_proj": nn.Linear(shapes["out"], shapes["hidden"]).to(
device="cuda", dtype=torch.float16
),
"training": False,
},
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,120 +0,0 @@
"""Test LoRA kernels under FSDP2 multi-GPU training.
Verifies that lora_qkv_kernel, lora_o_kernel, lora_mlp_kernel, and
lora_embedding_kernel work correctly with FSDP2 sharding, including
with bias, dropout, and DoRA enabled.
"""
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def _run_training(temp_dir, cfg):
"""Write config and launch multi-GPU training."""
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
def _base_lora_fsdp2_config(temp_dir, **overrides):
"""Base config for LoRA + FSDP2 + kernel tests."""
cfg = {
"base_model": "Qwen/Qwen3-0.6B",
"sequence_len": 512,
"val_set_size": 0.0,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:1%]",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 3,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 1e-4,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"flash_attention": True,
"bf16": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen3DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
# Enable all LoRA kernels
"lora_mlp_kernel": True,
"lora_qkv_kernel": True,
"lora_o_kernel": True,
"lora_embedding_kernel": True,
"save_safetensors": True,
}
cfg.update(overrides)
return DictDefault(cfg)
class TestFSDP2LoRAKernels:
"""Test LoRA kernels under FSDP2."""
@require_torch_2_7_0
def test_lora_kernels_basic(self, temp_dir):
"""Basic LoRA + kernels + FSDP2: no dropout, no bias, no DoRA."""
cfg = _base_lora_fsdp2_config(temp_dir)
_run_training(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@require_torch_2_7_0
def test_lora_kernels_with_dropout(self, temp_dir):
"""LoRA kernels + dropout + FSDP2."""
cfg = _base_lora_fsdp2_config(temp_dir, lora_dropout=0.1)
_run_training(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@require_torch_2_7_0
def test_lora_kernels_with_dora(self, temp_dir):
"""LoRA kernels + DoRA + FSDP2."""
cfg = _base_lora_fsdp2_config(temp_dir, peft_use_dora=True)
_run_training(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@require_torch_2_7_0
def test_lora_kernels_with_dora_and_dropout(self, temp_dir):
"""LoRA kernels + DoRA + dropout + FSDP2."""
cfg = _base_lora_fsdp2_config(
temp_dir,
peft_use_dora=True,
lora_dropout=0.05,
)
_run_training(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()

View File

@@ -222,9 +222,9 @@ def test_model_specific_activation(model_name, expected_activation):
def test_kernel_patch_conditions():
"""Test that kernels ARE patched even with dropout and bias (now supported)."""
"""Test various conditions that should prevent kernel patching."""
test_configs = [
# Dropout — kernels now support this
# Dropout prevents patching
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
@@ -234,7 +234,7 @@ def test_kernel_patch_conditions():
"lora_dropout": 0.1,
"bias": "none",
},
# Bias — kernels now support this
# Bias prevents patching
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
@@ -252,14 +252,13 @@ def test_kernel_patch_conditions():
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
# Should not patch
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
# Verify patches ARE applied (dropout and bias are now supported)
assert (
layer.forward.__func__ is apply_lora_mlp_swiglu
or layer.forward.__func__ is apply_lora_mlp_geglu
)
# Verify no patches applied
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
assert layer.forward.__func__ is not apply_lora_mlp_geglu
def test_kernel_config_options():
@@ -512,7 +511,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
def test_kernel_training_integration_dropout_non_zero(temp_dir):
"""Test model loading with dropout non-zero DOES patch (now supported)."""
"""Test model loading with dropout non-zero should not patch."""
from axolotl.cli.utils import load_model_and_tokenizer
@@ -547,18 +546,31 @@ def test_kernel_training_integration_dropout_non_zero(temp_dir):
# Load config
cfg = load_cfg(str(path))
# Get original attention class
attention_cls = get_attention_cls_from_config(cfg)
# Store original state before patching
original_forward_method = attention_cls.forward
# Load model
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
# We call modelloader as that's where the patches are applied
# despite the fact that we're not using it to load the model
model_loader = ModelLoader(cfg, tokenizer)
# Apply patches — should succeed even with dropout > 0
# Apply patch
model_loader.patch_manager._apply_self_attention_lora_patch()
# Verify patch was not applied
assert attention_cls.forward == original_forward_method
# Apply apply_lora_kernel_patches
model_loader.patch_manager._apply_lora_kernel_patch(model)
# Verify patches WERE applied (dropout is now supported by kernels)
# Verify patch was not applied
layers = get_layers(model)
for layer in layers:
for self_attn in find_self_attn_in_layer(layer):
assert hasattr(self_attn, "apply_qkv")
assert hasattr(self_attn, "apply_o")
assert not hasattr(self_attn, "apply_qkv")
assert not hasattr(self_attn, "apply_o")

View File

@@ -0,0 +1,79 @@
"""
E2E smoke tests for Aux-Loss-Free MoE routing via plugin
"""
import unittest
import torch
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
class TestMoeAuxFree(unittest.TestCase):
"""Smoke tests to ensure aux-free plugin enables and runs on Mixtral tiny."""
@with_temp_dir
def test_mixtral_aux_free_smoke(self, temp_dir):
cfg = DictDefault(
{
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": False,
"sequence_len": 512,
"bf16": False,
"fp16": False,
"val_set_size": 0.02,
"special_tokens": {},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 1e-5,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_steps": 0,
"eval_steps": 0,
"save_first_step": False,
# Aux-free plugin and toggles
"plugins": [
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
],
"moe_balance_type": "noaux_tc",
"moe_update_rate": 0.01,
"moe_update_momentum": 0.9,
"moe_bias_cap": 2.0,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
prepare_plugins(cfg)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
# Inspect model modules for a patched MoE layer
patched = None
for m in model.modules():
if hasattr(m, "_afb_patched") and getattr(m, "_afb_patched") is True:
patched = m
break
assert patched is not None, "No MoE layer patched by aux-free plugin"
assert hasattr(patched, "_afb_bias") and patched._afb_bias.ndim == 1
assert hasattr(patched, "_afb_counts") and patched._afb_counts.ndim == 1
# ensure counts buffer got reset by callback (best effort)
assert torch.all(patched._afb_counts == 0)
check_model_output_exists(temp_dir, cfg)

View File

@@ -0,0 +1,83 @@
"""
Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-tiny.
Checks that aux-free training loss does not degrade beyond a small tolerance.
"""
import unittest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
def _last_logged_loss(trainer) -> float | None:
# Scan log_history for the most recent entry with a 'loss' key
for entry in reversed(trainer.state.log_history):
if isinstance(entry, dict) and "loss" in entry:
return float(entry["loss"])
return None
class TestMoeAuxParity(unittest.TestCase):
@with_temp_dir
def test_mixtral_auxfree_vs_auxloss_loss_parity(self, temp_dir):
base_cfg = {
"base_model": "hf-internal-testing/Mixtral-tiny",
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
"flash_attention": False,
"sequence_len": 512,
"bf16": False,
"fp16": False,
"val_set_size": 0.02,
"special_tokens": {},
"datasets": [
{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 8,
"save_steps": 0,
"eval_steps": 0,
"save_first_step": False,
"seed": 42,
"logging_steps": 1,
}
# Baseline: aux-loss (gshard)
cfg0 = DictDefault(dict(base_cfg))
cfg0.output_dir = f"{temp_dir}/baseline"
cfg0 = validate_config(cfg0)
normalize_config(cfg0)
# baseline uses default aux-loss routing; no plugin registration
dataset_meta0 = load_datasets(cfg=cfg0)
model0, _, trainer0 = train(cfg=cfg0, dataset_meta=dataset_meta0)
loss0 = _last_logged_loss(trainer0)
assert loss0 is not None
# Aux-free: plugin + noaux_tc
cfg1 = DictDefault(dict(base_cfg))
cfg1.output_dir = f"{temp_dir}/auxfree"
cfg1.plugins = [
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
]
cfg1.moe_balance_type = "noaux_tc"
cfg1.moe_update_rate = 0.01
cfg1.moe_update_momentum = 0.9
cfg1.moe_bias_cap = 2.0
cfg1 = validate_config(cfg1)
normalize_config(cfg1)
prepare_plugins(cfg1)
dataset_meta1 = load_datasets(cfg=cfg1)
model1, _, trainer1 = train(cfg=cfg1, dataset_meta=dataset_meta1)
loss1 = _last_logged_loss(trainer1)
assert loss1 is not None
# Assert aux-free loss is within 10% of aux-loss baseline
assert loss1 <= 1.1 * loss0, f"aux-free loss {loss1} > 1.1 * baseline {loss0}"

View File

@@ -0,0 +1,76 @@
"""
E2E smoke test for Aux-Loss-Free MoE routing on Qwen3-MoE tiny
"""
import unittest
import torch
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
class TestQwen3MoeAuxFree(unittest.TestCase):
@with_temp_dir
def test_qwen3_moe_aux_free_smoke(self, temp_dir):
cfg = DictDefault(
{
"base_model": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
"tokenizer_config": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
"flash_attention": False,
"sequence_len": 512,
"bf16": False,
"fp16": False,
"val_set_size": 0.02,
"special_tokens": {},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 1e-5,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_steps": 0,
"eval_steps": 0,
"save_first_step": False,
# Aux-free plugin and toggles
"plugins": [
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
],
"moe_balance_type": "noaux_tc",
"moe_update_rate": 0.01,
"moe_update_momentum": 0.9,
"moe_bias_cap": 2.0,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
prepare_plugins(cfg)
dataset_meta = load_datasets(cfg=cfg)
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
# check that at least one sparse MoE block has been patched
found = False
for m in model.modules():
if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(m, "_afb_patched"):
assert m._afb_patched is True
assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1
assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1
found = True
break
assert found, "No Qwen3-MoE sparse block patched by aux-free plugin"
check_model_output_exists(temp_dir, cfg)

View File

@@ -12,7 +12,10 @@ from pathlib import Path
import torch
from packaging import version
from tbparse import SummaryReader
try:
from tbparse import SummaryReader
except ImportError: # pragma: no cover - optional dependency
SummaryReader = None
from axolotl.utils.dict import DictDefault
@@ -179,12 +182,14 @@ def check_tensorboard(
tag: str,
lt_val: float,
assertion_err: str,
rtol: float = 0.05,
rtol: float = 0.02,
gt_zero: bool = True,
) -> None:
"""
helper function to parse and check tensorboard logs
"""
if SummaryReader is None:
raise unittest.SkipTest("tbparse is not installed; skipping tensorboard assertions")
tb_log_path = most_recent_subdir(temp_run_dir)
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)

View File

@@ -1,229 +0,0 @@
"""
Correctness tests for fused RMSNorm + SiLU Gate kernel.
Tests against the eager Qwen3_5RMSNormGated implementation.
"""
import pytest
import torch
import torch.nn.functional as F
pytest.importorskip("triton", reason="triton required for fused kernels")
if not torch.cuda.is_available():
pytest.skip("CUDA required for fused kernel tests", allow_module_level=True)
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
class EagerRMSNormGated(torch.nn.Module):
"""Reference implementation matching Qwen3_5RMSNormGated exactly."""
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states, gate=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = self.weight * hidden_states.to(input_dtype)
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
return hidden_states.to(input_dtype)
def _sync_weights(eager_mod, fused_mod):
"""Copy weights from eager to fused module."""
fused_mod.weight.data.copy_(eager_mod.weight.data)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
"shape",
[
(2, 128, 256),
(4, 64, 512),
(1, 32, 1024),
(2, 16, 2560), # Qwen3.5-4B hidden_size
(2, 16, 4096), # Qwen3.5-9B hidden_size
(1, 8, 5120), # Qwen3.5-27B hidden_size
(4, 16, 2048), # Qwen3.5-35B-A3B (MoE) hidden_size
(4, 16, 3072), # Qwen3.5-122B-A10B (MoE) hidden_size
],
)
class TestRMSNormGatedForward:
def test_output_matches_eager(self, dtype, shape):
torch.manual_seed(42)
B, T, H = shape
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X, gate=G)
y_fused = fused(X, gate=G)
if dtype == torch.float32:
torch.testing.assert_close(y_fused, y_eager, atol=1e-5, rtol=1e-5)
else:
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
def test_output_shape(self, dtype, shape):
B, T, H = shape
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
y = fused(X, gate=G)
assert y.shape == (B, T, H)
assert y.dtype == dtype
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
"shape",
[
(2, 32, 256),
(2, 16, 512),
(2, 16, 2560), # Qwen3.5-4B
(1, 8, 4096), # Qwen3.5-9B
(1, 8, 5120), # Qwen3.5-27B
(2, 16, 2048), # Qwen3.5-35B-A3B (MoE)
(2, 16, 3072), # Qwen3.5-122B-A10B (MoE)
],
)
class TestRMSNormGatedBackward:
def test_grad_x(self, dtype, shape):
torch.manual_seed(42)
B, T, H = shape
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
X_ref = X.detach().clone().requires_grad_(True)
G_ref = G.detach().clone().requires_grad_(True)
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X_ref, gate=G_ref)
y_fused = fused(X, gate=G)
grad_out = torch.randn_like(y_eager)
y_eager.backward(grad_out)
y_fused.backward(grad_out)
if dtype == torch.float32:
atol, rtol = 1e-4, 1e-4
else:
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(X.grad, X_ref.grad, atol=atol, rtol=rtol)
def test_grad_gate(self, dtype, shape):
torch.manual_seed(42)
B, T, H = shape
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
X_ref = X.detach().clone().requires_grad_(True)
G_ref = G.detach().clone().requires_grad_(True)
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X_ref, gate=G_ref)
y_fused = fused(X, gate=G)
grad_out = torch.randn_like(y_eager)
y_eager.backward(grad_out)
y_fused.backward(grad_out)
if dtype == torch.float32:
atol, rtol = 1e-4, 1e-4
else:
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(G.grad, G_ref.grad, atol=atol, rtol=rtol)
def test_grad_weight(self, dtype, shape):
torch.manual_seed(42)
B, T, H = shape
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
X_ref = X.detach().clone().requires_grad_(True)
G_ref = G.detach().clone().requires_grad_(True)
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X_ref, gate=G_ref)
y_fused = fused(X, gate=G)
grad_out = torch.randn_like(y_eager)
y_eager.backward(grad_out)
y_fused.backward(grad_out)
if dtype == torch.float32:
atol, rtol = 1e-4, 1e-4
else:
atol, rtol = 5e-2, 5e-2
torch.testing.assert_close(
fused.weight.grad, eager.weight.grad, atol=atol, rtol=rtol
)
class TestRMSNormGatedEdgeCases:
def test_gate_none_raises(self):
fused = FusedRMSNormGated(256).cuda()
X = torch.randn(2, 4, 256, device="cuda")
with pytest.raises(ValueError, match="requires a gate tensor"):
fused(X, gate=None)
def test_2d_input(self):
"""Test with (BxT, H) shaped input instead of (B, T, H)."""
torch.manual_seed(42)
H = 512
X = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
G = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
X_ref = X.detach().clone().requires_grad_(True)
G_ref = G.detach().clone().requires_grad_(True)
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X_ref, gate=G_ref)
y_fused = fused(X, gate=G)
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
grad_out = torch.randn_like(y_eager)
y_eager.backward(grad_out)
y_fused.backward(grad_out)
torch.testing.assert_close(X.grad, X_ref.grad, atol=5e-2, rtol=5e-2)
torch.testing.assert_close(G.grad, G_ref.grad, atol=5e-2, rtol=5e-2)
def test_random_weight_init(self):
"""Test with non-default weight values."""
torch.manual_seed(123)
H = 256
X = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
G = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
# Randomize weights
eager.weight.data = torch.randn_like(eager.weight.data)
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
_sync_weights(eager, fused)
y_eager = eager(X, gate=G)
y_fused = fused(X, gate=G)
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)

View File

@@ -0,0 +1,267 @@
import os
import sys
import tempfile
import unittest
from types import SimpleNamespace
import torch
import torch.distributed as dist
import torch.nn as nn
from importlib import util as importlib_util
from pathlib import Path
from huggingface_hub import snapshot_download
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
def _cfg(**overrides):
defaults = dict(
moe_balance_type="noaux_tc",
moe_update_rate=0.1,
moe_update_momentum=0.9,
moe_bias_cap=2.0,
moe_afb_warmup_steps=0,
moe_bias_sync_group="world",
expert_parallel_size=1,
)
defaults.update(overrides)
return SimpleNamespace(**defaults)
def _load_bailing_modules():
repo_dir = snapshot_download(
repo_id="inclusionAI/Ring-mini-2.0",
allow_patterns=[
"configuration_bailing_moe_v2.py",
"modeling_bailing_moe_v2.py",
"__init__.py",
],
)
repo = Path(repo_dir)
config_path = repo / "configuration_bailing_moe_v2.py"
modeling_path = repo / "modeling_bailing_moe_v2.py"
config_name = "bailing_moe_v2.configuration_bailing_moe_v2"
if config_name not in sys.modules:
spec = importlib_util.spec_from_file_location(config_name, config_path)
module = importlib_util.module_from_spec(spec)
sys.modules[config_name] = module
sys.modules["configuration_bailing_moe_v2"] = module
assert spec.loader is not None
spec.loader.exec_module(module)
config_module = sys.modules[config_name]
modeling_name = "bailing_moe_v2.modeling_bailing_moe_v2"
if modeling_name not in sys.modules:
spec = importlib_util.spec_from_file_location(modeling_name, modeling_path)
module = importlib_util.module_from_spec(spec)
sys.modules[modeling_name] = module
sys.modules["modeling_bailing_moe_v2"] = module
assert spec.loader is not None
spec.loader.exec_module(module)
modeling_module = sys.modules[modeling_name]
BailingMoeV2Config = config_module.BailingMoeV2Config
BailingMoeV2SparseMoeBlock = modeling_module.BailingMoeV2SparseMoeBlock
return BailingMoeV2Config, BailingMoeV2SparseMoeBlock
def _build_bailing_model():
BailingConfig, BailingBlock = _load_bailing_modules()
config = BailingConfig(
hidden_size=16,
intermediate_size=32,
moe_intermediate_size=32,
num_experts=4,
num_shared_experts=None,
num_experts_per_tok=2,
n_group=1,
topk_group=1,
routed_scaling_factor=1.0,
)
block = BailingBlock(config)
class DummyModel(nn.Module):
def __init__(self, layer):
super().__init__()
self.block = layer
self.config = SimpleNamespace(model_type="bailing_moe")
def forward(self, hidden_states):
return self.block(hidden_states)
return DummyModel(block), block
def _build_llama4_model():
from transformers import Llama4TextConfig
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
config = Llama4TextConfig(
hidden_size=16,
intermediate_size=32,
num_local_experts=4,
num_attention_heads=2,
num_key_value_heads=2,
num_experts_per_tok=2,
)
layer = Llama4TextMoe(config)
class DummyModel(nn.Module):
def __init__(self, moe_layer):
super().__init__()
self.moe = moe_layer
self.config = SimpleNamespace(model_type="llama4")
def forward(self, hidden_states):
return self.moe(hidden_states)
return DummyModel(layer), layer
def _build_mixtral_model():
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
config = MixtralConfig(
hidden_size=16,
intermediate_size=32,
num_local_experts=4,
num_experts_per_tok=2,
num_attention_heads=2,
num_key_value_heads=2,
)
layer = MixtralSparseMoeBlock(config)
layer.config = config
class DummyModel(nn.Module):
def __init__(self, moe_layer):
super().__init__()
self.moe = moe_layer
self.config = SimpleNamespace(model_type="mixtral")
def forward(self, hidden_states):
return self.moe(hidden_states)
return DummyModel(layer), layer
def _run_callback(plugin, cfg):
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
assert callbacks, "expected aux-free callback to be registered"
callback = callbacks[0]
dummy = SimpleNamespace()
callback.on_step_end(args=dummy, state=dummy, control=dummy)
class TestAuxFreeAdapters(unittest.TestCase):
def test_bailing_adapter_updates_counts_and_bias(self):
model, block = _build_bailing_model()
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
self.assertTrue(hasattr(block, "_afb_bias"))
hidden = torch.randn(2, 3, block.config.hidden_size)
block(hidden)
self.assertGreater(torch.count_nonzero(block._afb_counts), 0)
_run_callback(plugin, cfg)
self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
self.assertFalse(torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema)))
def test_llama4_adapter_biases_router_selection(self):
model, layer = _build_llama4_model()
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
self.assertTrue(hasattr(layer, "_afb_bias"))
hidden = torch.randn(2, 4, layer.hidden_dim)
layer(hidden)
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
_run_callback(plugin, cfg)
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)))
def test_bias_warmup_respected(self):
model, block = _build_bailing_model()
cfg = _cfg(moe_afb_warmup_steps=2)
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
self.assertTrue(callbacks)
callback = callbacks[0]
dummy = SimpleNamespace()
def _step():
hidden = torch.randn(2, 3, block.config.hidden_size)
block(hidden)
callback.on_step_end(args=dummy, state=dummy, control=dummy)
# Warmup steps should leave bias untouched.
_step()
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
_step()
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
# Third step exceeds warmup -> bias should update.
_step()
self.assertGreater(torch.count_nonzero(block._afb_bias), 0)
def test_mixtral_adapter_respects_native_forward(self):
model, layer = _build_mixtral_model()
layer.jitter_noise = 0.0 # avoid stochasticity for comparison
hidden_dim = layer.config.hidden_size
hidden = torch.randn(2, 3, hidden_dim)
baseline_out, baseline_logits = layer(hidden.clone())
cfg = _cfg()
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
patched_out, patched_logits = layer(hidden.clone())
self.assertTrue(torch.allclose(baseline_out, patched_out))
self.assertTrue(torch.allclose(baseline_logits, patched_logits))
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
_run_callback(plugin, cfg)
def test_ep_group_resolution_deferred_until_dist_ready(self):
if dist.is_available() and dist.is_initialized():
dist.destroy_process_group()
model, block = _build_bailing_model()
cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
plugin = AuxFreeMoEPlugin()
plugin.post_model_build(cfg, model)
self.assertIsNotNone(plugin._shim)
self.assertIsNone(plugin._shim.ep_group)
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
self.assertTrue(callbacks)
callback = callbacks[0]
dummy = SimpleNamespace()
tmp_init = tempfile.NamedTemporaryFile(delete=False)
tmp_init.close()
init_method = f"file://{tmp_init.name}"
dist.init_process_group(backend="gloo", init_method=init_method, world_size=1, rank=0)
try:
hidden = torch.randn(2, 3, block.config.hidden_size)
block(hidden)
callback.on_step_end(args=dummy, state=dummy, control=dummy)
self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
finally:
dist.destroy_process_group()
os.unlink(tmp_init.name)
if __name__ == "__main__":
unittest.main()

View File

@@ -28,22 +28,20 @@ class TestLoRAConfigValidation:
result = validate_config(valid_config)
assert result["adapter"] == "lora"
# DoRA is now compatible with lora kernels
dora_kernel_config = DictDefault(
{
"adapter": "lora",
"lora_mlp_kernel": True,
"peft_use_dora": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(dora_kernel_config)
assert result["lora_mlp_kernel"] is True
assert result["peft_use_dora"] is True
with pytest.raises(ValueError, match="not compatible with DoRA"):
invalid_config = DictDefault(
{
"adapter": "lora",
"lora_mlp_kernel": True,
"peft_use_dora": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
validate_config(invalid_config)
def test_qlora_4bit_validation(self):
"""Test QLoRA 4-bit configuration validation"""

View File

@@ -38,11 +38,6 @@ class TestLoRAParameterFreezing:
mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
mock_layer.lora_B["default"].bias = None
# Required by get_lora_parameters for dropout/DoRA extraction
mock_layer.lora_dropout = {}
mock_layer.lora_magnitude_vector = None
else:
mock_layer.weight = base_layer.weight
mock_layer.bias = base_layer.bias
@@ -53,7 +48,7 @@ class TestLoRAParameterFreezing:
"""Test that LoRA parameters are None when adapters are disabled."""
layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
@@ -67,7 +62,7 @@ class TestLoRAParameterFreezing:
"""Test that LoRA parameters are None when adapters are merged."""
layer = self.create_mock_lora_layer(has_adapters=True, merged=True)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
@@ -82,7 +77,7 @@ class TestLoRAParameterFreezing:
"""Test parameter behavior when no adapters are present."""
layer = self.create_mock_lora_layer(has_adapters=False)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
@@ -99,7 +94,7 @@ class TestLoRAParameterFreezing:
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
# All parameters should be returned
assert W is not None
@@ -115,7 +110,7 @@ class TestLoRAParameterFreezing:
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
# Check shape consistency
assert W.shape == (512, 256)
@@ -129,7 +124,7 @@ class TestLoRAParameterFreezing:
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
assert W.dtype == self.dtype
assert b.dtype == self.dtype
@@ -143,7 +138,7 @@ class TestLoRAParameterFreezing:
quant_state_mock = Mock()
layer.base_layer.weight.quant_state = quant_state_mock
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
assert quant_state == quant_state_mock
@@ -162,7 +157,7 @@ class TestLoRAParameterFreezing:
layer.active_adapters = ["adapter2"]
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
assert s == 0.2
assert torch.equal(A, layer.lora_A["adapter2"].weight)
@@ -197,13 +192,13 @@ class TestLoRAParameterFreezingIntegration:
model = get_peft_model(base_model, lora_config)
lora_layer = model.base_model.model.linear
# Test with adapters enabled
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
assert A is not None
assert B is not None
assert s is not None
# Test with adapters disabled
model.disable_adapter_layers()
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
assert A is None
assert B is None
assert s is None