Compare commits

..

5 Commits

Author SHA1 Message Date
Wing Lian
db6af43f3b chore: lint 2026-03-23 04:54:00 +00:00
Wing Lian
35d06c8087 add textui 2026-03-23 04:54:00 +00:00
Wing Lian
0e583efeaa increase rtol, codecov informational only, don't silently fail errors w curl (#3534) [skip ci] 2026-03-22 13:54:03 -04:00
Wing Lian
b3289fd190 feat: LoRA kernel support for bias, dropout, dora, embeddings (#3528) [skip ci]
* feat: LoRA kernel support for bias, dropout, dora, embeddings

* chore: lint

* chore: lint

* address PR feedback, add regression tests, add fsdp2 tests for lora kernels

* update tests for new sigs

* update tests now that bias and dropout are supported
2026-03-22 13:53:19 -04:00
Wing Lian
a67392c427 liger support for qwen 3.5 and fused rmsnorm+gated (#3531) [skip ci]
* liger support for qwen 3.5 and fused rmsnorm+gated

* support for qwen 3.5 moe

* fix version ref

* fixups for PR code review
2026-03-22 13:19:21 -04:00
54 changed files with 5656 additions and 1714 deletions

View File

@@ -3,7 +3,8 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
curl --silent -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
set -o pipefail
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
# hf download "NousResearch/Meta-Llama-3-8B"
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
# hf download "microsoft/Phi-4-reasoning"

View File

@@ -37,6 +37,7 @@ coverage:
only_pulls: false
flags: null
paths: null
informational: true
parsers:
gcov:

View File

@@ -91,6 +91,7 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
type=click.Path(exists=True, path_type=str),
help="YAML config for sweeping hyperparameters",
)
@click.option("--tui", is_flag=True, default=False, help="Enable TUI dashboard")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@@ -101,6 +102,7 @@ def train(
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
cloud: str | None = None,
sweep: str | None = None,
tui: bool = False,
**kwargs,
):
"""
@@ -118,6 +120,10 @@ def train(
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
# Handle --tui flag: set env var so subprocess workers pick it up
if tui:
os.environ["AXOLOTL_TUI"] = "1"
# Handle Ray launcher override
_launcher = None if kwargs.get("use_ray") else launcher

View File

@@ -2,6 +2,7 @@
import gc
import os
import queue
from pathlib import Path
from typing import Union
@@ -34,22 +35,101 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
plugin_manager = PluginManager.get_instance()
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
if not dataset_meta:
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
# Start TUI early (before data loading) so it captures preprocessing events
tui_renderer = None
tui_queue: queue.Queue | None = None
is_rank_0 = int(os.getenv("LOCAL_RANK", "0")) == 0
if is_rank_0:
from axolotl.train import _is_tui_enabled
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
if _is_tui_enabled(cfg):
import queue as _queue
del model, tokenizer, trainer
from axolotl.train import _get_tui_config
from axolotl.tui.config import TUIConfig
from axolotl.tui.renderer import TUIRenderer
gc.collect()
tui_config_dict = _get_tui_config(cfg)
tui_config = (
TUIConfig(**tui_config_dict)
if isinstance(tui_config_dict, dict)
else tui_config_dict
)
tui_queue = _queue.Queue(maxsize=4096)
tui_renderer = TUIRenderer(config=tui_config, metric_queue=tui_queue)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train_unload(cfg)
# Send initial run info
model_name = cfg.base_model or ""
training_mode = str(cfg.rl) if cfg.rl else "sft"
world_size = int(os.environ.get("WORLD_SIZE", 1))
try:
tui_queue.put_nowait(
{
"type": "run_info",
"model_name": model_name,
"training_mode": training_mode,
"world_size": world_size,
}
)
except _queue.Full:
pass
tui_renderer.start()
# Attach logging handler early
import logging
from axolotl.tui.callback import _TUILogHandler
_early_log_handler = _TUILogHandler(
tui_queue, min_level=tui_config.log_level
)
_early_log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
# Attach to BOTH root and axolotl loggers because axolotl logger
# has propagate=False so root handler never sees axolotl.* messages
root_logger = logging.getLogger()
root_logger.addHandler(_early_log_handler)
axolotl_logger = logging.getLogger("axolotl")
axolotl_logger.addHandler(_early_log_handler)
# Stash refs on cfg so train() can reuse the renderer
cfg._tui_renderer = tui_renderer
cfg._tui_queue = tui_queue
cfg._tui_early_log_handler = _early_log_handler
try:
plugin_manager = PluginManager.get_instance()
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
if not dataset_meta:
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
del model, tokenizer, trainer
gc.collect()
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train_unload(cfg)
finally:
# If the TUI renderer started early but train() didn't get to stop it
# (e.g., error during data loading), clean up here
if tui_renderer is not None and not tui_renderer._stop_event.is_set():
try:
if tui_queue is not None:
tui_queue.put_nowait({"type": "done"})
except queue.Full:
pass
tui_renderer.stop()
# Remove early log handler from both root and axolotl loggers
if hasattr(cfg, "_tui_early_log_handler"):
import logging
logging.getLogger().removeHandler(cfg._tui_early_log_handler)
logging.getLogger("axolotl").removeHandler(cfg._tui_early_log_handler)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):

View File

@@ -1,40 +0,0 @@
# Aux-Loss-Free MoE Router Plugin
This integration adds an aux-loss-free (AFB) gating option to compatible MoE architectures without forking model code.
Summary
- Bias only affects expert selection (top-k); mixture weights come from unbiased logits.
- Per-expert token loads are accumulated on device and reduced across DP or EP groups.
- Bias is updated post-optimizer step outside autograd using EMA-smoothed loads.
- Existing aux loss is disabled when aux-free is enabled to avoid double signals.
Enable
- Add the plugin to your YAML, then set the aux-free toggle:
plugins:
- axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin
moe_balance_type: noaux_tc
moe_update_rate: 0.01 # default if unset
moe_update_momentum: 0.9 # default if unset
moe_bias_cap: 2.0 # default if unset
moe_afb_warmup_steps: 100 # optional
moe_bias_sync_group: world # or 'ep' if expert_parallel_size > 1
expert_parallel_size: 1 # set to your EP width when using moe_bias_sync_group: ep
Config keys
- moe_balance_type: gshard (auxiliary loss) | noaux_tc (aux-free). Default: model native.
- moe_update_rate: bias update rate (gamma). Default: 0.01.
- moe_update_momentum: EMA momentum for load smoothing. Default: 0.9.
- moe_bias_cap: absolute clamp for bias. Default: 2.0.
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
Compatibility
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
- Pass-through: Models with native aux-free routing (e.g., DeepSeek-V3) are left unmodified; only telemetry may be added in future.
Notes
- If you also enable Liger's aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
- Telemetry: future updates will log per-expert loads and bias magnitudes.

View File

@@ -1,2 +0,0 @@
"""Aux-loss-free (AFB) MoE router integration package."""

View File

@@ -1,317 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Optional
import torch
from torch import nn
import torch.nn.functional as F
from axolotl.utils.logging import get_logger
from .core import AuxFreeShim
LOG = get_logger(__name__)
# Record describing one discovered MoE layer. discover_and_prepare_layers builds
# one per layer and passes it to the matching adapter's prepare().
@dataclass
class LayerHandle:
# The MoE module that was discovered.
layer: nn.Module
# Sequential index assigned in discovery order.
layer_idx: int
# Expert count reported by the adapter for this layer.
num_experts: int
# Number of experts selected per token, as reported by the adapter.
top_k: int
class BaseMoEAdapter:
"""Base adapter that discovers MoE layers and wraps their forward.
Concrete adapters should implement discovery and per-layer attribute extraction.
"""
family: str = "generic"
def matches(self, model: nn.Module) -> bool: # pragma: no cover - thin shim
return False
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: # pragma: no cover
return []
def get_top_k(self, moe_layer: nn.Module) -> int:  # pragma: no cover
    """Return how many experts each token is routed to.

    Reads ``num_experts_per_tok`` when the layer has it, otherwise ``top_k``,
    and falls back to 2 when neither attribute exists.
    """
    fallback = getattr(moe_layer, "top_k", 2)
    return int(getattr(moe_layer, "num_experts_per_tok", fallback))
def get_num_experts(self, moe_layer: nn.Module) -> int: # pragma: no cover
return int(getattr(moe_layer, "num_experts"))
def disable_aux_loss(self, model_or_layer: nn.Module) -> None:
    """Zero ``router_aux_loss_coef`` on the given object when it has one.

    Best-effort: objects without the attribute are left untouched, and a
    failing assignment is ignored because this step is non-critical.
    """
    if not hasattr(model_or_layer, "router_aux_loss_coef"):
        return
    try:
        model_or_layer.router_aux_loss_coef = 0.0
    except Exception:  # pragma: no cover - non-critical
        pass
def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
device = next(moe_layer.parameters(), torch.tensor(0)).device
if not hasattr(moe_layer, "_afb_bias"):
moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device))
if not hasattr(moe_layer, "_afb_counts"):
moe_layer.register_buffer("_afb_counts", torch.zeros(handle.num_experts, device=device))
if not hasattr(moe_layer, "_afb_ema"):
moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device))
moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined]
moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined]
shim.register_layer_buffers(handle.layer_idx, moe_layer)
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
"""Attach per-layer buffers and mark as aux-free enabled."""
self._register_aux_buffers(moe_layer, handle, shim)
self._patch_forward_with_aux_free(moe_layer)
# Generic patch: binds the nested afb_forward onto the layer instance so its
# forward selects experts with biased logits and records per-expert counts.
def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None:
"""Replace the layer's forward with an aux-free gating version.
Assumes the layer exposes attributes:
- gate: linear router projecting hidden to num_experts
- num_experts: int
- experts: iterable of expert modules taking (tokens, H) -> (tokens, H)
"""
# A layer is patched at most once; the flag is set at the end of this method.
if getattr(moe_layer, "_afb_patched", False):
return
# Without both `gate` and `experts` the replacement forward cannot run, so the
# layer keeps its original forward.
if not hasattr(moe_layer, "gate") or not hasattr(moe_layer, "experts"):
LOG.info("AuxFreeMoE: layer missing gate/experts; skipping forward patch")
return
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
# hidden_states: (B, T, H)
bsz, seqlen, hdim = hidden_states.shape
hs = hidden_states.view(-1, hdim)
logits = self.gate(hs)
# selection uses biased logits; weights from unbiased logits
bias = getattr(self, "_afb_bias")
top_k = int(getattr(self, "_afb_top_k", 2))
biased = logits + bias # broadcast over tokens
topk_vals, topk_idx = torch.topk(biased, k=top_k, dim=-1, sorted=False)
chosen_logits = torch.gather(logits, -1, topk_idx)
# Softmax is taken over the selected experts only, in float32, then cast back.
weights = torch.softmax(chosen_logits.float(), dim=-1)
weights = weights.to(hs.dtype)
# accumulate counts for bias update callback
# NOTE(review): counts are added on every forward call, with no check of
# self.training — confirm whether evaluation passes should be counted.
flat_idx = topk_idx.reshape(-1)
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
getattr(self, "_afb_counts").add_(counts.to(getattr(self, "_afb_counts").dtype))
# dispatch tokens to experts
# Each token row is repeated top_k times consecutively, so row i of hs_rep
# lines up with entry i of flat_idx.
hs_rep = hs.repeat_interleave(top_k, dim=0)
# Uninitialised output; rows are filled expert by expert below.
# NOTE(review): assumes every selected index is < num_experts so that all
# rows get written — confirm the gate's output width matches num_experts.
y = torch.empty_like(hs_rep)
for eid in range(int(self.num_experts)):
mask = flat_idx == eid
if mask.any():
y[mask] = self.experts[eid](hs_rep[mask])
# Regroup the top_k expert outputs per token and take the weighted sum.
y = (y.view(-1, top_k, hdim) * weights.unsqueeze(-1)).sum(dim=1)
out = y.view(bsz, seqlen, hdim)
# Returns the mixed output together with the unbiased router logits.
return (out, logits)
# Bind as an instance-level method so only this layer object is affected.
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
setattr(moe_layer, "_afb_patched", True)
class MixtralAdapter(BaseMoEAdapter):
family = "mixtral"
def matches(self, model: nn.Module) -> bool:
    """Return True when ``model.config.model_type`` equals ``"mixtral"``.

    A model without a ``config`` or ``model_type`` attribute does not match.
    """
    model_config = getattr(model, "config", object())
    return getattr(model_config, "model_type", "") == "mixtral"
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
self._register_aux_buffers(moe_layer, handle, shim)
self._patch_mixtral_forward(moe_layer, shim)
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
    """Yield every submodule whose class name ends with ``SparseMoeBlock``.

    Modules are produced in ``model.modules()`` traversal order.
    """
    yield from (
        module
        for module in model.modules()
        if module.__class__.__name__.endswith("SparseMoeBlock")
    )
# Replaces a sparse-MoE block's forward with a version that delegates expert
# selection to the shim and records per-expert selection counts.
def _patch_mixtral_forward(self, moe_layer: nn.Module, shim: AuxFreeShim) -> None:
# A layer is patched at most once; the flag is set at the end of this method.
if getattr(moe_layer, "_afb_patched", False):
return
# Captured by the closure below so the patched forward can reach the shim.
shim_ref = shim
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
batch_size, sequence_length, hidden_dim = hidden_states.shape
# Multiplicative jitter is applied only in training and only when the layer
# defines a positive jitter_noise.
if self.training and getattr(self, "jitter_noise", 0) > 0:
hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
1.0 - self.jitter_noise, 1.0 + self.jitter_noise
)
flat_states = hidden_states.view(-1, hidden_dim)
router_logits = self.gate(flat_states)
layer_idx = int(getattr(self, "_afb_layer_idx", 0))
top_k = int(getattr(self, "_afb_top_k", self.top_k))
# The shim returns (indices, weights): biased selection, unbiased weights.
selected_experts, routing_weights = shim_ref.select_experts(layer_idx, router_logits, top_k)
routing_weights = routing_weights.to(flat_states.dtype)
# Per-expert selection counts accumulate in the layer's _afb_counts buffer.
# NOTE(review): counts are added on every forward call, with no check of
# self.training — confirm whether evaluation passes should be counted.
flat_idx = selected_experts.reshape(-1)
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
self._afb_counts.add_(counts.to(self._afb_counts.dtype))
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=flat_states.dtype,
device=flat_states.device,
)
# After the permute, expert_mask is indexed (expert, top-k slot, token).
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
# Only experts that received at least one token are visited.
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_layer = self.experts[expert_idx]
# idx: top-k slot, top_x: token row, for each assignment to this expert.
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = flat_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
# index_add_ sums the contributions of a token's different experts.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
# Returns the mixed output together with the unbiased router logits.
return final_hidden_states, router_logits
# Bind as an instance-level method so only this layer object is affected.
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
setattr(moe_layer, "_afb_patched", True)
class Qwen3Adapter(MixtralAdapter):
    """Adapter for Qwen MoE models; inherits layer discovery and patching from MixtralAdapter."""

    family = "qwen3_moe"

    def matches(self, model: nn.Module) -> bool:
        """Return True when the model's config reports ``qwen3_moe`` or ``qwen2_moe``."""
        model_config = getattr(model, "config", object())
        model_type = getattr(model_config, "model_type", "")
        return model_type in ("qwen3_moe", "qwen2_moe")
class BailingAdapter(BaseMoEAdapter):
family = "bailing_moe"
def matches(self, model: nn.Module) -> bool:
model_type = getattr(getattr(model, "config", object()), "model_type", "")
return model_type in ("bailing_moe", "bailing_moe_v2")
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
for m in model.modules():
if m.__class__.__name__ == "BailingMoeV2SparseMoeBlock":
yield m
def get_num_experts(self, moe_layer: nn.Module) -> int:
    """Return the expert count from the layer, or from ``layer.config`` as a fallback.

    Raises:
        AttributeError: when neither the layer nor its config has ``num_experts``.
    """
    source = moe_layer if hasattr(moe_layer, "num_experts") else getattr(moe_layer, "config", None)
    return int(getattr(source, "num_experts"))
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
self._register_aux_buffers(moe_layer, handle, shim)
self._patch_bailing_gate(moe_layer)
# Patches the layer's `gate` submodule (not the layer's own forward) so expert
# selection uses biased sigmoid scores while the weights use unbiased scores.
def _patch_bailing_gate(self, moe_layer: nn.Module) -> None:
gate = getattr(moe_layer, "gate", None)
if gate is None:
LOG.info("BailingAdapter: layer missing gate; skipping aux-free patch")
return
# The patched flag lives on the gate, since the gate is what gets replaced.
if getattr(gate, "_afb_patched", False):
return
# Inside the closure `self` is the gate; the aux-free buffers are read from the
# enclosing `moe_layer`.
def afb_gate_forward(self, hidden_states: torch.Tensor):
flat = hidden_states.view(-1, hidden_states.shape[-1])
# Router projection is computed in float32.
logits = F.linear(flat.float(), self.weight.float())
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
bias = getattr(moe_layer, "_afb_bias")
biased_scores = scores_unbiased + bias
# group_limited_topk is provided by the gate module itself; its semantics are
# not visible here.
topk_vals, topk_idx = self.group_limited_topk(biased_scores)
# Weights are gathered from the unbiased scores at the selected indices.
weights = torch.gather(scores_unbiased, 1, topk_idx)
if self.top_k > 1:
# Normalise across the selected experts; the clamp guards against division by zero.
denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20)
weights = weights / denom
weights = weights * self.routed_scaling_factor
# Per-expert selection counts accumulate in the layer's _afb_counts buffer.
flat_topk = topk_idx.reshape(-1)
counts = torch.bincount(flat_topk, minlength=bias.numel())
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
# Returns (selected indices, weights in the input dtype, float32 logits).
return topk_idx, weights.to(hidden_states.dtype), logits
# Bind as an instance-level method so only this gate object is affected.
gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
setattr(gate, "_afb_patched", True)
class Llama4Adapter(BaseMoEAdapter):
family = "llama4"
def matches(self, model: nn.Module) -> bool:
return getattr(getattr(model, "config", object()), "model_type", "") == "llama4"
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
for m in model.modules():
if m.__class__.__name__ == "Llama4TextMoe":
yield m
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
self._register_aux_buffers(moe_layer, handle, shim)
self._patch_llama4_router(moe_layer)
# Patches the layer's `router` submodule so top-k selection uses biased logits
# while the emitted scores come from the unbiased logits.
def _patch_llama4_router(self, moe_layer: nn.Module) -> None:
router = getattr(moe_layer, "router", None)
if router is None:
LOG.info("Llama4Adapter: layer missing router; skipping aux-free patch")
return
# The patched flag lives on the router, since the router is what gets replaced.
if getattr(router, "_afb_patched", False):
return
# Inside the closure `self` is the router; the aux-free buffers are read from
# the enclosing `moe_layer`.
def afb_router_forward(self, hidden_states: torch.Tensor):
# 2-D input is used as-is; anything else is flattened to (rows, hidden).
flat = hidden_states if hidden_states.dim() == 2 else hidden_states.view(-1, hidden_states.shape[-1])
router_logits = F.linear(flat, self.weight, self.bias)
bias = getattr(moe_layer, "_afb_bias")
biased_logits = router_logits + bias
_, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
unbiased_top = torch.gather(router_logits, 1, router_indices)
# Non-selected experts stay at -inf, so the sigmoid below maps them to 0.
router_scores = torch.full_like(router_logits, float("-inf"))
router_scores.scatter_(1, router_indices, unbiased_top)
router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)
# Per-expert selection counts accumulate in the layer's _afb_counts buffer.
counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
# Returns (dense scores with zeros for unselected experts, unbiased logits).
return router_scores, router_logits
# Bind as an instance-level method so only this router object is affected.
router.forward = afb_router_forward.__get__(router, router.__class__) # type: ignore[attr-defined]
setattr(router, "_afb_patched", True)
def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim) -> list[LayerHandle]:
    """Discover MoE layers using the first matching adapter and prepare them.

    Args:
        model: Model whose MoE layers should be prepared.
        adapters: Candidate adapters, tried in order; the first whose
            ``matches(model)`` is true is used for every layer.
        shim: Shim handed to ``adapter.prepare`` for each layer.

    Returns:
        One ``LayerHandle`` per prepared layer, indexed consecutively in
        discovery order. Empty when no adapter matches.
    """
    handles: list[LayerHandle] = []
    adapter: Optional[BaseMoEAdapter] = next(
        (candidate for candidate in adapters if candidate.matches(model)), None
    )
    if adapter is None:
        LOG.info("AuxFreeMoE: no matching adapter found; skipping aux-free routing")
        return handles
    # disable aux loss at model level if possible
    adapter.disable_aux_loss(getattr(model, "config", model))
    for layer in adapter.find_moe_layers(model):
        try:
            top_k = adapter.get_top_k(layer)
            num_experts = adapter.get_num_experts(layer)
        except Exception as err:
            # Still best-effort, but no longer silent: a layer that is skipped
            # here keeps its original routing, which is worth surfacing.
            LOG.warning(
                "AuxFreeMoE: skipping %s layer %s; could not read top_k/num_experts: %s",
                adapter.family,
                layer.__class__.__name__,
                err,
            )
            continue
        # Indices stay dense: only successfully prepared layers consume one.
        handle = LayerHandle(
            layer=layer, layer_idx=len(handles), num_experts=num_experts, top_k=top_k
        )
        adapter.prepare(layer, handle, shim)
        handles.append(handle)
    LOG.info(
        "AuxFreeMoE: prepared %s %s layers for aux-free routing",
        len(handles),
        adapter.family,
    )
    return handles

View File

@@ -1,150 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import torch.distributed as dist
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# Hyperparameters for the aux-loss-free bias update, read by AuxFreeShim.
@dataclass
class AuxFreeConfig:
# Multiplier applied to (uniform target - EMA load) when computing the bias step.
rate: float = 0.01
# Decay applied to the previous per-expert load estimate in the EMA update.
momentum: float = 0.9
# Bias is clamped to [-bias_cap, bias_cap] when this is not None and > 0.
bias_cap: float = 2.0
# Bias updates are skipped while the shim's step counter is <= this value.
warmup_steps: int = 0
# "ep" makes the shim resolve an expert-parallel reduction group lazily.
sync_group: str = "world" # or "ep"
class AuxFreeState:
    """Container for per-layer bias and EMA-load tensors plus the step counter."""

    def __init__(self, num_layers: int, num_experts: int, device: torch.device, cfg: AuxFreeConfig):
        """Allocate ``num_layers`` zero vectors of length ``num_experts`` for each buffer kind."""

        def _zeros_per_layer() -> list:
            # Every slot gets its own tensor so in-place updates never alias.
            return [torch.zeros(num_experts, device=device) for _ in range(num_layers)]

        self.bias = _zeros_per_layer()
        self.ema_load = _zeros_per_layer()
        self.cfg = cfg
        self.steps = 0
class AuxFreeShim:
"""Model-agnostic shim for aux-loss-free expert selection and bias updates."""
def __init__(
self,
state: AuxFreeState,
ep_group: Optional[dist.ProcessGroup] = None,
ep_size: Optional[int] = None,
):
self.state = state
self.ep_group = ep_group
self._ep_size = ep_size
self._ep_group_pending = (
self.state.cfg.sync_group == "ep" and self.ep_group is None
)
self._layer_modules: dict[int, torch.nn.Module] = {}
def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Pick experts with biased logits and weight them with unbiased logits.

    Args:
        layer_idx: Index of the MoE layer whose bias should be used.
        logits: Router logits; the last dimension indexes experts.
        top_k: Number of experts to select per row.

    Returns:
        ``(topk_indices, weights)``. ``weights`` is a softmax over the
        unbiased logits of the selected experts, cast to ``logits.dtype``.
    """
    # Prefer the live buffer on the registered module; fall back to the
    # shim-owned state when no module was bound for this layer.
    module = self._layer_modules.get(layer_idx)
    if module is not None and hasattr(module, "_afb_bias"):
        bias = module._afb_bias
    else:
        bias = self.state.bias[layer_idx]
    # Only the discrete selection is excluded from autograd. Previously the
    # whole method ran under no_grad, which also detached the mixture weights
    # and left the router projection without any gradient.
    with torch.no_grad():
        topk_idx = torch.topk(logits + bias, k=top_k, dim=-1).indices
    chosen_logits = torch.gather(logits, -1, topk_idx)
    weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
    return topk_idx, weights
def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
    """Remember ``module`` for ``layer_idx`` and alias its buffers into the shared state.

    State slots are replaced only when ``layer_idx`` falls inside the
    pre-allocated lists; the module mapping is always recorded.
    """
    self._layer_modules[layer_idx] = module
    # Point the state lists at the module's own tensors so both views share storage.
    aliases = (
        (self.state.bias, getattr(module, "_afb_bias")),
        (self.state.ema_load, getattr(module, "_afb_ema")),
    )
    for slots, buffer in aliases:
        if layer_idx < len(slots):
            slots[layer_idx] = buffer
def begin_step(self) -> None:
    """Advance the shared step counter; call once per optimizer step before per-layer updates."""
    state = self.state
    state.steps = state.steps + 1
@torch.no_grad()
def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
    """Sum ``counts`` in place across the reduction group and return it.

    Returns ``counts`` untouched when torch.distributed is unavailable or
    not initialised. The world group is used when no expert-parallel group
    has been resolved.
    """
    self._maybe_init_ep_group()
    distributed_ready = dist.is_available() and dist.is_initialized()
    if distributed_ready:
        target_group = dist.group.WORLD if self.ep_group is None else self.ep_group
        dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=target_group)
    return counts
@torch.no_grad()
def update_bias(self, layer_idx: int, step_counts: torch.Tensor, tokens_seen: int):
    """Nudge a layer's selection bias toward a uniform expert load.

    The per-expert selection frequency (``step_counts / tokens_seen``) is
    folded into an EMA, the bias moves by ``rate * (1/num_experts - ema)``
    after mean-centering, and the result is clamped to ``±bias_cap``.
    Nothing happens during warmup or when ``tokens_seen`` is not positive.
    """
    settings = self.state.cfg
    if self.state.steps <= settings.warmup_steps:
        return
    num_experts = step_counts.numel()
    if tokens_seen <= 0:
        return
    # Prefer the buffers living on the registered module; otherwise use the
    # shim-owned state tensors.
    owner = self._layer_modules.get(layer_idx)
    if owner is None or not hasattr(owner, "_afb_ema"):
        ema_load = self.state.ema_load[layer_idx]
        bias = self.state.bias[layer_idx]
    else:
        ema_load = getattr(owner, "_afb_ema")
        bias = getattr(owner, "_afb_bias")
    observed = step_counts.to(ema_load.device).float() / float(tokens_seen)
    ema_load.mul_(settings.momentum).add_((1.0 - settings.momentum) * observed)
    uniform_share = 1.0 / float(num_experts)
    step = settings.rate * (uniform_share - ema_load)
    # Mean-centre so the update keeps sum(bias) roughly constant.
    step = step - step.mean()
    bias.add_(step)
    if settings.bias_cap is not None and settings.bias_cap > 0:
        bias.clamp_(-settings.bias_cap, settings.bias_cap)
def _maybe_init_ep_group(self) -> None:
if not self._ep_group_pending:
return
if not dist.is_available() or not dist.is_initialized():
return
ep_size = self._ep_size
if not ep_size or ep_size <= 1:
LOG.warning(
"AuxFreeMoE: moe_bias_sync_group='ep' requested but expert_parallel_size<=1; defaulting to world group"
)
self.ep_group = dist.group.WORLD
self._ep_group_pending = False
return
world = dist.get_world_size()
if world % ep_size != 0:
LOG.warning(
"AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world group",
ep_size,
world,
)
self.ep_group = dist.group.WORLD
self._ep_group_pending = False
return
if ep_size == world:
self.ep_group = dist.group.WORLD
else:
rank = dist.get_rank()
group_start = (rank // ep_size) * ep_size
ranks = tuple(range(group_start, group_start + ep_size))
self.ep_group = dist.new_group(ranks)
LOG.info(
"AuxFreeMoE: initialized expert-parallel reduction group (size=%s, world=%s)",
ep_size,
dist.get_world_size(),
)
self._ep_group_pending = False

View File

@@ -1,175 +0,0 @@
"""Aux-loss-free MoE Router Plugin for Axolotl.
This plugin wires an aux-free gating option into compatible MoE models using
unbiased logits for mixture weights and per-expert biases for top-k selection.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.distributed as dist
from transformers.trainer_callback import TrainerCallback
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
from .adapters import (
BailingAdapter,
BaseMoEAdapter,
Llama4Adapter,
MixtralAdapter,
Qwen3Adapter,
discover_and_prepare_layers,
)
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
LOG = get_logger(__name__)
class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
"""Post-step callback to update aux-free biases from accumulated expert counts.
Note: The current revision expects per-layer counts to be accumulated on each
MoE layer as a buffer named `_afb_counts` during forward (to be added with
routing patches in a follow-up).
"""
def __init__(self, shim: AuxFreeShim, layer_modules: list[torch.nn.Module]):
self.shim = shim
self.layer_modules = layer_modules
def on_step_end(self, args, state, control, **kwargs):  # noqa: D401
    """Apply the bias update to every prepared MoE layer, then clear its counts.

    For each layer that carries ``_afb_counts`` and ``_afb_layer_idx``, the
    counts are reduced through the shim, used for the bias update, and zeroed.
    Returns ``control`` unchanged.
    """
    shim = self.shim
    shim.begin_step()
    for moe_layer in self.layer_modules:
        prepared = hasattr(moe_layer, "_afb_counts") and hasattr(moe_layer, "_afb_layer_idx")
        if not prepared:
            continue
        step_counts = getattr(moe_layer, "_afb_counts")
        if step_counts is None:
            continue
        step_counts = shim.all_reduce_counts(step_counts)
        layer_idx = getattr(moe_layer, "_afb_layer_idx", None)
        if layer_idx is not None:
            # Move the counts next to the bias before the update.
            local_counts = step_counts.to(getattr(moe_layer, "_afb_bias").device)
            tokens_seen = int(local_counts.sum().item())
            shim.update_bias(layer_idx, local_counts, tokens_seen)
        # Counts are cleared whether or not an update was applied.
        step_counts.zero_()
    return control
class AuxFreeMoEPlugin(BasePlugin):
"""Plugin that enables aux-loss-free routing when configured."""
def __init__(self):
super().__init__()
self._handles: list = []
self._shim: Optional[AuxFreeShim] = None
self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}
# Plugin hook: when cfg.moe_balance_type == "noaux_tc", builds the aux-free
# config/state/shim and patches the model's MoE layers. Results are stored on
# self._shim and self._handles; nothing is returned.
def post_model_build(self, cfg, model):
# Enable only when explicitly requested
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
return
# Be conservative — skip known native aux-free families
native_auxfree = getattr(getattr(model, "config", object()), "model_type", "") in (
"deepseek_v3",
"glm4_moe",
)
if native_auxfree:
LOG.info("AuxFreeMoE: model reports native aux-free routing; skipping patching")
return
# Build aux-free state and shim
# Each setting falls back to its default only when the config value is None
# (sync_group falls back on any falsy value).
rate = cfg.moe_update_rate if cfg.moe_update_rate is not None else 0.01
momentum = (
cfg.moe_update_momentum if cfg.moe_update_momentum is not None else 0.9
)
bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
af_cfg = AuxFreeConfig(
rate=rate, momentum=momentum, bias_cap=bias_cap, warmup_steps=warmup, sync_group=sync_group
)
# Discover layers to count the number and experts for state sizing
# Adapters are tried in this order by discover_and_prepare_layers.
adapters: list[BaseMoEAdapter] = [
MixtralAdapter(),
Qwen3Adapter(),
BailingAdapter(),
Llama4Adapter(),
]
# For initial state sizing, we conservatively assume the first discovered layer defines nE
# n_layers ends up as the count of ALL submodules of the model, not just MoE
# layers, so the state is allocated with that many placeholder slots.
n_layers = 0
n_experts = None
for m in model.modules():
n_layers += 1 # upper bound — we will re-use bias slots sparsely
device = next(model.parameters(), torch.tensor(0)).device
if n_layers <= 0:
n_layers = 1
# n_experts is never assigned between its initialisation above and this
# check, so the placeholder value 2 is always used here.
if n_experts is None:
# we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
n_experts = 2
state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg)
ep_size = getattr(cfg, "expert_parallel_size", None)
ep_group = None
# For "ep" sync the group is resolved now if torch.distributed is already
# initialised; otherwise the shim receives ep_group=None together with ep_size.
if sync_group == "ep":
if dist.is_available() and dist.is_initialized():
ep_group = self._resolve_ep_group(cfg)
else:
LOG.info(
"AuxFreeMoE: deferring expert-parallel group resolution until torch.distributed initializes"
)
self._shim = AuxFreeShim(state=state, ep_group=ep_group, ep_size=ep_size)
# Discover and prepare layers (attach per-layer buffers)
self._handles = discover_and_prepare_layers(model, adapters, self._shim)
LOG.info(
f"AuxFreeMoE: enabled with rate={rate}, momentum={momentum}, cap={bias_cap}, warmup={warmup}, group={sync_group}"
)
def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
    """Return the expert-parallel process group for this rank, or None.

    None is returned (after a warning) when torch.distributed is not
    initialised, when ``expert_parallel_size`` is unset or <= 1, or when it
    does not divide the world size. The world group is returned when the
    expert-parallel size equals the world size.
    """
    if not dist.is_available() or not dist.is_initialized():
        LOG.warning("AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world")
        return None
    ep_size = getattr(cfg, "expert_parallel_size", None)
    if not ep_size or ep_size <= 1:
        LOG.warning("AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world")
        return None
    world = dist.get_world_size()
    if world % ep_size != 0:
        LOG.warning(
            "AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world",
            ep_size,
            world,
        )
        return None
    if ep_size == world:
        return dist.group.WORLD
    rank = dist.get_rank()
    own_ranks = None
    # new_group() is collective: every rank must create every subgroup, in
    # the same order, even the ones it is not a member of. The previous code
    # created only the caller's own subgroup.
    for group_start in range(0, world, ep_size):
        ranks = tuple(range(group_start, group_start + ep_size))
        if ranks not in self._ep_group_cache:
            self._ep_group_cache[ranks] = dist.new_group(list(ranks))
        if group_start <= rank < group_start + ep_size:
            own_ranks = ranks
    return self._ep_group_cache[own_ranks]
def add_callbacks_post_trainer(self, cfg, trainer):
    """Return the post-step bias-update callback, or an empty list when not applicable.

    No callback is produced unless ``cfg.moe_balance_type`` is ``"noaux_tc"``
    and a shim has been built.
    """
    aux_free_requested = getattr(cfg, "moe_balance_type", None) == "noaux_tc"
    if not aux_free_requested or self._shim is None:
        return []
    # gather concrete layer modules from handles
    moe_layers = [handle.layer for handle in self._handles]
    callback = MoeAuxFreeBiasUpdateCallback(self._shim, moe_layers)
    LOG.info("AuxFreeMoE: registering post-step bias update callback")
    return [callback]

View File

@@ -30,6 +30,15 @@ class LigerArgs(BaseModel):
liger_rope: bool | None = None
liger_rms_norm: bool | None = None
liger_rms_norm_gated: bool | None = Field(
default=None,
json_schema_extra={
"description": (
"Enables fused RMSNorm+SiLU gate Triton kernel for models with "
"gated RMSNorm (e.g. Qwen3.5 / Qwen3.5 MoE linear attention layers)."
)
},
)
liger_layer_norm: bool | None = None
liger_swiglu: bool | None = None
liger_glu_activation: bool | None = None

View File

@@ -0,0 +1,175 @@
"""
Liger FLCE for Qwen3.5. Based on transformers v5.3.0.
"""
import sys
from copy import deepcopy
from typing import Optional, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
def lce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs,
) -> CausalLMOutputWithPast:
    r"""
    Causal-LM forward that uses Liger's fused linear cross entropy while training.

    When the module is in training mode and `labels` are given, the loss is computed
    directly from the final hidden states and `lm_head.weight`, so `logits` is
    returned as `None`. In every other case logits are materialized and, if `labels`
    are given, the loss comes from `self.loss_function`.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in
            `[0, ..., config.vocab_size]` or -100. Tokens with indices set to `-100` are
            ignored (masked); the loss is only computed for the remaining tokens.
        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. `0` keeps
            all positions (special case, since `slice(-0, None)` selects everything).
            If a `torch.Tensor`, it is used directly to index the sequence dimension.
            Only consulted on the logits-materializing path.

    Returns:
        `CausalLMOutputWithPast` with `loss`, `logits` and the cache / hidden state /
        attention fields forwarded from the decoder output.
    """
    decoder_out = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        cache_position=cache_position,
        **kwargs,
    )
    last_hidden = decoder_out[0]

    if self.training and (labels is not None):
        # Training: fused linear + cross entropy, logits are never materialized.
        logits = None
        loss = LigerForCausalLMLoss(
            hidden_states=last_hidden,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
    else:
        # Inference / eval: project (a slice of) the hidden states to logits.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(last_hidden[:, keep, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=decoder_out.past_key_values,
        hidden_states=decoder_out.hidden_states,
        attentions=decoder_out.attentions,
    )
def apply_liger_kernel_to_qwen3_5(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = False,
    rms_norm: bool = False,
    rms_norm_gated: bool = False,
    glu_activation: bool = False,
    layer_norm: bool = False,
    **kwargs,
) -> None:
    """
    Monkey-patch the HuggingFace Qwen3.5 modeling module with Liger kernels.

    Each flag swaps one attribute on
    `transformers.models.qwen3_5.modeling_qwen3_5`; flags left False change nothing.
    Extra keyword arguments are accepted and ignored.

    Note: Qwen3_5RMSNorm uses zero-init weight with offset 1.0 (like Gemma),
    so the replacement is a LigerRMSNorm with offset=1.0 and init_fn="zeros".

    Args:
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is False.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            When True, `Qwen3_5ForCausalLM.forward` is replaced by `lce_forward`,
            which does not materialize logits during training.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
        rms_norm_gated (bool): Whether to apply the fused RMSNorm+SiLU gate kernel for
            Qwen3_5RMSNormGated (used in linear attention layers). Default is False.
        glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.

    Raises:
        AssertionError: if both `cross_entropy` and `fused_linear_cross_entropy` are True.
    """
    # Importing the modeling module guarantees it is registered in sys.modules.
    import transformers.models.qwen3_5.modeling_qwen3_5  # noqa: F401
    from liger_kernel.transformers.functional import liger_cross_entropy
    from liger_kernel.transformers.layer_norm import LigerLayerNorm
    from liger_kernel.transformers.rms_norm import LigerRMSNorm
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    modeling_qwen3_5 = sys.modules["transformers.models.qwen3_5.modeling_qwen3_5"]

    if rms_norm:

        class LigerRMSNormForQwen3_5(LigerRMSNorm):
            """LigerRMSNorm preset for the `output * (1.0 + weight)` zero-init pattern."""

            def __init__(self, dim, eps=1e-6, **kwargs):
                # Extra constructor kwargs from the model are intentionally dropped.
                super().__init__(
                    dim,
                    eps=eps,
                    offset=1.0,
                    casting_mode="gemma",
                    init_fn="zeros",
                    in_place=False,
                )

        modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3_5

    if rms_norm_gated:
        from axolotl.kernels.rms_norm_gated import FusedRMSNormGated

        modeling_qwen3_5.Qwen3_5RMSNormGated = FusedRMSNormGated

    if glu_activation:

        def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
            """Accepts intermediate_size to pass to LigerSwiGLUMLP"""
            # Work on a copy so the caller's config object is never mutated.
            mlp_config = deepcopy(config)
            if intermediate_size is not None:
                mlp_config.intermediate_size = intermediate_size
            return LigerSwiGLUMLP(mlp_config, **kwargs)

        modeling_qwen3_5.Qwen3_5MLP = _liger_swiglu_mlp_wrapper

    if layer_norm:
        # NOTE(review): `nn` is whatever the modeling module imported under that name,
        # presumably `torch.nn`, which would make this replacement process-wide. Confirm.
        modeling_qwen3_5.nn.LayerNorm = LigerLayerNorm

    if cross_entropy:
        from transformers.loss.loss_utils import nn as loss_nn

        # NOTE(review): same caveat as above, this rebinds `functional.cross_entropy`
        # on the `nn` object shared by `transformers.loss.loss_utils`.
        loss_nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        modeling_qwen3_5.Qwen3_5ForCausalLM.forward = lce_forward

View File

@@ -0,0 +1,198 @@
"""
Liger FLCE for Qwen3.5 MoE. Based on transformers v5.3.0.
"""
import sys
from copy import deepcopy
from typing import Optional, Union
import torch
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
def lce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values=None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs,
) -> MoeCausalLMOutputWithPast:
    r"""
    MoE causal-LM forward that uses Liger's fused linear cross entropy while training.

    When the module is in training mode and `labels` are given, the loss is computed
    directly from the final hidden states and `lm_head.weight`, so `logits` is
    returned as `None`. Otherwise logits are materialized and, if `labels` are given,
    the loss comes from `self.loss_function`. When router logits are requested, the
    load-balancing auxiliary loss is computed and, if `labels` are given, added to the
    loss scaled by `self.router_aux_loss_coef`.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in
            `[0, ..., config.vocab_size]` or -100. Tokens with indices set to `-100` are
            ignored (masked); the loss is only computed for the remaining tokens.
        output_router_logits (`bool`, *optional*):
            Falls back to `self.config.output_router_logits` when `None`.
        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. `0` keeps
            all positions (special case). If a `torch.Tensor`, it is used directly to
            index the sequence dimension. Only consulted on the logits-materializing path.

    Returns:
        `MoeCausalLMOutputWithPast` with `loss`, `aux_loss`, `logits` and the cache /
        hidden state / attention / router logit fields forwarded from the decoder output.
    """
    from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
        load_balancing_loss_func,
    )

    if output_router_logits is None:
        output_router_logits = self.config.output_router_logits

    decoder_out = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
        **kwargs,
    )
    last_hidden = decoder_out[0]

    if self.training and (labels is not None):
        # Training: fused linear + cross entropy, logits are never materialized.
        logits = None
        loss = LigerForCausalLMLoss(
            hidden_states=last_hidden,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
    else:
        # Inference / eval: project (a slice of) the hidden states to logits.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(last_hidden[:, keep, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits,
                labels,
                self.vocab_size,
                **kwargs,
            )

    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            decoder_out.router_logits,
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        if labels is not None:
            # In-place accumulation; aux loss is moved to the loss' device first.
            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)

    return MoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,
        logits=logits,
        past_key_values=decoder_out.past_key_values,
        hidden_states=decoder_out.hidden_states,
        attentions=decoder_out.attentions,
        router_logits=decoder_out.router_logits,
    )
def apply_liger_kernel_to_qwen3_5_moe(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = False,
    rms_norm: bool = False,
    rms_norm_gated: bool = False,
    glu_activation: bool = False,
    layer_norm: bool = False,
    **kwargs,
) -> None:
    """
    Monkey-patch the HuggingFace Qwen3.5 MoE modeling module with Liger kernels.

    Each flag swaps one attribute on
    `transformers.models.qwen3_5_moe.modeling_qwen3_5_moe`; flags left False change
    nothing. Extra keyword arguments are accepted and ignored.

    Note: Qwen3_5MoeRMSNorm uses zero-init weight with offset 1.0 (like Gemma),
    so the replacement is a LigerRMSNorm with offset=1.0 and init_fn="zeros".

    Args:
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is False.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            When True, `Qwen3_5MoeForCausalLM.forward` is replaced by `lce_forward`,
            which does not materialize logits during training.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
        rms_norm_gated (bool): Whether to apply the fused RMSNorm+SiLU gate kernel for
            Qwen3_5MoeRMSNormGated (used in linear attention layers). Default is False.
        glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.

    Raises:
        AssertionError: if both `cross_entropy` and `fused_linear_cross_entropy` are True.
    """
    # Importing the modeling module guarantees it is registered in sys.modules.
    import transformers.models.qwen3_5_moe.modeling_qwen3_5_moe  # noqa: F401
    from liger_kernel.transformers.functional import liger_cross_entropy
    from liger_kernel.transformers.layer_norm import LigerLayerNorm
    from liger_kernel.transformers.rms_norm import LigerRMSNorm
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )

    modeling_mod = sys.modules["transformers.models.qwen3_5_moe.modeling_qwen3_5_moe"]

    if rms_norm:

        class LigerRMSNormForQwen3_5Moe(LigerRMSNorm):
            """LigerRMSNorm preset for the `output * (1.0 + weight)` zero-init pattern."""

            def __init__(self, dim, eps=1e-6, **kwargs):
                # Extra constructor kwargs from the model are intentionally dropped.
                super().__init__(
                    dim,
                    eps=eps,
                    offset=1.0,
                    casting_mode="gemma",
                    init_fn="zeros",
                    in_place=False,
                )

        modeling_mod.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3_5Moe

    if rms_norm_gated:
        from axolotl.kernels.rms_norm_gated import FusedRMSNormGated

        modeling_mod.Qwen3_5MoeRMSNormGated = FusedRMSNormGated

    if glu_activation:

        def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
            """Accepts intermediate_size to pass to LigerSwiGLUMLP"""
            # Work on a copy so the caller's config object is never mutated.
            mlp_config = deepcopy(config)
            if intermediate_size is not None:
                mlp_config.intermediate_size = intermediate_size
            return LigerSwiGLUMLP(mlp_config, **kwargs)

        modeling_mod.Qwen3_5MoeMLP = _liger_swiglu_mlp_wrapper

    if layer_norm:
        # NOTE(review): `nn` is whatever the modeling module imported under that name,
        # presumably `torch.nn`, which would make this replacement process-wide. Confirm.
        modeling_mod.nn.LayerNorm = LigerLayerNorm

    if cross_entropy:
        from transformers.loss.loss_utils import nn as loss_nn

        # NOTE(review): same caveat as above, this rebinds `functional.cross_entropy`
        # on the `nn` object shared by `transformers.loss.loss_utils`.
        loss_nn.functional.cross_entropy = liger_cross_entropy

    if fused_linear_cross_entropy:
        modeling_mod.Qwen3_5MoeForCausalLM.forward = lce_forward

View File

@@ -174,6 +174,19 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_5":
from axolotl.integrations.liger.models.qwen3_5 import (
apply_liger_kernel_to_qwen3_5,
)
apply_liger_kernel_to_qwen3_5(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_moe":
from axolotl.integrations.liger.models.qwen3_moe import (
apply_liger_kernel_to_qwen3_moe,
@@ -186,6 +199,19 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "qwen3_5_moe":
from axolotl.integrations.liger.models.qwen3_5_moe import (
apply_liger_kernel_to_qwen3_5_moe,
)
apply_liger_kernel_to_qwen3_5_moe(
cross_entropy=cfg.liger_cross_entropy,
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
glu_activation=cfg.liger_glu_activation,
rms_norm=cfg.liger_rms_norm,
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type == "granitemoe":
from liger_kernel.transformers import apply_liger_kernel_to_granite

147
src/axolotl/kernels/dora.py Normal file
View File

@@ -0,0 +1,147 @@
"""
Triton kernels for DoRA (Weight-Decomposed Low-Rank Adaptation).
Fuses the weight norm computation and magnitude scaling to avoid
materializing the full [out_features, in_features] combined weight matrix.
The B@A product is computed row-by-row inside the kernel.
"""
import torch
import triton
import triton.language as tl
from .quantize import dequantize
@triton.jit
def _dora_fused_norm_kernel(
    # Pointers
    W_ptr,  # base weight [out, in] (dequantized, row-major)
    B_ptr,  # LoRA B [out, rank] (row-major)
    A_ptr,  # LoRA A [rank, in] (row-major)
    mag_ptr,  # magnitude vector [out]
    out_ptr,  # output mag_norm_scale [out]
    # Shapes
    out_features,
    in_features,
    rank,
    # Scaling
    lora_scale,  # float scaling factor
    # Block sizes
    BLOCK_IN: tl.constexpr,
    BLOCK_R: tl.constexpr,  # >= rank, power of 2
):
    """Write out[i] = magnitude[i] / ||W[i, :] + lora_scale * (B[i, :] @ A)||_2.

    One program instance handles one output row. The row is walked in tiles of
    BLOCK_IN columns; for each tile the low-rank contribution is rebuilt as a sum
    over rank of scalar B[row, r] times vector A[r, tile], so the full
    [out, in] product B @ A is never stored. Squared values are accumulated in
    fp32 and reduced once at the end.
    """
    pid = tl.program_id(0)
    if pid >= out_features:
        return

    # Per-lane partial sums of squares; reduced to a scalar after the tile loop.
    sq_partial = tl.zeros([BLOCK_IN], dtype=tl.float32)

    for tile_start in range(0, in_features, BLOCK_IN):
        col_idx = tile_start + tl.arange(0, BLOCK_IN)
        in_range = col_idx < in_features

        # W[row, tile]
        w_tile = tl.load(
            W_ptr + pid * in_features + col_idx,
            mask=in_range,
            other=0.0,
        ).to(tl.float32)

        # (B[row, :] @ A[:, tile]) built one rank index at a time.
        # BLOCK_R may exceed rank, so indices >= rank are masked to zero.
        delta_tile = tl.zeros([BLOCK_IN], dtype=tl.float32)
        for k in tl.static_range(BLOCK_R):
            b_k = tl.load(
                B_ptr + pid * rank + k,
                mask=(k < rank),
                other=0.0,
            ).to(tl.float32)
            a_k = tl.load(
                A_ptr + k * in_features + col_idx,
                mask=(in_range & (k < rank)),
                other=0.0,
            ).to(tl.float32)
            delta_tile += b_k * a_k

        merged = w_tile + lora_scale * delta_tile
        sq_partial += tl.where(in_range, merged * merged, 0.0)

    # The small epsilon keeps the division below finite for an all-zero row.
    row_norm = tl.sqrt(tl.sum(sq_partial, axis=0) + 1e-12)

    row_mag = tl.load(mag_ptr + pid).to(tl.float32)
    tl.store(out_ptr + pid, row_mag / row_norm)
def triton_dora_scale(
    W: torch.Tensor,
    W_quant,
    A: torch.Tensor,
    B: torch.Tensor,
    s: float,
    magnitude: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Compute the DoRA per-row scale ``magnitude / ||W + s * B @ A||_2``.

    The base weight is dequantized and cast to ``dtype``, then a Triton kernel
    is launched with one program per output row. The kernel rebuilds B @ A
    tile by tile, so the [out_features, in_features] product is not allocated.

    Args:
        W: base weight [out, in] (possibly quantized)
        W_quant: quantization state handed to ``dequantize``
        A: LoRA A [rank, in]
        B: LoRA B [out, rank]
        s: LoRA scaling factor
        magnitude: learned magnitude [out]
        dtype: dtype used for the dequantized weight, A, B and the result

    Returns:
        Detached tensor of shape [out] and dtype ``dtype`` on ``W.device``.
    """
    # Dequantize through the transposed layout, then restore [out, in].
    base_weight = dequantize(W.t(), W_quant).t().contiguous().to(dtype)
    n_out, n_in = base_weight.shape
    lora_rank = A.shape[0]

    scale_out = torch.empty(n_out, dtype=dtype, device=W.device)

    # Column tile is capped at 2048; both block sizes must be powers of two.
    block_in = triton.next_power_of_2(min(n_in, 2048))
    block_r = triton.next_power_of_2(lora_rank)

    launch_grid = (n_out,)
    _dora_fused_norm_kernel[launch_grid](
        base_weight,
        B.contiguous().to(dtype),
        A.contiguous().to(dtype),
        magnitude.contiguous(),
        scale_out,
        out_features=n_out,
        in_features=n_in,
        rank=lora_rank,
        lora_scale=s,
        BLOCK_IN=block_in,
        BLOCK_R=block_r,
    )

    return scale_out.detach()

File diff suppressed because it is too large Load Diff

View File

@@ -105,6 +105,10 @@ def dequantize(
# Extract quantization state
if not isinstance(quant_state, list):
# New style quant_state class
# Non-double-quantized models have offset=None and state2=None
if quant_state.offset is None or quant_state.state2 is None:
# Fall back to bitsandbytes standard dequantize
return bnb.functional.dequantize_4bit(W, quant_state, quant_type="nf4")
absmax = quant_state.absmax.to(target_device)
shape = quant_state.shape
dtype = quant_state.dtype

View File

@@ -0,0 +1,333 @@
"""
Fused RMSNorm + SiLU Gate Triton kernel.
Computes: Y = (W + offset) * RMSNorm(X) * silu(G)
where RMSNorm(X) = X / sqrt(mean(X^2) + eps)
and silu(G) = G * sigmoid(G)
Used by Qwen3.5's GatedDeltaNet linear attention layers (Qwen3_5RMSNormGated).
"""
import math
import operator
import torch
import triton
import triton.language as tl
from liger_kernel.ops.utils import (
calculate_settings,
compare_version,
ensure_contiguous,
torch_to_triton_dtype,
)
from liger_kernel.utils import is_npu_available
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
try:
from triton.language.extra.libdevice import rsqrt
except ModuleNotFoundError:
from triton.language.extra.cuda.libdevice import rsqrt
else:
from triton.language.math import rsqrt
@triton.jit
def _rms_norm_gated_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    G_ptr,
    G_row_stride,
    W_ptr,
    W_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    n_cols,
    eps,
    offset,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Forward pass: Y = (W + offset) * (X * rstd) * silu(G), one row per program.

    rstd = rsqrt(mean(X^2) + eps) is also written to RSTD so the backward pass
    can reuse it. Inputs are upcast to fp32 for the arithmetic and the result is
    cast back to X's dtype on store. `W_row_stride` is accepted but not read.
    """
    pid = tl.program_id(0).to(tl.int64)
    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < n_cols

    x_raw = tl.load(X_ptr + pid * X_row_stride + cols, mask=in_bounds, other=0)
    g_raw = tl.load(G_ptr + pid * G_row_stride + cols, mask=in_bounds, other=0)
    w_raw = tl.load(W_ptr + cols, mask=in_bounds, other=0)
    out_dtype = x_raw.dtype

    x = x_raw.to(tl.float32)
    g = g_raw.to(tl.float32)
    w = w_raw.to(tl.float32)

    # Masked lanes were loaded as 0, so they do not contribute to the sum.
    mean_square = tl.sum(x * x, axis=0) / n_cols
    inv_rms = rsqrt(mean_square + eps)
    tl.store(RSTD_ptr + pid * RSTD_row_stride, inv_rms)

    normed = x * inv_rms

    # silu(G) = G * sigmoid(G)
    gate_sigmoid = tl.sigmoid(g)
    gate_act = g * gate_sigmoid

    y = (offset + w) * normed * gate_act

    tl.store(
        Y_ptr + pid * Y_row_stride + cols,
        y.to(out_dtype),
        mask=in_bounds,
    )
@triton.jit
def _rms_norm_gated_backward_kernel(
    dY_ptr,
    dY_row_stride,
    dX_ptr,
    dX_row_stride,
    dG_ptr,
    dG_row_stride,
    X_ptr,
    X_row_stride,
    X_dtype: tl.constexpr,
    G_ptr,
    G_row_stride,
    W_ptr,
    W_row_stride,
    RSTD_ptr,
    RSTD_row_stride,
    dW_ptr,
    dW_row_stride,
    n_rows,
    n_cols,
    offset,
    rows_per_program,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Backward pass for Y = (W + offset) * (X * RSTD) * silu(G).

    Each program walks a contiguous range of `rows_per_program` rows and writes:

    dG = dY * (W + offset) * X_norm * silu'(G)
        where silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
    dX = RSTD * (m - (1/N) * RSTD^2 * dot(m, X) * X)
        where m = dY * (W + offset) * silu(G)
    dW partial = sum over this program's rows of dY * X_norm * silu(G),
        stored in row `program_id` of the dW buffer (fp32).

    `W_row_stride` is accepted but not read.
    """
    block_id = tl.program_id(0).to(tl.int64)
    first_row = block_id * rows_per_program
    stop_row = min((block_id + 1) * rows_per_program, n_rows)
    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < n_cols

    dw_partial = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    # Effective weight (W + offset), loaded once and reused for every row.
    w_eff = tl.load(W_ptr + cols, mask=in_bounds, other=0.0)
    w_eff = w_eff.to(tl.float32) + offset

    for r in range(first_row, stop_row):
        dy_raw = tl.load(
            dY_ptr + r * dY_row_stride + cols, mask=in_bounds, other=0.0
        )
        x_raw = tl.load(
            X_ptr + r * X_row_stride + cols, mask=in_bounds, other=0.0
        )
        g_raw = tl.load(
            G_ptr + r * G_row_stride + cols, mask=in_bounds, other=0.0
        )
        inv_rms = tl.load(RSTD_ptr + r * RSTD_row_stride)

        dy = dy_raw.to(tl.float32)
        x = x_raw.to(tl.float32)
        g = g_raw.to(tl.float32)

        # Recompute the forward intermediates instead of storing them.
        normed = x * inv_rms
        gate_sigmoid = tl.sigmoid(g)
        gate_act = g * gate_sigmoid

        dw_partial += dy * normed * gate_act

        gate_grad = gate_sigmoid * (1.0 + g * (1.0 - gate_sigmoid))
        dg = dy * w_eff * normed * gate_grad
        tl.store(
            dG_ptr + r * dG_row_stride + cols,
            dg.to(X_dtype),
            mask=in_bounds,
        )

        # RMSNorm backward with effective upstream gradient m.
        m = dy * w_eff * gate_act
        dx = inv_rms * m
        dx += inv_rms * (
            -(1.0 / n_cols) * inv_rms * inv_rms * tl.sum(m * x, axis=0) * x
        )
        tl.store(
            dX_ptr + r * dX_row_stride + cols,
            dx.to(X_dtype),
            mask=in_bounds,
        )

    # Stored unconditionally, so programs with an empty row range contribute zeros.
    tl.store(
        dW_ptr + block_id * dW_row_stride + cols,
        dw_partial,
        mask=in_bounds,
    )
def rms_norm_gated_forward(X, G, W, eps, offset):
    """Launch the fused gated RMSNorm forward kernel.

    X and G are viewed as 2-D (rows, hidden) using X's last dimension, and one
    kernel program is launched per row.

    Returns:
        A 6-tuple ``(Y, X_2d, G_2d, RSTD, BLOCK_SIZE, num_warps)`` where Y has
        X's original shape, ``X_2d`` / ``G_2d`` are the flattened views, RSTD is
        the fp32 per-row reciprocal RMS, and the last two are the launch settings
        from ``calculate_settings``.

    Raises:
        AssertionError: if W's length differs from the hidden size, or if the
            flattened X and G shapes differ.
    """
    input_shape = X.shape
    hidden = input_shape[-1]
    X = X.view(-1, hidden)
    G = G.view(-1, hidden)
    n_rows, n_cols = X.shape

    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
    RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)

    assert X.shape[1] == W.shape[0], (
        f"Incompatible hidden size: X.shape[1]={X.shape[1]} vs W.shape[0]={W.shape[0]}"
    )
    assert X.shape == G.shape, (
        f"X and G must have same shape, got {X.shape} and {G.shape}"
    )

    launch_grid = (n_rows,)
    _rms_norm_gated_forward_kernel[launch_grid](
        Y,
        Y.stride(0),
        X,
        X.stride(0),
        G,
        G.stride(0),
        W,
        W.stride(0),
        RSTD,
        RSTD.stride(0),
        n_cols,
        eps,
        offset,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    Y_out = Y.view(*input_shape)
    return Y_out, X, G, RSTD, BLOCK_SIZE, num_warps
def rms_norm_gated_backward(dY, X, G, W, RSTD, offset, BLOCK_SIZE, num_warps):
    """Launch the fused gated RMSNorm backward kernel.

    The grid has one program per streaming multiprocessor of X's CUDA device;
    each program handles ``ceil(n_rows / sm_count)`` rows and writes one fp32
    partial dW row, which are summed here and cast to W's dtype.

    Returns:
        ``(dX, dG, dW)`` where dX and dG have dY's shape and dW has one entry
        per column.
    """
    grad_shape = dY.shape
    hidden = grad_shape[-1]
    dY = dY.view(-1, hidden)
    n_rows, n_cols = dY.shape

    device_props = torch.cuda.get_device_properties(X.device)
    sm_count = device_props.multi_processor_count

    # One partial-dW row per program; every program writes its row in the kernel.
    dW_partials = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
    dX = torch.empty_like(dY)
    dG = torch.empty_like(dY)

    rows_per_program = math.ceil(n_rows / sm_count)
    launch_grid = (sm_count,)

    _rms_norm_gated_backward_kernel[launch_grid](
        dY,
        dY.stride(0),
        dX,
        dX.stride(0),
        dG,
        dG.stride(0),
        X,
        X.stride(0),
        torch_to_triton_dtype[X.dtype],
        G,
        G.stride(0),
        W,
        W.stride(0),
        RSTD,
        RSTD.stride(0),
        dW_partials,
        dW_partials.stride(0),
        n_rows,
        n_cols,
        offset,
        rows_per_program,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )

    dX = dX.view(*grad_shape)
    dG = dG.view(*grad_shape)
    dW = dW_partials.sum(dim=0).to(W.dtype)
    return dX, dG, dW
class FusedRMSNormGatedFunction(torch.autograd.Function):
    """Autograd bridge between the fused gated RMSNorm forward and backward launchers."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, X, G, W, eps, offset=0.0):
        """
        X: (B, T, H) or (BxT, H) — input hidden states
        G: (B, T, H) or (BxT, H) — gate tensor
        W: (H,) — weight parameter

        Saves the flattened X and G, W and the per-row RSTD for backward, and
        stashes `offset` plus the kernel launch settings on `ctx`.
        """
        fwd_result = rms_norm_gated_forward(X, G, W, eps, offset)
        Y, X_2d, G_2d, RSTD, BLOCK_SIZE, num_warps = fwd_result

        ctx.offset = offset
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.save_for_backward(X_2d, G_2d, W, RSTD)
        return Y

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dY):
        """Return gradients for (X, G, W); `eps` and `offset` get None."""
        X, G, W, RSTD = ctx.saved_tensors
        grads = rms_norm_gated_backward(
            dY, X, G, W, RSTD, ctx.offset, ctx.BLOCK_SIZE, ctx.num_warps
        )
        dX, dG, dW = grads
        return dX, dG, dW, None, None
class FusedRMSNormGated(torch.nn.Module):
    """
    Fused RMSNorm + SiLU Gate.

    Computes: Y = (W + offset) * RMSNorm(X) * silu(G)

    With the default ``offset=0.0`` this reduces to ``W * RMSNorm(X) * silu(G)``.

    Drop-in replacement for Qwen3_5RMSNormGated with matching
    init signature: __init__(hidden_size, eps=1e-6, **kwargs)
    and forward signature: forward(hidden_states, gate=None)

    Args:
        hidden_size: length of the learnable weight vector (initialized to ones).
        eps: value added to the mean square before the reciprocal square root.
        offset: constant added to the weight inside the kernel.
        **kwargs: accepted for signature compatibility and ignored.
    """

    def __init__(self, hidden_size, eps=1e-6, offset=0.0, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.offset = offset

    def forward(self, hidden_states, gate=None):
        """Apply the fused kernel.

        Raises:
            ValueError: if ``gate`` is None, or if ``hidden_states`` is not on a
                CUDA device.
        """
        if gate is None:
            raise ValueError("FusedRMSNormGated requires a gate tensor")
        if hidden_states.device.type != "cuda":
            raise ValueError(
                f"FusedRMSNormGated requires CUDA tensors, got device={hidden_states.device}"
            )
        return FusedRMSNormGatedFunction.apply(
            hidden_states, gate, self.weight, self.variance_epsilon, self.offset
        )

    def extra_repr(self):
        """Describe the module; ``offset`` is shown only when it is non-zero.

        Keeping the default case unchanged means existing printouts stay the same,
        while modules configured with an offset are no longer indistinguishable.
        """
        summary = f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
        if self.offset:
            summary += f", offset={self.offset}"
        return summary

View File

@@ -12,6 +12,7 @@ from torch import nn
from transformers import AutoConfig
from axolotl.kernels.lora import (
apply_lora_embedding,
apply_lora_mlp_geglu,
apply_lora_mlp_swiglu,
apply_lora_o,
@@ -370,13 +371,13 @@ def apply_lora_kernel_patches(
active_adapter = model.active_adapter
lora_config = model.model.peft_config[active_adapter]
# Only patch if conditions are met
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
if not can_patch:
LOG.warning("Cannot patch layers - requires no dropout and no bias")
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
return model
# Log what features are active
if lora_config.lora_dropout > 0:
LOG.info(f"LoRA kernels: dropout={lora_config.lora_dropout} enabled")
if lora_config.bias != "none":
LOG.info(f"LoRA kernels: bias={lora_config.bias} enabled")
if lora_config.use_dora:
LOG.info("LoRA kernels: DoRA enabled")
# This needs to be reset after patching
original_level = LOG.getEffectiveLevel()
@@ -419,44 +420,33 @@ def apply_lora_kernel_patches(
for linear_proj in ["q_proj", "k_proj", "v_proj"]
]
can_patch_qkv = all(
hasattr(module, "lora_A")
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
hasattr(module, "lora_A") for module in layer_modules
)
if can_patch_qkv:
# Add optimized implementation
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention QKV projections - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
"Cannot patch some attention QKV projections - requires LoRA adapters"
)
if cfg.lora_o_kernel:
# Output patching
layer_modules = [
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
]
can_patch_o = all(
hasattr(module, "lora_A")
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
for module in layer_modules
)
can_patch_o = all(hasattr(module, "lora_A") for module in layer_modules)
if can_patch_o:
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
else:
LOG.warning_once(
"Cannot patch some attention output projection - requires LoRA "
"adapters and no lora_magnitude_vector (DoRA)"
"Cannot patch some attention output projection - requires LoRA adapters"
)
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
if cfg.lora_mlp_kernel:
# MLP patching
can_patch_mlp = all(
hasattr(proj, "lora_A")
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
for proj in (gate_proj, up_proj, down_proj)
hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj)
)
if can_patch_mlp:
@@ -464,15 +454,50 @@ def apply_lora_kernel_patches(
layer.mlp.forward = types.MethodType(apply_fn, mlp)
else:
LOG.warning_once(
"Cannot patch some MLP layers - requires LoRA adapters and no "
"lora_magnitude_vector (DoRA)"
"Cannot patch some MLP layers - requires LoRA adapters"
)
# Patch embedding layers (model-level, not per-layer)
if cfg.lora_embedding_kernel:
_patch_embedding_layers(model, cfg)
LOG.setLevel(original_level)
return model
def _patch_embedding_layers(model: PeftModelForCausalLM, cfg: DictDefault):
    """Patch LoRA-wrapped token embedding modules with the fused LoRA kernel.

    Looks for ``embed_tokens`` at two known attribute paths below ``model.model``
    and, for each one that exists and carries ``lora_embedding_A``, rebinds its
    ``forward`` to ``apply_lora_embedding``.

    ``lm_head`` is deliberately not touched here: per the note below it is
    wrapped by PEFT as a Linear layer and handled by the linear LoRA kernels.
    ``cfg`` is accepted but not read by this function.
    """
    base_model = model.model

    # Known locations of the token embedding module.
    candidate_paths = [
        ("model", "embed_tokens"),
        ("model", "language_model", "embed_tokens"),
    ]

    num_patched = 0
    for path in candidate_paths:
        # Walk the attribute chain, giving up at the first missing link.
        module = base_model
        for name in path:
            module = getattr(module, name, None)
            if module is None:
                break

        if module is None or not hasattr(module, "lora_embedding_A"):
            continue

        LOG.info(f"Patching embedding layer: {'.'.join(path)}")
        module.forward = types.MethodType(apply_lora_embedding, module)
        num_patched += 1

    # lm_head with LoRA is a Linear layer - already handled by LoRA_O/LoRA_W kernels
    # when included in target_modules. No special embedding handling needed since
    # PEFT wraps it as a Linear (not Embedding) even for tied models.

    if num_patched == 0:
        LOG.debug("No embedding layers with LoRA found to patch")
class FakeMLP(nn.Module):
"""
placeholder MLP for triton patching

View File

@@ -9,7 +9,6 @@ import os
import shutil
import signal
import sys
import typing
import weakref
from collections import OrderedDict
from contextlib import ExitStack
@@ -42,9 +41,6 @@ from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint
from axolotl.utils.trainer import setup_trainer
if typing.TYPE_CHECKING:
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
LOG = get_logger(__name__)
TELEMETRY_MANAGER = TelemetryManager.get_instance()
@@ -487,7 +483,7 @@ def handle_untrained_tokens_fix(
def setup_model_and_trainer(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
Trainer,
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,
@@ -554,6 +550,36 @@ def setup_model_and_trainer(
)
def _is_tui_enabled(cfg: DictDefault) -> bool:
"""Check if TUI is enabled via config or environment variable."""
if os.environ.get("AXOLOTL_TUI", "").lower() in ("1", "true", "yes"):
return True
tui = cfg.get("tui")
if tui is None:
return False
if isinstance(tui, bool):
return tui
if isinstance(tui, dict):
return tui.get("enabled", False)
if hasattr(tui, "enabled"):
return tui.enabled
return False
def _get_tui_config(cfg: DictDefault) -> dict:
"""Extract TUI config dict from cfg."""
tui = cfg.get("tui")
if tui is None or isinstance(tui, bool):
return {"enabled": True}
if isinstance(tui, dict):
return {**tui, "enabled": True}
if hasattr(tui, "model_dump"):
d = tui.model_dump()
d["enabled"] = True
return d
return {"enabled": True}
@send_errors
def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
@@ -577,6 +603,37 @@ def train(
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
# Register TUI callback if enabled and rank 0
tui_enabled = _is_tui_enabled(cfg)
if tui_enabled and cfg.local_rank == 0:
from axolotl.tui import AxolotlTUICallback
from axolotl.tui.config import TUIConfig
tui_config = _get_tui_config(cfg)
tui_config_obj = (
TUIConfig(**tui_config) if isinstance(tui_config, dict) else tui_config
)
# Reuse the early-started renderer if available (started in do_train)
early_renderer = getattr(cfg, "_tui_renderer", None)
early_queue = getattr(cfg, "_tui_queue", None)
tui_callback = AxolotlTUICallback(config=tui_config_obj)
if early_renderer is not None and early_queue is not None:
# Reuse the already-running renderer and queue
tui_callback._renderer = early_renderer
tui_callback._queue = early_queue
tui_callback._renderer_started_early = True
trainer.add_callback(tui_callback)
# Stash model info so on_train_begin can emit a single unified run_info event
tui_callback._pending_run_info = {
"model_name": cfg.base_model or "",
"training_mode": str(cfg.rl) if cfg.rl else "sft",
"world_size": int(os.environ.get("WORLD_SIZE", 1)),
}
LOG.info("TUI dashboard enabled")
# Handle untrained tokens if configured
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)

View File

@@ -0,0 +1,17 @@
"""Axolotl Training TUI — rich-based terminal dashboard for monitoring training runs."""
from axolotl.tui.callback import AxolotlTUICallback
from axolotl.tui.config import TUIConfig
from axolotl.tui.io_capture import LineParser, register_parser
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
__all__ = [
"AxolotlTUICallback",
"BasePanel",
"LineParser",
"TUIConfig",
"TUIState",
"register_panel",
"register_parser",
]

142
src/axolotl/tui/callback.py Normal file
View File

@@ -0,0 +1,142 @@
"""AxolotlTUICallback — HF TrainerCallback that feeds metrics to the TUI."""
from __future__ import annotations
import logging
import queue
from transformers.trainer_callback import TrainerCallback
from axolotl.tui.config import TUIConfig
from axolotl.tui.renderer import TUIRenderer
class _TUILogHandler(logging.Handler):
"""Logging handler that pushes log records into the TUI metric queue."""
_LEVEL_MAP = {
logging.DEBUG: "debug",
logging.INFO: "info",
logging.WARNING: "warning",
logging.ERROR: "error",
logging.CRITICAL: "error",
}
def __init__(self, metric_queue: queue.Queue, min_level: str = "info"):
super().__init__()
level_name = min_level.upper()
self.setLevel(getattr(logging, level_name, logging.INFO))
self._queue = metric_queue
def emit(self, record: logging.LogRecord) -> None:
try:
level = self._LEVEL_MAP.get(record.levelno, "info")
msg = self.format(record)
self._queue.put_nowait(
{
"type": "log_line",
"level": level,
"message": msg,
}
)
except queue.Full:
pass
except Exception:
self.handleError(record)
class AxolotlTUICallback(TrainerCallback):
    """Pushes training metrics into a queue for the TUI renderer.

    The callback never blocks on the render thread. The queue is bounded
    (maxsize=4096) and written with ``put_nowait``; overflow is silently
    dropped.
    """

    def __init__(self, config: TUIConfig):
        """Create the event queue and a renderer bound to it.

        Args:
            config: TUI settings; ``log_level`` is read when the log handler
                is attached in ``on_train_begin``.
        """
        self._config = config
        self._queue: queue.Queue = queue.Queue(maxsize=4096)
        self._renderer = TUIRenderer(config=config, metric_queue=self._queue)
        self._log_handler: _TUILogHandler | None = None
        # Set by the caller when it injects an already-running renderer/queue.
        self._renderer_started_early: bool = False
        # Extra run_info fields merged into the first run_info event.
        self._pending_run_info: dict | None = None

    def _put(self, event: dict) -> None:
        """Enqueue ``event`` without blocking; drop it when the queue is full."""
        try:
            self._queue.put_nowait(event)
        except queue.Full:
            pass

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Emit the run_info event and, unless started early, start the renderer."""
        # Send a single unified run_info event with all fields
        run_info = {
            "type": "run_info",
            "run_name": getattr(args, "run_name", "") or "",
            "total_steps": state.max_steps,
            "total_epochs": float(args.num_train_epochs)
            if args.num_train_epochs
            else 1.0,
        }
        # Merge in any fields stashed on _pending_run_info by the caller
        if self._pending_run_info:
            run_info.update(self._pending_run_info)
            self._pending_run_info = None
        self._put(run_info)

        # The handler doubles as the "already set up" marker: a repeated
        # on_train_begin must not attach a second handler (which would leak
        # the first one and duplicate every log line).
        if not self._renderer_started_early and self._log_handler is None:
            # Attach a logging handler to feed log messages into the events panel
            self._log_handler = _TUILogHandler(
                self._queue, min_level=self._config.log_level
            )
            self._log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
            # Attach to both root and axolotl loggers (axolotl has propagate=False)
            logging.getLogger().addHandler(self._log_handler)
            logging.getLogger("axolotl").addHandler(self._log_handler)
            # Start the renderer background thread
            self._renderer.start()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Forward the numeric entries of ``logs`` as a single metrics event."""
        if logs is None:
            return
        # Filter out non-numeric keys and internal keys
        filtered = {}
        for key, value in logs.items():
            if key.startswith("_"):
                continue
            if isinstance(value, (int, float)):
                filtered[key] = value
            elif isinstance(value, str):
                # HF Trainer sometimes passes string-encoded numbers
                try:
                    filtered[key] = float(value)
                except (ValueError, TypeError):
                    pass
        if filtered:
            self._put({"type": "metrics", "logs": filtered})

    def on_step_end(self, args, state, control, **kwargs):
        """Emit a step event carrying the current step, total steps and epoch."""
        self._put(
            {
                "type": "step",
                "step": state.global_step,
                "total_steps": state.max_steps,
                "epoch": state.epoch if state.epoch else 0,
            }
        )

    def on_prediction_step(self, args, state, control, **kwargs):
        """No-op."""
        pass

    def on_train_end(self, args, state, control, **kwargs):
        """Emit the done event and tear down what ``on_train_begin`` set up."""
        self._put({"type": "done"})
        # A renderer that was started early is not owned by this callback,
        # so it is left running here.
        if not self._renderer_started_early:
            self._renderer.stop()
        # Remove the logging handler (only if we added it)
        if self._log_handler:
            logging.getLogger().removeHandler(self._log_handler)
            logging.getLogger("axolotl").removeHandler(self._log_handler)
            self._log_handler = None

38
src/axolotl/tui/config.py Normal file
View File

@@ -0,0 +1,38 @@
"""TUI configuration — Pydantic model for TUI settings."""
from __future__ import annotations
from pydantic import BaseModel, Field
class TUIConfig(BaseModel):
    """Configuration for the Axolotl Training TUI dashboard.

    Every field has a default, so ``TUIConfig()`` is valid. The human-readable
    meaning of each field is carried in its ``json_schema_extra`` description.
    """

    # Master switch; off unless explicitly enabled.
    enabled: bool = Field(
        default=False,
        json_schema_extra={"description": "Enable the TUI dashboard"},
    )
    # Display refresh frequency, in renders per second.
    refresh_rate: int = Field(
        default=4,
        json_schema_extra={"description": "Renders per second"},
    )
    # Lowest log level forwarded to the TUI (level name as a string).
    log_level: str = Field(
        default="debug",
        json_schema_extra={"description": "Minimum log level shown in events panel"},
    )
    # Panel names, in display order.
    panels: list[str] = Field(
        default_factory=lambda: ["progress", "training", "hardware", "events", "debug"],
        json_schema_extra={"description": "Ordered list of panels to display"},
    )
    # GPU polling period, in seconds.
    hardware_poll_interval: int = Field(
        default=2,
        json_schema_extra={"description": "Seconds between pynvml GPU queries"},
    )
    # Destination file for the captured stdout/stderr stream.
    stdout_log_path: str = Field(
        default="axolotl_stdout.log",
        json_schema_extra={"description": "File path for captured stdout/stderr log"},
    )
    # Specs of additional parser classes; empty by default.
    parser_plugins: list[str] = Field(
        default_factory=list,
        json_schema_extra={"description": "List of extra parser classes to load"},
    )

72
src/axolotl/tui/gpu.py Normal file
View File

@@ -0,0 +1,72 @@
"""GPU polling wrapper around pynvml with graceful fallback."""
from __future__ import annotations
import logging
from axolotl.tui.state import GPUStats
LOG = logging.getLogger(__name__)
# Module-level flag: True only when pynvml was imported AND nvmlInit() succeeded.
_nvml_available = False
try:
    import pynvml

    pynvml.nvmlInit()
    _nvml_available = True
except Exception:
    # Any exception from the import or from nvmlInit() leaves the flag False;
    # the failure is only reported at debug level.
    LOG.debug("pynvml unavailable — GPU stats will not be shown")
class GPUPoller:
    """Reads per-GPU statistics through pynvml, degrading to "no data" on failure."""

    def __init__(self):
        """Query the device count once; it stays 0 when NVML is unusable."""
        count = 0
        if _nvml_available:
            try:
                count = pynvml.nvmlDeviceGetCount()
            except Exception:
                count = 0
        self._device_count = count

    @property
    def available(self) -> bool:
        """True when NVML initialised and at least one device was counted."""
        return _nvml_available and self._device_count > 0

    def _read_device(self, index: int) -> GPUStats:
        """Collect one device's stats; exceptions propagate to the caller."""
        handle = pynvml.nvmlDeviceGetHandleByIndex(index)

        name = pynvml.nvmlDeviceGetName(handle)
        if isinstance(name, bytes):
            # The name may come back as bytes; normalise to str.
            name = name.decode("utf-8")

        util = pynvml.nvmlDeviceGetUtilizationRates(handle)
        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
        temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)

        # A failing power query is tolerated and reported as None.
        try:
            power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
        except Exception:
            power = None

        gib = 1024**3
        return GPUStats(
            id=index,
            name=name,
            util_pct=util.gpu,
            vram_used_gb=mem.used / gib,
            vram_total_gb=mem.total / gib,
            temp_c=temp,
            power_w=power,
        )

    def poll(self) -> list[GPUStats]:
        """Return stats for every device that could be read.

        Devices whose queries raise are skipped (logged at debug level); the
        result is empty when polling is unavailable.
        """
        collected: list[GPUStats] = []
        if not self.available:
            return collected
        for index in range(self._device_count):
            try:
                collected.append(self._read_device(index))
            except Exception:
                LOG.debug("Error polling GPU device %d", index, exc_info=True)
        return collected

View File

@@ -0,0 +1,196 @@
"""I/O capture: OS-level stdout/stderr redirect, line parser chain, and parser registry."""
from __future__ import annotations
import logging
import os
import queue
import sys
import threading
from abc import ABC, abstractmethod
from datetime import datetime
from typing import IO
# ---------------------------------------------------------------------------
# Parser registry
# ---------------------------------------------------------------------------
# Registered parser classes, in registration order.
_parser_registry: list[type[LineParser]] = []


def register_parser(cls: type[LineParser]) -> type[LineParser]:
    """Class decorator: add ``cls`` to the parser registry and return it unchanged.

    Registering the same class more than once has no further effect.
    """
    already_known = cls in _parser_registry
    if not already_known:
        _parser_registry.append(cls)
    return cls


def get_registered_parsers() -> list[type[LineParser]]:
    """Return a copy of the registry, so callers cannot mutate it."""
    return _parser_registry[:]
# ---------------------------------------------------------------------------
# Base LineParser
# ---------------------------------------------------------------------------
class LineParser(ABC):
    """Abstract interface for parsers that inspect captured stdout/stderr lines.

    Subclasses implement :meth:`parse`. ``priority`` determines ordering inside
    a ``ParserChain`` (ascending sort); ``name`` is a free-form identifier.
    """

    priority: int = 50
    name: str = ""

    @abstractmethod
    def parse(self, line: str, source: str) -> list[dict]:
        """Turn one captured line into zero or more event dicts.

        Args:
            line: one line of captured output, trailing newline stripped.
            source: "stdout" or "stderr".

        Returns:
            Event dicts to push onto the metric queue; ``[]`` when the line is
            not relevant to this parser.
        """
        ...
# ---------------------------------------------------------------------------
# ParserChain
# ---------------------------------------------------------------------------
class ParserChain:
    """Priority-ordered collection of parsers applied to each captured line.

    A parser that raises is isolated: its exception does not propagate to the
    caller and the remaining parsers still see the line. The first failure of
    each parser instance is surfaced as an error-level ``log_line`` event.
    """

    def __init__(self):
        self._parsers: list[LineParser] = []
        # ids of parser instances whose failure has already been reported,
        # so one broken parser cannot flood the event stream.
        self._reported_failures: set[int] = set()

    def register(self, parser: LineParser) -> None:
        """Add ``parser`` and keep the chain sorted by ascending ``priority``."""
        self._parsers.append(parser)
        self._parsers.sort(key=lambda p: p.priority)

    def parse(self, line: str, source: str = "stdout") -> list[dict]:
        """Run every parser on ``line`` and concatenate their events.

        Args:
            line: the captured line.
            source: stream label forwarded to each parser.

        Returns:
            Events from all parsers in chain order, plus at most one
            error ``log_line`` per failing parser instance.
        """
        events: list[dict] = []
        for parser in self._parsers:
            try:
                events.extend(parser.parse(line, source))
            except Exception as exc:
                # Deliberately broad: parsers may be third-party plugins and
                # must not be able to take down the caller.
                if id(parser) in self._reported_failures:
                    continue
                self._reported_failures.add(id(parser))
                label = getattr(parser, "name", "") or type(parser).__name__
                events.append(
                    {
                        "type": "log_line",
                        "level": "error",
                        "message": f"line parser {label!r} failed: {exc!r}",
                    }
                )
        return events
# ---------------------------------------------------------------------------
# IOCapture — OS-level fd redirect to pipe
# ---------------------------------------------------------------------------
class IOCapture:
    """Redirects fd 1 and fd 2 into an OS pipe, drains via a reader thread,
    passes lines through a ParserChain, and tees to a log file."""

    def __init__(
        self, log_path: str, parser_chain: ParserChain, metric_queue: queue.Queue
    ):
        """Store collaborators; nothing is redirected until :meth:`start`.

        Args:
            log_path: file that receives every captured line (opened in append mode).
            parser_chain: object whose ``parse(line)`` yields event dicts.
            metric_queue: queue that receives those events via ``put_nowait``.
        """
        self._parser_chain = parser_chain
        self._queue = metric_queue
        self._log_path = log_path
        self._log_file: IO[str] | None = None
        self._thread: threading.Thread | None = None
        self._read_fd: int | None = None
        self._write_fd: int | None = None
        self._saved_stdout_fd: int | None = None
        self._saved_stderr_fd: int | None = None
        # Python-level streams that were active before start(), restored by stop().
        self._prev_stdout = None
        self._prev_stderr = None

    @staticmethod
    def _flush_quietly(stream) -> None:
        """Flush ``stream`` if possible; a closed or missing stream is ignored."""
        if stream is None:
            return
        try:
            stream.flush()
        except (OSError, ValueError):
            pass

    def _release_after_failed_start(self) -> None:
        """Close every descriptor and file acquired so far by a failing start()."""
        for attr in ("_read_fd", "_write_fd", "_saved_stdout_fd", "_saved_stderr_fd"):
            fd = getattr(self, attr)
            if fd is not None:
                os.close(fd)
                setattr(self, attr, None)
        if self._log_file is not None:
            self._log_file.close()
            self._log_file = None

    def start(self) -> None:
        """Open the log file, redirect fds 1/2 into a pipe and start the drain thread.

        Raises:
            OSError: if the log file, the pipe or the fd duplication fails; any
                resources acquired up to that point are released first.
        """
        # Explicit UTF-8: captured output may contain non-ASCII glyphs (e.g.
        # progress-bar blocks) that a locale-default encoding could reject.
        self._log_file = open(  # noqa: SIM115
            self._log_path, "a", buffering=1, encoding="utf-8", errors="replace"
        )
        try:
            # Write run-start separator
            self._log_file.write(
                f"\n=== axolotl run started {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ===\n"
            )
            self._log_file.flush()

            # OS-level pipe
            self._read_fd, self._write_fd = os.pipe()

            # Save originals
            self._saved_stdout_fd = os.dup(1)
            self._saved_stderr_fd = os.dup(2)
        except OSError:
            self._release_after_failed_start()
            raise

        # Push out anything still buffered so it reaches the original
        # destination instead of the pipe.
        self._flush_quietly(sys.stdout)
        self._flush_quietly(sys.stderr)

        # Redirect both stdout and stderr into the write end
        os.dup2(self._write_fd, 1)
        os.dup2(self._write_fd, 2)
        os.close(self._write_fd)  # write end now held by fds 1 and 2
        self._write_fd = None

        # Also redirect Python-level handles, remembering the previous ones
        self._prev_stdout, self._prev_stderr = sys.stdout, sys.stderr
        sys.stdout = open(1, "w", buffering=1, closefd=False)  # noqa: SIM115
        sys.stderr = open(2, "w", buffering=1, closefd=False)  # noqa: SIM115

        # Drain thread
        self._thread = threading.Thread(target=self._drain, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Undo the redirect, wait briefly for the drain thread, close the log."""
        # Restore fds — closes the write end, causing reader to see EOF
        if self._saved_stdout_fd is not None and self._saved_stderr_fd is not None:
            # Flush the capture wrappers while fds 1/2 still point at the pipe,
            # so their last bytes are captured rather than lost.
            self._flush_quietly(sys.stdout)
            self._flush_quietly(sys.stderr)
            sys.stdout = (
                self._prev_stdout if self._prev_stdout is not None else sys.__stdout__
            )
            sys.stderr = (
                self._prev_stderr if self._prev_stderr is not None else sys.__stderr__
            )
            self._prev_stdout = None
            self._prev_stderr = None
            os.dup2(self._saved_stdout_fd, 1)
            os.dup2(self._saved_stderr_fd, 2)
            os.close(self._saved_stdout_fd)
            os.close(self._saved_stderr_fd)
            self._saved_stdout_fd = None
            self._saved_stderr_fd = None

        if self._thread is not None:
            self._thread.join(timeout=2.0)
            if self._thread.is_alive():
                logging.getLogger(__name__).warning(
                    "IO capture thread did not exit after 2s"
                )
            self._thread = None

        if self._log_file is not None:
            self._log_file.close()
            self._log_file = None

    def _drain(self) -> None:
        """Thread target: read the pipe until EOF and process each line.

        Raises:
            RuntimeError: if called before :meth:`start` created the pipe.
        """
        # Read raw bytes and split on both \n and \r to handle tqdm progress bars
        # which use \r for in-place updates without \n
        if self._read_fd is None:
            raise RuntimeError("_drain called before start()")
        with os.fdopen(self._read_fd, "rb") as pipe:
            buf = b""
            while True:
                # read1() returns as soon as any data is available; read(n)
                # would block until n bytes or EOF and delay every line.
                chunk = pipe.read1(4096)
                if not chunk:
                    # EOF — process remaining buffer
                    if buf:
                        self._process_line(buf.decode("utf-8", errors="replace"))
                    break
                buf += chunk
                # Split on \n or \r
                while b"\n" in buf or b"\r" in buf:
                    # Find the earliest delimiter
                    idx_n = buf.find(b"\n")
                    idx_r = buf.find(b"\r")
                    if idx_n == -1:
                        idx = idx_r
                    elif idx_r == -1:
                        idx = idx_n
                    else:
                        idx = min(idx_n, idx_r)
                    line = buf[:idx].decode("utf-8", errors="replace")
                    buf = buf[idx + 1 :]
                    # Handle \r\n as single delimiter
                    if buf.startswith(b"\n"):
                        buf = buf[1:]
                    if line:
                        self._process_line(line)
        # The descriptor was closed by the context manager above.
        self._read_fd = None

    def _process_line(self, line: str) -> None:
        """Tee one non-empty line to the log file and enqueue its parsed events."""
        line = line.rstrip()
        if not line:
            return
        log_file = self._log_file
        if log_file is not None:
            try:
                log_file.write(line + "\n")
                log_file.flush()
            except ValueError:
                # stop() may close the file while a late line is in flight.
                pass
        for event in self._parser_chain.parse(line):
            try:
                self._queue.put_nowait(event)
            except queue.Full:
                pass

View File

@@ -0,0 +1,63 @@
"""Panel registry and base class for TUI panels."""
from __future__ import annotations
from abc import ABC, abstractmethod
from rich.console import RenderableType
from axolotl.tui.state import TUIState
# ---------------------------------------------------------------------------
# Panel registry
# ---------------------------------------------------------------------------
# Maps a panel's ``name`` to its class.
_panel_registry: dict[str, type[BasePanel]] = {}


def register_panel(position: str = "bottom", weight: int = 50):
    """Build a class decorator that registers a panel.

    The decorator stamps ``position`` and ``weight`` onto the class, stores it
    in the registry under the class's ``name`` attribute, and returns the class.
    """

    def _apply(panel_cls: type[BasePanel]) -> type[BasePanel]:
        panel_cls.position, panel_cls.weight = position, weight
        _panel_registry[panel_cls.name] = panel_cls
        return panel_cls

    return _apply


def get_registered_panels() -> dict[str, type[BasePanel]]:
    """Return a shallow copy of the name -> panel class registry."""
    return _panel_registry.copy()
# ---------------------------------------------------------------------------
# BasePanel
# ---------------------------------------------------------------------------
class BasePanel(ABC):
    """Abstract base for dashboard panels.

    Subclasses set ``name`` and implement :meth:`render`; the remaining class
    attributes are layout/visibility hints with the defaults shown here.
    """

    name: str = ""
    position: str = "bottom"
    weight: int = 50
    min_height: int = 4
    max_height: int | None = None
    modes: list[str] = ["*"]

    @abstractmethod
    def render(self, state: TUIState) -> RenderableType:
        """Produce the rich renderable for this panel; invoked on every tick."""
        ...

    def on_event(self, event: dict) -> None:  # noqa: B027
        """Hook for reacting to raw metric events before state is merged.

        Optional; the default implementation ignores the event.
        """
        return None
# Auto-import built-in panels to trigger registration
from axolotl.tui.panels.completions import CompletionsPanel # noqa: E402, F401
from axolotl.tui.panels.debug import DebugPanel # noqa: E402, F401
from axolotl.tui.panels.events import EventsPanel # noqa: E402, F401
from axolotl.tui.panels.hardware import HardwarePanel # noqa: E402, F401
from axolotl.tui.panels.progress import ProgressPanel # noqa: E402, F401
from axolotl.tui.panels.training import TrainingPanel # noqa: E402, F401

View File

@@ -0,0 +1,61 @@
"""CompletionsPanel — shows recent RL/log_completions samples."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
def _truncate(s: str, maxlen: int = 60) -> str:
return s[:maxlen] + "" if len(s) > maxlen else s
@register_panel(position="bottom", weight=20)
class CompletionsPanel(BasePanel):
    """Table of the most recent completion samples (step, prompt, completion, reward, adv)."""

    name = "completions"
    min_height = 6
    modes = ["grpo", "dpo"]

    def render(self, state: TUIState) -> RenderableType:
        """Render the last five samples, a placeholder, or nothing for other modes."""
        applies = "*" in self.modes or state.training_mode in self.modes
        if not applies:
            # Panel is not relevant for the current training mode.
            return Text("")

        if not state.completions:
            placeholder = Text("No completions yet...", style="dim")
            return Panel(placeholder, title="Completions", border_style="magenta")

        table = Table(
            show_header=True,
            header_style="bold",
            expand=True,
            box=None,
            pad_edge=False,
        )
        table.add_column("step", justify="right", width=6)
        table.add_column("prompt", no_wrap=False, max_width=40)
        table.add_column("completion", no_wrap=False, max_width=40)
        table.add_column("reward", justify="right", width=8)
        table.add_column("adv", justify="right", width=8)

        latest = list(state.completions)[-5:]
        for item in latest:
            # Missing numeric fields are shown as "--".
            reward_cell = "--" if item.reward is None else f"{item.reward:.2f}"
            adv_cell = "--" if item.advantage is None else f"{item.advantage:+.2f}"
            table.add_row(
                str(item.step),
                _truncate(item.prompt),
                _truncate(item.completion),
                reward_cell,
                adv_cell,
            )

        return Panel(table, title="Completions", border_style="magenta")

View File

@@ -0,0 +1,34 @@
"""DebugPanel — scrolling log of debug-level messages, separate from main events."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
@register_panel(position="bottom", weight=30)
class DebugPanel(BasePanel):
    """Shows the most recent debug-level log lines."""

    name = "debug"
    min_height = 6
    max_height = 10

    def render(self, state: TUIState) -> RenderableType:
        """Render up to the last 8 debug entries, or a placeholder when there are none."""
        debug_entries = [
            entry for entry in state.log_lines if entry.level == "debug"
        ]
        tail = debug_entries[-8:]

        if not tail:
            body = Text("No debug messages yet...", style="dim")
            return Panel(body, title="Debug", border_style="dim")

        body = Text()
        for entry in tail:
            stamp = entry.timestamp.strftime("%H:%M:%S")
            body.append(f"[{stamp}] ", style="dim")
            # Messages are capped at 200 characters.
            body.append(entry.message[:200], style="dim")
            body.append("\n")
        return Panel(body, title="Debug", border_style="dim")

View File

@@ -0,0 +1,45 @@
"""EventsPanel — scrolling log of recent events, color-coded by level."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
# rich style applied per log level; unknown levels render unstyled.
_LEVEL_STYLES = {
    "debug": "dim",
    "info": "",
    "warning": "yellow",
    "error": "red bold",
    "critical": "red bold",
}


@register_panel(position="bottom", weight=10)
class EventsPanel(BasePanel):
    """Shows recent non-debug log lines, styled by level."""

    name = "events"
    min_height = 8
    max_height = 20

    def render(self, state: TUIState) -> RenderableType:
        """Render up to the last 15 non-debug entries, or a placeholder when there are none."""
        visible = [entry for entry in state.log_lines if entry.level != "debug"]
        tail = visible[-15:]

        if not tail:
            body = Text("No events yet...", style="dim")
            return Panel(body, title="Events", border_style="yellow")

        body = Text()
        for entry in tail:
            stamp = entry.timestamp.strftime("%H:%M:%S")
            tag = entry.level.upper()
            look = _LEVEL_STYLES.get(entry.level, "")
            body.append(f"[{stamp}] ", style="dim")
            body.append(f"[{tag}] ", style=look)
            # Messages are capped at 200 characters.
            body.append(entry.message[:200], style=look)
            body.append("\n")
        return Panel(body, title="Events", border_style="yellow")

View File

@@ -0,0 +1,80 @@
"""HardwarePanel — per-GPU stats via pynvml."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
# Glyphs for the filled / unfilled cells of the utilisation bar. They must be
# non-empty, otherwise the bar renders as nothing.
_BAR_FULL = "█"
_BAR_EMPTY = "░"


def _util_bar(pct: float, width: int = 6) -> Text:
    """Build a coloured fixed-width utilisation bar followed by the percentage.

    Args:
        pct: utilisation in percent. Values outside 0-100 are clamped for the
            bar fill only; the numeric label shows the value as given.
        width: number of bar cells.

    Returns:
        A ``Text`` whose bar is green below 70, yellow below 90, red otherwise.
    """
    # Clamp so an out-of-range reading cannot yield a negative or oversized fill.
    clamped = min(max(pct, 0), 100)
    filled = int(clamped / 100 * width)
    bar = _BAR_FULL * filled + _BAR_EMPTY * (width - filled)
    color = "green" if pct < 70 else ("yellow" if pct < 90 else "red")
    return Text.assemble((bar, color), f" {pct:3.0f}%")
@register_panel(position="right", weight=10)
class HardwarePanel(BasePanel):
    """Per-GPU table: utilisation bar, VRAM, temperature and power."""

    name = "hardware"
    min_height = 6

    def render(self, state: TUIState) -> RenderableType:
        """Render one row per GPU plus an aggregate row when there are several."""
        gpus = state.gpus
        if not gpus:
            notice = Text("GPU stats unavailable", style="dim")
            return Panel(notice, title="Hardware", border_style="green")

        table = Table(
            show_header=True,
            header_style="bold",
            expand=True,
            box=None,
            pad_edge=False,
        )
        table.add_column("id", justify="right", width=3)
        table.add_column("util", no_wrap=True)
        table.add_column("vram", no_wrap=True)
        table.add_column("°C", justify="right", width=4)
        table.add_column("W", justify="right", width=5)

        for gpu in gpus:
            # Power may be None; show a placeholder in that case.
            power_cell = "--" if gpu.power_w is None else f"{gpu.power_w:.0f}"
            table.add_row(
                str(gpu.id),
                _util_bar(gpu.util_pct),
                f"{gpu.vram_used_gb:.1f}/{gpu.vram_total_gb:.1f} GB",
                str(gpu.temp_c),
                power_cell,
            )

        # Footer with aggregates, only meaningful for multi-GPU setups
        count = len(gpus)
        if count > 1:
            mean_util = sum(g.util_pct for g in gpus) / count
            used_sum = sum(g.vram_used_gb for g in gpus)
            total_sum = sum(g.vram_total_gb for g in gpus)
            table.add_row(
                "Σ",
                Text(f"avg {mean_util:.0f}%", style="dim"),
                Text(f"{used_sum:.1f}/{total_sum:.1f} GB", style="dim"),
                "",
                "",
            )

        return Panel(table, title="Hardware", border_style="green")

View File

@@ -0,0 +1,73 @@
"""ProgressPanel — top-bar progress display with step count, elapsed, ETA."""
from __future__ import annotations
from rich.console import RenderableType
from rich.progress import BarColumn, Progress, TextColumn
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
def _fmt_time(seconds: float | None) -> str:
if seconds is None or seconds < 0:
return "--:--:--"
h = int(seconds) // 3600
m = (int(seconds) % 3600) // 60
s = int(seconds) % 60
return f"{h}:{m:02d}:{s:02d}"
def _fmt_eta(seconds: float | None) -> str:
if seconds is None or seconds < 0:
return "eta --"
h = int(seconds) // 3600
m = (int(seconds) % 3600) // 60
if h > 0:
return f"eta {h}h{m:02d}m"
return f"eta {m}m{int(seconds) % 60:02d}s"
@register_panel(position="top", weight=10)
class ProgressPanel(BasePanel):
    """Two-row header: a summary line and a progress bar."""

    name = "progress"
    min_height = 3
    max_height = 3

    def render(self, state: TUIState) -> RenderableType:
        """Render mode, model, step counter, elapsed/ETA and percentage above a bar."""
        # Guard against division by zero before the step total is known.
        if state.total_steps > 0:
            pct = state.current_step / state.total_steps * 100
        else:
            pct = 0

        # Header line
        mode_label = state.training_mode.upper() if state.training_mode else "SFT"
        # Only the last path component of the model name is shown.
        model_label = state.model_name.split("/")[-1] if state.model_name else "model"
        step_counter = (f"{state.current_step} / {state.total_steps}", "bold")
        timing = (
            f" · {_fmt_time(state.elapsed_seconds)} elapsed"
            f" · {_fmt_eta(state.eta_seconds)} · {pct:.1f}%"
        )
        header = Text.assemble(
            ("", "bold green"),
            ("AXOLOTL", "bold cyan"),
            f" {mode_label} · {model_label} ",
            step_counter,
            timing,
        )

        # Progress bar; ``or 1`` keeps the task total non-zero.
        bar = Progress(
            TextColumn(""),
            BarColumn(bar_width=None),
            TextColumn("{task.percentage:>3.0f}%"),
            expand=True,
        )
        task_id = bar.add_task("", total=state.total_steps or 1)
        bar.update(task_id, completed=state.current_step)

        grid = Table.grid(expand=True)
        grid.add_row(header)
        grid.add_row(bar)
        return grid

View File

@@ -0,0 +1,97 @@
"""TrainingPanel — live scalar metrics table with loss sparkline."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
# Braille sparkline characters (8 levels)
_SPARK_CHARS = "▁▂▃▄▅▆▇█"
def _sparkline(values: list[float] | None, width: int = 20) -> str:
if not values or len(values) < 2:
return ""
vals = list(values)[-width:]
lo, hi = min(vals), max(vals)
rng = hi - lo if hi != lo else 1.0
return "".join(_SPARK_CHARS[min(int((v - lo) / rng * 7), 7)] for v in vals)
# Known key ordering and formatting
# (state attribute / extra key, display label, format spec), in display order.
_KNOWN_KEYS: list[tuple[str, str, str]] = [
    ("loss", "loss", ".4f"),
    ("grad_norm", "grad norm", ".3f"),
    ("learning_rate", "lr", ".2e"),
    ("tokens_per_second", "tok/s", ".1f"),
    ("samples_per_second", "samples/s", ".1f"),
    ("mfu", "MFU", ".1f"),
    # RL-specific
    ("rewards_mean", "rewards/mean", ".4f"),
    ("rewards_std", "rewards/std", ".4f"),
    ("kl_divergence", "KL", ".4f"),
    ("clip_ratio", "clip ratio", ".3f"),
    ("queue_size", "queue", "d"),
]


@register_panel(position="left", weight=10)
class TrainingPanel(BasePanel):
    """Table of current scalar metrics, with a sparkline trend for the loss."""

    name = "training"
    min_height = 8

    @staticmethod
    def _format(value, spec: str) -> str:
        """Apply ``spec``; fall back to ``str(value)`` when the spec does not fit."""
        try:
            return format(value, spec)
        except (ValueError, TypeError):
            return str(value)

    def render(self, state: TUIState) -> RenderableType:
        """Render known metrics first (fixed order), then remaining extras sorted by key."""
        table = Table(
            show_header=True,
            header_style="bold",
            expand=True,
            box=None,
            pad_edge=False,
        )
        table.add_column("metric", style="cyan", no_wrap=True)
        table.add_column("value", justify="right")
        table.add_column("trend", justify="left", no_wrap=True)

        for attr, label, spec in _KNOWN_KEYS:
            # Prefer the dedicated state attribute, then the extra dict.
            value = getattr(state, attr, None)
            if value is None:
                value = state.extra.get(attr)
            if value is None:
                continue
            trend = _sparkline(list(state.loss_history)) if attr == "loss" else ""
            table.add_row(label, self._format(value, spec), trend)

        # Any extra keys not in _KNOWN_KEYS
        known = {entry[0] for entry in _KNOWN_KEYS}
        for key, value in sorted(state.extra.items()):
            if value is None or key in known:
                continue
            table.add_row(key, self._format(value, ".4f"), "")

        if table.row_count == 0:
            waiting = Text("Waiting for first log step...", style="dim")
            return Panel(waiting, title="Training", border_style="blue")

        return Panel(table, title="Training", border_style="blue")

View File

@@ -0,0 +1,7 @@
"""Built-in line parsers — auto-imported to trigger @register_parser decorators."""
from axolotl.tui.parsers.deepspeed import DeepSpeedParser # noqa: F401
from axolotl.tui.parsers.nccl import NCCLErrorParser # noqa: F401
from axolotl.tui.parsers.raw_log import RawLogParser # noqa: F401
from axolotl.tui.parsers.torch_compile import TorchCompileParser # noqa: F401
from axolotl.tui.parsers.tqdm import TqdmParser # noqa: F401

View File

@@ -0,0 +1,29 @@
"""DeepSpeedParser — extracts DeepSpeed stage info and throughput metrics."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class DeepSpeedParser(LineParser):
    """Extracts a samples/sec throughput metric and the ZeRO stage from log lines."""

    priority = 20
    name = "deepspeed"

    # Capture exactly one well-formed decimal number. A permissive character
    # class such as ``[0-9.]+`` would also capture "12.5." or "1.2.3", which
    # float() rejects with ValueError.
    _SAMPLES_RE = re.compile(r"samples/sec=(\d+(?:\.\d*)?|\.\d+)")
    _STAGE_RE = re.compile(r"ZeRO Stage (\d)")

    def parse(self, line: str, source: str) -> list[dict]:
        """Return metric / run_info events found in ``line``.

        Args:
            line: the captured line.
            source: stream label; not used by this parser.

        Returns:
            Up to two events: a ``metrics`` event with ``samples_per_second``
            and a ``run_info`` event with ``zero_stage``; ``[]`` if neither
            pattern matches.
        """
        events: list[dict] = []

        if m := self._SAMPLES_RE.search(line):
            events.append(
                {
                    "type": "metrics",
                    "logs": {"samples_per_second": float(m.group(1))},
                }
            )

        if m := self._STAGE_RE.search(line):
            events.append({"type": "run_info", "zero_stage": int(m.group(1))})

        return events

View File

@@ -0,0 +1,27 @@
"""NCCLErrorParser — surfaces NCCL errors as red alert events."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class NCCLErrorParser(LineParser):
    """Turns lines mentioning an NCCL error into an error log line plus an alert."""

    priority = 10
    name = "nccl_error"

    _RE = re.compile(r"NCCL error|Unhandled NCCL", re.IGNORECASE)

    def parse(self, line: str, source: str) -> list[dict]:
        """Return the log_line and alert events for a matching line, else ``[]``."""
        if self._RE.search(line) is None:
            return []
        log_event = {
            "type": "log_line",
            "level": "error",
            "message": f"⚠ NCCL: {line}",
        }
        alert_event = {"type": "alert", "severity": "error", "message": line}
        return [log_event, alert_event]

View File

@@ -0,0 +1,37 @@
"""RawLogParser — catches every line as a log_line event."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class RawLogParser(LineParser):
    """Catch-all parser: every non-empty, non-progress-bar line becomes a log_line event."""

    priority = 99
    name = "raw_log"

    # "<timestamp> - <LEVEL> - <message>" style log lines
    _LOG_RE = re.compile(
        r"^(?P<ts>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}[,\.]\d+)"
        r"\s*[-]\s*(?P<level>DEBUG|INFO|WARNING|ERROR|CRITICAL)"
        r"\s*[-]\s*(?P<msg>.+)$",
        re.IGNORECASE,
    )

    # Filter out tqdm progress bar lines and other noisy output
    _TQDM_RE = re.compile(r"^\s*\d+%\|.*\|")
    _EMPTY_RE = re.compile(r"^\s*$")

    def parse(self, line: str, source: str) -> list[dict]:
        """Return one log_line event for ``line``, or ``[]`` if it is filtered out.

        The level comes from the line itself when it matches the log pattern;
        otherwise it is "error" for stderr and "info" for anything else.
        """
        is_noise = self._EMPTY_RE.match(line) or self._TQDM_RE.match(line)
        if is_noise:
            return []

        matched = self._LOG_RE.match(line)
        if matched:
            level = matched.group("level").lower()
        elif source == "stderr":
            level = "error"
        else:
            level = "info"
        return [{"type": "log_line", "level": level, "message": line}]

View File

@@ -0,0 +1,26 @@
"""TorchCompileParser — detects torch.compile graph breaks and recompilations."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class TorchCompileParser(LineParser):
    """Flags lines that mention graph breaks, recompilation or torch.compile."""

    priority = 20
    name = "torch_compile"

    _RE = re.compile(r"Graph break|Recompiling|torch\.compile", re.IGNORECASE)

    def parse(self, line: str, source: str) -> list[dict]:
        """Return a single warning log_line event for a matching line, else ``[]``."""
        if self._RE.search(line) is None:
            return []
        warning = {
            "type": "log_line",
            "level": "warning",
            "message": f"⚡ compile: {line}",
        }
        return [warning]

View File

@@ -0,0 +1,86 @@
"""TqdmParser — captures tqdm progress bar output and surfaces as structured events."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class TqdmParser(LineParser):
    """Recognises tqdm-style progress output and converts it into events."""

    priority = 15
    name = "tqdm"

    # Match tqdm-style progress lines, e.g.:
    #   Tokenizing Prompts (num_proc=24):  35%|███▍      | 19008/54568 [00:02<00:02, 17417.65 examples/s]
    #   Loading weights:  53%|█████▎    | 77/146 [00:00<00:00, 396.39it/s]
    #   0%|          | 0/30 [00:00<?, ?it/s]
    _TQDM_RE = re.compile(
        r"(?P<desc>.*?)\s*"
        r"(?P<pct>\d+)%\|[▏▎▍▌▋▊▉█░▓▒# ]*\|\s*"
        r"(?P<current>[\d,]+)/(?P<total>[\d,]+)"
        r"\s*\[(?P<elapsed>[^\]]*)\]"
    )

    # Also match simpler forms like:
    #   Fetching 0 files: 0it [00:00, ?it/s]
    _FETCH_RE = re.compile(r"(?P<desc>[\w\s]+):\s*(?P<current>\d+)(?:it)?\s*\[.*?\]")

    @staticmethod
    def _clean_desc(raw: str) -> str:
        """Trim whitespace and the trailing colon(s) from a captured description."""
        return raw.strip().rstrip(":")

    def _bar_events(self, match: re.Match) -> list[dict]:
        """Events for a full progress-bar line: optional log line + progress metric."""
        desc = self._clean_desc(match.group("desc"))
        pct = int(match.group("pct"))
        current = int(match.group("current").replace(",", ""))
        total = int(match.group("total").replace(",", ""))

        out: list[dict] = []

        # Surface as a log line only at every 25% (0 and 100 included)
        if pct % 25 == 0:
            summary = f"{pct}% ({current}/{total})"
            message = f"[{desc}] {summary}" if desc else summary
            out.append({"type": "log_line", "level": "info", "message": message})

        # Also emit as a progress metric, keyed by a slug of the description
        slug = desc.strip().lower().replace(" ", "_") or "progress"
        out.append({"type": "metrics", "logs": {f"progress/{slug}": pct / 100.0}})
        return out

    def parse(self, line: str, source: str) -> list[dict]:
        """Return events for a progress line, or ``[]`` when ``line`` is not one."""
        bar = self._TQDM_RE.search(line)
        if bar:
            return self._bar_events(bar)

        # Fallback: try simpler fetch-style progress lines
        fetch = self._FETCH_RE.search(line)
        if not fetch:
            return []
        desc = self._clean_desc(fetch.group("desc"))
        current = int(fetch.group("current"))
        message = f"[{desc}] {current}" if desc else f"{current}"
        return [{"type": "log_line", "level": "info", "message": message}]

449
src/axolotl/tui/renderer.py Normal file
View File

@@ -0,0 +1,449 @@
"""TUIRenderer — background daemon thread that drives the rich.live.Live display."""
from __future__ import annotations
import logging
import queue
import threading
import time
from datetime import datetime
from typing import Any
from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from axolotl.tui.config import TUIConfig
from axolotl.tui.gpu import GPUPoller
from axolotl.tui.io_capture import (
IOCapture,
ParserChain,
get_registered_parsers,
)
from axolotl.tui.panels import BasePanel, get_registered_panels
from axolotl.tui.state import CompletionSample, LogLine, TUIState
LOG = logging.getLogger(__name__)
class TUIRenderer:
"""Background thread that renders the TUI dashboard using rich.live.Live."""
def __init__(self, config: TUIConfig, metric_queue: queue.Queue):
    """Store the inputs and create idle state; no thread is created here.

    Args:
        config: TUI settings.
        metric_queue: queue from which events are read.
    """
    # Inputs
    self._config = config
    self._queue = metric_queue

    # Render-side data
    self._state = TUIState()
    self._gpu_poller = GPUPoller()
    self._panels: list[BasePanel] = []

    # Thread control
    self._thread: threading.Thread | None = None
    self._stop_event = threading.Event()

    # stdout/stderr capture; left unset here
    self._io_capture: IOCapture | None = None
    self._parser_chain: ParserChain | None = None
def _init_panels(self) -> None:
    """Instantiate the configured panels, in configuration order.

    Names that are not in the panel registry are skipped with a warning, so a
    typo in the configuration does not silently yield a missing panel.
    """
    registry = get_registered_panels()
    for panel_name in self._config.panels:
        panel_cls = registry.get(panel_name)
        if panel_cls is None:
            LOG.warning(
                "Unknown TUI panel %r ignored (registered panels: %s)",
                panel_name,
                ", ".join(sorted(registry)) or "none",
            )
            continue
        self._panels.append(panel_cls())
def _init_parser_chain(self) -> None:
    """Build the parser chain from the registered parsers plus configured plugins.

    Each plugin spec is either ``"<file path>::<ClassName>"`` or a dotted
    ``"package.module.ClassName"`` path. A plugin that fails to load is logged
    as a warning and skipped.
    """
    # Imported up front, for both branches below. Importing ``importlib.util``
    # inside only one branch makes ``importlib`` a function-local name that is
    # unbound in the other branch (UnboundLocalError for dotted-path plugins).
    import importlib
    import importlib.util

    # Ensure built-in parsers are imported so @register_parser decorators fire
    import axolotl.tui.parsers  # noqa: F401

    self._parser_chain = ParserChain()

    # Register all built-in parsers
    for parser_cls in get_registered_parsers():
        self._parser_chain.register(parser_cls())

    # Load plugin parsers
    for plugin_spec in self._config.parser_plugins:
        try:
            if "::" in plugin_spec:
                # file path :: class name
                file_path, class_name = plugin_spec.split("::", 1)
                spec = importlib.util.spec_from_file_location(
                    "custom_parser", file_path
                )
                if spec is None or spec.loader is None:
                    raise ImportError(f"Cannot load spec for {file_path}")
                mod = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(mod)
                parser_cls = getattr(mod, class_name)
            else:
                # dotted module path
                module_path, class_name = plugin_spec.rsplit(".", 1)
                mod = importlib.import_module(module_path)
                parser_cls = getattr(mod, class_name)
            self._parser_chain.register(parser_cls())
        except Exception as exc:
            # Broad on purpose: plugin code is arbitrary and one bad plugin
            # must not prevent the remaining ones from loading.
            LOG.warning("Failed to load parser plugin %s: %s", plugin_spec, exc)
def _build_layout(self) -> Layout:
    """Create the rich ``Layout`` tree that matches the configured panels.

    The tree has up to three rows, each present only when a panel uses it:
    ``top`` (fixed size 3), ``middle`` (ratio 3, split into ``left`` and/or
    ``right`` columns) and ``bottom`` (ratio 2, split into ``bottom_<i>``
    columns when more than one bottom panel exists).
    """

    def count(position: str) -> int:
        return sum(1 for panel in self._panels if panel.position == position)

    root = Layout()
    rows = []

    if count("top"):
        rows.append(Layout(name="top", size=3))

    columns = [
        Layout(name=side, ratio=1) for side in ("left", "right") if count(side)
    ]
    if columns:
        middle = Layout(name="middle", ratio=3)
        middle.split_row(*columns)
        rows.append(middle)

    bottom_count = count("bottom")
    if bottom_count:
        bottom = Layout(name="bottom", ratio=2)
        if bottom_count > 1:
            bottom.split_row(
                *(Layout(name=f"bottom_{i}", ratio=1) for i in range(bottom_count))
            )
        rows.append(bottom)

    if rows:
        root.split_column(*rows)
    return root
def _update_layout(self, layout: Layout) -> None:
    """Render the panels into their slots of ``layout``.

    Only the first panel registered for ``top``, ``left`` and ``right`` is
    drawn. A single bottom panel fills ``bottom``; several bottom panels are
    drawn into ``bottom_0``, ``bottom_1``, ... in registration order.
    """

    def panels_at(position: str) -> list:
        return [panel for panel in self._panels if panel.position == position]

    for slot in ("top", "left", "right"):
        candidates = panels_at(slot)
        if candidates:
            layout[slot].update(candidates[0].render(self._state))

    bottoms = panels_at("bottom")
    if len(bottoms) == 1:
        layout["bottom"].update(bottoms[0].render(self._state))
    elif bottoms:
        for index, panel in enumerate(bottoms):
            layout[f"bottom_{index}"].update(panel.render(self._state))
def _drain_queue(self) -> None:
    """Consume every pending event from the queue and fold it into the state.

    Each event is first offered to every panel via ``on_event`` and then
    handled according to its ``"type"``: ``metrics``, ``step``, ``log_line``,
    ``completion``, ``run_info`` or ``done`` (which sets the stop event).
    Events of any other type are only seen by the panels.
    """
    state = self._state
    # run_info keys map one-to-one onto state attributes of the same name.
    run_info_fields = (
        "run_name",
        "model_name",
        "training_mode",
        "world_size",
        "total_steps",
        "total_epochs",
        "zero_stage",
    )

    while True:
        try:
            event = self._queue.get_nowait()
        except queue.Empty:
            return

        # Dispatch event to panels first
        for panel in self._panels:
            panel.on_event(event)

        kind = event.get("type")

        if kind == "metrics":
            self._apply_metrics(event.get("logs", {}))

        elif kind == "step":
            state.current_step = event.get("step", state.current_step)
            state.total_steps = event.get("total_steps", state.total_steps)
            state.current_epoch = event.get("epoch", state.current_epoch)
            state.elapsed_seconds = time.time() - state.start_time.timestamp()
            if state.current_step > 0 and state.total_steps > 0:
                # ETA = average seconds per finished step * steps left.
                seconds_per_step = state.elapsed_seconds / state.current_step
                steps_left = state.total_steps - state.current_step
                state.eta_seconds = seconds_per_step * steps_left

        elif kind == "log_line":
            level = event.get("level", "info")
            message = event.get("message", "")
            entry = LogLine(timestamp=datetime.now(), level=level, message=message)
            state.log_lines.append(entry)

        elif kind == "completion":
            sample = CompletionSample(
                step=event.get("step", 0),
                prompt=event.get("prompt", ""),
                completion=event.get("completion", ""),
                reward=event.get("reward"),
                advantage=event.get("advantage"),
            )
            state.completions.append(sample)

        elif kind == "run_info":
            for field_name in run_info_fields:
                if field_name in event:
                    setattr(state, field_name, event[field_name])

        elif kind == "done":
            self._stop_event.set()
def _apply_metrics(self, logs: dict[str, Any]) -> None:
    """Copy a dict of logged metrics onto the dashboard state.

    Recognised keys (and their aliases such as ``rewards/mean`` or ``kl``)
    are written to the matching state attribute; every other key is stored
    in ``state.extra``. A non-``None`` ``loss`` is also appended to
    ``state.loss_history``.
    """
    state_fields = frozenset(
        (
            "loss",
            "grad_norm",
            "learning_rate",
            "tokens_per_second",
            "samples_per_second",
            "mfu",
            "rewards_mean",
            "rewards_std",
            "kl_divergence",
            "clip_ratio",
            "queue_size",
        )
    )
    # Alternative spellings that map onto one of the fields above.
    aliases = {
        "rewards/mean": "rewards_mean",
        "rewards/std": "rewards_std",
        "kl": "kl_divergence",
    }

    state = self._state
    for name, value in logs.items():
        attribute = aliases.get(name, name)
        if attribute in state_fields:
            setattr(state, attribute, value)
        else:
            state.extra[name] = value

    loss = logs.get("loss")
    if loss is not None:
        state.loss_history.append(loss)
def start(self) -> None:
    """Initialise panels, parsers and I/O capture, then start the render thread.

    Blocks for up to five seconds until the render thread signals that I/O
    capture is live, and logs a warning if that signal does not arrive in
    time.

    Raises:
        RuntimeError: If the parser chain was not created by
            ``_init_parser_chain``.
    """
    self._init_panels()
    self._init_parser_chain()

    # Set up I/O capture. An explicit check is used instead of ``assert``
    # because assertions are stripped when Python runs with ``-O``.
    if self._parser_chain is None:
        raise RuntimeError("_init_parser_chain must be called first")
    self._io_capture = IOCapture(
        log_path=self._config.stdout_log_path,
        parser_chain=self._parser_chain,
        metric_queue=self._queue,
    )

    # Monkeypatch tqdm to suppress terminal output and route through our queue.
    # This prevents tqdm progress bars from flickering through the TUI and
    # ensures all progress events appear in the Events panel.
    self._install_tqdm_hook()

    self._io_capture_ready = threading.Event()
    self._thread = threading.Thread(target=self._run, daemon=True)
    self._thread.start()

    # Event.wait() returns False on timeout; surface that instead of
    # silently continuing with output that may not be captured yet.
    if not self._io_capture_ready.wait(timeout=5.0):
        LOG.warning("TUI I/O capture was not ready within 5 seconds")
def _install_tqdm_hook(self) -> None:
    """Replace tqdm's display method to route updates through TUI queue.

    ``tqdm.tqdm``, ``tqdm.auto.tqdm`` and ``tqdm.std.tqdm`` are replaced by a
    subclass that writes nothing to the terminal and instead pushes
    ``log_line`` and ``metrics`` events onto the metric queue. The original
    classes are saved on ``self`` so ``_uninstall_tqdm_hook`` can restore
    them. Any failure is logged at debug level and leaves tqdm untouched or
    partially patched.
    """
    try:
        import io

        import tqdm
        import tqdm.auto

        q = self._queue
        self._tqdm_parser = None
        # Find our tqdm parser in the chain
        for p in self._parser_chain._parsers if self._parser_chain else []:
            if p.name == "tqdm":
                self._tqdm_parser = p
                break

        # Save originals for restore
        self._orig_tqdm_class_auto = tqdm.auto.tqdm
        self._orig_tqdm_class_tqdm = tqdm.tqdm
        self._orig_tqdm_class_std = tqdm.std.tqdm

        class TUITqdm(tqdm.tqdm):
            """tqdm subclass that sends progress to TUI instead of terminal."""

            def __init__(self, *args, **kwargs):
                # Bookkeeping is set before super().__init__() because the
                # base constructor may already trigger display().
                self._tui_last_milestone = None
                self._tui_closed = False
                # Send tqdm's own output to an in-memory buffer so nothing
                # reaches the terminal.
                kwargs["file"] = io.StringIO()
                kwargs["dynamic_ncols"] = False
                kwargs["ncols"] = 80
                super().__init__(*args, **kwargs)

            def display(self, msg=None, pos=None):
                # Build a progress string and push to queue
                if self.total and self.total > 0:
                    pct = self.n / self.total * 100
                    desc = self.desc.rstrip(": ") if self.desc else ""

                    # Milestones: start, finish, and whole percentages
                    # divisible by 25.
                    if self.n == 0:
                        milestone = 0
                    elif self.n >= self.total:
                        milestone = 100
                    elif int(pct) % 25 == 0:
                        milestone = int(pct)
                    else:
                        milestone = None

                    # display() is called many times while the percentage
                    # stays inside one milestone band; emit each milestone
                    # only once so the log panel is not flooded.
                    if milestone is not None and milestone != getattr(
                        self, "_tui_last_milestone", None
                    ):
                        self._tui_last_milestone = milestone
                        try:
                            q.put_nowait(
                                {
                                    "type": "log_line",
                                    "level": "info",
                                    "message": f"[{desc}] {pct:.0f}% ({self.n}/{self.total})"
                                    if desc
                                    else f"{pct:.0f}% ({self.n}/{self.total})",
                                }
                            )
                        except Exception:
                            # Best effort: never let TUI reporting break a
                            # progress bar (e.g. when the queue is full).
                            pass

                    try:
                        metric_key = (
                            f"progress/{desc.lower().replace(' ', '_')}"
                            if desc
                            else "progress/unknown"
                        )
                        q.put_nowait(
                            {
                                "type": "metrics",
                                "logs": {metric_key: pct / 100.0},
                            }
                        )
                    except Exception:
                        pass

            def close(self):
                # close() can run more than once (explicit call, context
                # manager exit, garbage collection); report only the first.
                already_closed = getattr(self, "_tui_closed", False)
                self._tui_closed = True
                total = getattr(self, "total", None)
                done_count = getattr(self, "n", 0)
                if not already_closed and total and total > 0 and done_count > 0:
                    raw_desc = getattr(self, "desc", None)
                    desc = raw_desc.rstrip(": ") if raw_desc else ""
                    if done_count >= total:
                        summary = f"100% ({total}/{total}) done"
                    else:
                        # The bar was closed before finishing; report the
                        # real progress instead of claiming 100%.
                        summary = (
                            f"{done_count / total * 100:.0f}% "
                            f"({done_count}/{total}) closed early"
                        )
                    try:
                        q.put_nowait(
                            {
                                "type": "log_line",
                                "level": "info",
                                "message": f"[{desc}] {summary}" if desc else summary,
                            }
                        )
                    except Exception:
                        pass
                super().close()

        # Replace tqdm globally
        tqdm.auto.tqdm = TUITqdm
        tqdm.tqdm = TUITqdm
        # Also patch tqdm.std which some libraries use directly
        tqdm.std.tqdm = TUITqdm

        self._tui_tqdm_cls = TUITqdm

    except Exception as exc:
        LOG.debug(f"Failed to install tqdm hook: {exc}")
def _uninstall_tqdm_hook(self) -> None:
    """Restore original tqdm.

    Puts back each tqdm class that ``_install_tqdm_hook`` saved on ``self``.
    Entries that were never saved are left alone, and any error is ignored.
    """
    try:
        import tqdm
        import tqdm.auto

        # (module to patch, attribute on self holding the saved class)
        saved_classes = (
            (tqdm.auto, "_orig_tqdm_class_auto"),
            (tqdm, "_orig_tqdm_class_tqdm"),
            (tqdm.std, "_orig_tqdm_class_std"),
        )
        for module, saved_attr in saved_classes:
            if hasattr(self, saved_attr):
                module.tqdm = getattr(self, saved_attr)
    except Exception:
        pass
def stop(self) -> None:
    """Ask the render loop to exit, restore tqdm and wait for the thread.

    The join is bounded to five seconds; nothing is joined when the render
    thread was never started.
    """
    self._stop_event.set()
    self._uninstall_tqdm_hook()
    worker = self._thread
    if worker is None:
        return
    worker.join(timeout=5.0)
def _run(self) -> None:
    """Render loop executed on the background thread.

    Duplicates fd 1 so the dashboard keeps writing to the real terminal,
    starts I/O capture, then repeatedly drains the event queue, polls GPU
    statistics and refreshes the ``rich`` live display until the stop event
    is set. I/O capture is stopped and the duplicated terminal handle closed
    on exit, including when setup or rendering fails.
    """
    import os

    # Save a handle to the REAL terminal BEFORE IO capture redirects fds.
    # This ensures rich.live.Live writes to the terminal, not the pipe.
    saved_tty_fd = os.dup(1)
    tty_file = os.fdopen(saved_tty_fd, "w", buffering=1, closefd=True)
    capture_started = False

    try:
        console = Console(file=tty_file)

        layout = self._build_layout()
        tick_interval = 1.0 / max(self._config.refresh_rate, 1)
        gpu_poll_counter = 0
        gpu_poll_ticks = max(
            1, int(self._config.hardware_poll_interval / tick_interval)
        )

        try:
            # Start I/O capture — redirects fd 1/2 to pipe AFTER we saved the tty fd
            if self._io_capture:
                self._io_capture.start()
                capture_started = True
        finally:
            # Signal start() even when capture failed to start, so the
            # caller is not left waiting for its full timeout.
            if hasattr(self, "_io_capture_ready"):
                self._io_capture_ready.set()

        with Live(
            layout,
            console=console,
            refresh_per_second=self._config.refresh_rate,
            screen=True,
            redirect_stdout=False,
            redirect_stderr=False,
        ) as live:
            while not self._stop_event.is_set():
                self._drain_queue()

                # Poll GPU stats periodically
                gpu_poll_counter += 1
                if gpu_poll_counter >= gpu_poll_ticks:
                    gpu_poll_counter = 0
                    if self._gpu_poller.available:
                        self._state.gpus = self._gpu_poller.poll()

                # Update elapsed time
                self._state.elapsed_seconds = (
                    time.time() - self._state.start_time.timestamp()
                )

                self._update_layout(layout)
                live.update(layout)

                # Wait on the stop event instead of sleeping so that stop()
                # wakes the loop immediately rather than after a full tick.
                self._stop_event.wait(tick_interval)

            # Final drain
            self._drain_queue()
            self._update_layout(layout)
            live.update(layout)
    finally:
        # Only stop a capture that was actually started.
        if capture_started and self._io_capture:
            self._io_capture.stop()
        try:
            tty_file.close()
        except Exception:
            # Best effort: the terminal handle may already be unusable.
            pass

88
src/axolotl/tui/state.py Normal file
View File

@@ -0,0 +1,88 @@
"""TUI shared data model — dataclasses for the dashboard state."""
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
@dataclass
class GPUStats:
    """Statistics snapshot for a single GPU.

    NOTE(review): units are only implied by the field-name suffixes
    (``_pct``, ``_gb``, ``_c``, ``_w``); confirm against the code that
    fills this record.
    """

    id: int
    name: str
    util_pct: float
    vram_used_gb: float
    vram_total_gb: float
    temp_c: int
    # May be None; presumably when the device reports no power reading.
    power_w: float | None
@dataclass
class LogLine:
    """One log message shown in the dashboard, with its time and level."""

    timestamp: datetime
    level: str  # "info" | "debug" | "warning" | "error"
    message: str
@dataclass
class CompletionSample:
    """A prompt/completion pair recorded at a given training step."""

    step: int
    prompt: str
    completion: str
    # Both are optional and may be None.
    reward: float | None
    advantage: float | None
@dataclass
class TUIState:
    """Mutable state shared by the TUI dashboard.

    Every field has a default, so ``TUIState()`` yields an empty state whose
    ``start_time`` is the moment of construction. The history fields are
    bounded deques that drop their oldest entries once full.
    """

    # Run metadata
    run_name: str = ""
    model_name: str = ""
    training_mode: str = "sft"
    world_size: int = 1
    start_time: datetime = field(default_factory=datetime.now)

    # Progress
    current_step: int = 0
    total_steps: int = 0
    current_epoch: float = 0.0
    total_epochs: float = 1.0
    elapsed_seconds: float = 0.0
    eta_seconds: float | None = None

    # Training metrics (rolling window + current)
    loss: float | None = None
    grad_norm: float | None = None
    learning_rate: float | None = None
    tokens_per_second: float | None = None
    samples_per_second: float | None = None
    mfu: float | None = None

    # RL-specific (None for non-RL modes)
    rewards_mean: float | None = None
    rewards_std: float | None = None
    kl_divergence: float | None = None
    clip_ratio: float | None = None
    queue_size: int | None = None

    # Per-GPU hardware (list indexed by local rank)
    gpus: list[GPUStats] = field(default_factory=list)

    # Recent log lines (at most 200 are kept)
    log_lines: deque[LogLine] = field(default_factory=lambda: deque(maxlen=200))

    # Recent completions (GRPO/SFT with log_completions); at most 20 are kept
    completions: deque[CompletionSample] = field(
        default_factory=lambda: deque(maxlen=20)
    )

    # Loss history for sparkline (at most 50 values are kept)
    loss_history: deque[float] = field(default_factory=lambda: deque(maxlen=50))

    # DeepSpeed zero stage (None if not using DeepSpeed)
    zero_stage: int | None = None

    # Arbitrary plugin state
    extra: dict[str, Any] = field(default_factory=dict)

View File

@@ -13,6 +13,7 @@ from pydantic import (
model_validator,
)
from axolotl.tui.config import TUIConfig
from axolotl.utils.datasets import get_default_process_count
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import (
@@ -140,6 +141,12 @@ class AxolotlInputConfig(
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(),
)
tui: TUIConfig | None = Field(
default=None,
json_schema_extra={
"description": "TUI dashboard configuration. Set enabled: true to activate."
},
)
qat: QATConfig | None = None
quantization: PTQConfig | None = None
reward_model: bool | None = Field(
@@ -703,6 +710,12 @@ class AxolotlInputConfig(
"description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
},
)
lora_embedding_kernel: bool | None = Field(
default=None,
json_schema_extra={
"description": "Apply custom LoRA autograd function for embedding layers. See: https://docs.axolotl.ai/docs/lora_optims.html"
},
)
chunked_cross_entropy: bool | None = Field(
default=None,
@@ -758,44 +771,6 @@ class AxolotlInputConfig(
llama4_linearized_experts: bool | None = None
# MoE aux-loss-free (AFB) toggles
moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
default=None,
json_schema_extra={
"description": "MoE load balancing strategy: 'gshard' for auxiliary loss, 'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. Defaults to model's native behavior when unset.",
},
)
moe_update_rate: float | None = Field(
default=None,
json_schema_extra={
"description": "Per-step bias update rate (gamma). Recommended: 0.0050.05. If unset, plugin default is 0.01.",
},
)
moe_update_momentum: float | None = Field(
default=None,
json_schema_extra={
"description": "EMA momentum for expert load smoothing (01). If unset, plugin default is 0.9.",
},
)
moe_bias_cap: float | None = Field(
default=None,
json_schema_extra={
"description": "Absolute clamp for expert bias magnitude. If unset, plugin default is 2.0.",
},
)
moe_afb_warmup_steps: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of initial steps to delay aux-free bias updates, allowing routing to stabilize. If unset, plugin default is 0.",
},
)
moe_bias_sync_group: Literal["world", "ep"] | None = Field(
default=None,
json_schema_extra={
"description": "Reduction group for expert load counts: 'world' (DP) or 'ep' (expert-parallel group if available). Defaults to 'world' when unset.",
},
)
deepspeed: str | dict[str, Any] | None = Field(
default=None,
json_schema_extra={
@@ -874,12 +849,6 @@ class AxolotlInputConfig(
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
},
)
expert_parallel_size: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of processes participating in expert-parallel collectives. Set >1 to form EP groups for aux-free reductions; defaults to world when unset."
},
)
special_tokens: SpecialTokensConfig | None = Field(
default=None,
json_schema_extra={
@@ -1357,6 +1326,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
or data.get("lora_embedding_kernel")
):
capabilities = data.get("capabilities")
is_fsdp = data.get("fsdp_config") is not None
@@ -1404,7 +1374,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("adapter") in ["lora", "qlora"]:
# Skip if already set, using unsloth optimizations, or using 8-bit
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
kernel_fields = [
"lora_mlp_kernel",
"lora_qkv_kernel",
"lora_o_kernel",
"lora_embedding_kernel",
]
if (
any(data.get(k) is not None for k in kernel_fields)
or any(data.get(k) for k in unsloth_fields)
@@ -1417,10 +1392,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("trust_remote_code"):
return data
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
if data.get("lora_dropout") != 0:
return data
# Check multi-GPU compatibility
capabilities = data.get("capabilities")
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
@@ -1442,6 +1413,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data.get("lora_o_kernel") is None:
data["lora_o_kernel"] = True
if data.get("lora_embedding_kernel") is None:
data["lora_embedding_kernel"] = True
LOG.warning(
"Auto-enabling LoRA kernel optimizations for faster training. "
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "

View File

@@ -681,15 +681,7 @@ class LoRAValidationMixin:
@model_validator(mode="before")
@classmethod
def check_lora_kernels_dora(cls, data):
if (
data.get("lora_mlp_kernel")
or data.get("lora_qkv_kernel")
or data.get("lora_o_kernel")
) and data.get("peft_use_dora"):
raise ValueError(
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
"compatible with DoRA at the moment."
)
# DoRA is now supported by lora kernels
return data
@model_validator(mode="before")
@@ -1386,14 +1378,6 @@ class ComplexValidationMixin:
self.tensor_parallel_size = 1
return self
@model_validator(mode="after")
def check_expert_parallel_size(self):
if not getattr(self, "expert_parallel_size", None):
self.expert_parallel_size = 1
elif self.expert_parallel_size < 1:
raise ValueError("expert_parallel_size must be >= 1")
return self
@model_validator(mode="after")
def check_context_parallel_size(self):
if self.sequence_parallel_degree and not self.context_parallel_size:

View File

@@ -153,7 +153,7 @@ class TestLoraFP8Guard(unittest.TestCase):
proj.base_layer = base_layer
W, b, quant_state, A, B, s = get_lora_parameters(proj)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
# quant_state should be None since weight is bf16, not FP8
self.assertIsNone(quant_state)
@@ -174,7 +174,7 @@ class TestLoraFP8Guard(unittest.TestCase):
scale_inv = torch.ones(1)
base_layer.weight_scale_inv = scale_inv
W, b, quant_state, A, B, s = get_lora_parameters(proj)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
self.assertIs(quant_state, scale_inv)

View File

@@ -102,7 +102,7 @@ def mock_proj():
def test_get_lora_parameters(mock_proj):
"""Tests get_lora_parameters function"""
# Test with LoRA enabled
W, b, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
assert isinstance(W, torch.Tensor)
assert W.shape == (128, 64)
@@ -113,13 +113,13 @@ def test_get_lora_parameters(mock_proj):
# Test with LoRA disabled
mock_proj.disable_adapters = True
W, b, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None
# Test with merged state
mock_proj.disable_adapters = False
mock_proj.merged = True
W, b, _, A, B, s = get_lora_parameters(mock_proj)
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
assert A is None and B is None and s is None

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,120 @@
"""Test LoRA kernels under FSDP2 multi-GPU training.
Verifies that lora_qkv_kernel, lora_o_kernel, lora_mlp_kernel, and
lora_embedding_kernel work correctly with FSDP2 sharding, including
with bias, dropout, and DoRA enabled.
"""
from pathlib import Path
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def _run_training(temp_dir, cfg):
    """Dump ``cfg`` to ``<temp_dir>/config.yaml`` and train on two processes.

    The training run is started through the ``axolotl train`` command line
    with a unique main-process port.
    """
    run_dir = Path(temp_dir)
    run_dir.mkdir(parents=True, exist_ok=True)

    config_path = run_dir / "config.yaml"
    with config_path.open("w", encoding="utf-8") as handle:
        handle.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

    command = [
        "axolotl",
        "train",
        str(config_path),
        "--num-processes",
        "2",
        "--main-process-port",
        f"{get_torch_dist_unique_port()}",
    ]
    execute_subprocess_async(command)
def _base_lora_fsdp2_config(temp_dir, **overrides):
    """Return the shared LoRA + FSDP2 config with every LoRA kernel enabled.

    Keyword arguments replace or extend the default entries.
    """
    dataset = {
        "path": "tatsu-lab/alpaca",
        "type": "alpaca",
        "split": "train[:1%]",
    }
    fsdp_settings = {
        "offload_params": False,
        "cpu_ram_efficient_loading": False,
        "transformer_layer_cls_to_wrap": "Qwen3DecoderLayer",
        "state_dict_type": "FULL_STATE_DICT",
        "auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
        "reshard_after_forward": True,
    }
    # Enable all LoRA kernels
    kernel_flags = {
        "lora_mlp_kernel": True,
        "lora_qkv_kernel": True,
        "lora_o_kernel": True,
        "lora_embedding_kernel": True,
    }
    defaults = {
        "base_model": "Qwen/Qwen3-0.6B",
        "sequence_len": 512,
        "val_set_size": 0.0,
        "datasets": [dataset],
        "adapter": "lora",
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_target_linear": True,
        "num_epochs": 1,
        "max_steps": 3,
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 1,
        "output_dir": temp_dir,
        "learning_rate": 1e-4,
        "optimizer": "adamw_torch_fused",
        "lr_scheduler": "cosine",
        "flash_attention": True,
        "bf16": True,
        "fsdp_version": 2,
        "fsdp_config": fsdp_settings,
        **kernel_flags,
        "save_safetensors": True,
    }
    return DictDefault({**defaults, **overrides})
class TestFSDP2LoRAKernels:
    """Test LoRA kernels under FSDP2."""

    @staticmethod
    def _train_and_expect_adapter(temp_dir, **overrides):
        """Train with the base config plus ``overrides`` and check the adapter file."""
        _run_training(temp_dir, _base_lora_fsdp2_config(temp_dir, **overrides))
        adapter_file = Path(temp_dir) / "adapter_model.safetensors"
        assert adapter_file.exists()

    @require_torch_2_7_0
    def test_lora_kernels_basic(self, temp_dir):
        """Basic LoRA + kernels + FSDP2: no dropout, no bias, no DoRA."""
        self._train_and_expect_adapter(temp_dir)

    @require_torch_2_7_0
    def test_lora_kernels_with_dropout(self, temp_dir):
        """LoRA kernels + dropout + FSDP2."""
        self._train_and_expect_adapter(temp_dir, lora_dropout=0.1)

    @require_torch_2_7_0
    def test_lora_kernels_with_dora(self, temp_dir):
        """LoRA kernels + DoRA + FSDP2."""
        self._train_and_expect_adapter(temp_dir, peft_use_dora=True)

    @require_torch_2_7_0
    def test_lora_kernels_with_dora_and_dropout(self, temp_dir):
        """LoRA kernels + DoRA + dropout + FSDP2."""
        self._train_and_expect_adapter(
            temp_dir,
            peft_use_dora=True,
            lora_dropout=0.05,
        )

View File

@@ -222,9 +222,9 @@ def test_model_specific_activation(model_name, expected_activation):
def test_kernel_patch_conditions():
"""Test various conditions that should prevent kernel patching."""
"""Test that kernels ARE patched even with dropout and bias (now supported)."""
test_configs = [
# Dropout prevents patching
# Dropout — kernels now support this
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
@@ -234,7 +234,7 @@ def test_kernel_patch_conditions():
"lora_dropout": 0.1,
"bias": "none",
},
# Bias prevents patching
# Bias — kernels now support this
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
@@ -252,13 +252,14 @@ def test_kernel_patch_conditions():
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
# Should not patch
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
# Verify no patches applied
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
assert layer.forward.__func__ is not apply_lora_mlp_geglu
# Verify patches ARE applied (dropout and bias are now supported)
assert (
layer.forward.__func__ is apply_lora_mlp_swiglu
or layer.forward.__func__ is apply_lora_mlp_geglu
)
def test_kernel_config_options():
@@ -511,7 +512,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
def test_kernel_training_integration_dropout_non_zero(temp_dir):
"""Test model loading with dropout non-zero should not patch."""
"""Test model loading with dropout non-zero DOES patch (now supported)."""
from axolotl.cli.utils import load_model_and_tokenizer
@@ -546,31 +547,18 @@ def test_kernel_training_integration_dropout_non_zero(temp_dir):
# Load config
cfg = load_cfg(str(path))
# Get original attention class
attention_cls = get_attention_cls_from_config(cfg)
# Store original state before patching
original_forward_method = attention_cls.forward
# Load model
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
# We call modelloader as that's where the patches are applied
# despite the fact that we're not using it to load the model
model_loader = ModelLoader(cfg, tokenizer)
# Apply patch
# Apply patches — should succeed even with dropout > 0
model_loader.patch_manager._apply_self_attention_lora_patch()
# Verify patch was not applied
assert attention_cls.forward == original_forward_method
# Apply apply_lora_kernel_patches
model_loader.patch_manager._apply_lora_kernel_patch(model)
# Verify patch was not applied
# Verify patches WERE applied (dropout is now supported by kernels)
layers = get_layers(model)
for layer in layers:
for self_attn in find_self_attn_in_layer(layer):
assert not hasattr(self_attn, "apply_qkv")
assert not hasattr(self_attn, "apply_o")
assert hasattr(self_attn, "apply_qkv")
assert hasattr(self_attn, "apply_o")

View File

@@ -1,79 +0,0 @@
"""
E2E smoke tests for Aux-Loss-Free MoE routing via plugin
"""
import unittest
import torch
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
class TestMoeAuxFree(unittest.TestCase):
    """Smoke tests to ensure aux-free plugin enables and runs on Mixtral tiny."""

    @with_temp_dir
    def test_mixtral_aux_free_smoke(self, temp_dir):
        """Train Mixtral-tiny for five steps with the plugin and inspect the model.

        Passes when at least one module carries the ``_afb_patched`` marker
        together with one-dimensional ``_afb_bias`` and ``_afb_counts``
        attributes, the counts are all zero, and the model output exists.
        """
        cfg = DictDefault(
            {
                "base_model": "hf-internal-testing/Mixtral-tiny",
                "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
                "flash_attention": False,
                "sequence_len": 512,
                "bf16": False,
                "fp16": False,
                "val_set_size": 0.02,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 1e-5,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_steps": 0,
                "eval_steps": 0,
                "save_first_step": False,
                # Aux-free plugin and toggles
                "plugins": [
                    "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
                ],
                "moe_balance_type": "noaux_tc",
                "moe_update_rate": 0.01,
                "moe_update_momentum": 0.9,
                "moe_bias_cap": 2.0,
            }
        )

        # Validate and normalize the config, then register the plugin.
        cfg = validate_config(cfg)
        normalize_config(cfg)
        prepare_plugins(cfg)
        dataset_meta = load_datasets(cfg=cfg)

        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)

        # Inspect model modules for a patched MoE layer
        patched = None
        for m in model.modules():
            if hasattr(m, "_afb_patched") and getattr(m, "_afb_patched") is True:
                patched = m
                break
        assert patched is not None, "No MoE layer patched by aux-free plugin"
        assert hasattr(patched, "_afb_bias") and patched._afb_bias.ndim == 1
        assert hasattr(patched, "_afb_counts") and patched._afb_counts.ndim == 1
        # ensure counts buffer got reset by callback (best effort)
        assert torch.all(patched._afb_counts == 0)
        check_model_output_exists(temp_dir, cfg)

View File

@@ -1,83 +0,0 @@
"""
Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-tiny.
Checks that aux-free training loss does not degrade beyond a small tolerance.
"""
import unittest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
def _last_logged_loss(trainer) -> float | None:
    """Return the newest ``loss`` in ``trainer.state.log_history`` as a float.

    Entries that are not dicts or have no ``"loss"`` key are ignored;
    ``None`` is returned when no entry qualifies.
    """
    newest_first = reversed(trainer.state.log_history)
    losses = (
        float(entry["loss"])
        for entry in newest_first
        if isinstance(entry, dict) and "loss" in entry
    )
    return next(losses, None)
class TestMoeAuxParity(unittest.TestCase):
    """Compare final training loss with and without the aux-free MoE plugin."""

    @with_temp_dir
    def test_mixtral_auxfree_vs_auxloss_loss_parity(self, temp_dir):
        """Train Mixtral-tiny twice from one base config and compare the losses.

        The first run uses the base config alone; the second adds the
        aux-free plugin with ``moe_balance_type="noaux_tc"``. Passes when the
        last logged loss of the second run is at most 1.1 times the first.
        """
        base_cfg = {
            "base_model": "hf-internal-testing/Mixtral-tiny",
            "tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
            "flash_attention": False,
            "sequence_len": 512,
            "bf16": False,
            "fp16": False,
            "val_set_size": 0.02,
            "special_tokens": {},
            "datasets": [
                {"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"},
            ],
            "num_epochs": 1,
            "micro_batch_size": 2,
            "gradient_accumulation_steps": 1,
            "learning_rate": 1e-5,
            "optimizer": "adamw_torch",
            "lr_scheduler": "cosine",
            "max_steps": 8,
            "save_steps": 0,
            "eval_steps": 0,
            "save_first_step": False,
            "seed": 42,
            "logging_steps": 1,
        }

        # Baseline: aux-loss (gshard)
        cfg0 = DictDefault(dict(base_cfg))
        cfg0.output_dir = f"{temp_dir}/baseline"
        cfg0 = validate_config(cfg0)
        normalize_config(cfg0)
        # baseline uses default aux-loss routing; no plugin registration
        dataset_meta0 = load_datasets(cfg=cfg0)
        model0, _, trainer0 = train(cfg=cfg0, dataset_meta=dataset_meta0)
        loss0 = _last_logged_loss(trainer0)
        assert loss0 is not None

        # Aux-free: plugin + noaux_tc
        cfg1 = DictDefault(dict(base_cfg))
        cfg1.output_dir = f"{temp_dir}/auxfree"
        cfg1.plugins = [
            "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
        ]
        cfg1.moe_balance_type = "noaux_tc"
        cfg1.moe_update_rate = 0.01
        cfg1.moe_update_momentum = 0.9
        cfg1.moe_bias_cap = 2.0
        cfg1 = validate_config(cfg1)
        normalize_config(cfg1)
        prepare_plugins(cfg1)
        dataset_meta1 = load_datasets(cfg=cfg1)
        model1, _, trainer1 = train(cfg=cfg1, dataset_meta=dataset_meta1)
        loss1 = _last_logged_loss(trainer1)
        assert loss1 is not None

        # Assert aux-free loss is within 10% of aux-loss baseline
        assert loss1 <= 1.1 * loss0, f"aux-free loss {loss1} > 1.1 * baseline {loss0}"

View File

@@ -1,76 +0,0 @@
"""
E2E smoke test for Aux-Loss-Free MoE routing on Qwen3-MoE tiny
"""
import unittest
import torch
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
class TestQwen3MoeAuxFree(unittest.TestCase):
    """Smoke test for the aux-free MoE plugin on a tiny Qwen3-MoE model."""

    @with_temp_dir
    def test_qwen3_moe_aux_free_smoke(self, temp_dir):
        """Train for five steps with the plugin and look for a patched MoE block.

        Passes when a module whose class name ends in ``SparseMoeBlock``
        carries ``_afb_patched``, with one-dimensional ``_afb_bias`` and
        ``_afb_counts`` attributes, and the model output exists.
        """
        cfg = DictDefault(
            {
                "base_model": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
                "tokenizer_config": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
                "flash_attention": False,
                "sequence_len": 512,
                "bf16": False,
                "fp16": False,
                "val_set_size": 0.02,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 1e-5,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
                "max_steps": 5,
                "save_steps": 0,
                "eval_steps": 0,
                "save_first_step": False,
                # Aux-free plugin and toggles
                "plugins": [
                    "axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
                ],
                "moe_balance_type": "noaux_tc",
                "moe_update_rate": 0.01,
                "moe_update_momentum": 0.9,
                "moe_bias_cap": 2.0,
            }
        )

        # Validate and normalize the config, then register the plugin.
        cfg = validate_config(cfg)
        normalize_config(cfg)
        prepare_plugins(cfg)
        dataset_meta = load_datasets(cfg=cfg)

        model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)

        # check that at least one sparse MoE block has been patched
        found = False
        for m in model.modules():
            if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(m, "_afb_patched"):
                assert m._afb_patched is True
                assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1
                assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1
                found = True
                break
        assert found, "No Qwen3-MoE sparse block patched by aux-free plugin"
        check_model_output_exists(temp_dir, cfg)

View File

@@ -12,10 +12,7 @@ from pathlib import Path
import torch
from packaging import version
try:
from tbparse import SummaryReader
except ImportError: # pragma: no cover - optional dependency
SummaryReader = None
from tbparse import SummaryReader
from axolotl.utils.dict import DictDefault
@@ -182,14 +179,12 @@ def check_tensorboard(
tag: str,
lt_val: float,
assertion_err: str,
rtol: float = 0.02,
rtol: float = 0.05,
gt_zero: bool = True,
) -> None:
"""
helper function to parse and check tensorboard logs
"""
if SummaryReader is None:
raise unittest.SkipTest("tbparse is not installed; skipping tensorboard assertions")
tb_log_path = most_recent_subdir(temp_run_dir)
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
reader = SummaryReader(event_file)

View File

@@ -0,0 +1,229 @@
"""
Correctness tests for fused RMSNorm + SiLU Gate kernel.
Tests against the eager Qwen3_5RMSNormGated implementation.
"""
import pytest
import torch
import torch.nn.functional as F
pytest.importorskip("triton", reason="triton required for fused kernels")
if not torch.cuda.is_available():
pytest.skip("CUDA required for fused kernel tests", allow_module_level=True)
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
class EagerRMSNormGated(torch.nn.Module):
    """Eager reference for RMSNorm followed by a SiLU gate.

    Written to match Qwen3_5RMSNormGated exactly so the fused kernel can be
    compared against it. ``gate`` defaults to ``None`` only to mirror that
    signature; ``forward`` calls ``gate.to(...)`` unconditionally, so a tensor
    must always be supplied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.variance_epsilon = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states, gate=None):
        out_dtype = hidden_states.dtype
        # Statistics are computed in float32 regardless of the input dtype.
        x_fp32 = hidden_states.to(torch.float32)
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        # The normalised activations are cast back *before* the weight multiply.
        normed = (x_fp32 * inv_rms).to(out_dtype)
        scaled = self.weight * normed
        # Gate is applied in float32; the product is cast back once at the end.
        gated = scaled * F.silu(gate.to(torch.float32))
        return gated.to(out_dtype)
def _sync_weights(eager_mod, fused_mod):
    """Overwrite the fused module's weight values in place with the eager module's."""
    source = eager_mod.weight.data
    destination = fused_mod.weight.data
    destination.copy_(source)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
    "shape",
    [
        (2, 128, 256),
        (4, 64, 512),
        (1, 32, 1024),
        (2, 16, 2560),  # Qwen3.5-4B hidden_size
        (2, 16, 4096),  # Qwen3.5-9B hidden_size
        (1, 8, 5120),  # Qwen3.5-27B hidden_size
        (4, 16, 2048),  # Qwen3.5-35B-A3B (MoE) hidden_size
        (4, 16, 3072),  # Qwen3.5-122B-A10B (MoE) hidden_size
    ],
)
class TestRMSNormGatedForward:
    """Forward-pass checks of FusedRMSNormGated across dtypes and shapes."""

    def test_output_matches_eager(self, dtype, shape):
        """Fused output agrees with the eager reference within a dtype-dependent tolerance."""
        torch.manual_seed(42)
        batch, seq_len, hidden = shape
        inputs = torch.randn(batch, seq_len, hidden, dtype=dtype, device="cuda")
        gates = torch.randn(batch, seq_len, hidden, dtype=dtype, device="cuda")

        reference = EagerRMSNormGated(hidden).to(dtype=dtype, device="cuda")
        kernel = FusedRMSNormGated(hidden).to(dtype=dtype, device="cuda")
        _sync_weights(reference, kernel)

        expected = reference(inputs, gate=gates)
        actual = kernel(inputs, gate=gates)

        # float32 is held to a tight bound; reduced-precision dtypes get a looser one.
        tolerance = 1e-5 if dtype == torch.float32 else 1e-2
        torch.testing.assert_close(actual, expected, atol=tolerance, rtol=tolerance)

    def test_output_shape(self, dtype, shape):
        """Fused output keeps the input's shape and dtype."""
        batch, seq_len, hidden = shape
        inputs = torch.randn(batch, seq_len, hidden, dtype=dtype, device="cuda")
        gates = torch.randn(batch, seq_len, hidden, dtype=dtype, device="cuda")
        kernel = FusedRMSNormGated(hidden).to(dtype=dtype, device="cuda")

        result = kernel(inputs, gate=gates)

        assert result.shape == (batch, seq_len, hidden)
        assert result.dtype == dtype
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize(
    "shape",
    [
        (2, 32, 256),
        (2, 16, 512),
        (2, 16, 2560),  # Qwen3.5-4B
        (1, 8, 4096),  # Qwen3.5-9B
        (1, 8, 5120),  # Qwen3.5-27B
        (2, 16, 2048),  # Qwen3.5-35B-A3B (MoE)
        (2, 16, 3072),  # Qwen3.5-122B-A10B (MoE)
    ],
)
class TestRMSNormGatedBackward:
    """Gradient checks of FusedRMSNormGated against the eager reference.

    Each test runs the same seeded forward/backward pass through both modules
    and compares one gradient: input, gate, or weight.
    """

    @staticmethod
    def _tolerances(dtype):
        """Return ``(atol, rtol)``: tight for float32, looser for other dtypes."""
        if dtype == torch.float32:
            return 1e-4, 1e-4
        return 5e-2, 5e-2

    @staticmethod
    def _run_backward(dtype, shape):
        """Run one seeded forward/backward through the fused and eager modules.

        Returns:
            ``(X, G, fused, X_ref, G_ref, eager)`` where the fused module saw
            ``X``/``G`` and the eager module saw the detached copies
            ``X_ref``/``G_ref``; ``.grad`` is populated on all four tensors
            and on both modules' ``weight``.
        """
        torch.manual_seed(42)
        B, T, H = shape
        X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
        G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
        # Independent leaf copies so the two autograd graphs do not share gradients.
        X_ref = X.detach().clone().requires_grad_(True)
        G_ref = G.detach().clone().requires_grad_(True)

        eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
        fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
        _sync_weights(eager, fused)

        y_eager = eager(X_ref, gate=G_ref)
        y_fused = fused(X, gate=G)

        # Both graphs are driven by the same upstream gradient.
        grad_out = torch.randn_like(y_eager)
        y_eager.backward(grad_out)
        y_fused.backward(grad_out)
        return X, G, fused, X_ref, G_ref, eager

    def test_grad_x(self, dtype, shape):
        """Gradient w.r.t. the input matches the eager reference."""
        X, _, _, X_ref, _, _ = self._run_backward(dtype, shape)
        atol, rtol = self._tolerances(dtype)
        torch.testing.assert_close(X.grad, X_ref.grad, atol=atol, rtol=rtol)

    def test_grad_gate(self, dtype, shape):
        """Gradient w.r.t. the gate matches the eager reference."""
        _, G, _, _, G_ref, _ = self._run_backward(dtype, shape)
        atol, rtol = self._tolerances(dtype)
        torch.testing.assert_close(G.grad, G_ref.grad, atol=atol, rtol=rtol)

    def test_grad_weight(self, dtype, shape):
        """Gradient w.r.t. the norm weight matches the eager reference."""
        _, _, fused, _, _, eager = self._run_backward(dtype, shape)
        atol, rtol = self._tolerances(dtype)
        torch.testing.assert_close(
            fused.weight.grad, eager.weight.grad, atol=atol, rtol=rtol
        )
class TestRMSNormGatedEdgeCases:
    """Edge cases: missing gate, 2-D inputs, and non-default weights."""

    def test_gate_none_raises(self):
        """Calling the fused module without a gate raises ValueError."""
        kernel = FusedRMSNormGated(256).cuda()
        inputs = torch.randn(2, 4, 256, device="cuda")
        with pytest.raises(ValueError, match="requires a gate tensor"):
            kernel(inputs, gate=None)

    def test_2d_input(self):
        """Test with (BxT, H) shaped input instead of (B, T, H)."""
        torch.manual_seed(42)
        hidden = 512
        tensor_kwargs = {
            "dtype": torch.bfloat16,
            "device": "cuda",
            "requires_grad": True,
        }
        X = torch.randn(64, hidden, **tensor_kwargs)
        G = torch.randn(64, hidden, **tensor_kwargs)
        # Separate leaves for the reference so gradients do not accumulate together.
        X_ref = X.detach().clone().requires_grad_(True)
        G_ref = G.detach().clone().requires_grad_(True)

        reference = EagerRMSNormGated(hidden).to(dtype=torch.bfloat16, device="cuda")
        kernel = FusedRMSNormGated(hidden).to(dtype=torch.bfloat16, device="cuda")
        _sync_weights(reference, kernel)

        expected = reference(X_ref, gate=G_ref)
        actual = kernel(X, gate=G)
        torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)

        upstream = torch.randn_like(expected)
        expected.backward(upstream)
        actual.backward(upstream)
        for fused_leaf, ref_leaf in ((X, X_ref), (G, G_ref)):
            torch.testing.assert_close(
                fused_leaf.grad, ref_leaf.grad, atol=5e-2, rtol=5e-2
            )

    def test_random_weight_init(self):
        """Test with non-default weight values."""
        torch.manual_seed(123)
        hidden = 256
        inputs = torch.randn(2, 16, hidden, dtype=torch.bfloat16, device="cuda")
        gates = torch.randn(2, 16, hidden, dtype=torch.bfloat16, device="cuda")

        reference = EagerRMSNormGated(hidden).to(dtype=torch.bfloat16, device="cuda")
        # Replace the all-ones default so the weight multiply is actually exercised.
        reference.weight.data = torch.randn_like(reference.weight.data)
        kernel = FusedRMSNormGated(hidden).to(dtype=torch.bfloat16, device="cuda")
        _sync_weights(reference, kernel)

        expected = reference(inputs, gate=gates)
        actual = kernel(inputs, gate=gates)
        torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2)

View File

@@ -1,267 +0,0 @@
import os
import sys
import tempfile
import unittest
from types import SimpleNamespace
import torch
import torch.distributed as dist
import torch.nn as nn
from importlib import util as importlib_util
from pathlib import Path
from huggingface_hub import snapshot_download
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
def _cfg(**overrides):
    """Return a config namespace of aux-free MoE defaults, with keyword overrides applied."""
    settings = {
        "moe_balance_type": "noaux_tc",
        "moe_update_rate": 0.1,
        "moe_update_momentum": 0.9,
        "moe_bias_cap": 2.0,
        "moe_afb_warmup_steps": 0,
        "moe_bias_sync_group": "world",
        "expert_parallel_size": 1,
    }
    # Overrides win over defaults and may also introduce new keys.
    settings.update(overrides)
    return SimpleNamespace(**settings)
def _import_module_from_path(qualified_name, alias, path):
    """Import the file at ``path`` as ``qualified_name``, also registering it under ``alias``.

    A module already present in ``sys.modules`` under ``qualified_name`` is
    returned as-is (the alias is not touched in that case).

    Raises:
        ImportError: if no import spec/loader can be created for ``path``.
    """
    if qualified_name in sys.modules:
        return sys.modules[qualified_name]

    spec = importlib_util.spec_from_file_location(qualified_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot create an import spec for {qualified_name!r} from {path}"
        )
    module = importlib_util.module_from_spec(spec)
    # Registered before execution, under both names -- presumably so imports
    # issued by the downloaded code itself resolve; confirm before changing.
    sys.modules[qualified_name] = module
    sys.modules[alias] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Never leave a half-initialised module cached: a later call would
        # skip loading and fail with a misleading AttributeError.
        if sys.modules.get(qualified_name) is module:
            del sys.modules[qualified_name]
        if sys.modules.get(alias) is module:
            del sys.modules[alias]
        raise
    return module


def _load_bailing_modules():
    """Fetch the Bailing MoE v2 remote code and return its config and MoE block classes.

    Only the configuration/modeling sources (and ``__init__.py``) are
    downloaded from the ``inclusionAI/Ring-mini-2.0`` repository.

    Returns:
        ``(BailingMoeV2Config, BailingMoeV2SparseMoeBlock)``.
    """
    repo_dir = snapshot_download(
        repo_id="inclusionAI/Ring-mini-2.0",
        allow_patterns=[
            "configuration_bailing_moe_v2.py",
            "modeling_bailing_moe_v2.py",
            "__init__.py",
        ],
    )
    repo = Path(repo_dir)

    # The configuration module is loaded first, before the modeling module.
    config_module = _import_module_from_path(
        "bailing_moe_v2.configuration_bailing_moe_v2",
        "configuration_bailing_moe_v2",
        repo / "configuration_bailing_moe_v2.py",
    )
    modeling_module = _import_module_from_path(
        "bailing_moe_v2.modeling_bailing_moe_v2",
        "modeling_bailing_moe_v2",
        repo / "modeling_bailing_moe_v2.py",
    )

    BailingMoeV2Config = config_module.BailingMoeV2Config
    BailingMoeV2SparseMoeBlock = modeling_module.BailingMoeV2SparseMoeBlock
    return BailingMoeV2Config, BailingMoeV2SparseMoeBlock
def _build_bailing_model():
    """Build a tiny Bailing MoE block wrapped in a minimal model.

    Returns:
        ``(model, block)`` where ``model.block`` is ``block`` and
        ``model.config.model_type`` is ``"bailing_moe"``.
    """
    config_cls, block_cls = _load_bailing_modules()
    tiny_config = config_cls(
        hidden_size=16,
        intermediate_size=32,
        moe_intermediate_size=32,
        num_experts=4,
        num_shared_experts=None,
        num_experts_per_tok=2,
        n_group=1,
        topk_group=1,
        routed_scaling_factor=1.0,
    )
    moe_block = block_cls(tiny_config)

    class DummyModel(nn.Module):
        """Thin wrapper exposing the block plus a ``config.model_type``."""

        def __init__(self, layer):
            super().__init__()
            self.block = layer
            self.config = SimpleNamespace(model_type="bailing_moe")

        def forward(self, hidden_states):
            return self.block(hidden_states)

    wrapper = DummyModel(moe_block)
    return wrapper, moe_block
def _build_llama4_model():
    """Build a tiny Llama4 text MoE layer wrapped in a minimal model.

    Returns:
        ``(model, layer)`` where ``model.moe`` is ``layer`` and
        ``model.config.model_type`` is ``"llama4"``.
    """
    from transformers import Llama4TextConfig
    from transformers.models.llama4.modeling_llama4 import Llama4TextMoe

    tiny_config = Llama4TextConfig(
        hidden_size=16,
        intermediate_size=32,
        num_local_experts=4,
        num_attention_heads=2,
        num_key_value_heads=2,
        num_experts_per_tok=2,
    )
    moe = Llama4TextMoe(tiny_config)

    class DummyModel(nn.Module):
        """Thin wrapper exposing the MoE layer plus a ``config.model_type``."""

        def __init__(self, moe_layer):
            super().__init__()
            self.moe = moe_layer
            self.config = SimpleNamespace(model_type="llama4")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    wrapper = DummyModel(moe)
    return wrapper, moe
def _build_mixtral_model():
    """Build a tiny Mixtral sparse MoE block wrapped in a minimal model.

    The config is also attached to the layer as ``layer.config``.

    Returns:
        ``(model, layer)`` where ``model.moe`` is ``layer`` and
        ``model.config.model_type`` is ``"mixtral"``.
    """
    from transformers import MixtralConfig
    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

    tiny_config = MixtralConfig(
        hidden_size=16,
        intermediate_size=32,
        num_local_experts=4,
        num_experts_per_tok=2,
        num_attention_heads=2,
        num_key_value_heads=2,
    )
    moe = MixtralSparseMoeBlock(tiny_config)
    # Callers read ``layer.config.hidden_size``, so expose the config on the layer.
    moe.config = tiny_config

    class DummyModel(nn.Module):
        """Thin wrapper exposing the MoE layer plus a ``config.model_type``."""

        def __init__(self, moe_layer):
            super().__init__()
            self.moe = moe_layer
            self.config = SimpleNamespace(model_type="mixtral")

        def forward(self, hidden_states):
            return self.moe(hidden_states)

    wrapper = DummyModel(moe)
    return wrapper, moe
def _run_callback(plugin, cfg):
    """Fire ``on_step_end`` once on the first callback the plugin registers.

    The trainer and the ``args``/``state``/``control`` arguments are empty
    placeholder namespaces.
    """
    registered = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
    assert registered, "expected aux-free callback to be registered"
    placeholder = SimpleNamespace()
    first_callback = registered[0]
    first_callback.on_step_end(
        args=placeholder, state=placeholder, control=placeholder
    )
class TestAuxFreeAdapters(unittest.TestCase):
    """Adapter-level tests of AuxFreeMoEPlugin on tiny Bailing, Llama4 and Mixtral MoE layers."""

    def test_bailing_adapter_updates_counts_and_bias(self):
        """A forward pass records counts; the step-end callback clears them and moves the EMA."""
        model, block = _build_bailing_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)
        self.assertTrue(hasattr(block, "_afb_bias"))

        hidden = torch.randn(2, 3, block.config.hidden_size)
        block(hidden)
        self.assertGreater(torch.count_nonzero(block._afb_counts), 0)

        _run_callback(plugin, cfg)
        self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
        self.assertFalse(
            torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema))
        )

    def test_llama4_adapter_biases_router_selection(self):
        """Same count/EMA lifecycle as the Bailing test, on a Llama4 MoE layer."""
        model, layer = _build_llama4_model()
        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)
        self.assertTrue(hasattr(layer, "_afb_bias"))

        hidden = torch.randn(2, 4, layer.hidden_dim)
        layer(hidden)
        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)

        _run_callback(plugin, cfg)
        self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
        self.assertFalse(
            torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema))
        )

    def test_bias_warmup_respected(self):
        """With two warmup steps the bias stays zero for two steps and changes on the third."""
        model, block = _build_bailing_model()
        cfg = _cfg(moe_afb_warmup_steps=2)
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)
        callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
        self.assertTrue(callbacks)
        callback = callbacks[0]
        dummy = SimpleNamespace()

        def _step():
            hidden = torch.randn(2, 3, block.config.hidden_size)
            block(hidden)
            callback.on_step_end(args=dummy, state=dummy, control=dummy)

        # Warmup steps should leave bias untouched.
        _step()
        self.assertTrue(
            torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
        )
        _step()
        self.assertTrue(
            torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias))
        )
        # Third step exceeds warmup -> bias should update.
        _step()
        self.assertGreater(torch.count_nonzero(block._afb_bias), 0)

    def test_mixtral_adapter_respects_native_forward(self):
        """Patching must not change Mixtral's output or router logits, yet must record counts."""
        model, layer = _build_mixtral_model()
        layer.jitter_noise = 0.0  # avoid stochasticity for comparison
        hidden_dim = layer.config.hidden_size
        hidden = torch.randn(2, 3, hidden_dim)
        baseline_out, baseline_logits = layer(hidden.clone())

        cfg = _cfg()
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)
        patched_out, patched_logits = layer(hidden.clone())

        self.assertTrue(torch.allclose(baseline_out, patched_out))
        self.assertTrue(torch.allclose(baseline_logits, patched_logits))
        self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
        _run_callback(plugin, cfg)

    def test_ep_group_resolution_deferred_until_dist_ready(self):
        """The "ep" sync group is unset before distributed init and resolved at step end."""
        # Start from a clean slate: the test needs distributed to be uninitialised.
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()

        model, block = _build_bailing_model()
        cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
        plugin = AuxFreeMoEPlugin()
        plugin.post_model_build(cfg, model)
        self.assertIsNotNone(plugin._shim)
        self.assertIsNone(plugin._shim.ep_group)

        callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
        self.assertTrue(callbacks)
        callback = callbacks[0]
        dummy = SimpleNamespace()

        # Empty rendezvous file for the single-process gloo group.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_init:
            init_path = tmp_init.name
        try:
            # Initialised inside the try so a failed init still removes the temp file.
            dist.init_process_group(
                backend="gloo",
                init_method=f"file://{init_path}",
                world_size=1,
                rank=0,
            )
            hidden = torch.randn(2, 3, block.config.hidden_size)
            block(hidden)
            callback.on_step_end(args=dummy, state=dummy, control=dummy)
            self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
        finally:
            # Only tear down a group that actually came up.
            if dist.is_initialized():
                dist.destroy_process_group()
            os.unlink(init_path)
if __name__ == "__main__":
unittest.main()

View File

@@ -28,20 +28,22 @@ class TestLoRAConfigValidation:
result = validate_config(valid_config)
assert result["adapter"] == "lora"
with pytest.raises(ValueError, match="not compatible with DoRA"):
invalid_config = DictDefault(
{
"adapter": "lora",
"lora_mlp_kernel": True,
"peft_use_dora": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
validate_config(invalid_config)
# DoRA is now compatible with lora kernels
dora_kernel_config = DictDefault(
{
"adapter": "lora",
"lora_mlp_kernel": True,
"peft_use_dora": True,
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-5,
"base_model": "dummy_model",
}
)
result = validate_config(dora_kernel_config)
assert result["lora_mlp_kernel"] is True
assert result["peft_use_dora"] is True
def test_qlora_4bit_validation(self):
"""Test QLoRA 4-bit configuration validation"""

View File

@@ -38,6 +38,11 @@ class TestLoRAParameterFreezing:
mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
mock_layer.lora_B["default"].bias = None
# Required by get_lora_parameters for dropout/DoRA extraction
mock_layer.lora_dropout = {}
mock_layer.lora_magnitude_vector = None
else:
mock_layer.weight = base_layer.weight
mock_layer.bias = base_layer.bias
@@ -48,7 +53,7 @@ class TestLoRAParameterFreezing:
"""Test that LoRA parameters are None when adapters are disabled."""
layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
@@ -62,7 +67,7 @@ class TestLoRAParameterFreezing:
"""Test that LoRA parameters are None when adapters are merged."""
layer = self.create_mock_lora_layer(has_adapters=True, merged=True)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
@@ -77,7 +82,7 @@ class TestLoRAParameterFreezing:
"""Test parameter behavior when no adapters are present."""
layer = self.create_mock_lora_layer(has_adapters=False)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
# Base parameters should be returned
assert W is not None
@@ -94,7 +99,7 @@ class TestLoRAParameterFreezing:
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
# All parameters should be returned
assert W is not None
@@ -110,7 +115,7 @@ class TestLoRAParameterFreezing:
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
# Check shape consistency
assert W.shape == (512, 256)
@@ -124,7 +129,7 @@ class TestLoRAParameterFreezing:
has_adapters=True, adapters_disabled=False, merged=False
)
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
assert W.dtype == self.dtype
assert b.dtype == self.dtype
@@ -138,7 +143,7 @@ class TestLoRAParameterFreezing:
quant_state_mock = Mock()
layer.base_layer.weight.quant_state = quant_state_mock
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
assert quant_state == quant_state_mock
@@ -157,7 +162,7 @@ class TestLoRAParameterFreezing:
layer.active_adapters = ["adapter2"]
W, b, quant_state, A, B, s = get_lora_parameters(layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
assert s == 0.2
assert torch.equal(A, layer.lora_A["adapter2"].weight)
@@ -192,13 +197,13 @@ class TestLoRAParameterFreezingIntegration:
model = get_peft_model(base_model, lora_config)
lora_layer = model.base_model.model.linear
# Test with adapters enabled
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
assert A is not None
assert B is not None
assert s is not None
# Test with adapters disabled
model.disable_adapter_layers()
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
assert A is None
assert B is None
assert s is None