Compare commits
5 Commits
949cdf01eb
...
scattermoe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
936149380f | ||
|
|
86be9f329e | ||
|
|
0e583efeaa | ||
|
|
b3289fd190 | ||
|
|
a67392c427 |
@@ -3,7 +3,8 @@ set -e
|
||||
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
curl --silent -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
||||
set -o pipefail
|
||||
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
||||
# hf download "NousResearch/Meta-Llama-3-8B"
|
||||
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
# hf download "microsoft/Phi-4-reasoning"
|
||||
|
||||
@@ -37,6 +37,7 @@ coverage:
|
||||
only_pulls: false
|
||||
flags: null
|
||||
paths: null
|
||||
informational: true
|
||||
|
||||
parsers:
|
||||
gcov:
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
# Aux-Loss-Free MoE Router Plugin
|
||||
|
||||
This integration adds an aux-loss-free (AFB) gating option to compatible MoE architectures without forking model code.
|
||||
|
||||
Summary
|
||||
- Bias only affects expert selection (top-k); mixture weights come from unbiased logits.
|
||||
- Per-expert token loads are accumulated on device and reduced across DP or EP groups.
|
||||
- Bias is updated post-optimizer step outside autograd using EMA-smoothed loads.
|
||||
- Existing aux loss is disabled when aux-free is enabled to avoid double signals.
|
||||
|
||||
Enable
|
||||
- Add the plugin to your YAML, then set the aux-free toggle:
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin
|
||||
|
||||
moe_balance_type: noaux_tc
|
||||
moe_update_rate: 0.01 # default if unset
|
||||
moe_update_momentum: 0.9 # default if unset
|
||||
moe_bias_cap: 2.0 # default if unset
|
||||
moe_afb_warmup_steps: 100 # optional
|
||||
moe_bias_sync_group: world # or 'ep' if expert_parallel_size > 1
|
||||
expert_parallel_size: 1 # set to your EP width when using moe_bias_sync_group: ep
|
||||
|
||||
Config keys
|
||||
- moe_balance_type: gshard (auxiliary loss) | noaux_tc (aux-free). Default: model native.
|
||||
- moe_update_rate: bias update rate (gamma). Default: 0.01.
|
||||
- moe_update_momentum: EMA momentum for load smoothing. Default: 0.9.
|
||||
- moe_bias_cap: absolute clamp for bias. Default: 2.0.
|
||||
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
|
||||
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
|
||||
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
|
||||
|
||||
Compatibility
|
||||
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
|
||||
- Pass-through: Models with native aux-free routing (e.g., DeepSeek-V3) are left unmodified; only telemetry may be added in future.
|
||||
|
||||
Notes
|
||||
- If you also enable Liger’s aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
|
||||
- Telemetry: future updates will log per-expert loads and bias magnitudes.
|
||||
@@ -1,2 +0,0 @@
|
||||
"""Aux-loss-free (AFB) MoE router integration package."""
|
||||
|
||||
@@ -1,317 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .core import AuxFreeShim
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerHandle:
|
||||
layer: nn.Module
|
||||
layer_idx: int
|
||||
num_experts: int
|
||||
top_k: int
|
||||
|
||||
|
||||
class BaseMoEAdapter:
|
||||
"""Base adapter that discovers MoE layers and wraps their forward.
|
||||
|
||||
Concrete adapters should implement discovery and per-layer attribute extraction.
|
||||
"""
|
||||
|
||||
family: str = "generic"
|
||||
|
||||
def matches(self, model: nn.Module) -> bool: # pragma: no cover - thin shim
|
||||
return False
|
||||
|
||||
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: # pragma: no cover
|
||||
return []
|
||||
|
||||
def get_top_k(self, moe_layer: nn.Module) -> int: # pragma: no cover
|
||||
return int(getattr(moe_layer, "num_experts_per_tok", getattr(moe_layer, "top_k", 2)))
|
||||
|
||||
def get_num_experts(self, moe_layer: nn.Module) -> int: # pragma: no cover
|
||||
return int(getattr(moe_layer, "num_experts"))
|
||||
|
||||
def disable_aux_loss(self, model_or_layer: nn.Module) -> None:
|
||||
# Best-effort: zero router aux loss coef if present
|
||||
if hasattr(model_or_layer, "router_aux_loss_coef"):
|
||||
try:
|
||||
setattr(model_or_layer, "router_aux_loss_coef", 0.0)
|
||||
except Exception: # pragma: no cover - non-critical
|
||||
pass
|
||||
|
||||
def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||
device = next(moe_layer.parameters(), torch.tensor(0)).device
|
||||
if not hasattr(moe_layer, "_afb_bias"):
|
||||
moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device))
|
||||
if not hasattr(moe_layer, "_afb_counts"):
|
||||
moe_layer.register_buffer("_afb_counts", torch.zeros(handle.num_experts, device=device))
|
||||
if not hasattr(moe_layer, "_afb_ema"):
|
||||
moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device))
|
||||
moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined]
|
||||
moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined]
|
||||
shim.register_layer_buffers(handle.layer_idx, moe_layer)
|
||||
|
||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||
"""Attach per-layer buffers and mark as aux-free enabled."""
|
||||
self._register_aux_buffers(moe_layer, handle, shim)
|
||||
self._patch_forward_with_aux_free(moe_layer)
|
||||
|
||||
def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None:
|
||||
"""Replace the layer's forward with an aux-free gating version.
|
||||
|
||||
Assumes the layer exposes attributes:
|
||||
- gate: linear router projecting hidden to num_experts
|
||||
- num_experts: int
|
||||
- experts: iterable of expert modules taking (tokens, H) -> (tokens, H)
|
||||
"""
|
||||
if getattr(moe_layer, "_afb_patched", False):
|
||||
return
|
||||
|
||||
if not hasattr(moe_layer, "gate") or not hasattr(moe_layer, "experts"):
|
||||
LOG.info("AuxFreeMoE: layer missing gate/experts; skipping forward patch")
|
||||
return
|
||||
|
||||
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
|
||||
# hidden_states: (B, T, H)
|
||||
bsz, seqlen, hdim = hidden_states.shape
|
||||
hs = hidden_states.view(-1, hdim)
|
||||
logits = self.gate(hs)
|
||||
# selection uses biased logits; weights from unbiased logits
|
||||
bias = getattr(self, "_afb_bias")
|
||||
top_k = int(getattr(self, "_afb_top_k", 2))
|
||||
biased = logits + bias # broadcast over tokens
|
||||
topk_vals, topk_idx = torch.topk(biased, k=top_k, dim=-1, sorted=False)
|
||||
chosen_logits = torch.gather(logits, -1, topk_idx)
|
||||
weights = torch.softmax(chosen_logits.float(), dim=-1)
|
||||
weights = weights.to(hs.dtype)
|
||||
|
||||
# accumulate counts for bias update callback
|
||||
flat_idx = topk_idx.reshape(-1)
|
||||
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
|
||||
getattr(self, "_afb_counts").add_(counts.to(getattr(self, "_afb_counts").dtype))
|
||||
|
||||
# dispatch tokens to experts
|
||||
hs_rep = hs.repeat_interleave(top_k, dim=0)
|
||||
y = torch.empty_like(hs_rep)
|
||||
for eid in range(int(self.num_experts)):
|
||||
mask = flat_idx == eid
|
||||
if mask.any():
|
||||
y[mask] = self.experts[eid](hs_rep[mask])
|
||||
|
||||
y = (y.view(-1, top_k, hdim) * weights.unsqueeze(-1)).sum(dim=1)
|
||||
out = y.view(bsz, seqlen, hdim)
|
||||
return (out, logits)
|
||||
|
||||
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
|
||||
setattr(moe_layer, "_afb_patched", True)
|
||||
|
||||
|
||||
class MixtralAdapter(BaseMoEAdapter):
|
||||
family = "mixtral"
|
||||
|
||||
def matches(self, model: nn.Module) -> bool:
|
||||
return getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
|
||||
|
||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||
self._register_aux_buffers(moe_layer, handle, shim)
|
||||
self._patch_mixtral_forward(moe_layer, shim)
|
||||
|
||||
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
||||
for m in model.modules():
|
||||
if m.__class__.__name__.endswith("SparseMoeBlock"):
|
||||
yield m
|
||||
|
||||
def _patch_mixtral_forward(self, moe_layer: nn.Module, shim: AuxFreeShim) -> None:
|
||||
if getattr(moe_layer, "_afb_patched", False):
|
||||
return
|
||||
|
||||
shim_ref = shim
|
||||
|
||||
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
if self.training and getattr(self, "jitter_noise", 0) > 0:
|
||||
hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
|
||||
1.0 - self.jitter_noise, 1.0 + self.jitter_noise
|
||||
)
|
||||
flat_states = hidden_states.view(-1, hidden_dim)
|
||||
router_logits = self.gate(flat_states)
|
||||
|
||||
layer_idx = int(getattr(self, "_afb_layer_idx", 0))
|
||||
top_k = int(getattr(self, "_afb_top_k", self.top_k))
|
||||
selected_experts, routing_weights = shim_ref.select_experts(layer_idx, router_logits, top_k)
|
||||
routing_weights = routing_weights.to(flat_states.dtype)
|
||||
|
||||
flat_idx = selected_experts.reshape(-1)
|
||||
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
|
||||
self._afb_counts.add_(counts.to(self._afb_counts.dtype))
|
||||
|
||||
final_hidden_states = torch.zeros(
|
||||
(batch_size * sequence_length, hidden_dim),
|
||||
dtype=flat_states.dtype,
|
||||
device=flat_states.device,
|
||||
)
|
||||
|
||||
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||
for expert_idx in expert_hit:
|
||||
expert_layer = self.experts[expert_idx]
|
||||
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
|
||||
current_state = flat_states[None, top_x].reshape(-1, hidden_dim)
|
||||
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
|
||||
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype))
|
||||
|
||||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return final_hidden_states, router_logits
|
||||
|
||||
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
|
||||
setattr(moe_layer, "_afb_patched", True)
|
||||
|
||||
|
||||
class Qwen3Adapter(MixtralAdapter):
|
||||
family = "qwen3_moe"
|
||||
|
||||
def matches(self, model: nn.Module) -> bool:
|
||||
return getattr(getattr(model, "config", object()), "model_type", "") in ("qwen3_moe", "qwen2_moe")
|
||||
|
||||
|
||||
class BailingAdapter(BaseMoEAdapter):
|
||||
family = "bailing_moe"
|
||||
|
||||
def matches(self, model: nn.Module) -> bool:
|
||||
model_type = getattr(getattr(model, "config", object()), "model_type", "")
|
||||
return model_type in ("bailing_moe", "bailing_moe_v2")
|
||||
|
||||
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
||||
for m in model.modules():
|
||||
if m.__class__.__name__ == "BailingMoeV2SparseMoeBlock":
|
||||
yield m
|
||||
|
||||
def get_num_experts(self, moe_layer: nn.Module) -> int:
|
||||
if hasattr(moe_layer, "num_experts"):
|
||||
return int(getattr(moe_layer, "num_experts"))
|
||||
cfg = getattr(moe_layer, "config", None)
|
||||
return int(getattr(cfg, "num_experts"))
|
||||
|
||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||
self._register_aux_buffers(moe_layer, handle, shim)
|
||||
self._patch_bailing_gate(moe_layer)
|
||||
|
||||
def _patch_bailing_gate(self, moe_layer: nn.Module) -> None:
|
||||
gate = getattr(moe_layer, "gate", None)
|
||||
if gate is None:
|
||||
LOG.info("BailingAdapter: layer missing gate; skipping aux-free patch")
|
||||
return
|
||||
if getattr(gate, "_afb_patched", False):
|
||||
return
|
||||
|
||||
def afb_gate_forward(self, hidden_states: torch.Tensor):
|
||||
flat = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
logits = F.linear(flat.float(), self.weight.float())
|
||||
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
|
||||
bias = getattr(moe_layer, "_afb_bias")
|
||||
biased_scores = scores_unbiased + bias
|
||||
topk_vals, topk_idx = self.group_limited_topk(biased_scores)
|
||||
weights = torch.gather(scores_unbiased, 1, topk_idx)
|
||||
if self.top_k > 1:
|
||||
denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20)
|
||||
weights = weights / denom
|
||||
weights = weights * self.routed_scaling_factor
|
||||
|
||||
flat_topk = topk_idx.reshape(-1)
|
||||
counts = torch.bincount(flat_topk, minlength=bias.numel())
|
||||
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
|
||||
|
||||
return topk_idx, weights.to(hidden_states.dtype), logits
|
||||
|
||||
gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
|
||||
setattr(gate, "_afb_patched", True)
|
||||
|
||||
|
||||
class Llama4Adapter(BaseMoEAdapter):
|
||||
family = "llama4"
|
||||
|
||||
def matches(self, model: nn.Module) -> bool:
|
||||
return getattr(getattr(model, "config", object()), "model_type", "") == "llama4"
|
||||
|
||||
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
||||
for m in model.modules():
|
||||
if m.__class__.__name__ == "Llama4TextMoe":
|
||||
yield m
|
||||
|
||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
||||
self._register_aux_buffers(moe_layer, handle, shim)
|
||||
self._patch_llama4_router(moe_layer)
|
||||
|
||||
def _patch_llama4_router(self, moe_layer: nn.Module) -> None:
|
||||
router = getattr(moe_layer, "router", None)
|
||||
if router is None:
|
||||
LOG.info("Llama4Adapter: layer missing router; skipping aux-free patch")
|
||||
return
|
||||
if getattr(router, "_afb_patched", False):
|
||||
return
|
||||
|
||||
def afb_router_forward(self, hidden_states: torch.Tensor):
|
||||
flat = hidden_states if hidden_states.dim() == 2 else hidden_states.view(-1, hidden_states.shape[-1])
|
||||
router_logits = F.linear(flat, self.weight, self.bias)
|
||||
bias = getattr(moe_layer, "_afb_bias")
|
||||
biased_logits = router_logits + bias
|
||||
_, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
|
||||
unbiased_top = torch.gather(router_logits, 1, router_indices)
|
||||
router_scores = torch.full_like(router_logits, float("-inf"))
|
||||
router_scores.scatter_(1, router_indices, unbiased_top)
|
||||
router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)
|
||||
|
||||
counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
|
||||
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
|
||||
|
||||
return router_scores, router_logits
|
||||
|
||||
router.forward = afb_router_forward.__get__(router, router.__class__) # type: ignore[attr-defined]
|
||||
setattr(router, "_afb_patched", True)
|
||||
|
||||
|
||||
def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim) -> list[LayerHandle]:
|
||||
"""Discover MoE layers using the first matching adapter and attach per-layer buffers.
|
||||
|
||||
Returns a list of layer handles for later routing patching and updates.
|
||||
"""
|
||||
handles: list[LayerHandle] = []
|
||||
adapter: Optional[BaseMoEAdapter] = None
|
||||
for a in adapters:
|
||||
if a.matches(model):
|
||||
adapter = a
|
||||
break
|
||||
|
||||
if adapter is None:
|
||||
LOG.info("AuxFreeMoE: no matching adapter found; skipping aux-free routing")
|
||||
return handles
|
||||
|
||||
# disable aux loss at model level if possible
|
||||
adapter.disable_aux_loss(getattr(model, "config", model))
|
||||
|
||||
idx = 0
|
||||
for layer in adapter.find_moe_layers(model):
|
||||
try:
|
||||
top_k = adapter.get_top_k(layer)
|
||||
nE = adapter.get_num_experts(layer)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=nE, top_k=top_k)
|
||||
adapter.prepare(layer, handle, shim)
|
||||
handles.append(handle)
|
||||
idx += 1
|
||||
|
||||
LOG.info(f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing")
|
||||
return handles
|
||||
@@ -1,150 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuxFreeConfig:
|
||||
rate: float = 0.01
|
||||
momentum: float = 0.9
|
||||
bias_cap: float = 2.0
|
||||
warmup_steps: int = 0
|
||||
sync_group: str = "world" # or "ep"
|
||||
|
||||
|
||||
class AuxFreeState:
|
||||
"""Holds per-layer bias and EMA load buffers."""
|
||||
|
||||
def __init__(self, num_layers: int, num_experts: int, device: torch.device, cfg: AuxFreeConfig):
|
||||
self.bias = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
|
||||
self.ema_load = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
|
||||
self.cfg = cfg
|
||||
self.steps = 0
|
||||
|
||||
|
||||
class AuxFreeShim:
|
||||
"""Model-agnostic shim for aux-loss-free expert selection and bias updates."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state: AuxFreeState,
|
||||
ep_group: Optional[dist.ProcessGroup] = None,
|
||||
ep_size: Optional[int] = None,
|
||||
):
|
||||
self.state = state
|
||||
self.ep_group = ep_group
|
||||
self._ep_size = ep_size
|
||||
self._ep_group_pending = (
|
||||
self.state.cfg.sync_group == "ep" and self.ep_group is None
|
||||
)
|
||||
self._layer_modules: dict[int, torch.nn.Module] = {}
|
||||
|
||||
@torch.no_grad()
|
||||
def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Returns (topk_indices, weights) using biased selection and unbiased weights."""
|
||||
module = self._layer_modules.get(layer_idx)
|
||||
if module is not None and hasattr(module, "_afb_bias"):
|
||||
b = getattr(module, "_afb_bias")
|
||||
else:
|
||||
b = self.state.bias[layer_idx]
|
||||
biased = logits + b # bias is a buffer
|
||||
topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
|
||||
chosen_logits = torch.gather(logits, -1, topk_idx)
|
||||
weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
|
||||
return topk_idx, weights
|
||||
|
||||
def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
|
||||
"""Bind model buffers so shim updates stay in sync with patched layers."""
|
||||
self._layer_modules[layer_idx] = module
|
||||
bias = getattr(module, "_afb_bias")
|
||||
ema = getattr(module, "_afb_ema")
|
||||
# Keep state views pointing to the same tensors to avoid drift.
|
||||
if layer_idx < len(self.state.bias):
|
||||
self.state.bias[layer_idx] = bias
|
||||
if layer_idx < len(self.state.ema_load):
|
||||
self.state.ema_load[layer_idx] = ema
|
||||
|
||||
def begin_step(self) -> None:
|
||||
"""Call once per optimizer step before per-layer updates."""
|
||||
self.state.steps += 1
|
||||
|
||||
@torch.no_grad()
|
||||
def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
|
||||
self._maybe_init_ep_group()
|
||||
if not dist.is_available() or not dist.is_initialized():
|
||||
return counts
|
||||
group = self.ep_group if self.ep_group is not None else dist.group.WORLD
|
||||
dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=group)
|
||||
return counts
|
||||
|
||||
@torch.no_grad()
|
||||
def update_bias(self, layer_idx: int, step_counts: torch.Tensor, tokens_seen: int):
|
||||
"""Apply EMA-smoothed bias update toward uniform target, with clamp and optional mean-centering."""
|
||||
cfg = self.state.cfg
|
||||
if self.state.steps <= cfg.warmup_steps:
|
||||
return
|
||||
|
||||
nE = step_counts.numel()
|
||||
if tokens_seen <= 0:
|
||||
return
|
||||
module = self._layer_modules.get(layer_idx)
|
||||
if module is not None and hasattr(module, "_afb_ema"):
|
||||
ema = getattr(module, "_afb_ema")
|
||||
bias = getattr(module, "_afb_bias")
|
||||
else:
|
||||
ema = self.state.ema_load[layer_idx]
|
||||
bias = self.state.bias[layer_idx]
|
||||
counts = step_counts.to(ema.device)
|
||||
freq = counts.float() / float(tokens_seen)
|
||||
ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
|
||||
target = 1.0 / float(nE)
|
||||
delta = cfg.rate * (target - ema)
|
||||
# optional mean-centering to keep sum(bias) ~ 0
|
||||
delta = delta - delta.mean()
|
||||
bias.add_(delta)
|
||||
if cfg.bias_cap is not None and cfg.bias_cap > 0:
|
||||
bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
|
||||
|
||||
def _maybe_init_ep_group(self) -> None:
|
||||
if not self._ep_group_pending:
|
||||
return
|
||||
if not dist.is_available() or not dist.is_initialized():
|
||||
return
|
||||
ep_size = self._ep_size
|
||||
if not ep_size or ep_size <= 1:
|
||||
LOG.warning(
|
||||
"AuxFreeMoE: moe_bias_sync_group='ep' requested but expert_parallel_size<=1; defaulting to world group"
|
||||
)
|
||||
self.ep_group = dist.group.WORLD
|
||||
self._ep_group_pending = False
|
||||
return
|
||||
world = dist.get_world_size()
|
||||
if world % ep_size != 0:
|
||||
LOG.warning(
|
||||
"AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world group",
|
||||
ep_size,
|
||||
world,
|
||||
)
|
||||
self.ep_group = dist.group.WORLD
|
||||
self._ep_group_pending = False
|
||||
return
|
||||
if ep_size == world:
|
||||
self.ep_group = dist.group.WORLD
|
||||
else:
|
||||
rank = dist.get_rank()
|
||||
group_start = (rank // ep_size) * ep_size
|
||||
ranks = tuple(range(group_start, group_start + ep_size))
|
||||
self.ep_group = dist.new_group(ranks)
|
||||
LOG.info(
|
||||
"AuxFreeMoE: initialized expert-parallel reduction group (size=%s, world=%s)",
|
||||
ep_size,
|
||||
dist.get_world_size(),
|
||||
)
|
||||
self._ep_group_pending = False
|
||||
@@ -1,175 +0,0 @@
|
||||
"""Aux-loss-free MoE Router Plugin for Axolotl.
|
||||
|
||||
This plugin wires an aux-free gating option into compatible MoE models using
|
||||
unbiased logits for mixture weights and per-expert biases for top-k selection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
from .adapters import (
|
||||
BailingAdapter,
|
||||
BaseMoEAdapter,
|
||||
Llama4Adapter,
|
||||
MixtralAdapter,
|
||||
Qwen3Adapter,
|
||||
discover_and_prepare_layers,
|
||||
)
|
||||
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
|
||||
"""Post-step callback to update aux-free biases from accumulated expert counts.
|
||||
|
||||
Note: The current revision expects per-layer counts to be accumulated on each
|
||||
MoE layer as a buffer named `_afb_counts` during forward (to be added with
|
||||
routing patches in a follow-up).
|
||||
"""
|
||||
|
||||
def __init__(self, shim: AuxFreeShim, layer_modules: list[torch.nn.Module]):
|
||||
self.shim = shim
|
||||
self.layer_modules = layer_modules
|
||||
|
||||
def on_step_end(self, args, state, control, **kwargs): # noqa: D401
|
||||
# Iterate prepared MoE layers and apply the bias update rule.
|
||||
self.shim.begin_step()
|
||||
for layer in self.layer_modules:
|
||||
if not hasattr(layer, "_afb_counts") or not hasattr(layer, "_afb_layer_idx"):
|
||||
continue
|
||||
counts = getattr(layer, "_afb_counts")
|
||||
if counts is None:
|
||||
continue
|
||||
counts = self.shim.all_reduce_counts(counts)
|
||||
layer_idx = getattr(layer, "_afb_layer_idx", None)
|
||||
if layer_idx is None:
|
||||
counts.zero_()
|
||||
continue
|
||||
bias = getattr(layer, "_afb_bias")
|
||||
counts_for_update = counts.to(bias.device)
|
||||
tokens_seen = int(counts_for_update.sum().item())
|
||||
# local layer-state EMA and bias update
|
||||
self.shim.update_bias(layer_idx, counts_for_update, tokens_seen)
|
||||
# reset step counts
|
||||
counts.zero_()
|
||||
return control
|
||||
|
||||
|
||||
class AuxFreeMoEPlugin(BasePlugin):
|
||||
"""Plugin that enables aux-loss-free routing when configured."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._handles: list = []
|
||||
self._shim: Optional[AuxFreeShim] = None
|
||||
self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}
|
||||
|
||||
def post_model_build(self, cfg, model):
|
||||
# Enable only when explicitly requested
|
||||
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
|
||||
return
|
||||
|
||||
# Be conservative — skip known native aux-free families
|
||||
native_auxfree = getattr(getattr(model, "config", object()), "model_type", "") in (
|
||||
"deepseek_v3",
|
||||
"glm4_moe",
|
||||
)
|
||||
if native_auxfree:
|
||||
LOG.info("AuxFreeMoE: model reports native aux-free routing; skipping patching")
|
||||
return
|
||||
|
||||
# Build aux-free state and shim
|
||||
rate = cfg.moe_update_rate if cfg.moe_update_rate is not None else 0.01
|
||||
momentum = (
|
||||
cfg.moe_update_momentum if cfg.moe_update_momentum is not None else 0.9
|
||||
)
|
||||
bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
|
||||
warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
|
||||
sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
|
||||
af_cfg = AuxFreeConfig(
|
||||
rate=rate, momentum=momentum, bias_cap=bias_cap, warmup_steps=warmup, sync_group=sync_group
|
||||
)
|
||||
|
||||
# Discover layers to count the number and experts for state sizing
|
||||
adapters: list[BaseMoEAdapter] = [
|
||||
MixtralAdapter(),
|
||||
Qwen3Adapter(),
|
||||
BailingAdapter(),
|
||||
Llama4Adapter(),
|
||||
]
|
||||
|
||||
# For initial state sizing, we conservatively assume the first discovered layer defines nE
|
||||
n_layers = 0
|
||||
n_experts = None
|
||||
for m in model.modules():
|
||||
n_layers += 1 # upper bound — we will re-use bias slots sparsely
|
||||
device = next(model.parameters(), torch.tensor(0)).device
|
||||
if n_layers <= 0:
|
||||
n_layers = 1
|
||||
if n_experts is None:
|
||||
# we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
|
||||
n_experts = 2
|
||||
state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg)
|
||||
ep_size = getattr(cfg, "expert_parallel_size", None)
|
||||
ep_group = None
|
||||
if sync_group == "ep":
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
ep_group = self._resolve_ep_group(cfg)
|
||||
else:
|
||||
LOG.info(
|
||||
"AuxFreeMoE: deferring expert-parallel group resolution until torch.distributed initializes"
|
||||
)
|
||||
self._shim = AuxFreeShim(state=state, ep_group=ep_group, ep_size=ep_size)
|
||||
|
||||
# Discover and prepare layers (attach per-layer buffers)
|
||||
self._handles = discover_and_prepare_layers(model, adapters, self._shim)
|
||||
|
||||
LOG.info(
|
||||
f"AuxFreeMoE: enabled with rate={rate}, momentum={momentum}, cap={bias_cap}, warmup={warmup}, group={sync_group}"
|
||||
)
|
||||
|
||||
def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
|
||||
if not dist.is_available() or not dist.is_initialized():
|
||||
LOG.warning("AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world")
|
||||
return None
|
||||
ep_size = getattr(cfg, "expert_parallel_size", None)
|
||||
if not ep_size or ep_size <= 1:
|
||||
LOG.warning("AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world")
|
||||
return None
|
||||
world = dist.get_world_size()
|
||||
if world % ep_size != 0:
|
||||
LOG.warning(
|
||||
"AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world",
|
||||
ep_size,
|
||||
world,
|
||||
)
|
||||
return None
|
||||
if ep_size == world:
|
||||
return dist.group.WORLD
|
||||
|
||||
rank = dist.get_rank()
|
||||
group_start = (rank // ep_size) * ep_size
|
||||
ranks = tuple(range(group_start, group_start + ep_size))
|
||||
if ranks not in self._ep_group_cache:
|
||||
self._ep_group_cache[ranks] = dist.new_group(ranks)
|
||||
return self._ep_group_cache[ranks]
|
||||
|
||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
|
||||
return []
|
||||
if self._shim is None:
|
||||
return []
|
||||
# gather concrete layer modules from handles
|
||||
layers = [h.layer for h in self._handles]
|
||||
cb = MoeAuxFreeBiasUpdateCallback(self._shim, layers)
|
||||
LOG.info("AuxFreeMoE: registering post-step bias update callback")
|
||||
return [cb]
|
||||
@@ -36,6 +36,8 @@ SPARSE_MOE_BLOCK = {
|
||||
"glm4v_moe": "Glm4vMoeTextMoE",
|
||||
# sigmoid -> topk routing (no group selection)
|
||||
"minimax_m2": "MiniMaxM2SparseMoeBlock",
|
||||
# sigmoid -> topk routing, non-gated experts (up_proj + down_proj, no gate_up_proj)
|
||||
"nemotron_h": "NemotronHMoE",
|
||||
# Models below need custom routing (not yet implemented):
|
||||
# "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk
|
||||
# "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group)
|
||||
|
||||
@@ -168,6 +168,9 @@ def _unwrap_experts_lora(experts_module):
|
||||
-> base_layer: ParamWrapper(gate_up_proj)
|
||||
-> base_layer: OlmoeExperts (the real module)
|
||||
|
||||
For non-gated experts (e.g. NemotronH), the chain targets ``up_proj``
|
||||
instead of ``gate_up_proj``.
|
||||
|
||||
This function walks the chain, collects LoRA params keyed by
|
||||
``parameter_name``, and returns the base experts module.
|
||||
|
||||
@@ -176,6 +179,7 @@ def _unwrap_experts_lora(experts_module):
|
||||
|
||||
Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
|
||||
A/B are already in scattermoe layout.
|
||||
For non-gated experts, ``gup_lora`` holds the ``up_proj`` LoRA.
|
||||
"""
|
||||
# Collect ParamWrapper layers by their parameter_name
|
||||
wrappers = {}
|
||||
@@ -195,13 +199,15 @@ def _unwrap_experts_lora(experts_module):
|
||||
num_experts = getattr(base_experts, "num_experts", None)
|
||||
if num_experts is None:
|
||||
# Fallback: infer from parameter shape
|
||||
gup = getattr(base_experts, "gate_up_proj", None)
|
||||
if gup is not None:
|
||||
num_experts = gup.shape[0]
|
||||
for attr in ("gate_up_proj", "up_proj"):
|
||||
param = getattr(base_experts, attr, None)
|
||||
if param is not None:
|
||||
num_experts = param.shape[0]
|
||||
break
|
||||
|
||||
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
|
||||
# Extract gate_up_proj or up_proj LoRA (needs A<->B swap due to transposition)
|
||||
gup_lora = None
|
||||
gup_wrapper = wrappers.get("gate_up_proj")
|
||||
gup_wrapper = wrappers.get("gate_up_proj") or wrappers.get("up_proj")
|
||||
if gup_wrapper is not None:
|
||||
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
|
||||
if lora_A is not None:
|
||||
@@ -441,10 +447,12 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
Supports:
|
||||
|
||||
* **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
|
||||
* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
|
||||
* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2, NemotronH
|
||||
* **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
|
||||
* **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
|
||||
extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
|
||||
* **Non-gated experts**: NemotronH (up_proj + down_proj, no gate_up_proj)
|
||||
* **Latent projections**: NemotronH (fc1/fc2_latent_proj wrapping experts)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@@ -467,7 +475,7 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
hidden_states_flat = layer_input.view(-1, hidden_dim)
|
||||
|
||||
# ====================================================================
|
||||
# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
|
||||
# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3, NemotronH)
|
||||
# ====================================================================
|
||||
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
|
||||
|
||||
@@ -489,6 +497,22 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
# ====================================================================
|
||||
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
|
||||
|
||||
# ====================================================================
|
||||
# Detect non-gated experts (e.g. NemotronH: up_proj + down_proj only)
|
||||
# ====================================================================
|
||||
is_gated = hasattr(experts, "gate_up_proj")
|
||||
up_proj_attr = "gate_up_proj" if is_gated else "up_proj"
|
||||
|
||||
# ====================================================================
|
||||
# Optional latent projection (NemotronH: fc1/fc2_latent_proj)
|
||||
# ====================================================================
|
||||
fc1_latent_proj = getattr(self, "fc1_latent_proj", None)
|
||||
fc2_latent_proj = getattr(self, "fc2_latent_proj", None)
|
||||
|
||||
expert_input = hidden_states_flat
|
||||
if fc1_latent_proj is not None and not isinstance(fc1_latent_proj, nn.Identity):
|
||||
expert_input = fc1_latent_proj(hidden_states_flat)
|
||||
|
||||
# ====================================================================
|
||||
# Selective expert weight dequantization
|
||||
# ====================================================================
|
||||
@@ -498,7 +522,7 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
use_selective = (
|
||||
getattr(self, "_use_selective_dequant", False)
|
||||
and hasattr(experts, "parametrizations")
|
||||
and "gate_up_proj" in experts.parametrizations
|
||||
and up_proj_attr in experts.parametrizations
|
||||
)
|
||||
|
||||
if use_selective:
|
||||
@@ -517,11 +541,11 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
num_experts,
|
||||
)
|
||||
# Dequantize only active experts' weights
|
||||
gate_up_W = selective_expert_weights(
|
||||
up_W = selective_expert_weights(
|
||||
experts,
|
||||
"gate_up_proj",
|
||||
up_proj_attr,
|
||||
active_experts,
|
||||
).transpose(2, 1) # [num_active, hidden, 2*inter]
|
||||
).transpose(2, 1)
|
||||
|
||||
# Remap LoRA weights to match compact expert indices
|
||||
if gup_lora is not None:
|
||||
@@ -538,18 +562,18 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
sei_gup = remapped_expert_idxs
|
||||
eo_gup = compact_offsets
|
||||
else:
|
||||
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
|
||||
up_W = getattr(experts, up_proj_attr).transpose(2, 1)
|
||||
sei_gup = sorted_expert_idxs
|
||||
eo_gup = expert_offsets
|
||||
|
||||
# ====================================================================
|
||||
# Gate + Up projection
|
||||
# Up projection (gated: gate_up_proj; non-gated: up_proj)
|
||||
# ====================================================================
|
||||
if gup_lora is not None:
|
||||
gup_A, gup_B, gup_scaling = gup_lora
|
||||
gup = parallel_linear_lora(
|
||||
hidden_states_flat,
|
||||
gate_up_W,
|
||||
up_out = parallel_linear_lora(
|
||||
expert_input,
|
||||
up_W,
|
||||
top_k,
|
||||
sei_gup,
|
||||
sorted_scattered_idxs,
|
||||
@@ -563,9 +587,9 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
use_fused_gather=True,
|
||||
)
|
||||
else:
|
||||
gup = parallel_linear(
|
||||
hidden_states_flat,
|
||||
gate_up_W,
|
||||
up_out = parallel_linear(
|
||||
expert_input,
|
||||
up_W,
|
||||
top_k,
|
||||
sei_gup,
|
||||
sorted_scattered_idxs,
|
||||
@@ -574,8 +598,14 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
grouped_out=True,
|
||||
)
|
||||
|
||||
gates, h = gup.chunk(2, dim=-1)
|
||||
h = experts.act_fn(gates) * h
|
||||
# ====================================================================
|
||||
# Activation: gated (act_fn(gate) * up) vs non-gated (act_fn(up))
|
||||
# ====================================================================
|
||||
if is_gated:
|
||||
gates, h = up_out.chunk(2, dim=-1)
|
||||
h = experts.act_fn(gates) * h
|
||||
else:
|
||||
h = experts.act_fn(up_out)
|
||||
|
||||
# ====================================================================
|
||||
# Down projection
|
||||
@@ -635,6 +665,12 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
gates=routing_weights,
|
||||
)
|
||||
|
||||
# ====================================================================
|
||||
# Optional latent projection back to hidden_size (NemotronH)
|
||||
# ====================================================================
|
||||
if fc2_latent_proj is not None and not isinstance(fc2_latent_proj, nn.Identity):
|
||||
expert_output = fc2_latent_proj(expert_output)
|
||||
|
||||
# ====================================================================
|
||||
# Combine with shared expert and reshape
|
||||
# ====================================================================
|
||||
|
||||
@@ -30,6 +30,15 @@ class LigerArgs(BaseModel):
|
||||
|
||||
liger_rope: bool | None = None
|
||||
liger_rms_norm: bool | None = None
|
||||
liger_rms_norm_gated: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": (
|
||||
"Enables fused RMSNorm+SiLU gate Triton kernel for models with "
|
||||
"gated RMSNorm (e.g. Qwen3.5 / Qwen3.5 MoE linear attention layers)."
|
||||
)
|
||||
},
|
||||
)
|
||||
liger_layer_norm: bool | None = None
|
||||
liger_swiglu: bool | None = None
|
||||
liger_glu_activation: bool | None = None
|
||||
|
||||
175
src/axolotl/integrations/liger/models/qwen3_5.py
Normal file
175
src/axolotl/integrations/liger/models/qwen3_5.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Liger FLCE for Qwen3.5. Based on transformers v5.3.0.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
|
||||
def lce_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
) -> CausalLMOutputWithPast:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
"""
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = None
|
||||
loss = None
|
||||
# if in training mode, don't materialize logits
|
||||
if self.training and (labels is not None):
|
||||
loss = LigerForCausalLMLoss(
|
||||
hidden_states=hidden_states,
|
||||
lm_head_weight=self.lm_head.weight,
|
||||
labels=labels,
|
||||
hidden_size=self.config.hidden_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
else: # if in inference mode materialize logits
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits=logits,
|
||||
labels=labels,
|
||||
vocab_size=self.config.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def apply_liger_kernel_to_qwen3_5(
|
||||
cross_entropy: bool = False,
|
||||
fused_linear_cross_entropy: bool = False,
|
||||
rms_norm: bool = False,
|
||||
rms_norm_gated: bool = False,
|
||||
glu_activation: bool = False,
|
||||
layer_norm: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 models.
|
||||
|
||||
Note: Qwen3_5RMSNorm uses zero-init weight with offset 1.0 (like Gemma),
|
||||
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
|
||||
|
||||
Args:
|
||||
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
||||
fused_linear_cross_entropy (bool):
|
||||
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
||||
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
||||
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
||||
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
||||
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
|
||||
Qwen3_5RMSNormGated (used in linear attention layers). Default is False.
|
||||
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
||||
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
||||
"""
|
||||
|
||||
import transformers.models.qwen3_5.modeling_qwen3_5 # noqa: F401
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
assert not (cross_entropy and fused_linear_cross_entropy), (
|
||||
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
||||
)
|
||||
|
||||
modeling_qwen3_5 = sys.modules["transformers.models.qwen3_5.modeling_qwen3_5"]
|
||||
|
||||
if rms_norm:
|
||||
# Qwen3_5RMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
|
||||
class LigerRMSNormForQwen3_5(LigerRMSNorm):
|
||||
def __init__(self, dim, eps=1e-6, **kwargs):
|
||||
super().__init__(
|
||||
dim,
|
||||
eps=eps,
|
||||
offset=1.0,
|
||||
casting_mode="gemma",
|
||||
init_fn="zeros",
|
||||
in_place=False,
|
||||
)
|
||||
|
||||
modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3_5
|
||||
|
||||
if rms_norm_gated:
|
||||
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
||||
|
||||
modeling_qwen3_5.Qwen3_5RMSNormGated = FusedRMSNormGated
|
||||
|
||||
if glu_activation:
|
||||
|
||||
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
|
||||
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
|
||||
config = deepcopy(config)
|
||||
if intermediate_size is not None:
|
||||
config.intermediate_size = intermediate_size
|
||||
return LigerSwiGLUMLP(config, **kwargs)
|
||||
|
||||
modeling_qwen3_5.Qwen3_5MLP = _liger_swiglu_mlp_wrapper
|
||||
|
||||
if layer_norm:
|
||||
modeling_qwen3_5.nn.LayerNorm = LigerLayerNorm
|
||||
|
||||
if cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
|
||||
if fused_linear_cross_entropy:
|
||||
modeling_qwen3_5.Qwen3_5ForCausalLM.forward = lce_forward
|
||||
198
src/axolotl/integrations/liger/models/qwen3_5_moe.py
Normal file
198
src/axolotl/integrations/liger/models/qwen3_5_moe.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Liger FLCE for Qwen3.5 MoE. Based on transformers v5.3.0.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
||||
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
||||
|
||||
|
||||
def lce_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values=None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
) -> MoeCausalLMOutputWithPast:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||
|
||||
Returns:
|
||||
"""
|
||||
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
|
||||
load_balancing_loss_func,
|
||||
)
|
||||
|
||||
output_router_logits = (
|
||||
output_router_logits
|
||||
if output_router_logits is not None
|
||||
else self.config.output_router_logits
|
||||
)
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_router_logits=output_router_logits,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = None
|
||||
loss = None
|
||||
# if in training mode, don't materialize logits
|
||||
if self.training and (labels is not None):
|
||||
loss = LigerForCausalLMLoss(
|
||||
hidden_states=hidden_states,
|
||||
lm_head_weight=self.lm_head.weight,
|
||||
labels=labels,
|
||||
hidden_size=self.config.hidden_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
else: # if in inference mode materialize logits
|
||||
slice_indices = (
|
||||
slice(-logits_to_keep, None)
|
||||
if isinstance(logits_to_keep, int)
|
||||
else logits_to_keep
|
||||
)
|
||||
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||
if labels is not None:
|
||||
loss = self.loss_function(
|
||||
logits,
|
||||
labels,
|
||||
self.vocab_size,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
aux_loss = None
|
||||
if output_router_logits:
|
||||
aux_loss = load_balancing_loss_func(
|
||||
outputs.router_logits,
|
||||
self.num_experts,
|
||||
self.num_experts_per_tok,
|
||||
attention_mask,
|
||||
)
|
||||
if labels is not None:
|
||||
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
|
||||
|
||||
return MoeCausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
aux_loss=aux_loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
router_logits=outputs.router_logits,
|
||||
)
|
||||
|
||||
|
||||
def apply_liger_kernel_to_qwen3_5_moe(
|
||||
cross_entropy: bool = False,
|
||||
fused_linear_cross_entropy: bool = False,
|
||||
rms_norm: bool = False,
|
||||
rms_norm_gated: bool = False,
|
||||
glu_activation: bool = False,
|
||||
layer_norm: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 MoE models.
|
||||
|
||||
Note: Qwen3_5MoeRMSNorm uses zero-init weight with offset 1.0 (like Gemma),
|
||||
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
|
||||
|
||||
Args:
|
||||
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
||||
fused_linear_cross_entropy (bool):
|
||||
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
||||
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
||||
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
||||
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
||||
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
|
||||
Qwen3_5MoeRMSNormGated (used in linear attention layers). Default is False.
|
||||
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
||||
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
||||
"""
|
||||
|
||||
import transformers.models.qwen3_5_moe.modeling_qwen3_5_moe # noqa: F401
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
assert not (cross_entropy and fused_linear_cross_entropy), (
|
||||
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
||||
)
|
||||
|
||||
modeling_mod = sys.modules["transformers.models.qwen3_5_moe.modeling_qwen3_5_moe"]
|
||||
|
||||
if rms_norm:
|
||||
# Qwen3_5MoeRMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
|
||||
class LigerRMSNormForQwen3_5Moe(LigerRMSNorm):
|
||||
def __init__(self, dim, eps=1e-6, **kwargs):
|
||||
super().__init__(
|
||||
dim,
|
||||
eps=eps,
|
||||
offset=1.0,
|
||||
casting_mode="gemma",
|
||||
init_fn="zeros",
|
||||
in_place=False,
|
||||
)
|
||||
|
||||
modeling_mod.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3_5Moe
|
||||
|
||||
if rms_norm_gated:
|
||||
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
||||
|
||||
modeling_mod.Qwen3_5MoeRMSNormGated = FusedRMSNormGated
|
||||
|
||||
if glu_activation:
|
||||
|
||||
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
|
||||
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
|
||||
config = deepcopy(config)
|
||||
if intermediate_size is not None:
|
||||
config.intermediate_size = intermediate_size
|
||||
return LigerSwiGLUMLP(config, **kwargs)
|
||||
|
||||
modeling_mod.Qwen3_5MoeMLP = _liger_swiglu_mlp_wrapper
|
||||
|
||||
if layer_norm:
|
||||
modeling_mod.nn.LayerNorm = LigerLayerNorm
|
||||
|
||||
if cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
|
||||
if fused_linear_cross_entropy:
|
||||
modeling_mod.Qwen3_5MoeForCausalLM.forward = lce_forward
|
||||
@@ -174,6 +174,19 @@ class LigerPlugin(BasePlugin):
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "qwen3_5":
|
||||
from axolotl.integrations.liger.models.qwen3_5 import (
|
||||
apply_liger_kernel_to_qwen3_5,
|
||||
)
|
||||
|
||||
apply_liger_kernel_to_qwen3_5(
|
||||
cross_entropy=cfg.liger_cross_entropy,
|
||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||
glu_activation=cfg.liger_glu_activation,
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "qwen3_moe":
|
||||
from axolotl.integrations.liger.models.qwen3_moe import (
|
||||
apply_liger_kernel_to_qwen3_moe,
|
||||
@@ -186,6 +199,19 @@ class LigerPlugin(BasePlugin):
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "qwen3_5_moe":
|
||||
from axolotl.integrations.liger.models.qwen3_5_moe import (
|
||||
apply_liger_kernel_to_qwen3_5_moe,
|
||||
)
|
||||
|
||||
apply_liger_kernel_to_qwen3_5_moe(
|
||||
cross_entropy=cfg.liger_cross_entropy,
|
||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||
glu_activation=cfg.liger_glu_activation,
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "granitemoe":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_granite
|
||||
|
||||
|
||||
147
src/axolotl/kernels/dora.py
Normal file
147
src/axolotl/kernels/dora.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
Triton kernels for DoRA (Weight-Decomposed Low-Rank Adaptation).
|
||||
|
||||
Fuses the weight norm computation and magnitude scaling to avoid
|
||||
materializing the full [out_features, in_features] combined weight matrix.
|
||||
The B@A product is computed row-by-row inside the kernel.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .quantize import dequantize
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _dora_fused_norm_kernel(
|
||||
# Pointers
|
||||
W_ptr, # base weight [out, in] (dequantized, row-major)
|
||||
B_ptr, # LoRA B [out, rank] (row-major)
|
||||
A_ptr, # LoRA A [rank, in] (row-major)
|
||||
mag_ptr, # magnitude vector [out]
|
||||
out_ptr, # output mag_norm_scale [out]
|
||||
# Shapes
|
||||
out_features,
|
||||
in_features,
|
||||
rank,
|
||||
# Scaling
|
||||
lora_scale, # float scaling factor
|
||||
# Block sizes
|
||||
BLOCK_IN: tl.constexpr,
|
||||
BLOCK_R: tl.constexpr, # >= rank, power of 2
|
||||
):
|
||||
"""Compute mag_norm_scale[i] = magnitude[i] / ||W[i,:] + s * (B[i,:] @ A)[:] ||_2
|
||||
|
||||
Each program handles one output row. B[row,:] is loaded once (small),
|
||||
then we tile over in_features computing the dot product with A[:,tile]
|
||||
and accumulating the squared norm.
|
||||
|
||||
This avoids materializing the full [out, in] B@A matrix.
|
||||
"""
|
||||
row = tl.program_id(0)
|
||||
if row >= out_features:
|
||||
return
|
||||
|
||||
# Accumulate squared norm across tiles of in_features
|
||||
norm_sq_acc = tl.zeros([BLOCK_IN], dtype=tl.float32)
|
||||
|
||||
for start in range(0, in_features, BLOCK_IN):
|
||||
cols = start + tl.arange(0, BLOCK_IN)
|
||||
col_mask = cols < in_features
|
||||
|
||||
# Load W[row, cols]
|
||||
w_vals = tl.load(
|
||||
W_ptr + row * in_features + cols,
|
||||
mask=col_mask,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
# Compute (B[row,:] @ A[:, cols]) for this tile
|
||||
# Load B[row, r] as scalar and A[r, cols] as vector for each r
|
||||
ba_vals = tl.zeros([BLOCK_IN], dtype=tl.float32)
|
||||
for r in tl.static_range(BLOCK_R):
|
||||
# Load scalar B[row, r]
|
||||
b_val = tl.load(
|
||||
B_ptr + row * rank + r,
|
||||
mask=(r < rank),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
# Load vector A[r, cols]
|
||||
a_vals = tl.load(
|
||||
A_ptr + r * in_features + cols,
|
||||
mask=(col_mask & (r < rank)),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
ba_vals += b_val * a_vals
|
||||
|
||||
# Combined: W + s * (B @ A)
|
||||
combined = w_vals + lora_scale * ba_vals
|
||||
|
||||
# Accumulate squared values
|
||||
norm_sq_acc += tl.where(col_mask, combined * combined, 0.0)
|
||||
|
||||
# Reduce to scalar norm
|
||||
norm_sq = tl.sum(norm_sq_acc, axis=0)
|
||||
norm = tl.sqrt(norm_sq + 1e-12) # epsilon for numerical stability
|
||||
|
||||
# Load magnitude and compute scale
|
||||
mag = tl.load(mag_ptr + row).to(tl.float32)
|
||||
scale = mag / norm
|
||||
|
||||
tl.store(out_ptr + row, scale)
|
||||
|
||||
|
||||
def triton_dora_scale(
|
||||
W: torch.Tensor,
|
||||
W_quant,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
s: float,
|
||||
magnitude: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
"""Compute DoRA mag_norm_scale using fused Triton kernel.
|
||||
|
||||
Computes B@A row-by-row inside the kernel, avoiding the full
|
||||
[out_features, in_features] materialization.
|
||||
|
||||
Args:
|
||||
W: base weight [out, in] (possibly quantized)
|
||||
W_quant: quantization state
|
||||
A: LoRA A [rank, in]
|
||||
B: LoRA B [out, rank]
|
||||
s: LoRA scaling factor
|
||||
magnitude: learned magnitude [out]
|
||||
dtype: compute dtype
|
||||
|
||||
Returns:
|
||||
mag_norm_scale: [out] tensor = magnitude / ||W + s * B @ A||_2
|
||||
"""
|
||||
# Dequantize W to [out, in]
|
||||
W_full = dequantize(W.t(), W_quant).t().contiguous().to(dtype)
|
||||
|
||||
out_features, in_features = W_full.shape
|
||||
rank = A.shape[0]
|
||||
|
||||
out = torch.empty(out_features, dtype=dtype, device=W.device)
|
||||
|
||||
# Block sizes
|
||||
BLOCK_IN = triton.next_power_of_2(min(in_features, 2048))
|
||||
BLOCK_R = triton.next_power_of_2(rank)
|
||||
|
||||
_dora_fused_norm_kernel[(out_features,)](
|
||||
W_full,
|
||||
B.contiguous().to(dtype),
|
||||
A.contiguous().to(dtype),
|
||||
magnitude.contiguous(),
|
||||
out,
|
||||
out_features=out_features,
|
||||
in_features=in_features,
|
||||
rank=rank,
|
||||
lora_scale=s,
|
||||
BLOCK_IN=BLOCK_IN,
|
||||
BLOCK_R=BLOCK_R,
|
||||
)
|
||||
|
||||
return out.detach()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -105,6 +105,10 @@ def dequantize(
|
||||
# Extract quantization state
|
||||
if not isinstance(quant_state, list):
|
||||
# New style quant_state class
|
||||
# Non-double-quantized models have offset=None and state2=None
|
||||
if quant_state.offset is None or quant_state.state2 is None:
|
||||
# Fall back to bitsandbytes standard dequantize
|
||||
return bnb.functional.dequantize_4bit(W, quant_state, quant_type="nf4")
|
||||
absmax = quant_state.absmax.to(target_device)
|
||||
shape = quant_state.shape
|
||||
dtype = quant_state.dtype
|
||||
|
||||
333
src/axolotl/kernels/rms_norm_gated.py
Normal file
333
src/axolotl/kernels/rms_norm_gated.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
Fused RMSNorm + SiLU Gate Triton kernel.
|
||||
|
||||
Computes: Y = (W + offset) * RMSNorm(X) * silu(G)
|
||||
where RMSNorm(X) = X / sqrt(mean(X^2) + eps)
|
||||
and silu(G) = G * sigmoid(G)
|
||||
|
||||
Used by Qwen3.5's GatedDeltaNet linear attention layers (Qwen3_5RMSNormGated).
|
||||
"""
|
||||
|
||||
import math
|
||||
import operator
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from liger_kernel.ops.utils import (
|
||||
calculate_settings,
|
||||
compare_version,
|
||||
ensure_contiguous,
|
||||
torch_to_triton_dtype,
|
||||
)
|
||||
from liger_kernel.utils import is_npu_available
|
||||
|
||||
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
||||
try:
|
||||
from triton.language.extra.libdevice import rsqrt
|
||||
except ModuleNotFoundError:
|
||||
from triton.language.extra.cuda.libdevice import rsqrt
|
||||
else:
|
||||
from triton.language.math import rsqrt
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_norm_gated_forward_kernel(
|
||||
Y_ptr,
|
||||
Y_row_stride,
|
||||
X_ptr,
|
||||
X_row_stride,
|
||||
G_ptr,
|
||||
G_row_stride,
|
||||
W_ptr,
|
||||
W_row_stride,
|
||||
RSTD_ptr,
|
||||
RSTD_row_stride,
|
||||
n_cols,
|
||||
eps,
|
||||
offset,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Y = (W + offset) * (X / RMS(X)) * silu(G)
|
||||
|
||||
All computation done in fp32 (Gemma-style), result cast to input dtype.
|
||||
"""
|
||||
row_idx = tl.program_id(0).to(tl.int64)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
||||
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
|
||||
G_row = tl.load(G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0)
|
||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
||||
|
||||
X_row_dtype = X_row.dtype
|
||||
|
||||
# Cast everything to fp32
|
||||
X_fp32 = X_row.to(tl.float32)
|
||||
G_fp32 = G_row.to(tl.float32)
|
||||
W_fp32 = W_row.to(tl.float32)
|
||||
|
||||
# RMS norm
|
||||
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
|
||||
rstd = rsqrt(mean_sq + eps)
|
||||
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
|
||||
|
||||
X_norm = X_fp32 * rstd
|
||||
|
||||
# SiLU gate: silu(G) = G * sigmoid(G)
|
||||
sig_G = tl.sigmoid(G_fp32)
|
||||
silu_G = G_fp32 * sig_G
|
||||
|
||||
# Fused output
|
||||
Y_row = (offset + W_fp32) * X_norm * silu_G
|
||||
|
||||
tl.store(
|
||||
Y_ptr + row_idx * Y_row_stride + col_offsets,
|
||||
Y_row.to(X_row_dtype),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_norm_gated_backward_kernel(
|
||||
dY_ptr,
|
||||
dY_row_stride,
|
||||
dX_ptr,
|
||||
dX_row_stride,
|
||||
dG_ptr,
|
||||
dG_row_stride,
|
||||
X_ptr,
|
||||
X_row_stride,
|
||||
X_dtype: tl.constexpr,
|
||||
G_ptr,
|
||||
G_row_stride,
|
||||
W_ptr,
|
||||
W_row_stride,
|
||||
RSTD_ptr,
|
||||
RSTD_row_stride,
|
||||
dW_ptr,
|
||||
dW_row_stride,
|
||||
n_rows,
|
||||
n_cols,
|
||||
offset,
|
||||
rows_per_program,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Backward for Y = (W + offset) * (X * RSTD) * silu(G)
|
||||
|
||||
dW = sum_batch(dY * X_norm * silu(G))
|
||||
dG = dY * (W + offset) * X_norm * silu'(G)
|
||||
where silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
|
||||
dX = RSTD * (m - (1/N) * RSTD^2 * dot(m, X) * X)
|
||||
where m = dY * (W + offset) * silu(G)
|
||||
"""
|
||||
row_block_id = tl.program_id(0).to(tl.int64)
|
||||
row_start = row_block_id * rows_per_program
|
||||
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
||||
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||
|
||||
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
||||
W_row = W_row.to(tl.float32) + offset
|
||||
|
||||
for row_idx in range(row_start, row_end):
|
||||
dY_row = tl.load(
|
||||
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0.0
|
||||
)
|
||||
X_row = tl.load(
|
||||
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0.0
|
||||
)
|
||||
G_row = tl.load(
|
||||
G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0.0
|
||||
)
|
||||
rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
|
||||
|
||||
# Cast to fp32
|
||||
dY_fp32 = dY_row.to(tl.float32)
|
||||
X_fp32 = X_row.to(tl.float32)
|
||||
G_fp32 = G_row.to(tl.float32)
|
||||
|
||||
# Recompute intermediates
|
||||
X_norm = X_fp32 * rstd_row
|
||||
sig_G = tl.sigmoid(G_fp32)
|
||||
silu_G = G_fp32 * sig_G
|
||||
|
||||
# dW: accumulate dY * X_norm * silu(G)
|
||||
dW_acc += dY_fp32 * X_norm * silu_G
|
||||
|
||||
# dG: dY * (W + offset) * X_norm * silu'(G)
|
||||
# silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
|
||||
silu_prime_G = sig_G * (1.0 + G_fp32 * (1.0 - sig_G))
|
||||
dG_row = dY_fp32 * W_row * X_norm * silu_prime_G
|
||||
tl.store(
|
||||
dG_ptr + row_idx * dG_row_stride + col_offsets,
|
||||
dG_row.to(X_dtype),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
# dX: standard RMSNorm backward with effective gradient m = dY * W * silu(G)
|
||||
m = dY_fp32 * W_row * silu_G
|
||||
dX_row = rstd_row * m
|
||||
dX_row += rstd_row * (
|
||||
-(1.0 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_fp32, axis=0) * X_fp32
|
||||
)
|
||||
tl.store(
|
||||
dX_ptr + row_idx * dX_row_stride + col_offsets,
|
||||
dX_row.to(X_dtype),
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
tl.store(
|
||||
dW_ptr + row_block_id * dW_row_stride + col_offsets,
|
||||
dW_acc,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
def rms_norm_gated_forward(X, G, W, eps, offset):
|
||||
shape = X.shape
|
||||
dim = shape[-1]
|
||||
X = X.view(-1, dim)
|
||||
G = G.view(-1, dim)
|
||||
n_rows, n_cols = X.shape
|
||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||
|
||||
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
||||
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
|
||||
|
||||
assert X.shape[1] == W.shape[0], (
|
||||
f"Incompatible hidden size: X.shape[1]={X.shape[1]} vs W.shape[0]={W.shape[0]}"
|
||||
)
|
||||
assert X.shape == G.shape, (
|
||||
f"X and G must have same shape, got {X.shape} and {G.shape}"
|
||||
)
|
||||
|
||||
_rms_norm_gated_forward_kernel[(n_rows,)](
|
||||
Y,
|
||||
Y.stride(0),
|
||||
X,
|
||||
X.stride(0),
|
||||
G,
|
||||
G.stride(0),
|
||||
W,
|
||||
W.stride(0),
|
||||
RSTD,
|
||||
RSTD.stride(0),
|
||||
n_cols,
|
||||
eps,
|
||||
offset,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return Y.view(*shape), X, G, RSTD, BLOCK_SIZE, num_warps
|
||||
|
||||
|
||||
def rms_norm_gated_backward(dY, X, G, W, RSTD, offset, BLOCK_SIZE, num_warps):
|
||||
shape = dY.shape
|
||||
dim = shape[-1]
|
||||
dY = dY.view(-1, dim)
|
||||
n_rows, n_cols = dY.shape
|
||||
|
||||
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
||||
|
||||
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
||||
dX = torch.empty_like(dY)
|
||||
dG = torch.empty_like(dY)
|
||||
|
||||
rows_per_program = math.ceil(n_rows / sm_count)
|
||||
grid = (sm_count,)
|
||||
|
||||
_rms_norm_gated_backward_kernel[grid](
|
||||
dY,
|
||||
dY.stride(0),
|
||||
dX,
|
||||
dX.stride(0),
|
||||
dG,
|
||||
dG.stride(0),
|
||||
X,
|
||||
X.stride(0),
|
||||
torch_to_triton_dtype[X.dtype],
|
||||
G,
|
||||
G.stride(0),
|
||||
W,
|
||||
W.stride(0),
|
||||
RSTD,
|
||||
RSTD.stride(0),
|
||||
_dW,
|
||||
_dW.stride(0),
|
||||
n_rows,
|
||||
n_cols,
|
||||
offset,
|
||||
rows_per_program,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
dX = dX.view(*shape)
|
||||
dG = dG.view(*shape)
|
||||
dW = _dW.sum(dim=0).to(W.dtype)
|
||||
return dX, dG, dW
|
||||
|
||||
|
||||
class FusedRMSNormGatedFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@ensure_contiguous
|
||||
def forward(ctx, X, G, W, eps, offset=0.0):
|
||||
"""
|
||||
X: (B, T, H) or (BxT, H) — input hidden states
|
||||
G: (B, T, H) or (BxT, H) — gate tensor
|
||||
W: (H,) — weight parameter
|
||||
"""
|
||||
Y, X, G, RSTD, BLOCK_SIZE, num_warps = rms_norm_gated_forward(
|
||||
X, G, W, eps, offset
|
||||
)
|
||||
ctx.offset = offset
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.save_for_backward(X, G, W, RSTD)
|
||||
return Y
|
||||
|
||||
@staticmethod
|
||||
@ensure_contiguous
|
||||
def backward(ctx, dY):
|
||||
X, G, W, RSTD = ctx.saved_tensors
|
||||
dX, dG, dW = rms_norm_gated_backward(
|
||||
dY, X, G, W, RSTD, ctx.offset, ctx.BLOCK_SIZE, ctx.num_warps
|
||||
)
|
||||
return dX, dG, dW, None, None
|
||||
|
||||
|
||||
class FusedRMSNormGated(torch.nn.Module):
|
||||
"""
|
||||
Fused RMSNorm + SiLU Gate.
|
||||
|
||||
Computes: Y = W * RMSNorm(X) * silu(G)
|
||||
|
||||
Drop-in replacement for Qwen3_5RMSNormGated with matching
|
||||
init signature: __init__(hidden_size, eps=1e-6, **kwargs)
|
||||
and forward signature: forward(hidden_states, gate=None)
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6, offset=0.0, **kwargs):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
self.offset = offset
|
||||
|
||||
def forward(self, hidden_states, gate=None):
|
||||
if gate is None:
|
||||
raise ValueError("FusedRMSNormGated requires a gate tensor")
|
||||
if hidden_states.device.type != "cuda":
|
||||
raise ValueError(
|
||||
f"FusedRMSNormGated requires CUDA tensors, got device={hidden_states.device}"
|
||||
)
|
||||
return FusedRMSNormGatedFunction.apply(
|
||||
hidden_states, gate, self.weight, self.variance_epsilon, self.offset
|
||||
)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
@@ -12,6 +12,7 @@ from torch import nn
|
||||
from transformers import AutoConfig
|
||||
|
||||
from axolotl.kernels.lora import (
|
||||
apply_lora_embedding,
|
||||
apply_lora_mlp_geglu,
|
||||
apply_lora_mlp_swiglu,
|
||||
apply_lora_o,
|
||||
@@ -370,13 +371,13 @@ def apply_lora_kernel_patches(
|
||||
active_adapter = model.active_adapter
|
||||
lora_config = model.model.peft_config[active_adapter]
|
||||
|
||||
# Only patch if conditions are met
|
||||
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
|
||||
|
||||
if not can_patch:
|
||||
LOG.warning("Cannot patch layers - requires no dropout and no bias")
|
||||
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
|
||||
return model
|
||||
# Log what features are active
|
||||
if lora_config.lora_dropout > 0:
|
||||
LOG.info(f"LoRA kernels: dropout={lora_config.lora_dropout} enabled")
|
||||
if lora_config.bias != "none":
|
||||
LOG.info(f"LoRA kernels: bias={lora_config.bias} enabled")
|
||||
if lora_config.use_dora:
|
||||
LOG.info("LoRA kernels: DoRA enabled")
|
||||
|
||||
# This needs to be reset after patching
|
||||
original_level = LOG.getEffectiveLevel()
|
||||
@@ -419,44 +420,33 @@ def apply_lora_kernel_patches(
|
||||
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
||||
]
|
||||
can_patch_qkv = all(
|
||||
hasattr(module, "lora_A")
|
||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
hasattr(module, "lora_A") for module in layer_modules
|
||||
)
|
||||
|
||||
if can_patch_qkv:
|
||||
# Add optimized implementation
|
||||
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some attention QKV projections - requires LoRA "
|
||||
"adapters and no lora_magnitude_vector (DoRA)"
|
||||
"Cannot patch some attention QKV projections - requires LoRA adapters"
|
||||
)
|
||||
if cfg.lora_o_kernel:
|
||||
# Output patching
|
||||
layer_modules = [
|
||||
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
|
||||
]
|
||||
can_patch_o = all(
|
||||
hasattr(module, "lora_A")
|
||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
can_patch_o = all(hasattr(module, "lora_A") for module in layer_modules)
|
||||
|
||||
if can_patch_o:
|
||||
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some attention output projection - requires LoRA "
|
||||
"adapters and no lora_magnitude_vector (DoRA)"
|
||||
"Cannot patch some attention output projection - requires LoRA adapters"
|
||||
)
|
||||
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
|
||||
if cfg.lora_mlp_kernel:
|
||||
# MLP patching
|
||||
can_patch_mlp = all(
|
||||
hasattr(proj, "lora_A")
|
||||
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
|
||||
for proj in (gate_proj, up_proj, down_proj)
|
||||
hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj)
|
||||
)
|
||||
|
||||
if can_patch_mlp:
|
||||
@@ -464,15 +454,50 @@ def apply_lora_kernel_patches(
|
||||
layer.mlp.forward = types.MethodType(apply_fn, mlp)
|
||||
else:
|
||||
LOG.warning_once(
|
||||
"Cannot patch some MLP layers - requires LoRA adapters and no "
|
||||
"lora_magnitude_vector (DoRA)"
|
||||
"Cannot patch some MLP layers - requires LoRA adapters"
|
||||
)
|
||||
|
||||
# Patch embedding layers (model-level, not per-layer)
|
||||
if cfg.lora_embedding_kernel:
|
||||
_patch_embedding_layers(model, cfg)
|
||||
|
||||
LOG.setLevel(original_level)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def _patch_embedding_layers(model: PeftModelForCausalLM, cfg: DictDefault):
|
||||
"""Patch embedding layers with fused LoRA kernel.
|
||||
|
||||
Handles both embed_tokens (nn.Embedding with lora_embedding_A/B) and
|
||||
lm_head (nn.Linear with lora_A/B, used when tied embeddings are untied by PEFT).
|
||||
"""
|
||||
pretrained_model = model.model
|
||||
patched = 0
|
||||
|
||||
# Find embedding modules - check common locations
|
||||
for attr_path in [
|
||||
("model", "embed_tokens"),
|
||||
("model", "language_model", "embed_tokens"),
|
||||
]:
|
||||
parent = pretrained_model
|
||||
for attr in attr_path:
|
||||
parent = getattr(parent, attr, None)
|
||||
if parent is None:
|
||||
break
|
||||
if parent is not None and hasattr(parent, "lora_embedding_A"):
|
||||
LOG.info(f"Patching embedding layer: {'.'.join(attr_path)}")
|
||||
parent.forward = types.MethodType(apply_lora_embedding, parent)
|
||||
patched += 1
|
||||
|
||||
# lm_head with LoRA is a Linear layer - already handled by LoRA_O/LoRA_W kernels
|
||||
# when included in target_modules. No special embedding handling needed since
|
||||
# PEFT wraps it as a Linear (not Embedding) even for tied models.
|
||||
|
||||
if not patched:
|
||||
LOG.debug("No embedding layers with LoRA found to patch")
|
||||
|
||||
|
||||
class FakeMLP(nn.Module):
|
||||
"""
|
||||
placeholder MLP for triton patching
|
||||
|
||||
@@ -703,6 +703,12 @@ class AxolotlInputConfig(
|
||||
"description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
||||
},
|
||||
)
|
||||
lora_embedding_kernel: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Apply custom LoRA autograd function for embedding layers. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
||||
},
|
||||
)
|
||||
|
||||
chunked_cross_entropy: bool | None = Field(
|
||||
default=None,
|
||||
@@ -758,44 +764,6 @@ class AxolotlInputConfig(
|
||||
|
||||
llama4_linearized_experts: bool | None = None
|
||||
|
||||
# MoE aux-loss-free (AFB) toggles
|
||||
moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "MoE load balancing strategy: 'gshard' for auxiliary loss, 'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. Defaults to model's native behavior when unset.",
|
||||
},
|
||||
)
|
||||
moe_update_rate: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Per-step bias update rate (gamma). Recommended: 0.005–0.05. If unset, plugin default is 0.01.",
|
||||
},
|
||||
)
|
||||
moe_update_momentum: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "EMA momentum for expert load smoothing (0–1). If unset, plugin default is 0.9.",
|
||||
},
|
||||
)
|
||||
moe_bias_cap: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Absolute clamp for expert bias magnitude. If unset, plugin default is 2.0.",
|
||||
},
|
||||
)
|
||||
moe_afb_warmup_steps: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of initial steps to delay aux-free bias updates, allowing routing to stabilize. If unset, plugin default is 0.",
|
||||
},
|
||||
)
|
||||
moe_bias_sync_group: Literal["world", "ep"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Reduction group for expert load counts: 'world' (DP) or 'ep' (expert-parallel group if available). Defaults to 'world' when unset.",
|
||||
},
|
||||
)
|
||||
|
||||
deepspeed: str | dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
@@ -874,12 +842,6 @@ class AxolotlInputConfig(
|
||||
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
|
||||
},
|
||||
)
|
||||
expert_parallel_size: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of processes participating in expert-parallel collectives. Set >1 to form EP groups for aux-free reductions; defaults to world when unset."
|
||||
},
|
||||
)
|
||||
special_tokens: SpecialTokensConfig | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
@@ -1357,6 +1319,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
data.get("lora_mlp_kernel")
|
||||
or data.get("lora_qkv_kernel")
|
||||
or data.get("lora_o_kernel")
|
||||
or data.get("lora_embedding_kernel")
|
||||
):
|
||||
capabilities = data.get("capabilities")
|
||||
is_fsdp = data.get("fsdp_config") is not None
|
||||
@@ -1404,7 +1367,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
if data.get("adapter") in ["lora", "qlora"]:
|
||||
# Skip if already set, using unsloth optimizations, or using 8-bit
|
||||
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
||||
kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
|
||||
kernel_fields = [
|
||||
"lora_mlp_kernel",
|
||||
"lora_qkv_kernel",
|
||||
"lora_o_kernel",
|
||||
"lora_embedding_kernel",
|
||||
]
|
||||
if (
|
||||
any(data.get(k) is not None for k in kernel_fields)
|
||||
or any(data.get(k) for k in unsloth_fields)
|
||||
@@ -1417,9 +1385,38 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
if data.get("trust_remote_code"):
|
||||
return data
|
||||
|
||||
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
|
||||
if data.get("lora_dropout") != 0:
|
||||
return data
|
||||
# Skip auto-enable for MoE models when native grouped_mm is unavailable
|
||||
# (torch < 2.9). The grouped_mm fallback in transformers uses torch.mm
|
||||
# with out= which bypasses autocast and fails on mixed dtypes during eval.
|
||||
env_capabilities = data.get("env_capabilities", {})
|
||||
torch_version = env_capabilities.get("torch_version")
|
||||
if torch_version is None:
|
||||
import torch
|
||||
|
||||
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||
has_grouped_mm = version.parse(torch_version) >= version.parse("2.9.0")
|
||||
if not has_grouped_mm:
|
||||
is_moe = False
|
||||
model_type = data.get("model_config_type", "")
|
||||
if model_type and "moe" in model_type.lower():
|
||||
is_moe = True
|
||||
if not is_moe:
|
||||
try:
|
||||
from transformers import AutoConfig
|
||||
|
||||
base_model = data.get("base_model")
|
||||
if base_model:
|
||||
auto_cfg = AutoConfig.from_pretrained(
|
||||
base_model, trust_remote_code=False
|
||||
)
|
||||
if getattr(auto_cfg, "num_local_experts", None) or getattr(
|
||||
auto_cfg, "num_experts", None
|
||||
):
|
||||
is_moe = True
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
if is_moe:
|
||||
return data
|
||||
|
||||
# Check multi-GPU compatibility
|
||||
capabilities = data.get("capabilities")
|
||||
@@ -1442,6 +1439,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
if data.get("lora_o_kernel") is None:
|
||||
data["lora_o_kernel"] = True
|
||||
|
||||
if data.get("lora_embedding_kernel") is None:
|
||||
data["lora_embedding_kernel"] = True
|
||||
|
||||
LOG.warning(
|
||||
"Auto-enabling LoRA kernel optimizations for faster training. "
|
||||
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
|
||||
|
||||
@@ -681,15 +681,7 @@ class LoRAValidationMixin:
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lora_kernels_dora(cls, data):
|
||||
if (
|
||||
data.get("lora_mlp_kernel")
|
||||
or data.get("lora_qkv_kernel")
|
||||
or data.get("lora_o_kernel")
|
||||
) and data.get("peft_use_dora"):
|
||||
raise ValueError(
|
||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
||||
"compatible with DoRA at the moment."
|
||||
)
|
||||
# DoRA is now supported by lora kernels
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -1386,14 +1378,6 @@ class ComplexValidationMixin:
|
||||
self.tensor_parallel_size = 1
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_expert_parallel_size(self):
|
||||
if not getattr(self, "expert_parallel_size", None):
|
||||
self.expert_parallel_size = 1
|
||||
elif self.expert_parallel_size < 1:
|
||||
raise ValueError("expert_parallel_size must be >= 1")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_context_parallel_size(self):
|
||||
if self.sequence_parallel_degree and not self.context_parallel_size:
|
||||
|
||||
@@ -153,7 +153,7 @@ class TestLoraFP8Guard(unittest.TestCase):
|
||||
|
||||
proj.base_layer = base_layer
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(proj)
|
||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
|
||||
# quant_state should be None since weight is bf16, not FP8
|
||||
self.assertIsNone(quant_state)
|
||||
|
||||
@@ -174,7 +174,7 @@ class TestLoraFP8Guard(unittest.TestCase):
|
||||
scale_inv = torch.ones(1)
|
||||
base_layer.weight_scale_inv = scale_inv
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(proj)
|
||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
|
||||
self.assertIs(quant_state, scale_inv)
|
||||
|
||||
|
||||
|
||||
@@ -102,7 +102,7 @@ def mock_proj():
|
||||
def test_get_lora_parameters(mock_proj):
|
||||
"""Tests get_lora_parameters function"""
|
||||
# Test with LoRA enabled
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
||||
|
||||
assert isinstance(W, torch.Tensor)
|
||||
assert W.shape == (128, 64)
|
||||
@@ -113,13 +113,13 @@ def test_get_lora_parameters(mock_proj):
|
||||
|
||||
# Test with LoRA disabled
|
||||
mock_proj.disable_adapters = True
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
||||
assert A is None and B is None and s is None
|
||||
|
||||
# Test with merged state
|
||||
mock_proj.disable_adapters = False
|
||||
mock_proj.merged = True
|
||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
||||
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
||||
assert A is None and B is None and s is None
|
||||
|
||||
|
||||
@@ -176,24 +176,31 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
|
||||
X.requires_grad = True
|
||||
output = LoRA_MLP.apply(
|
||||
X,
|
||||
None, # X_drop
|
||||
gate_proj.weight,
|
||||
gate_proj.bias,
|
||||
None, # gate_quant
|
||||
None, # gate_A
|
||||
None, # gate_B
|
||||
None, # gate_scale
|
||||
None, # gate_lora_bias
|
||||
None, # gate_magnitude
|
||||
up_proj.weight,
|
||||
up_proj.bias,
|
||||
None, # up_quant
|
||||
None, # up_A
|
||||
None, # up_B
|
||||
None, # up_scale
|
||||
None, # up_lora_bias
|
||||
None, # up_magnitude
|
||||
down_proj.weight,
|
||||
down_proj.bias,
|
||||
None, # down_quant
|
||||
None, # down_A
|
||||
None, # down_B
|
||||
None, # down_scale
|
||||
None, # down_lora_bias
|
||||
None, # down_magnitude
|
||||
activation_forward,
|
||||
activation_backward,
|
||||
True, # inplace
|
||||
@@ -247,24 +254,31 @@ def test_lora_mlp_with_adapters(
|
||||
# Forward pass with adapters
|
||||
output = LoRA_MLP.apply(
|
||||
X,
|
||||
None, # X_drop
|
||||
gate_proj.weight,
|
||||
gate_proj.bias,
|
||||
None,
|
||||
gate_A,
|
||||
gate_B,
|
||||
scale,
|
||||
None, # gate_lora_bias
|
||||
None, # gate_magnitude
|
||||
up_proj.weight,
|
||||
up_proj.bias,
|
||||
None,
|
||||
up_A,
|
||||
up_B,
|
||||
scale,
|
||||
None, # up_lora_bias
|
||||
None, # up_magnitude
|
||||
down_proj.weight,
|
||||
down_proj.bias,
|
||||
None,
|
||||
down_A,
|
||||
down_B,
|
||||
scale,
|
||||
None, # down_lora_bias
|
||||
None, # down_magnitude
|
||||
activation_forward,
|
||||
activation_backward,
|
||||
True,
|
||||
@@ -334,25 +348,32 @@ def test_lora_qkv(sample_tensors):
|
||||
|
||||
Q1, K1, V1 = LoRA_QKV.apply(
|
||||
X,
|
||||
None, # X_drop
|
||||
q_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None, # Q: weight, bias, quant, A, B, scale, lora_bias, magnitude
|
||||
k_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None, # K
|
||||
v_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
True,
|
||||
None,
|
||||
None, # V
|
||||
True, # inplace
|
||||
)
|
||||
|
||||
assert Q1.shape == K1.shape == V1.shape == X.shape
|
||||
@@ -366,25 +387,32 @@ def test_lora_qkv(sample_tensors):
|
||||
# Test with LoRA adapters
|
||||
Q2, K2, V2 = LoRA_QKV.apply(
|
||||
X,
|
||||
None, # X_drop
|
||||
q_weight,
|
||||
None,
|
||||
None,
|
||||
q_A,
|
||||
q_B,
|
||||
scale,
|
||||
None,
|
||||
None, # Q
|
||||
k_weight,
|
||||
None,
|
||||
None,
|
||||
k_A,
|
||||
k_B,
|
||||
scale,
|
||||
None,
|
||||
None, # K
|
||||
v_weight,
|
||||
None,
|
||||
None,
|
||||
v_A,
|
||||
v_B,
|
||||
scale,
|
||||
True,
|
||||
None,
|
||||
None, # V
|
||||
True, # inplace
|
||||
)
|
||||
|
||||
assert Q2.shape == K2.shape == V2.shape == X.shape
|
||||
@@ -427,7 +455,9 @@ def test_lora_o(sample_tensors):
|
||||
|
||||
# Test forward pass
|
||||
X.requires_grad = True
|
||||
output = LoRA_O.apply(X, W, b, None, A, B, scale)
|
||||
output = LoRA_O.apply(
|
||||
X, None, W, b, None, A, B, scale, None, None
|
||||
) # X_drop, ..., lora_bias, magnitude
|
||||
|
||||
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
|
||||
@@ -542,6 +572,7 @@ def test_inplace_operations(sample_tensors, apply_function):
|
||||
"down_proj": nn.Linear(shapes["out"], shapes["hidden"]).to(
|
||||
device="cuda", dtype=torch.float16
|
||||
),
|
||||
"training": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
1245
tests/e2e/kernels/test_lora_features.py
Normal file
1245
tests/e2e/kernels/test_lora_features.py
Normal file
File diff suppressed because it is too large
Load Diff
120
tests/e2e/multigpu/test_fsdp2_lora_kernels.py
Normal file
120
tests/e2e/multigpu/test_fsdp2_lora_kernels.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Test LoRA kernels under FSDP2 multi-GPU training.
|
||||
|
||||
Verifies that lora_qkv_kernel, lora_o_kernel, lora_mlp_kernel, and
|
||||
lora_embedding_kernel work correctly with FSDP2 sharding, including
|
||||
with bias, dropout, and DoRA enabled.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.e2e.utils import require_torch_2_7_0
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
def _run_training(temp_dir, cfg):
|
||||
"""Write config and launch multi-GPU training."""
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"axolotl",
|
||||
"train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main-process-port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _base_lora_fsdp2_config(temp_dir, **overrides):
|
||||
"""Base config for LoRA + FSDP2 + kernel tests."""
|
||||
cfg = {
|
||||
"base_model": "Qwen/Qwen3-0.6B",
|
||||
"sequence_len": 512,
|
||||
"val_set_size": 0.0,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
"type": "alpaca",
|
||||
"split": "train[:1%]",
|
||||
},
|
||||
],
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_target_linear": True,
|
||||
"num_epochs": 1,
|
||||
"max_steps": 3,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 1e-4,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"bf16": True,
|
||||
"fsdp_version": 2,
|
||||
"fsdp_config": {
|
||||
"offload_params": False,
|
||||
"cpu_ram_efficient_loading": False,
|
||||
"transformer_layer_cls_to_wrap": "Qwen3DecoderLayer",
|
||||
"state_dict_type": "FULL_STATE_DICT",
|
||||
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
"reshard_after_forward": True,
|
||||
},
|
||||
# Enable all LoRA kernels
|
||||
"lora_mlp_kernel": True,
|
||||
"lora_qkv_kernel": True,
|
||||
"lora_o_kernel": True,
|
||||
"lora_embedding_kernel": True,
|
||||
"save_safetensors": True,
|
||||
}
|
||||
cfg.update(overrides)
|
||||
return DictDefault(cfg)
|
||||
|
||||
|
||||
class TestFSDP2LoRAKernels:
|
||||
"""Test LoRA kernels under FSDP2."""
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_lora_kernels_basic(self, temp_dir):
|
||||
"""Basic LoRA + kernels + FSDP2: no dropout, no bias, no DoRA."""
|
||||
cfg = _base_lora_fsdp2_config(temp_dir)
|
||||
_run_training(temp_dir, cfg)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_lora_kernels_with_dropout(self, temp_dir):
|
||||
"""LoRA kernels + dropout + FSDP2."""
|
||||
cfg = _base_lora_fsdp2_config(temp_dir, lora_dropout=0.1)
|
||||
_run_training(temp_dir, cfg)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_lora_kernels_with_dora(self, temp_dir):
|
||||
"""LoRA kernels + DoRA + FSDP2."""
|
||||
cfg = _base_lora_fsdp2_config(temp_dir, peft_use_dora=True)
|
||||
_run_training(temp_dir, cfg)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
|
||||
@require_torch_2_7_0
|
||||
def test_lora_kernels_with_dora_and_dropout(self, temp_dir):
|
||||
"""LoRA kernels + DoRA + dropout + FSDP2."""
|
||||
cfg = _base_lora_fsdp2_config(
|
||||
temp_dir,
|
||||
peft_use_dora=True,
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
_run_training(temp_dir, cfg)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
@@ -222,9 +222,9 @@ def test_model_specific_activation(model_name, expected_activation):
|
||||
|
||||
|
||||
def test_kernel_patch_conditions():
|
||||
"""Test various conditions that should prevent kernel patching."""
|
||||
"""Test that kernels ARE patched even with dropout and bias (now supported)."""
|
||||
test_configs = [
|
||||
# Dropout prevents patching
|
||||
# Dropout — kernels now support this
|
||||
{
|
||||
"peft_type": "LORA",
|
||||
"task_type": "CAUSAL_LM",
|
||||
@@ -234,7 +234,7 @@ def test_kernel_patch_conditions():
|
||||
"lora_dropout": 0.1,
|
||||
"bias": "none",
|
||||
},
|
||||
# Bias prevents patching
|
||||
# Bias — kernels now support this
|
||||
{
|
||||
"peft_type": "LORA",
|
||||
"task_type": "CAUSAL_LM",
|
||||
@@ -252,13 +252,14 @@ def test_kernel_patch_conditions():
|
||||
model = PeftModelForCausalLM(model, peft_config)
|
||||
cfg = DictDefault({"lora_mlp_kernel": True})
|
||||
|
||||
# Should not patch
|
||||
patched_model = apply_lora_kernel_patches(model, cfg)
|
||||
layer = patched_model.model.model.layers[0].mlp
|
||||
|
||||
# Verify no patches applied
|
||||
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
|
||||
assert layer.forward.__func__ is not apply_lora_mlp_geglu
|
||||
# Verify patches ARE applied (dropout and bias are now supported)
|
||||
assert (
|
||||
layer.forward.__func__ is apply_lora_mlp_swiglu
|
||||
or layer.forward.__func__ is apply_lora_mlp_geglu
|
||||
)
|
||||
|
||||
|
||||
def test_kernel_config_options():
|
||||
@@ -511,7 +512,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
||||
|
||||
|
||||
def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
||||
"""Test model loading with dropout non-zero should not patch."""
|
||||
"""Test model loading with dropout non-zero DOES patch (now supported)."""
|
||||
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
|
||||
@@ -546,31 +547,18 @@ def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
||||
# Load config
|
||||
cfg = load_cfg(str(path))
|
||||
|
||||
# Get original attention class
|
||||
attention_cls = get_attention_cls_from_config(cfg)
|
||||
|
||||
# Store original state before patching
|
||||
original_forward_method = attention_cls.forward
|
||||
|
||||
# Load model
|
||||
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
|
||||
|
||||
# We call modelloader as that's where the patches are applied
|
||||
# despite the fact that we're not using it to load the model
|
||||
model_loader = ModelLoader(cfg, tokenizer)
|
||||
|
||||
# Apply patch
|
||||
# Apply patches — should succeed even with dropout > 0
|
||||
model_loader.patch_manager._apply_self_attention_lora_patch()
|
||||
|
||||
# Verify patch was not applied
|
||||
assert attention_cls.forward == original_forward_method
|
||||
|
||||
# Apply apply_lora_kernel_patches
|
||||
model_loader.patch_manager._apply_lora_kernel_patch(model)
|
||||
|
||||
# Verify patch was not applied
|
||||
# Verify patches WERE applied (dropout is now supported by kernels)
|
||||
layers = get_layers(model)
|
||||
for layer in layers:
|
||||
for self_attn in find_self_attn_in_layer(layer):
|
||||
assert not hasattr(self_attn, "apply_qkv")
|
||||
assert not hasattr(self_attn, "apply_o")
|
||||
assert hasattr(self_attn, "apply_qkv")
|
||||
assert hasattr(self_attn, "apply_o")
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
"""
|
||||
E2E smoke tests for Aux-Loss-Free MoE routing via plugin
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
|
||||
class TestMoeAuxFree(unittest.TestCase):
|
||||
"""Smoke tests to ensure aux-free plugin enables and runs on Mixtral tiny."""
|
||||
|
||||
@with_temp_dir
|
||||
def test_mixtral_aux_free_smoke(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||
"flash_attention": False,
|
||||
"sequence_len": 512,
|
||||
"bf16": False,
|
||||
"fp16": False,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 1e-5,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_steps": 0,
|
||||
"eval_steps": 0,
|
||||
"save_first_step": False,
|
||||
# Aux-free plugin and toggles
|
||||
"plugins": [
|
||||
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||
],
|
||||
"moe_balance_type": "noaux_tc",
|
||||
"moe_update_rate": 0.01,
|
||||
"moe_update_momentum": 0.9,
|
||||
"moe_bias_cap": 2.0,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
# Inspect model modules for a patched MoE layer
|
||||
patched = None
|
||||
for m in model.modules():
|
||||
if hasattr(m, "_afb_patched") and getattr(m, "_afb_patched") is True:
|
||||
patched = m
|
||||
break
|
||||
assert patched is not None, "No MoE layer patched by aux-free plugin"
|
||||
assert hasattr(patched, "_afb_bias") and patched._afb_bias.ndim == 1
|
||||
assert hasattr(patched, "_afb_counts") and patched._afb_counts.ndim == 1
|
||||
# ensure counts buffer got reset by callback (best effort)
|
||||
assert torch.all(patched._afb_counts == 0)
|
||||
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -1,83 +0,0 @@
|
||||
"""
|
||||
Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-tiny.
|
||||
Checks that aux-free training loss does not degrade beyond a small tolerance.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
|
||||
def _last_logged_loss(trainer) -> float | None:
|
||||
# Scan log_history for the most recent entry with a 'loss' key
|
||||
for entry in reversed(trainer.state.log_history):
|
||||
if isinstance(entry, dict) and "loss" in entry:
|
||||
return float(entry["loss"])
|
||||
return None
|
||||
|
||||
|
||||
class TestMoeAuxParity(unittest.TestCase):
|
||||
@with_temp_dir
|
||||
def test_mixtral_auxfree_vs_auxloss_loss_parity(self, temp_dir):
|
||||
base_cfg = {
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
||||
"flash_attention": False,
|
||||
"sequence_len": 512,
|
||||
"bf16": False,
|
||||
"fp16": False,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {},
|
||||
"datasets": [
|
||||
{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 8,
|
||||
"save_steps": 0,
|
||||
"eval_steps": 0,
|
||||
"save_first_step": False,
|
||||
"seed": 42,
|
||||
"logging_steps": 1,
|
||||
}
|
||||
|
||||
# Baseline: aux-loss (gshard)
|
||||
cfg0 = DictDefault(dict(base_cfg))
|
||||
cfg0.output_dir = f"{temp_dir}/baseline"
|
||||
cfg0 = validate_config(cfg0)
|
||||
normalize_config(cfg0)
|
||||
# baseline uses default aux-loss routing; no plugin registration
|
||||
dataset_meta0 = load_datasets(cfg=cfg0)
|
||||
model0, _, trainer0 = train(cfg=cfg0, dataset_meta=dataset_meta0)
|
||||
loss0 = _last_logged_loss(trainer0)
|
||||
assert loss0 is not None
|
||||
|
||||
# Aux-free: plugin + noaux_tc
|
||||
cfg1 = DictDefault(dict(base_cfg))
|
||||
cfg1.output_dir = f"{temp_dir}/auxfree"
|
||||
cfg1.plugins = [
|
||||
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||
]
|
||||
cfg1.moe_balance_type = "noaux_tc"
|
||||
cfg1.moe_update_rate = 0.01
|
||||
cfg1.moe_update_momentum = 0.9
|
||||
cfg1.moe_bias_cap = 2.0
|
||||
cfg1 = validate_config(cfg1)
|
||||
normalize_config(cfg1)
|
||||
prepare_plugins(cfg1)
|
||||
dataset_meta1 = load_datasets(cfg=cfg1)
|
||||
model1, _, trainer1 = train(cfg=cfg1, dataset_meta=dataset_meta1)
|
||||
loss1 = _last_logged_loss(trainer1)
|
||||
assert loss1 is not None
|
||||
|
||||
# Assert aux-free loss is within 10% of aux-loss baseline
|
||||
assert loss1 <= 1.1 * loss0, f"aux-free loss {loss1} > 1.1 * baseline {loss0}"
|
||||
@@ -1,76 +0,0 @@
|
||||
"""
|
||||
E2E smoke test for Aux-Loss-Free MoE routing on Qwen3-MoE tiny
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
|
||||
class TestQwen3MoeAuxFree(unittest.TestCase):
|
||||
@with_temp_dir
|
||||
def test_qwen3_moe_aux_free_smoke(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
|
||||
"tokenizer_config": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
|
||||
"flash_attention": False,
|
||||
"sequence_len": 512,
|
||||
"bf16": False,
|
||||
"fp16": False,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 1e-5,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 5,
|
||||
"save_steps": 0,
|
||||
"eval_steps": 0,
|
||||
"save_first_step": False,
|
||||
# Aux-free plugin and toggles
|
||||
"plugins": [
|
||||
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
||||
],
|
||||
"moe_balance_type": "noaux_tc",
|
||||
"moe_update_rate": 0.01,
|
||||
"moe_update_momentum": 0.9,
|
||||
"moe_bias_cap": 2.0,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
# check that at least one sparse MoE block has been patched
|
||||
found = False
|
||||
for m in model.modules():
|
||||
if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(m, "_afb_patched"):
|
||||
assert m._afb_patched is True
|
||||
assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1
|
||||
assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1
|
||||
found = True
|
||||
break
|
||||
assert found, "No Qwen3-MoE sparse block patched by aux-free plugin"
|
||||
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -12,10 +12,7 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
try:
|
||||
from tbparse import SummaryReader
|
||||
except ImportError: # pragma: no cover - optional dependency
|
||||
SummaryReader = None
|
||||
from tbparse import SummaryReader
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -182,14 +179,12 @@ def check_tensorboard(
|
||||
tag: str,
|
||||
lt_val: float,
|
||||
assertion_err: str,
|
||||
rtol: float = 0.02,
|
||||
rtol: float = 0.05,
|
||||
gt_zero: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
helper function to parse and check tensorboard logs
|
||||
"""
|
||||
if SummaryReader is None:
|
||||
raise unittest.SkipTest("tbparse is not installed; skipping tensorboard assertions")
|
||||
tb_log_path = most_recent_subdir(temp_run_dir)
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
|
||||
229
tests/kernels/test_rms_norm_gated.py
Normal file
229
tests/kernels/test_rms_norm_gated.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
Correctness tests for fused RMSNorm + SiLU Gate kernel.
|
||||
|
||||
Tests against the eager Qwen3_5RMSNormGated implementation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
pytest.importorskip("triton", reason="triton required for fused kernels")
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA required for fused kernel tests", allow_module_level=True)
|
||||
|
||||
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
||||
|
||||
|
||||
class EagerRMSNormGated(torch.nn.Module):
|
||||
"""Reference implementation matching Qwen3_5RMSNormGated exactly."""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states, gate=None):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
hidden_states = self.weight * hidden_states.to(input_dtype)
|
||||
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
|
||||
return hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
def _sync_weights(eager_mod, fused_mod):
|
||||
"""Copy weights from eager to fused module."""
|
||||
fused_mod.weight.data.copy_(eager_mod.weight.data)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[
|
||||
(2, 128, 256),
|
||||
(4, 64, 512),
|
||||
(1, 32, 1024),
|
||||
(2, 16, 2560), # Qwen3.5-4B hidden_size
|
||||
(2, 16, 4096), # Qwen3.5-9B hidden_size
|
||||
(1, 8, 5120), # Qwen3.5-27B hidden_size
|
||||
(4, 16, 2048), # Qwen3.5-35B-A3B (MoE) hidden_size
|
||||
(4, 16, 3072), # Qwen3.5-122B-A10B (MoE) hidden_size
|
||||
],
|
||||
)
|
||||
class TestRMSNormGatedForward:
|
||||
def test_output_matches_eager(self, dtype, shape):
|
||||
torch.manual_seed(42)
|
||||
B, T, H = shape
|
||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X, gate=G)
|
||||
y_fused = fused(X, gate=G)
|
||||
|
||||
if dtype == torch.float32:
|
||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-5, rtol=1e-5)
|
||||
else:
|
||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
||||
|
||||
def test_output_shape(self, dtype, shape):
|
||||
B, T, H = shape
|
||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||
|
||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
y = fused(X, gate=G)
|
||||
assert y.shape == (B, T, H)
|
||||
assert y.dtype == dtype
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[
|
||||
(2, 32, 256),
|
||||
(2, 16, 512),
|
||||
(2, 16, 2560), # Qwen3.5-4B
|
||||
(1, 8, 4096), # Qwen3.5-9B
|
||||
(1, 8, 5120), # Qwen3.5-27B
|
||||
(2, 16, 2048), # Qwen3.5-35B-A3B (MoE)
|
||||
(2, 16, 3072), # Qwen3.5-122B-A10B (MoE)
|
||||
],
|
||||
)
|
||||
class TestRMSNormGatedBackward:
|
||||
def test_grad_x(self, dtype, shape):
|
||||
torch.manual_seed(42)
|
||||
B, T, H = shape
|
||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
X_ref = X.detach().clone().requires_grad_(True)
|
||||
G_ref = G.detach().clone().requires_grad_(True)
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X_ref, gate=G_ref)
|
||||
y_fused = fused(X, gate=G)
|
||||
|
||||
grad_out = torch.randn_like(y_eager)
|
||||
y_eager.backward(grad_out)
|
||||
y_fused.backward(grad_out)
|
||||
|
||||
if dtype == torch.float32:
|
||||
atol, rtol = 1e-4, 1e-4
|
||||
else:
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
|
||||
torch.testing.assert_close(X.grad, X_ref.grad, atol=atol, rtol=rtol)
|
||||
|
||||
def test_grad_gate(self, dtype, shape):
|
||||
torch.manual_seed(42)
|
||||
B, T, H = shape
|
||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
X_ref = X.detach().clone().requires_grad_(True)
|
||||
G_ref = G.detach().clone().requires_grad_(True)
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X_ref, gate=G_ref)
|
||||
y_fused = fused(X, gate=G)
|
||||
|
||||
grad_out = torch.randn_like(y_eager)
|
||||
y_eager.backward(grad_out)
|
||||
y_fused.backward(grad_out)
|
||||
|
||||
if dtype == torch.float32:
|
||||
atol, rtol = 1e-4, 1e-4
|
||||
else:
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
|
||||
torch.testing.assert_close(G.grad, G_ref.grad, atol=atol, rtol=rtol)
|
||||
|
||||
def test_grad_weight(self, dtype, shape):
|
||||
torch.manual_seed(42)
|
||||
B, T, H = shape
|
||||
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||
X_ref = X.detach().clone().requires_grad_(True)
|
||||
G_ref = G.detach().clone().requires_grad_(True)
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X_ref, gate=G_ref)
|
||||
y_fused = fused(X, gate=G)
|
||||
|
||||
grad_out = torch.randn_like(y_eager)
|
||||
y_eager.backward(grad_out)
|
||||
y_fused.backward(grad_out)
|
||||
|
||||
if dtype == torch.float32:
|
||||
atol, rtol = 1e-4, 1e-4
|
||||
else:
|
||||
atol, rtol = 5e-2, 5e-2
|
||||
|
||||
torch.testing.assert_close(
|
||||
fused.weight.grad, eager.weight.grad, atol=atol, rtol=rtol
|
||||
)
|
||||
|
||||
|
||||
class TestRMSNormGatedEdgeCases:
|
||||
def test_gate_none_raises(self):
|
||||
fused = FusedRMSNormGated(256).cuda()
|
||||
X = torch.randn(2, 4, 256, device="cuda")
|
||||
with pytest.raises(ValueError, match="requires a gate tensor"):
|
||||
fused(X, gate=None)
|
||||
|
||||
def test_2d_input(self):
|
||||
"""Test with (BxT, H) shaped input instead of (B, T, H)."""
|
||||
torch.manual_seed(42)
|
||||
H = 512
|
||||
X = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
||||
G = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
||||
X_ref = X.detach().clone().requires_grad_(True)
|
||||
G_ref = G.detach().clone().requires_grad_(True)
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X_ref, gate=G_ref)
|
||||
y_fused = fused(X, gate=G)
|
||||
|
||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
||||
|
||||
grad_out = torch.randn_like(y_eager)
|
||||
y_eager.backward(grad_out)
|
||||
y_fused.backward(grad_out)
|
||||
|
||||
torch.testing.assert_close(X.grad, X_ref.grad, atol=5e-2, rtol=5e-2)
|
||||
torch.testing.assert_close(G.grad, G_ref.grad, atol=5e-2, rtol=5e-2)
|
||||
|
||||
def test_random_weight_init(self):
|
||||
"""Test with non-default weight values."""
|
||||
torch.manual_seed(123)
|
||||
H = 256
|
||||
X = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
|
||||
G = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||
# Randomize weights
|
||||
eager.weight.data = torch.randn_like(eager.weight.data)
|
||||
|
||||
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||
_sync_weights(eager, fused)
|
||||
|
||||
y_eager = eager(X, gate=G)
|
||||
y_fused = fused(X, gate=G)
|
||||
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
||||
@@ -1,267 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from importlib import util as importlib_util
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
|
||||
|
||||
|
||||
def _cfg(**overrides):
|
||||
defaults = dict(
|
||||
moe_balance_type="noaux_tc",
|
||||
moe_update_rate=0.1,
|
||||
moe_update_momentum=0.9,
|
||||
moe_bias_cap=2.0,
|
||||
moe_afb_warmup_steps=0,
|
||||
moe_bias_sync_group="world",
|
||||
expert_parallel_size=1,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
return SimpleNamespace(**defaults)
|
||||
|
||||
|
||||
def _load_bailing_modules():
|
||||
repo_dir = snapshot_download(
|
||||
repo_id="inclusionAI/Ring-mini-2.0",
|
||||
allow_patterns=[
|
||||
"configuration_bailing_moe_v2.py",
|
||||
"modeling_bailing_moe_v2.py",
|
||||
"__init__.py",
|
||||
],
|
||||
)
|
||||
repo = Path(repo_dir)
|
||||
config_path = repo / "configuration_bailing_moe_v2.py"
|
||||
modeling_path = repo / "modeling_bailing_moe_v2.py"
|
||||
|
||||
config_name = "bailing_moe_v2.configuration_bailing_moe_v2"
|
||||
if config_name not in sys.modules:
|
||||
spec = importlib_util.spec_from_file_location(config_name, config_path)
|
||||
module = importlib_util.module_from_spec(spec)
|
||||
sys.modules[config_name] = module
|
||||
sys.modules["configuration_bailing_moe_v2"] = module
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
config_module = sys.modules[config_name]
|
||||
|
||||
modeling_name = "bailing_moe_v2.modeling_bailing_moe_v2"
|
||||
if modeling_name not in sys.modules:
|
||||
spec = importlib_util.spec_from_file_location(modeling_name, modeling_path)
|
||||
module = importlib_util.module_from_spec(spec)
|
||||
sys.modules[modeling_name] = module
|
||||
sys.modules["modeling_bailing_moe_v2"] = module
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
modeling_module = sys.modules[modeling_name]
|
||||
|
||||
BailingMoeV2Config = config_module.BailingMoeV2Config
|
||||
BailingMoeV2SparseMoeBlock = modeling_module.BailingMoeV2SparseMoeBlock
|
||||
|
||||
return BailingMoeV2Config, BailingMoeV2SparseMoeBlock
|
||||
|
||||
|
||||
def _build_bailing_model():
|
||||
BailingConfig, BailingBlock = _load_bailing_modules()
|
||||
config = BailingConfig(
|
||||
hidden_size=16,
|
||||
intermediate_size=32,
|
||||
moe_intermediate_size=32,
|
||||
num_experts=4,
|
||||
num_shared_experts=None,
|
||||
num_experts_per_tok=2,
|
||||
n_group=1,
|
||||
topk_group=1,
|
||||
routed_scaling_factor=1.0,
|
||||
)
|
||||
block = BailingBlock(config)
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
def __init__(self, layer):
|
||||
super().__init__()
|
||||
self.block = layer
|
||||
self.config = SimpleNamespace(model_type="bailing_moe")
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return self.block(hidden_states)
|
||||
|
||||
return DummyModel(block), block
|
||||
|
||||
|
||||
def _build_llama4_model():
|
||||
from transformers import Llama4TextConfig
|
||||
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
|
||||
|
||||
config = Llama4TextConfig(
|
||||
hidden_size=16,
|
||||
intermediate_size=32,
|
||||
num_local_experts=4,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
num_experts_per_tok=2,
|
||||
)
|
||||
layer = Llama4TextMoe(config)
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
def __init__(self, moe_layer):
|
||||
super().__init__()
|
||||
self.moe = moe_layer
|
||||
self.config = SimpleNamespace(model_type="llama4")
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return self.moe(hidden_states)
|
||||
|
||||
return DummyModel(layer), layer
|
||||
|
||||
|
||||
def _build_mixtral_model():
|
||||
from transformers import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
config = MixtralConfig(
|
||||
hidden_size=16,
|
||||
intermediate_size=32,
|
||||
num_local_experts=4,
|
||||
num_experts_per_tok=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
)
|
||||
layer = MixtralSparseMoeBlock(config)
|
||||
layer.config = config
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
def __init__(self, moe_layer):
|
||||
super().__init__()
|
||||
self.moe = moe_layer
|
||||
self.config = SimpleNamespace(model_type="mixtral")
|
||||
|
||||
def forward(self, hidden_states):
|
||||
return self.moe(hidden_states)
|
||||
|
||||
return DummyModel(layer), layer
|
||||
|
||||
|
||||
def _run_callback(plugin, cfg):
|
||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
||||
assert callbacks, "expected aux-free callback to be registered"
|
||||
callback = callbacks[0]
|
||||
dummy = SimpleNamespace()
|
||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
||||
|
||||
|
||||
class TestAuxFreeAdapters(unittest.TestCase):
|
||||
def test_bailing_adapter_updates_counts_and_bias(self):
|
||||
model, block = _build_bailing_model()
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
self.assertTrue(hasattr(block, "_afb_bias"))
|
||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||
block(hidden)
|
||||
self.assertGreater(torch.count_nonzero(block._afb_counts), 0)
|
||||
|
||||
_run_callback(plugin, cfg)
|
||||
self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
|
||||
self.assertFalse(torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema)))
|
||||
|
||||
def test_llama4_adapter_biases_router_selection(self):
|
||||
model, layer = _build_llama4_model()
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
self.assertTrue(hasattr(layer, "_afb_bias"))
|
||||
hidden = torch.randn(2, 4, layer.hidden_dim)
|
||||
layer(hidden)
|
||||
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||
|
||||
_run_callback(plugin, cfg)
|
||||
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
|
||||
self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)))
|
||||
|
||||
def test_bias_warmup_respected(self):
|
||||
model, block = _build_bailing_model()
|
||||
cfg = _cfg(moe_afb_warmup_steps=2)
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
||||
self.assertTrue(callbacks)
|
||||
callback = callbacks[0]
|
||||
dummy = SimpleNamespace()
|
||||
|
||||
def _step():
|
||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||
block(hidden)
|
||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
||||
|
||||
# Warmup steps should leave bias untouched.
|
||||
_step()
|
||||
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
|
||||
|
||||
_step()
|
||||
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
|
||||
|
||||
# Third step exceeds warmup -> bias should update.
|
||||
_step()
|
||||
self.assertGreater(torch.count_nonzero(block._afb_bias), 0)
|
||||
|
||||
def test_mixtral_adapter_respects_native_forward(self):
|
||||
model, layer = _build_mixtral_model()
|
||||
layer.jitter_noise = 0.0 # avoid stochasticity for comparison
|
||||
|
||||
hidden_dim = layer.config.hidden_size
|
||||
hidden = torch.randn(2, 3, hidden_dim)
|
||||
baseline_out, baseline_logits = layer(hidden.clone())
|
||||
|
||||
cfg = _cfg()
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
patched_out, patched_logits = layer(hidden.clone())
|
||||
self.assertTrue(torch.allclose(baseline_out, patched_out))
|
||||
self.assertTrue(torch.allclose(baseline_logits, patched_logits))
|
||||
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
||||
_run_callback(plugin, cfg)
|
||||
|
||||
def test_ep_group_resolution_deferred_until_dist_ready(self):
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
model, block = _build_bailing_model()
|
||||
cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
|
||||
plugin = AuxFreeMoEPlugin()
|
||||
plugin.post_model_build(cfg, model)
|
||||
|
||||
self.assertIsNotNone(plugin._shim)
|
||||
self.assertIsNone(plugin._shim.ep_group)
|
||||
|
||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
||||
self.assertTrue(callbacks)
|
||||
callback = callbacks[0]
|
||||
dummy = SimpleNamespace()
|
||||
|
||||
tmp_init = tempfile.NamedTemporaryFile(delete=False)
|
||||
tmp_init.close()
|
||||
init_method = f"file://{tmp_init.name}"
|
||||
dist.init_process_group(backend="gloo", init_method=init_method, world_size=1, rank=0)
|
||||
try:
|
||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
||||
block(hidden)
|
||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
||||
self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
|
||||
finally:
|
||||
dist.destroy_process_group()
|
||||
os.unlink(tmp_init.name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -28,20 +28,22 @@ class TestLoRAConfigValidation:
|
||||
result = validate_config(valid_config)
|
||||
assert result["adapter"] == "lora"
|
||||
|
||||
with pytest.raises(ValueError, match="not compatible with DoRA"):
|
||||
invalid_config = DictDefault(
|
||||
{
|
||||
"adapter": "lora",
|
||||
"lora_mlp_kernel": True,
|
||||
"peft_use_dora": True,
|
||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"base_model": "dummy_model",
|
||||
}
|
||||
)
|
||||
validate_config(invalid_config)
|
||||
# DoRA is now compatible with lora kernels
|
||||
dora_kernel_config = DictDefault(
|
||||
{
|
||||
"adapter": "lora",
|
||||
"lora_mlp_kernel": True,
|
||||
"peft_use_dora": True,
|
||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"learning_rate": 1e-5,
|
||||
"base_model": "dummy_model",
|
||||
}
|
||||
)
|
||||
result = validate_config(dora_kernel_config)
|
||||
assert result["lora_mlp_kernel"] is True
|
||||
assert result["peft_use_dora"] is True
|
||||
|
||||
def test_qlora_4bit_validation(self):
|
||||
"""Test QLoRA 4-bit configuration validation"""
|
||||
|
||||
@@ -38,6 +38,11 @@ class TestLoRAParameterFreezing:
|
||||
|
||||
mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
|
||||
mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
|
||||
mock_layer.lora_B["default"].bias = None
|
||||
|
||||
# Required by get_lora_parameters for dropout/DoRA extraction
|
||||
mock_layer.lora_dropout = {}
|
||||
mock_layer.lora_magnitude_vector = None
|
||||
else:
|
||||
mock_layer.weight = base_layer.weight
|
||||
mock_layer.bias = base_layer.bias
|
||||
@@ -48,7 +53,7 @@ class TestLoRAParameterFreezing:
|
||||
"""Test that LoRA parameters are None when adapters are disabled."""
|
||||
layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||
|
||||
# Base parameters should be returned
|
||||
assert W is not None
|
||||
@@ -62,7 +67,7 @@ class TestLoRAParameterFreezing:
|
||||
"""Test that LoRA parameters are None when adapters are merged."""
|
||||
layer = self.create_mock_lora_layer(has_adapters=True, merged=True)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||
|
||||
# Base parameters should be returned
|
||||
assert W is not None
|
||||
@@ -77,7 +82,7 @@ class TestLoRAParameterFreezing:
|
||||
"""Test parameter behavior when no adapters are present."""
|
||||
layer = self.create_mock_lora_layer(has_adapters=False)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||
|
||||
# Base parameters should be returned
|
||||
assert W is not None
|
||||
@@ -94,7 +99,7 @@ class TestLoRAParameterFreezing:
|
||||
has_adapters=True, adapters_disabled=False, merged=False
|
||||
)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||
|
||||
# All parameters should be returned
|
||||
assert W is not None
|
||||
@@ -110,7 +115,7 @@ class TestLoRAParameterFreezing:
|
||||
has_adapters=True, adapters_disabled=False, merged=False
|
||||
)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||
|
||||
# Check shape consistency
|
||||
assert W.shape == (512, 256)
|
||||
@@ -124,7 +129,7 @@ class TestLoRAParameterFreezing:
|
||||
has_adapters=True, adapters_disabled=False, merged=False
|
||||
)
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||
|
||||
assert W.dtype == self.dtype
|
||||
assert b.dtype == self.dtype
|
||||
@@ -138,7 +143,7 @@ class TestLoRAParameterFreezing:
|
||||
quant_state_mock = Mock()
|
||||
layer.base_layer.weight.quant_state = quant_state_mock
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||
|
||||
assert quant_state == quant_state_mock
|
||||
|
||||
@@ -157,7 +162,7 @@ class TestLoRAParameterFreezing:
|
||||
|
||||
layer.active_adapters = ["adapter2"]
|
||||
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||
|
||||
assert s == 0.2
|
||||
assert torch.equal(A, layer.lora_A["adapter2"].weight)
|
||||
@@ -192,13 +197,13 @@ class TestLoRAParameterFreezingIntegration:
|
||||
model = get_peft_model(base_model, lora_config)
|
||||
lora_layer = model.base_model.model.linear
|
||||
# Test with adapters enabled
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
|
||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
|
||||
assert A is not None
|
||||
assert B is not None
|
||||
assert s is not None
|
||||
# Test with adapters disabled
|
||||
model.disable_adapter_layers()
|
||||
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
|
||||
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
|
||||
assert A is None
|
||||
assert B is None
|
||||
assert s is None
|
||||
|
||||
Reference in New Issue
Block a user