Compare commits
5 Commits
949cdf01eb
...
textui
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
db6af43f3b | ||
|
|
35d06c8087 | ||
|
|
0e583efeaa | ||
|
|
b3289fd190 | ||
|
|
a67392c427 |
@@ -3,7 +3,8 @@ set -e
|
|||||||
|
|
||||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||||
|
|
||||||
curl --silent -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
set -o pipefail
|
||||||
|
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
||||||
# hf download "NousResearch/Meta-Llama-3-8B"
|
# hf download "NousResearch/Meta-Llama-3-8B"
|
||||||
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||||
# hf download "microsoft/Phi-4-reasoning"
|
# hf download "microsoft/Phi-4-reasoning"
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ coverage:
|
|||||||
only_pulls: false
|
only_pulls: false
|
||||||
flags: null
|
flags: null
|
||||||
paths: null
|
paths: null
|
||||||
|
informational: true
|
||||||
|
|
||||||
parsers:
|
parsers:
|
||||||
gcov:
|
gcov:
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
|
|||||||
type=click.Path(exists=True, path_type=str),
|
type=click.Path(exists=True, path_type=str),
|
||||||
help="YAML config for sweeping hyperparameters",
|
help="YAML config for sweeping hyperparameters",
|
||||||
)
|
)
|
||||||
|
@click.option("--tui", is_flag=True, default=False, help="Enable TUI dashboard")
|
||||||
@add_options_from_dataclass(TrainerCliArgs)
|
@add_options_from_dataclass(TrainerCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
@filter_none_kwargs
|
@filter_none_kwargs
|
||||||
@@ -101,6 +102,7 @@ def train(
|
|||||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||||
cloud: str | None = None,
|
cloud: str | None = None,
|
||||||
sweep: str | None = None,
|
sweep: str | None = None,
|
||||||
|
tui: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -118,6 +120,10 @@ def train(
|
|||||||
# Extract launcher args from extra args (after --)
|
# Extract launcher args from extra args (after --)
|
||||||
launcher_args = ctx.args if ctx.args else []
|
launcher_args = ctx.args if ctx.args else []
|
||||||
|
|
||||||
|
# Handle --tui flag: set env var so subprocess workers pick it up
|
||||||
|
if tui:
|
||||||
|
os.environ["AXOLOTL_TUI"] = "1"
|
||||||
|
|
||||||
# Handle Ray launcher override
|
# Handle Ray launcher override
|
||||||
_launcher = None if kwargs.get("use_ray") else launcher
|
_launcher = None if kwargs.get("use_ray") else launcher
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
|
import queue
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -34,22 +35,101 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
|||||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
# Start TUI early (before data loading) so it captures preprocessing events
|
||||||
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
|
tui_renderer = None
|
||||||
if not dataset_meta:
|
tui_queue: queue.Queue | None = None
|
||||||
if cfg.rl:
|
is_rank_0 = int(os.getenv("LOCAL_RANK", "0")) == 0
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
if is_rank_0:
|
||||||
else:
|
from axolotl.train import _is_tui_enabled
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
if _is_tui_enabled(cfg):
|
||||||
|
import queue as _queue
|
||||||
|
|
||||||
del model, tokenizer, trainer
|
from axolotl.train import _get_tui_config
|
||||||
|
from axolotl.tui.config import TUIConfig
|
||||||
|
from axolotl.tui.renderer import TUIRenderer
|
||||||
|
|
||||||
gc.collect()
|
tui_config_dict = _get_tui_config(cfg)
|
||||||
|
tui_config = (
|
||||||
|
TUIConfig(**tui_config_dict)
|
||||||
|
if isinstance(tui_config_dict, dict)
|
||||||
|
else tui_config_dict
|
||||||
|
)
|
||||||
|
tui_queue = _queue.Queue(maxsize=4096)
|
||||||
|
tui_renderer = TUIRenderer(config=tui_config, metric_queue=tui_queue)
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
# Send initial run info
|
||||||
plugin_manager.post_train_unload(cfg)
|
model_name = cfg.base_model or ""
|
||||||
|
training_mode = str(cfg.rl) if cfg.rl else "sft"
|
||||||
|
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||||
|
try:
|
||||||
|
tui_queue.put_nowait(
|
||||||
|
{
|
||||||
|
"type": "run_info",
|
||||||
|
"model_name": model_name,
|
||||||
|
"training_mode": training_mode,
|
||||||
|
"world_size": world_size,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except _queue.Full:
|
||||||
|
pass
|
||||||
|
|
||||||
|
tui_renderer.start()
|
||||||
|
|
||||||
|
# Attach logging handler early
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from axolotl.tui.callback import _TUILogHandler
|
||||||
|
|
||||||
|
_early_log_handler = _TUILogHandler(
|
||||||
|
tui_queue, min_level=tui_config.log_level
|
||||||
|
)
|
||||||
|
_early_log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
|
||||||
|
# Attach to BOTH root and axolotl loggers because axolotl logger
|
||||||
|
# has propagate=False so root handler never sees axolotl.* messages
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
root_logger.addHandler(_early_log_handler)
|
||||||
|
axolotl_logger = logging.getLogger("axolotl")
|
||||||
|
axolotl_logger.addHandler(_early_log_handler)
|
||||||
|
|
||||||
|
# Stash refs on cfg so train() can reuse the renderer
|
||||||
|
cfg._tui_renderer = tui_renderer
|
||||||
|
cfg._tui_queue = tui_queue
|
||||||
|
cfg._tui_early_log_handler = _early_log_handler
|
||||||
|
|
||||||
|
try:
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
|
||||||
|
if not dataset_meta:
|
||||||
|
if cfg.rl:
|
||||||
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
else:
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
|
del model, tokenizer, trainer
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
plugin_manager.post_train_unload(cfg)
|
||||||
|
finally:
|
||||||
|
# If the TUI renderer started early but train() didn't get to stop it
|
||||||
|
# (e.g., error during data loading), clean up here
|
||||||
|
if tui_renderer is not None and not tui_renderer._stop_event.is_set():
|
||||||
|
try:
|
||||||
|
if tui_queue is not None:
|
||||||
|
tui_queue.put_nowait({"type": "done"})
|
||||||
|
except queue.Full:
|
||||||
|
pass
|
||||||
|
tui_renderer.stop()
|
||||||
|
# Remove early log handler from both root and axolotl loggers
|
||||||
|
if hasattr(cfg, "_tui_early_log_handler"):
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.getLogger().removeHandler(cfg._tui_early_log_handler)
|
||||||
|
logging.getLogger("axolotl").removeHandler(cfg._tui_early_log_handler)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
# Aux-Loss-Free MoE Router Plugin
|
|
||||||
|
|
||||||
This integration adds an aux-loss-free (AFB) gating option to compatible MoE architectures without forking model code.
|
|
||||||
|
|
||||||
Summary
|
|
||||||
- Bias only affects expert selection (top-k); mixture weights come from unbiased logits.
|
|
||||||
- Per-expert token loads are accumulated on device and reduced across DP or EP groups.
|
|
||||||
- Bias is updated post-optimizer step outside autograd using EMA-smoothed loads.
|
|
||||||
- Existing aux loss is disabled when aux-free is enabled to avoid double signals.
|
|
||||||
|
|
||||||
Enable
|
|
||||||
- Add the plugin to your YAML, then set the aux-free toggle:
|
|
||||||
|
|
||||||
plugins:
|
|
||||||
- axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin
|
|
||||||
|
|
||||||
moe_balance_type: noaux_tc
|
|
||||||
moe_update_rate: 0.01 # default if unset
|
|
||||||
moe_update_momentum: 0.9 # default if unset
|
|
||||||
moe_bias_cap: 2.0 # default if unset
|
|
||||||
moe_afb_warmup_steps: 100 # optional
|
|
||||||
moe_bias_sync_group: world # or 'ep' if expert_parallel_size > 1
|
|
||||||
expert_parallel_size: 1 # set to your EP width when using moe_bias_sync_group: ep
|
|
||||||
|
|
||||||
Config keys
|
|
||||||
- moe_balance_type: gshard (auxiliary loss) | noaux_tc (aux-free). Default: model native.
|
|
||||||
- moe_update_rate: bias update rate (gamma). Default: 0.01.
|
|
||||||
- moe_update_momentum: EMA momentum for load smoothing. Default: 0.9.
|
|
||||||
- moe_bias_cap: absolute clamp for bias. Default: 2.0.
|
|
||||||
- moe_afb_warmup_steps: delay before applying updates. Default: 0.
|
|
||||||
- moe_bias_sync_group: reduction group for counts, 'world' (DP) or 'ep' (expert-parallel). Default: world.
|
|
||||||
- expert_parallel_size: number of ranks per expert-parallel group when using `moe_bias_sync_group: ep`. Defaults to 1 (world).
|
|
||||||
|
|
||||||
Compatibility
|
|
||||||
- Targeted families: Mixtral, Qwen3-MoE, Bailing/Ring 2.0, and Llama 4 text MoE layers.
|
|
||||||
- Pass-through: Models with native aux-free routing (e.g., DeepSeek-V3) are left unmodified; only telemetry may be added in future.
|
|
||||||
|
|
||||||
Notes
|
|
||||||
- If you also enable Liger’s aux-loss paths, the plugin neutralizes aux loss when aux-free is on.
|
|
||||||
- Telemetry: future updates will log per-expert loads and bias magnitudes.
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
"""Aux-loss-free (AFB) MoE router integration package."""
|
|
||||||
|
|
||||||
@@ -1,317 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Iterable, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .core import AuxFreeShim
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LayerHandle:
|
|
||||||
layer: nn.Module
|
|
||||||
layer_idx: int
|
|
||||||
num_experts: int
|
|
||||||
top_k: int
|
|
||||||
|
|
||||||
|
|
||||||
class BaseMoEAdapter:
|
|
||||||
"""Base adapter that discovers MoE layers and wraps their forward.
|
|
||||||
|
|
||||||
Concrete adapters should implement discovery and per-layer attribute extraction.
|
|
||||||
"""
|
|
||||||
|
|
||||||
family: str = "generic"
|
|
||||||
|
|
||||||
def matches(self, model: nn.Module) -> bool: # pragma: no cover - thin shim
|
|
||||||
return False
|
|
||||||
|
|
||||||
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]: # pragma: no cover
|
|
||||||
return []
|
|
||||||
|
|
||||||
def get_top_k(self, moe_layer: nn.Module) -> int: # pragma: no cover
|
|
||||||
return int(getattr(moe_layer, "num_experts_per_tok", getattr(moe_layer, "top_k", 2)))
|
|
||||||
|
|
||||||
def get_num_experts(self, moe_layer: nn.Module) -> int: # pragma: no cover
|
|
||||||
return int(getattr(moe_layer, "num_experts"))
|
|
||||||
|
|
||||||
def disable_aux_loss(self, model_or_layer: nn.Module) -> None:
|
|
||||||
# Best-effort: zero router aux loss coef if present
|
|
||||||
if hasattr(model_or_layer, "router_aux_loss_coef"):
|
|
||||||
try:
|
|
||||||
setattr(model_or_layer, "router_aux_loss_coef", 0.0)
|
|
||||||
except Exception: # pragma: no cover - non-critical
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _register_aux_buffers(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
|
||||||
device = next(moe_layer.parameters(), torch.tensor(0)).device
|
|
||||||
if not hasattr(moe_layer, "_afb_bias"):
|
|
||||||
moe_layer.register_buffer("_afb_bias", torch.zeros(handle.num_experts, device=device))
|
|
||||||
if not hasattr(moe_layer, "_afb_counts"):
|
|
||||||
moe_layer.register_buffer("_afb_counts", torch.zeros(handle.num_experts, device=device))
|
|
||||||
if not hasattr(moe_layer, "_afb_ema"):
|
|
||||||
moe_layer.register_buffer("_afb_ema", torch.zeros(handle.num_experts, device=device))
|
|
||||||
moe_layer._afb_layer_idx = handle.layer_idx # type: ignore[attr-defined]
|
|
||||||
moe_layer._afb_top_k = handle.top_k # type: ignore[attr-defined]
|
|
||||||
shim.register_layer_buffers(handle.layer_idx, moe_layer)
|
|
||||||
|
|
||||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
|
||||||
"""Attach per-layer buffers and mark as aux-free enabled."""
|
|
||||||
self._register_aux_buffers(moe_layer, handle, shim)
|
|
||||||
self._patch_forward_with_aux_free(moe_layer)
|
|
||||||
|
|
||||||
def _patch_forward_with_aux_free(self, moe_layer: nn.Module) -> None:
|
|
||||||
"""Replace the layer's forward with an aux-free gating version.
|
|
||||||
|
|
||||||
Assumes the layer exposes attributes:
|
|
||||||
- gate: linear router projecting hidden to num_experts
|
|
||||||
- num_experts: int
|
|
||||||
- experts: iterable of expert modules taking (tokens, H) -> (tokens, H)
|
|
||||||
"""
|
|
||||||
if getattr(moe_layer, "_afb_patched", False):
|
|
||||||
return
|
|
||||||
|
|
||||||
if not hasattr(moe_layer, "gate") or not hasattr(moe_layer, "experts"):
|
|
||||||
LOG.info("AuxFreeMoE: layer missing gate/experts; skipping forward patch")
|
|
||||||
return
|
|
||||||
|
|
||||||
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
|
|
||||||
# hidden_states: (B, T, H)
|
|
||||||
bsz, seqlen, hdim = hidden_states.shape
|
|
||||||
hs = hidden_states.view(-1, hdim)
|
|
||||||
logits = self.gate(hs)
|
|
||||||
# selection uses biased logits; weights from unbiased logits
|
|
||||||
bias = getattr(self, "_afb_bias")
|
|
||||||
top_k = int(getattr(self, "_afb_top_k", 2))
|
|
||||||
biased = logits + bias # broadcast over tokens
|
|
||||||
topk_vals, topk_idx = torch.topk(biased, k=top_k, dim=-1, sorted=False)
|
|
||||||
chosen_logits = torch.gather(logits, -1, topk_idx)
|
|
||||||
weights = torch.softmax(chosen_logits.float(), dim=-1)
|
|
||||||
weights = weights.to(hs.dtype)
|
|
||||||
|
|
||||||
# accumulate counts for bias update callback
|
|
||||||
flat_idx = topk_idx.reshape(-1)
|
|
||||||
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
|
|
||||||
getattr(self, "_afb_counts").add_(counts.to(getattr(self, "_afb_counts").dtype))
|
|
||||||
|
|
||||||
# dispatch tokens to experts
|
|
||||||
hs_rep = hs.repeat_interleave(top_k, dim=0)
|
|
||||||
y = torch.empty_like(hs_rep)
|
|
||||||
for eid in range(int(self.num_experts)):
|
|
||||||
mask = flat_idx == eid
|
|
||||||
if mask.any():
|
|
||||||
y[mask] = self.experts[eid](hs_rep[mask])
|
|
||||||
|
|
||||||
y = (y.view(-1, top_k, hdim) * weights.unsqueeze(-1)).sum(dim=1)
|
|
||||||
out = y.view(bsz, seqlen, hdim)
|
|
||||||
return (out, logits)
|
|
||||||
|
|
||||||
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
|
|
||||||
setattr(moe_layer, "_afb_patched", True)
|
|
||||||
|
|
||||||
|
|
||||||
class MixtralAdapter(BaseMoEAdapter):
|
|
||||||
family = "mixtral"
|
|
||||||
|
|
||||||
def matches(self, model: nn.Module) -> bool:
|
|
||||||
return getattr(getattr(model, "config", object()), "model_type", "") == "mixtral"
|
|
||||||
|
|
||||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
|
||||||
self._register_aux_buffers(moe_layer, handle, shim)
|
|
||||||
self._patch_mixtral_forward(moe_layer, shim)
|
|
||||||
|
|
||||||
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
|
||||||
for m in model.modules():
|
|
||||||
if m.__class__.__name__.endswith("SparseMoeBlock"):
|
|
||||||
yield m
|
|
||||||
|
|
||||||
def _patch_mixtral_forward(self, moe_layer: nn.Module, shim: AuxFreeShim) -> None:
|
|
||||||
if getattr(moe_layer, "_afb_patched", False):
|
|
||||||
return
|
|
||||||
|
|
||||||
shim_ref = shim
|
|
||||||
|
|
||||||
def afb_forward(self, hidden_states: torch.Tensor): # type: ignore[no-redef]
|
|
||||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
|
||||||
if self.training and getattr(self, "jitter_noise", 0) > 0:
|
|
||||||
hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
|
|
||||||
1.0 - self.jitter_noise, 1.0 + self.jitter_noise
|
|
||||||
)
|
|
||||||
flat_states = hidden_states.view(-1, hidden_dim)
|
|
||||||
router_logits = self.gate(flat_states)
|
|
||||||
|
|
||||||
layer_idx = int(getattr(self, "_afb_layer_idx", 0))
|
|
||||||
top_k = int(getattr(self, "_afb_top_k", self.top_k))
|
|
||||||
selected_experts, routing_weights = shim_ref.select_experts(layer_idx, router_logits, top_k)
|
|
||||||
routing_weights = routing_weights.to(flat_states.dtype)
|
|
||||||
|
|
||||||
flat_idx = selected_experts.reshape(-1)
|
|
||||||
counts = torch.bincount(flat_idx, minlength=int(self.num_experts))
|
|
||||||
self._afb_counts.add_(counts.to(self._afb_counts.dtype))
|
|
||||||
|
|
||||||
final_hidden_states = torch.zeros(
|
|
||||||
(batch_size * sequence_length, hidden_dim),
|
|
||||||
dtype=flat_states.dtype,
|
|
||||||
device=flat_states.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
|
|
||||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
|
||||||
for expert_idx in expert_hit:
|
|
||||||
expert_layer = self.experts[expert_idx]
|
|
||||||
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
|
|
||||||
current_state = flat_states[None, top_x].reshape(-1, hidden_dim)
|
|
||||||
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
|
|
||||||
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(flat_states.dtype))
|
|
||||||
|
|
||||||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
|
||||||
return final_hidden_states, router_logits
|
|
||||||
|
|
||||||
moe_layer.forward = afb_forward.__get__(moe_layer, moe_layer.__class__) # type: ignore[attr-defined]
|
|
||||||
setattr(moe_layer, "_afb_patched", True)
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen3Adapter(MixtralAdapter):
|
|
||||||
family = "qwen3_moe"
|
|
||||||
|
|
||||||
def matches(self, model: nn.Module) -> bool:
|
|
||||||
return getattr(getattr(model, "config", object()), "model_type", "") in ("qwen3_moe", "qwen2_moe")
|
|
||||||
|
|
||||||
|
|
||||||
class BailingAdapter(BaseMoEAdapter):
|
|
||||||
family = "bailing_moe"
|
|
||||||
|
|
||||||
def matches(self, model: nn.Module) -> bool:
|
|
||||||
model_type = getattr(getattr(model, "config", object()), "model_type", "")
|
|
||||||
return model_type in ("bailing_moe", "bailing_moe_v2")
|
|
||||||
|
|
||||||
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
|
||||||
for m in model.modules():
|
|
||||||
if m.__class__.__name__ == "BailingMoeV2SparseMoeBlock":
|
|
||||||
yield m
|
|
||||||
|
|
||||||
def get_num_experts(self, moe_layer: nn.Module) -> int:
|
|
||||||
if hasattr(moe_layer, "num_experts"):
|
|
||||||
return int(getattr(moe_layer, "num_experts"))
|
|
||||||
cfg = getattr(moe_layer, "config", None)
|
|
||||||
return int(getattr(cfg, "num_experts"))
|
|
||||||
|
|
||||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
|
||||||
self._register_aux_buffers(moe_layer, handle, shim)
|
|
||||||
self._patch_bailing_gate(moe_layer)
|
|
||||||
|
|
||||||
def _patch_bailing_gate(self, moe_layer: nn.Module) -> None:
|
|
||||||
gate = getattr(moe_layer, "gate", None)
|
|
||||||
if gate is None:
|
|
||||||
LOG.info("BailingAdapter: layer missing gate; skipping aux-free patch")
|
|
||||||
return
|
|
||||||
if getattr(gate, "_afb_patched", False):
|
|
||||||
return
|
|
||||||
|
|
||||||
def afb_gate_forward(self, hidden_states: torch.Tensor):
|
|
||||||
flat = hidden_states.view(-1, hidden_states.shape[-1])
|
|
||||||
logits = F.linear(flat.float(), self.weight.float())
|
|
||||||
scores_unbiased = torch.sigmoid(logits.float()).to(logits.dtype)
|
|
||||||
bias = getattr(moe_layer, "_afb_bias")
|
|
||||||
biased_scores = scores_unbiased + bias
|
|
||||||
topk_vals, topk_idx = self.group_limited_topk(biased_scores)
|
|
||||||
weights = torch.gather(scores_unbiased, 1, topk_idx)
|
|
||||||
if self.top_k > 1:
|
|
||||||
denom = weights.sum(dim=-1, keepdim=True).clamp_min_(1e-20)
|
|
||||||
weights = weights / denom
|
|
||||||
weights = weights * self.routed_scaling_factor
|
|
||||||
|
|
||||||
flat_topk = topk_idx.reshape(-1)
|
|
||||||
counts = torch.bincount(flat_topk, minlength=bias.numel())
|
|
||||||
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
|
|
||||||
|
|
||||||
return topk_idx, weights.to(hidden_states.dtype), logits
|
|
||||||
|
|
||||||
gate.forward = afb_gate_forward.__get__(gate, gate.__class__) # type: ignore[attr-defined]
|
|
||||||
setattr(gate, "_afb_patched", True)
|
|
||||||
|
|
||||||
|
|
||||||
class Llama4Adapter(BaseMoEAdapter):
|
|
||||||
family = "llama4"
|
|
||||||
|
|
||||||
def matches(self, model: nn.Module) -> bool:
|
|
||||||
return getattr(getattr(model, "config", object()), "model_type", "") == "llama4"
|
|
||||||
|
|
||||||
def find_moe_layers(self, model: nn.Module) -> Iterable[nn.Module]:
|
|
||||||
for m in model.modules():
|
|
||||||
if m.__class__.__name__ == "Llama4TextMoe":
|
|
||||||
yield m
|
|
||||||
|
|
||||||
def prepare(self, moe_layer: nn.Module, handle: LayerHandle, shim: AuxFreeShim) -> None:
|
|
||||||
self._register_aux_buffers(moe_layer, handle, shim)
|
|
||||||
self._patch_llama4_router(moe_layer)
|
|
||||||
|
|
||||||
def _patch_llama4_router(self, moe_layer: nn.Module) -> None:
|
|
||||||
router = getattr(moe_layer, "router", None)
|
|
||||||
if router is None:
|
|
||||||
LOG.info("Llama4Adapter: layer missing router; skipping aux-free patch")
|
|
||||||
return
|
|
||||||
if getattr(router, "_afb_patched", False):
|
|
||||||
return
|
|
||||||
|
|
||||||
def afb_router_forward(self, hidden_states: torch.Tensor):
|
|
||||||
flat = hidden_states if hidden_states.dim() == 2 else hidden_states.view(-1, hidden_states.shape[-1])
|
|
||||||
router_logits = F.linear(flat, self.weight, self.bias)
|
|
||||||
bias = getattr(moe_layer, "_afb_bias")
|
|
||||||
biased_logits = router_logits + bias
|
|
||||||
_, router_indices = torch.topk(biased_logits, self.top_k, dim=1)
|
|
||||||
unbiased_top = torch.gather(router_logits, 1, router_indices)
|
|
||||||
router_scores = torch.full_like(router_logits, float("-inf"))
|
|
||||||
router_scores.scatter_(1, router_indices, unbiased_top)
|
|
||||||
router_scores = torch.sigmoid(router_scores.float()).to(router_scores.dtype)
|
|
||||||
|
|
||||||
counts = torch.bincount(router_indices.reshape(-1), minlength=bias.numel())
|
|
||||||
getattr(moe_layer, "_afb_counts").add_(counts.to(moe_layer._afb_counts.dtype))
|
|
||||||
|
|
||||||
return router_scores, router_logits
|
|
||||||
|
|
||||||
router.forward = afb_router_forward.__get__(router, router.__class__) # type: ignore[attr-defined]
|
|
||||||
setattr(router, "_afb_patched", True)
|
|
||||||
|
|
||||||
|
|
||||||
def discover_and_prepare_layers(model: nn.Module, adapters: list[BaseMoEAdapter], shim: AuxFreeShim) -> list[LayerHandle]:
|
|
||||||
"""Discover MoE layers using the first matching adapter and attach per-layer buffers.
|
|
||||||
|
|
||||||
Returns a list of layer handles for later routing patching and updates.
|
|
||||||
"""
|
|
||||||
handles: list[LayerHandle] = []
|
|
||||||
adapter: Optional[BaseMoEAdapter] = None
|
|
||||||
for a in adapters:
|
|
||||||
if a.matches(model):
|
|
||||||
adapter = a
|
|
||||||
break
|
|
||||||
|
|
||||||
if adapter is None:
|
|
||||||
LOG.info("AuxFreeMoE: no matching adapter found; skipping aux-free routing")
|
|
||||||
return handles
|
|
||||||
|
|
||||||
# disable aux loss at model level if possible
|
|
||||||
adapter.disable_aux_loss(getattr(model, "config", model))
|
|
||||||
|
|
||||||
idx = 0
|
|
||||||
for layer in adapter.find_moe_layers(model):
|
|
||||||
try:
|
|
||||||
top_k = adapter.get_top_k(layer)
|
|
||||||
nE = adapter.get_num_experts(layer)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
handle = LayerHandle(layer=layer, layer_idx=idx, num_experts=nE, top_k=top_k)
|
|
||||||
adapter.prepare(layer, handle, shim)
|
|
||||||
handles.append(handle)
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
LOG.info(f"AuxFreeMoE: prepared {len(handles)} {adapter.family} layers for aux-free routing")
|
|
||||||
return handles
|
|
||||||
@@ -1,150 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AuxFreeConfig:
|
|
||||||
rate: float = 0.01
|
|
||||||
momentum: float = 0.9
|
|
||||||
bias_cap: float = 2.0
|
|
||||||
warmup_steps: int = 0
|
|
||||||
sync_group: str = "world" # or "ep"
|
|
||||||
|
|
||||||
|
|
||||||
class AuxFreeState:
|
|
||||||
"""Holds per-layer bias and EMA load buffers."""
|
|
||||||
|
|
||||||
def __init__(self, num_layers: int, num_experts: int, device: torch.device, cfg: AuxFreeConfig):
|
|
||||||
self.bias = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
|
|
||||||
self.ema_load = [torch.zeros(num_experts, device=device) for _ in range(num_layers)]
|
|
||||||
self.cfg = cfg
|
|
||||||
self.steps = 0
|
|
||||||
|
|
||||||
|
|
||||||
class AuxFreeShim:
|
|
||||||
"""Model-agnostic shim for aux-loss-free expert selection and bias updates."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
state: AuxFreeState,
|
|
||||||
ep_group: Optional[dist.ProcessGroup] = None,
|
|
||||||
ep_size: Optional[int] = None,
|
|
||||||
):
|
|
||||||
self.state = state
|
|
||||||
self.ep_group = ep_group
|
|
||||||
self._ep_size = ep_size
|
|
||||||
self._ep_group_pending = (
|
|
||||||
self.state.cfg.sync_group == "ep" and self.ep_group is None
|
|
||||||
)
|
|
||||||
self._layer_modules: dict[int, torch.nn.Module] = {}
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def select_experts(self, layer_idx: int, logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""Returns (topk_indices, weights) using biased selection and unbiased weights."""
|
|
||||||
module = self._layer_modules.get(layer_idx)
|
|
||||||
if module is not None and hasattr(module, "_afb_bias"):
|
|
||||||
b = getattr(module, "_afb_bias")
|
|
||||||
else:
|
|
||||||
b = self.state.bias[layer_idx]
|
|
||||||
biased = logits + b # bias is a buffer
|
|
||||||
topk_scores, topk_idx = torch.topk(biased, k=top_k, dim=-1)
|
|
||||||
chosen_logits = torch.gather(logits, -1, topk_idx)
|
|
||||||
weights = torch.softmax(chosen_logits.float(), dim=-1).to(logits.dtype)
|
|
||||||
return topk_idx, weights
|
|
||||||
|
|
||||||
def register_layer_buffers(self, layer_idx: int, module: torch.nn.Module) -> None:
|
|
||||||
"""Bind model buffers so shim updates stay in sync with patched layers."""
|
|
||||||
self._layer_modules[layer_idx] = module
|
|
||||||
bias = getattr(module, "_afb_bias")
|
|
||||||
ema = getattr(module, "_afb_ema")
|
|
||||||
# Keep state views pointing to the same tensors to avoid drift.
|
|
||||||
if layer_idx < len(self.state.bias):
|
|
||||||
self.state.bias[layer_idx] = bias
|
|
||||||
if layer_idx < len(self.state.ema_load):
|
|
||||||
self.state.ema_load[layer_idx] = ema
|
|
||||||
|
|
||||||
def begin_step(self) -> None:
|
|
||||||
"""Call once per optimizer step before per-layer updates."""
|
|
||||||
self.state.steps += 1
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def all_reduce_counts(self, counts: torch.Tensor) -> torch.Tensor:
|
|
||||||
self._maybe_init_ep_group()
|
|
||||||
if not dist.is_available() or not dist.is_initialized():
|
|
||||||
return counts
|
|
||||||
group = self.ep_group if self.ep_group is not None else dist.group.WORLD
|
|
||||||
dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=group)
|
|
||||||
return counts
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def update_bias(self, layer_idx: int, step_counts: torch.Tensor, tokens_seen: int):
|
|
||||||
"""Apply EMA-smoothed bias update toward uniform target, with clamp and optional mean-centering."""
|
|
||||||
cfg = self.state.cfg
|
|
||||||
if self.state.steps <= cfg.warmup_steps:
|
|
||||||
return
|
|
||||||
|
|
||||||
nE = step_counts.numel()
|
|
||||||
if tokens_seen <= 0:
|
|
||||||
return
|
|
||||||
module = self._layer_modules.get(layer_idx)
|
|
||||||
if module is not None and hasattr(module, "_afb_ema"):
|
|
||||||
ema = getattr(module, "_afb_ema")
|
|
||||||
bias = getattr(module, "_afb_bias")
|
|
||||||
else:
|
|
||||||
ema = self.state.ema_load[layer_idx]
|
|
||||||
bias = self.state.bias[layer_idx]
|
|
||||||
counts = step_counts.to(ema.device)
|
|
||||||
freq = counts.float() / float(tokens_seen)
|
|
||||||
ema.mul_(cfg.momentum).add_((1.0 - cfg.momentum) * freq)
|
|
||||||
target = 1.0 / float(nE)
|
|
||||||
delta = cfg.rate * (target - ema)
|
|
||||||
# optional mean-centering to keep sum(bias) ~ 0
|
|
||||||
delta = delta - delta.mean()
|
|
||||||
bias.add_(delta)
|
|
||||||
if cfg.bias_cap is not None and cfg.bias_cap > 0:
|
|
||||||
bias.clamp_(-cfg.bias_cap, cfg.bias_cap)
|
|
||||||
|
|
||||||
def _maybe_init_ep_group(self) -> None:
|
|
||||||
if not self._ep_group_pending:
|
|
||||||
return
|
|
||||||
if not dist.is_available() or not dist.is_initialized():
|
|
||||||
return
|
|
||||||
ep_size = self._ep_size
|
|
||||||
if not ep_size or ep_size <= 1:
|
|
||||||
LOG.warning(
|
|
||||||
"AuxFreeMoE: moe_bias_sync_group='ep' requested but expert_parallel_size<=1; defaulting to world group"
|
|
||||||
)
|
|
||||||
self.ep_group = dist.group.WORLD
|
|
||||||
self._ep_group_pending = False
|
|
||||||
return
|
|
||||||
world = dist.get_world_size()
|
|
||||||
if world % ep_size != 0:
|
|
||||||
LOG.warning(
|
|
||||||
"AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world group",
|
|
||||||
ep_size,
|
|
||||||
world,
|
|
||||||
)
|
|
||||||
self.ep_group = dist.group.WORLD
|
|
||||||
self._ep_group_pending = False
|
|
||||||
return
|
|
||||||
if ep_size == world:
|
|
||||||
self.ep_group = dist.group.WORLD
|
|
||||||
else:
|
|
||||||
rank = dist.get_rank()
|
|
||||||
group_start = (rank // ep_size) * ep_size
|
|
||||||
ranks = tuple(range(group_start, group_start + ep_size))
|
|
||||||
self.ep_group = dist.new_group(ranks)
|
|
||||||
LOG.info(
|
|
||||||
"AuxFreeMoE: initialized expert-parallel reduction group (size=%s, world=%s)",
|
|
||||||
ep_size,
|
|
||||||
dist.get_world_size(),
|
|
||||||
)
|
|
||||||
self._ep_group_pending = False
|
|
||||||
@@ -1,175 +0,0 @@
|
|||||||
"""Aux-loss-free MoE Router Plugin for Axolotl.
|
|
||||||
|
|
||||||
This plugin wires an aux-free gating option into compatible MoE models using
|
|
||||||
unbiased logits for mixture weights and per-expert biases for top-k selection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
from transformers.trainer_callback import TrainerCallback
|
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
from .adapters import (
|
|
||||||
BailingAdapter,
|
|
||||||
BaseMoEAdapter,
|
|
||||||
Llama4Adapter,
|
|
||||||
MixtralAdapter,
|
|
||||||
Qwen3Adapter,
|
|
||||||
discover_and_prepare_layers,
|
|
||||||
)
|
|
||||||
from .core import AuxFreeConfig, AuxFreeShim, AuxFreeState
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MoeAuxFreeBiasUpdateCallback(TrainerCallback):
|
|
||||||
"""Post-step callback to update aux-free biases from accumulated expert counts.
|
|
||||||
|
|
||||||
Note: The current revision expects per-layer counts to be accumulated on each
|
|
||||||
MoE layer as a buffer named `_afb_counts` during forward (to be added with
|
|
||||||
routing patches in a follow-up).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, shim: AuxFreeShim, layer_modules: list[torch.nn.Module]):
|
|
||||||
self.shim = shim
|
|
||||||
self.layer_modules = layer_modules
|
|
||||||
|
|
||||||
def on_step_end(self, args, state, control, **kwargs): # noqa: D401
|
|
||||||
# Iterate prepared MoE layers and apply the bias update rule.
|
|
||||||
self.shim.begin_step()
|
|
||||||
for layer in self.layer_modules:
|
|
||||||
if not hasattr(layer, "_afb_counts") or not hasattr(layer, "_afb_layer_idx"):
|
|
||||||
continue
|
|
||||||
counts = getattr(layer, "_afb_counts")
|
|
||||||
if counts is None:
|
|
||||||
continue
|
|
||||||
counts = self.shim.all_reduce_counts(counts)
|
|
||||||
layer_idx = getattr(layer, "_afb_layer_idx", None)
|
|
||||||
if layer_idx is None:
|
|
||||||
counts.zero_()
|
|
||||||
continue
|
|
||||||
bias = getattr(layer, "_afb_bias")
|
|
||||||
counts_for_update = counts.to(bias.device)
|
|
||||||
tokens_seen = int(counts_for_update.sum().item())
|
|
||||||
# local layer-state EMA and bias update
|
|
||||||
self.shim.update_bias(layer_idx, counts_for_update, tokens_seen)
|
|
||||||
# reset step counts
|
|
||||||
counts.zero_()
|
|
||||||
return control
|
|
||||||
|
|
||||||
|
|
||||||
class AuxFreeMoEPlugin(BasePlugin):
|
|
||||||
"""Plugin that enables aux-loss-free routing when configured."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self._handles: list = []
|
|
||||||
self._shim: Optional[AuxFreeShim] = None
|
|
||||||
self._ep_group_cache: dict[tuple[int, ...], dist.ProcessGroup] = {}
|
|
||||||
|
|
||||||
def post_model_build(self, cfg, model):
|
|
||||||
# Enable only when explicitly requested
|
|
||||||
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
|
|
||||||
return
|
|
||||||
|
|
||||||
# Be conservative — skip known native aux-free families
|
|
||||||
native_auxfree = getattr(getattr(model, "config", object()), "model_type", "") in (
|
|
||||||
"deepseek_v3",
|
|
||||||
"glm4_moe",
|
|
||||||
)
|
|
||||||
if native_auxfree:
|
|
||||||
LOG.info("AuxFreeMoE: model reports native aux-free routing; skipping patching")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Build aux-free state and shim
|
|
||||||
rate = cfg.moe_update_rate if cfg.moe_update_rate is not None else 0.01
|
|
||||||
momentum = (
|
|
||||||
cfg.moe_update_momentum if cfg.moe_update_momentum is not None else 0.9
|
|
||||||
)
|
|
||||||
bias_cap = cfg.moe_bias_cap if cfg.moe_bias_cap is not None else 2.0
|
|
||||||
warmup = cfg.moe_afb_warmup_steps if cfg.moe_afb_warmup_steps is not None else 0
|
|
||||||
sync_group = cfg.moe_bias_sync_group if cfg.moe_bias_sync_group else "world"
|
|
||||||
af_cfg = AuxFreeConfig(
|
|
||||||
rate=rate, momentum=momentum, bias_cap=bias_cap, warmup_steps=warmup, sync_group=sync_group
|
|
||||||
)
|
|
||||||
|
|
||||||
# Discover layers to count the number and experts for state sizing
|
|
||||||
adapters: list[BaseMoEAdapter] = [
|
|
||||||
MixtralAdapter(),
|
|
||||||
Qwen3Adapter(),
|
|
||||||
BailingAdapter(),
|
|
||||||
Llama4Adapter(),
|
|
||||||
]
|
|
||||||
|
|
||||||
# For initial state sizing, we conservatively assume the first discovered layer defines nE
|
|
||||||
n_layers = 0
|
|
||||||
n_experts = None
|
|
||||||
for m in model.modules():
|
|
||||||
n_layers += 1 # upper bound — we will re-use bias slots sparsely
|
|
||||||
device = next(model.parameters(), torch.tensor(0)).device
|
|
||||||
if n_layers <= 0:
|
|
||||||
n_layers = 1
|
|
||||||
if n_experts is None:
|
|
||||||
# we'll set a minimal placeholder; prepare() will conceptually use module buffers instead
|
|
||||||
n_experts = 2
|
|
||||||
state = AuxFreeState(num_layers=n_layers, num_experts=n_experts, device=device, cfg=af_cfg)
|
|
||||||
ep_size = getattr(cfg, "expert_parallel_size", None)
|
|
||||||
ep_group = None
|
|
||||||
if sync_group == "ep":
|
|
||||||
if dist.is_available() and dist.is_initialized():
|
|
||||||
ep_group = self._resolve_ep_group(cfg)
|
|
||||||
else:
|
|
||||||
LOG.info(
|
|
||||||
"AuxFreeMoE: deferring expert-parallel group resolution until torch.distributed initializes"
|
|
||||||
)
|
|
||||||
self._shim = AuxFreeShim(state=state, ep_group=ep_group, ep_size=ep_size)
|
|
||||||
|
|
||||||
# Discover and prepare layers (attach per-layer buffers)
|
|
||||||
self._handles = discover_and_prepare_layers(model, adapters, self._shim)
|
|
||||||
|
|
||||||
LOG.info(
|
|
||||||
f"AuxFreeMoE: enabled with rate={rate}, momentum={momentum}, cap={bias_cap}, warmup={warmup}, group={sync_group}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _resolve_ep_group(self, cfg) -> Optional[dist.ProcessGroup]:
|
|
||||||
if not dist.is_available() or not dist.is_initialized():
|
|
||||||
LOG.warning("AuxFreeMoE: EP sync requested but torch.distributed is not initialized; defaulting to world")
|
|
||||||
return None
|
|
||||||
ep_size = getattr(cfg, "expert_parallel_size", None)
|
|
||||||
if not ep_size or ep_size <= 1:
|
|
||||||
LOG.warning("AuxFreeMoE: moe_bias_sync_group='ep' but expert_parallel_size<=1; defaulting to world")
|
|
||||||
return None
|
|
||||||
world = dist.get_world_size()
|
|
||||||
if world % ep_size != 0:
|
|
||||||
LOG.warning(
|
|
||||||
"AuxFreeMoE: expert_parallel_size %s does not divide world size %s; defaulting to world",
|
|
||||||
ep_size,
|
|
||||||
world,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
if ep_size == world:
|
|
||||||
return dist.group.WORLD
|
|
||||||
|
|
||||||
rank = dist.get_rank()
|
|
||||||
group_start = (rank // ep_size) * ep_size
|
|
||||||
ranks = tuple(range(group_start, group_start + ep_size))
|
|
||||||
if ranks not in self._ep_group_cache:
|
|
||||||
self._ep_group_cache[ranks] = dist.new_group(ranks)
|
|
||||||
return self._ep_group_cache[ranks]
|
|
||||||
|
|
||||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
|
||||||
if getattr(cfg, "moe_balance_type", None) != "noaux_tc":
|
|
||||||
return []
|
|
||||||
if self._shim is None:
|
|
||||||
return []
|
|
||||||
# gather concrete layer modules from handles
|
|
||||||
layers = [h.layer for h in self._handles]
|
|
||||||
cb = MoeAuxFreeBiasUpdateCallback(self._shim, layers)
|
|
||||||
LOG.info("AuxFreeMoE: registering post-step bias update callback")
|
|
||||||
return [cb]
|
|
||||||
@@ -30,6 +30,15 @@ class LigerArgs(BaseModel):
|
|||||||
|
|
||||||
liger_rope: bool | None = None
|
liger_rope: bool | None = None
|
||||||
liger_rms_norm: bool | None = None
|
liger_rms_norm: bool | None = None
|
||||||
|
liger_rms_norm_gated: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": (
|
||||||
|
"Enables fused RMSNorm+SiLU gate Triton kernel for models with "
|
||||||
|
"gated RMSNorm (e.g. Qwen3.5 / Qwen3.5 MoE linear attention layers)."
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
liger_layer_norm: bool | None = None
|
liger_layer_norm: bool | None = None
|
||||||
liger_swiglu: bool | None = None
|
liger_swiglu: bool | None = None
|
||||||
liger_glu_activation: bool | None = None
|
liger_glu_activation: bool | None = None
|
||||||
|
|||||||
175
src/axolotl/integrations/liger/models/qwen3_5.py
Normal file
175
src/axolotl/integrations/liger/models/qwen3_5.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
"""
|
||||||
|
Liger FLCE for Qwen3.5. Based on transformers v5.3.0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
|
||||||
|
|
||||||
|
def lce_forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Cache] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**kwargs,
|
||||||
|
) -> CausalLMOutputWithPast:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
"""
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
|
||||||
|
logits = None
|
||||||
|
loss = None
|
||||||
|
# if in training mode, don't materialize logits
|
||||||
|
if self.training and (labels is not None):
|
||||||
|
loss = LigerForCausalLMLoss(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
lm_head_weight=self.lm_head.weight,
|
||||||
|
labels=labels,
|
||||||
|
hidden_size=self.config.hidden_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
else: # if in inference mode materialize logits
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits=logits,
|
||||||
|
labels=labels,
|
||||||
|
vocab_size=self.config.vocab_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_liger_kernel_to_qwen3_5(
|
||||||
|
cross_entropy: bool = False,
|
||||||
|
fused_linear_cross_entropy: bool = False,
|
||||||
|
rms_norm: bool = False,
|
||||||
|
rms_norm_gated: bool = False,
|
||||||
|
glu_activation: bool = False,
|
||||||
|
layer_norm: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 models.
|
||||||
|
|
||||||
|
Note: Qwen3_5RMSNorm uses zero-init weight with offset 1.0 (like Gemma),
|
||||||
|
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
||||||
|
fused_linear_cross_entropy (bool):
|
||||||
|
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
||||||
|
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
||||||
|
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
||||||
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
||||||
|
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
|
||||||
|
Qwen3_5RMSNormGated (used in linear attention layers). Default is False.
|
||||||
|
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
||||||
|
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import transformers.models.qwen3_5.modeling_qwen3_5 # noqa: F401
|
||||||
|
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||||
|
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||||
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||||
|
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||||
|
|
||||||
|
assert not (cross_entropy and fused_linear_cross_entropy), (
|
||||||
|
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
||||||
|
)
|
||||||
|
|
||||||
|
modeling_qwen3_5 = sys.modules["transformers.models.qwen3_5.modeling_qwen3_5"]
|
||||||
|
|
||||||
|
if rms_norm:
|
||||||
|
# Qwen3_5RMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
|
||||||
|
class LigerRMSNormForQwen3_5(LigerRMSNorm):
|
||||||
|
def __init__(self, dim, eps=1e-6, **kwargs):
|
||||||
|
super().__init__(
|
||||||
|
dim,
|
||||||
|
eps=eps,
|
||||||
|
offset=1.0,
|
||||||
|
casting_mode="gemma",
|
||||||
|
init_fn="zeros",
|
||||||
|
in_place=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3_5
|
||||||
|
|
||||||
|
if rms_norm_gated:
|
||||||
|
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
||||||
|
|
||||||
|
modeling_qwen3_5.Qwen3_5RMSNormGated = FusedRMSNormGated
|
||||||
|
|
||||||
|
if glu_activation:
|
||||||
|
|
||||||
|
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
|
||||||
|
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
|
||||||
|
config = deepcopy(config)
|
||||||
|
if intermediate_size is not None:
|
||||||
|
config.intermediate_size = intermediate_size
|
||||||
|
return LigerSwiGLUMLP(config, **kwargs)
|
||||||
|
|
||||||
|
modeling_qwen3_5.Qwen3_5MLP = _liger_swiglu_mlp_wrapper
|
||||||
|
|
||||||
|
if layer_norm:
|
||||||
|
modeling_qwen3_5.nn.LayerNorm = LigerLayerNorm
|
||||||
|
|
||||||
|
if cross_entropy:
|
||||||
|
from transformers.loss.loss_utils import nn
|
||||||
|
|
||||||
|
nn.functional.cross_entropy = liger_cross_entropy
|
||||||
|
|
||||||
|
if fused_linear_cross_entropy:
|
||||||
|
modeling_qwen3_5.Qwen3_5ForCausalLM.forward = lce_forward
|
||||||
198
src/axolotl/integrations/liger/models/qwen3_5_moe.py
Normal file
198
src/axolotl/integrations/liger/models/qwen3_5_moe.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
"""
|
||||||
|
Liger FLCE for Qwen3.5 MoE. Based on transformers v5.3.0.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
|
||||||
|
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
|
||||||
|
|
||||||
|
|
||||||
|
def lce_forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_router_logits: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**kwargs,
|
||||||
|
) -> MoeCausalLMOutputWithPast:
|
||||||
|
r"""
|
||||||
|
Args:
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
"""
|
||||||
|
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
|
||||||
|
load_balancing_loss_func,
|
||||||
|
)
|
||||||
|
|
||||||
|
output_router_logits = (
|
||||||
|
output_router_logits
|
||||||
|
if output_router_logits is not None
|
||||||
|
else self.config.output_router_logits
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_router_logits=output_router_logits,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
|
||||||
|
logits = None
|
||||||
|
loss = None
|
||||||
|
# if in training mode, don't materialize logits
|
||||||
|
if self.training and (labels is not None):
|
||||||
|
loss = LigerForCausalLMLoss(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
lm_head_weight=self.lm_head.weight,
|
||||||
|
labels=labels,
|
||||||
|
hidden_size=self.config.hidden_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
else: # if in inference mode materialize logits
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits,
|
||||||
|
labels,
|
||||||
|
self.vocab_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
aux_loss = None
|
||||||
|
if output_router_logits:
|
||||||
|
aux_loss = load_balancing_loss_func(
|
||||||
|
outputs.router_logits,
|
||||||
|
self.num_experts,
|
||||||
|
self.num_experts_per_tok,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
if labels is not None:
|
||||||
|
loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
|
||||||
|
|
||||||
|
return MoeCausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
aux_loss=aux_loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
router_logits=outputs.router_logits,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_liger_kernel_to_qwen3_5_moe(
|
||||||
|
cross_entropy: bool = False,
|
||||||
|
fused_linear_cross_entropy: bool = False,
|
||||||
|
rms_norm: bool = False,
|
||||||
|
rms_norm_gated: bool = False,
|
||||||
|
glu_activation: bool = False,
|
||||||
|
layer_norm: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Apply Liger kernels to replace original implementation in HuggingFace Qwen3.5 MoE models.
|
||||||
|
|
||||||
|
Note: Qwen3_5MoeRMSNorm uses zero-init weight with offset 1.0 (like Gemma),
|
||||||
|
so we use LigerRMSNorm with offset=1.0 and init_fn="zeros".
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
|
||||||
|
fused_linear_cross_entropy (bool):
|
||||||
|
Whether to apply Liger's fused linear cross entropy loss. Default is False.
|
||||||
|
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
|
||||||
|
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
|
||||||
|
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is False.
|
||||||
|
rms_norm_gated (bool): Whether to apply fused RMSNorm+SiLU gate kernel for
|
||||||
|
Qwen3_5MoeRMSNormGated (used in linear attention layers). Default is False.
|
||||||
|
glu_activation (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
|
||||||
|
layer_norm (bool): Whether to apply Liger's LayerNorm. Default is False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import transformers.models.qwen3_5_moe.modeling_qwen3_5_moe # noqa: F401
|
||||||
|
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||||
|
from liger_kernel.transformers.layer_norm import LigerLayerNorm
|
||||||
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||||
|
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||||
|
|
||||||
|
assert not (cross_entropy and fused_linear_cross_entropy), (
|
||||||
|
"cross_entropy and fused_linear_cross_entropy cannot both be True."
|
||||||
|
)
|
||||||
|
|
||||||
|
modeling_mod = sys.modules["transformers.models.qwen3_5_moe.modeling_qwen3_5_moe"]
|
||||||
|
|
||||||
|
if rms_norm:
|
||||||
|
# Qwen3_5MoeRMSNorm uses zero-init weight with `output * (1.0 + weight)` pattern
|
||||||
|
class LigerRMSNormForQwen3_5Moe(LigerRMSNorm):
|
||||||
|
def __init__(self, dim, eps=1e-6, **kwargs):
|
||||||
|
super().__init__(
|
||||||
|
dim,
|
||||||
|
eps=eps,
|
||||||
|
offset=1.0,
|
||||||
|
casting_mode="gemma",
|
||||||
|
init_fn="zeros",
|
||||||
|
in_place=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
modeling_mod.Qwen3_5MoeRMSNorm = LigerRMSNormForQwen3_5Moe
|
||||||
|
|
||||||
|
if rms_norm_gated:
|
||||||
|
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
||||||
|
|
||||||
|
modeling_mod.Qwen3_5MoeRMSNormGated = FusedRMSNormGated
|
||||||
|
|
||||||
|
if glu_activation:
|
||||||
|
|
||||||
|
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
|
||||||
|
"""Accepts intermediate_size to pass to LigerSwiGLUMLP"""
|
||||||
|
config = deepcopy(config)
|
||||||
|
if intermediate_size is not None:
|
||||||
|
config.intermediate_size = intermediate_size
|
||||||
|
return LigerSwiGLUMLP(config, **kwargs)
|
||||||
|
|
||||||
|
modeling_mod.Qwen3_5MoeMLP = _liger_swiglu_mlp_wrapper
|
||||||
|
|
||||||
|
if layer_norm:
|
||||||
|
modeling_mod.nn.LayerNorm = LigerLayerNorm
|
||||||
|
|
||||||
|
if cross_entropy:
|
||||||
|
from transformers.loss.loss_utils import nn
|
||||||
|
|
||||||
|
nn.functional.cross_entropy = liger_cross_entropy
|
||||||
|
|
||||||
|
if fused_linear_cross_entropy:
|
||||||
|
modeling_mod.Qwen3_5MoeForCausalLM.forward = lce_forward
|
||||||
@@ -174,6 +174,19 @@ class LigerPlugin(BasePlugin):
|
|||||||
rms_norm=cfg.liger_rms_norm,
|
rms_norm=cfg.liger_rms_norm,
|
||||||
layer_norm=cfg.liger_layer_norm,
|
layer_norm=cfg.liger_layer_norm,
|
||||||
)
|
)
|
||||||
|
elif cfg.model_config_type == "qwen3_5":
|
||||||
|
from axolotl.integrations.liger.models.qwen3_5 import (
|
||||||
|
apply_liger_kernel_to_qwen3_5,
|
||||||
|
)
|
||||||
|
|
||||||
|
apply_liger_kernel_to_qwen3_5(
|
||||||
|
cross_entropy=cfg.liger_cross_entropy,
|
||||||
|
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||||
|
glu_activation=cfg.liger_glu_activation,
|
||||||
|
rms_norm=cfg.liger_rms_norm,
|
||||||
|
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
|
||||||
|
layer_norm=cfg.liger_layer_norm,
|
||||||
|
)
|
||||||
elif cfg.model_config_type == "qwen3_moe":
|
elif cfg.model_config_type == "qwen3_moe":
|
||||||
from axolotl.integrations.liger.models.qwen3_moe import (
|
from axolotl.integrations.liger.models.qwen3_moe import (
|
||||||
apply_liger_kernel_to_qwen3_moe,
|
apply_liger_kernel_to_qwen3_moe,
|
||||||
@@ -186,6 +199,19 @@ class LigerPlugin(BasePlugin):
|
|||||||
rms_norm=cfg.liger_rms_norm,
|
rms_norm=cfg.liger_rms_norm,
|
||||||
layer_norm=cfg.liger_layer_norm,
|
layer_norm=cfg.liger_layer_norm,
|
||||||
)
|
)
|
||||||
|
elif cfg.model_config_type == "qwen3_5_moe":
|
||||||
|
from axolotl.integrations.liger.models.qwen3_5_moe import (
|
||||||
|
apply_liger_kernel_to_qwen3_5_moe,
|
||||||
|
)
|
||||||
|
|
||||||
|
apply_liger_kernel_to_qwen3_5_moe(
|
||||||
|
cross_entropy=cfg.liger_cross_entropy,
|
||||||
|
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||||
|
glu_activation=cfg.liger_glu_activation,
|
||||||
|
rms_norm=cfg.liger_rms_norm,
|
||||||
|
rms_norm_gated=getattr(cfg, "liger_rms_norm_gated", False),
|
||||||
|
layer_norm=cfg.liger_layer_norm,
|
||||||
|
)
|
||||||
elif cfg.model_config_type == "granitemoe":
|
elif cfg.model_config_type == "granitemoe":
|
||||||
from liger_kernel.transformers import apply_liger_kernel_to_granite
|
from liger_kernel.transformers import apply_liger_kernel_to_granite
|
||||||
|
|
||||||
|
|||||||
147
src/axolotl/kernels/dora.py
Normal file
147
src/axolotl/kernels/dora.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
"""
|
||||||
|
Triton kernels for DoRA (Weight-Decomposed Low-Rank Adaptation).
|
||||||
|
|
||||||
|
Fuses the weight norm computation and magnitude scaling to avoid
|
||||||
|
materializing the full [out_features, in_features] combined weight matrix.
|
||||||
|
The B@A product is computed row-by-row inside the kernel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
from .quantize import dequantize
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _dora_fused_norm_kernel(
|
||||||
|
# Pointers
|
||||||
|
W_ptr, # base weight [out, in] (dequantized, row-major)
|
||||||
|
B_ptr, # LoRA B [out, rank] (row-major)
|
||||||
|
A_ptr, # LoRA A [rank, in] (row-major)
|
||||||
|
mag_ptr, # magnitude vector [out]
|
||||||
|
out_ptr, # output mag_norm_scale [out]
|
||||||
|
# Shapes
|
||||||
|
out_features,
|
||||||
|
in_features,
|
||||||
|
rank,
|
||||||
|
# Scaling
|
||||||
|
lora_scale, # float scaling factor
|
||||||
|
# Block sizes
|
||||||
|
BLOCK_IN: tl.constexpr,
|
||||||
|
BLOCK_R: tl.constexpr, # >= rank, power of 2
|
||||||
|
):
|
||||||
|
"""Compute mag_norm_scale[i] = magnitude[i] / ||W[i,:] + s * (B[i,:] @ A)[:] ||_2
|
||||||
|
|
||||||
|
Each program handles one output row. B[row,:] is loaded once (small),
|
||||||
|
then we tile over in_features computing the dot product with A[:,tile]
|
||||||
|
and accumulating the squared norm.
|
||||||
|
|
||||||
|
This avoids materializing the full [out, in] B@A matrix.
|
||||||
|
"""
|
||||||
|
row = tl.program_id(0)
|
||||||
|
if row >= out_features:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Accumulate squared norm across tiles of in_features
|
||||||
|
norm_sq_acc = tl.zeros([BLOCK_IN], dtype=tl.float32)
|
||||||
|
|
||||||
|
for start in range(0, in_features, BLOCK_IN):
|
||||||
|
cols = start + tl.arange(0, BLOCK_IN)
|
||||||
|
col_mask = cols < in_features
|
||||||
|
|
||||||
|
# Load W[row, cols]
|
||||||
|
w_vals = tl.load(
|
||||||
|
W_ptr + row * in_features + cols,
|
||||||
|
mask=col_mask,
|
||||||
|
other=0.0,
|
||||||
|
).to(tl.float32)
|
||||||
|
|
||||||
|
# Compute (B[row,:] @ A[:, cols]) for this tile
|
||||||
|
# Load B[row, r] as scalar and A[r, cols] as vector for each r
|
||||||
|
ba_vals = tl.zeros([BLOCK_IN], dtype=tl.float32)
|
||||||
|
for r in tl.static_range(BLOCK_R):
|
||||||
|
# Load scalar B[row, r]
|
||||||
|
b_val = tl.load(
|
||||||
|
B_ptr + row * rank + r,
|
||||||
|
mask=(r < rank),
|
||||||
|
other=0.0,
|
||||||
|
).to(tl.float32)
|
||||||
|
# Load vector A[r, cols]
|
||||||
|
a_vals = tl.load(
|
||||||
|
A_ptr + r * in_features + cols,
|
||||||
|
mask=(col_mask & (r < rank)),
|
||||||
|
other=0.0,
|
||||||
|
).to(tl.float32)
|
||||||
|
ba_vals += b_val * a_vals
|
||||||
|
|
||||||
|
# Combined: W + s * (B @ A)
|
||||||
|
combined = w_vals + lora_scale * ba_vals
|
||||||
|
|
||||||
|
# Accumulate squared values
|
||||||
|
norm_sq_acc += tl.where(col_mask, combined * combined, 0.0)
|
||||||
|
|
||||||
|
# Reduce to scalar norm
|
||||||
|
norm_sq = tl.sum(norm_sq_acc, axis=0)
|
||||||
|
norm = tl.sqrt(norm_sq + 1e-12) # epsilon for numerical stability
|
||||||
|
|
||||||
|
# Load magnitude and compute scale
|
||||||
|
mag = tl.load(mag_ptr + row).to(tl.float32)
|
||||||
|
scale = mag / norm
|
||||||
|
|
||||||
|
tl.store(out_ptr + row, scale)
|
||||||
|
|
||||||
|
|
||||||
|
def triton_dora_scale(
|
||||||
|
W: torch.Tensor,
|
||||||
|
W_quant,
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
s: float,
|
||||||
|
magnitude: torch.Tensor,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute DoRA mag_norm_scale using fused Triton kernel.
|
||||||
|
|
||||||
|
Computes B@A row-by-row inside the kernel, avoiding the full
|
||||||
|
[out_features, in_features] materialization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
W: base weight [out, in] (possibly quantized)
|
||||||
|
W_quant: quantization state
|
||||||
|
A: LoRA A [rank, in]
|
||||||
|
B: LoRA B [out, rank]
|
||||||
|
s: LoRA scaling factor
|
||||||
|
magnitude: learned magnitude [out]
|
||||||
|
dtype: compute dtype
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
mag_norm_scale: [out] tensor = magnitude / ||W + s * B @ A||_2
|
||||||
|
"""
|
||||||
|
# Dequantize W to [out, in]
|
||||||
|
W_full = dequantize(W.t(), W_quant).t().contiguous().to(dtype)
|
||||||
|
|
||||||
|
out_features, in_features = W_full.shape
|
||||||
|
rank = A.shape[0]
|
||||||
|
|
||||||
|
out = torch.empty(out_features, dtype=dtype, device=W.device)
|
||||||
|
|
||||||
|
# Block sizes
|
||||||
|
BLOCK_IN = triton.next_power_of_2(min(in_features, 2048))
|
||||||
|
BLOCK_R = triton.next_power_of_2(rank)
|
||||||
|
|
||||||
|
_dora_fused_norm_kernel[(out_features,)](
|
||||||
|
W_full,
|
||||||
|
B.contiguous().to(dtype),
|
||||||
|
A.contiguous().to(dtype),
|
||||||
|
magnitude.contiguous(),
|
||||||
|
out,
|
||||||
|
out_features=out_features,
|
||||||
|
in_features=in_features,
|
||||||
|
rank=rank,
|
||||||
|
lora_scale=s,
|
||||||
|
BLOCK_IN=BLOCK_IN,
|
||||||
|
BLOCK_R=BLOCK_R,
|
||||||
|
)
|
||||||
|
|
||||||
|
return out.detach()
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -105,6 +105,10 @@ def dequantize(
|
|||||||
# Extract quantization state
|
# Extract quantization state
|
||||||
if not isinstance(quant_state, list):
|
if not isinstance(quant_state, list):
|
||||||
# New style quant_state class
|
# New style quant_state class
|
||||||
|
# Non-double-quantized models have offset=None and state2=None
|
||||||
|
if quant_state.offset is None or quant_state.state2 is None:
|
||||||
|
# Fall back to bitsandbytes standard dequantize
|
||||||
|
return bnb.functional.dequantize_4bit(W, quant_state, quant_type="nf4")
|
||||||
absmax = quant_state.absmax.to(target_device)
|
absmax = quant_state.absmax.to(target_device)
|
||||||
shape = quant_state.shape
|
shape = quant_state.shape
|
||||||
dtype = quant_state.dtype
|
dtype = quant_state.dtype
|
||||||
|
|||||||
333
src/axolotl/kernels/rms_norm_gated.py
Normal file
333
src/axolotl/kernels/rms_norm_gated.py
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
"""
|
||||||
|
Fused RMSNorm + SiLU Gate Triton kernel.
|
||||||
|
|
||||||
|
Computes: Y = (W + offset) * RMSNorm(X) * silu(G)
|
||||||
|
where RMSNorm(X) = X / sqrt(mean(X^2) + eps)
|
||||||
|
and silu(G) = G * sigmoid(G)
|
||||||
|
|
||||||
|
Used by Qwen3.5's GatedDeltaNet linear attention layers (Qwen3_5RMSNormGated).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import operator
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from liger_kernel.ops.utils import (
|
||||||
|
calculate_settings,
|
||||||
|
compare_version,
|
||||||
|
ensure_contiguous,
|
||||||
|
torch_to_triton_dtype,
|
||||||
|
)
|
||||||
|
from liger_kernel.utils import is_npu_available
|
||||||
|
|
||||||
|
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
|
||||||
|
try:
|
||||||
|
from triton.language.extra.libdevice import rsqrt
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
from triton.language.extra.cuda.libdevice import rsqrt
|
||||||
|
else:
|
||||||
|
from triton.language.math import rsqrt
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _rms_norm_gated_forward_kernel(
|
||||||
|
Y_ptr,
|
||||||
|
Y_row_stride,
|
||||||
|
X_ptr,
|
||||||
|
X_row_stride,
|
||||||
|
G_ptr,
|
||||||
|
G_row_stride,
|
||||||
|
W_ptr,
|
||||||
|
W_row_stride,
|
||||||
|
RSTD_ptr,
|
||||||
|
RSTD_row_stride,
|
||||||
|
n_cols,
|
||||||
|
eps,
|
||||||
|
offset,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Y = (W + offset) * (X / RMS(X)) * silu(G)
|
||||||
|
|
||||||
|
All computation done in fp32 (Gemma-style), result cast to input dtype.
|
||||||
|
"""
|
||||||
|
row_idx = tl.program_id(0).to(tl.int64)
|
||||||
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = col_offsets < n_cols
|
||||||
|
|
||||||
|
X_row = tl.load(X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0)
|
||||||
|
G_row = tl.load(G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0)
|
||||||
|
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
|
||||||
|
|
||||||
|
X_row_dtype = X_row.dtype
|
||||||
|
|
||||||
|
# Cast everything to fp32
|
||||||
|
X_fp32 = X_row.to(tl.float32)
|
||||||
|
G_fp32 = G_row.to(tl.float32)
|
||||||
|
W_fp32 = W_row.to(tl.float32)
|
||||||
|
|
||||||
|
# RMS norm
|
||||||
|
mean_sq = tl.sum(X_fp32 * X_fp32, axis=0) / n_cols
|
||||||
|
rstd = rsqrt(mean_sq + eps)
|
||||||
|
tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd)
|
||||||
|
|
||||||
|
X_norm = X_fp32 * rstd
|
||||||
|
|
||||||
|
# SiLU gate: silu(G) = G * sigmoid(G)
|
||||||
|
sig_G = tl.sigmoid(G_fp32)
|
||||||
|
silu_G = G_fp32 * sig_G
|
||||||
|
|
||||||
|
# Fused output
|
||||||
|
Y_row = (offset + W_fp32) * X_norm * silu_G
|
||||||
|
|
||||||
|
tl.store(
|
||||||
|
Y_ptr + row_idx * Y_row_stride + col_offsets,
|
||||||
|
Y_row.to(X_row_dtype),
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _rms_norm_gated_backward_kernel(
|
||||||
|
dY_ptr,
|
||||||
|
dY_row_stride,
|
||||||
|
dX_ptr,
|
||||||
|
dX_row_stride,
|
||||||
|
dG_ptr,
|
||||||
|
dG_row_stride,
|
||||||
|
X_ptr,
|
||||||
|
X_row_stride,
|
||||||
|
X_dtype: tl.constexpr,
|
||||||
|
G_ptr,
|
||||||
|
G_row_stride,
|
||||||
|
W_ptr,
|
||||||
|
W_row_stride,
|
||||||
|
RSTD_ptr,
|
||||||
|
RSTD_row_stride,
|
||||||
|
dW_ptr,
|
||||||
|
dW_row_stride,
|
||||||
|
n_rows,
|
||||||
|
n_cols,
|
||||||
|
offset,
|
||||||
|
rows_per_program,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Backward for Y = (W + offset) * (X * RSTD) * silu(G)
|
||||||
|
|
||||||
|
dW = sum_batch(dY * X_norm * silu(G))
|
||||||
|
dG = dY * (W + offset) * X_norm * silu'(G)
|
||||||
|
where silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
|
||||||
|
dX = RSTD * (m - (1/N) * RSTD^2 * dot(m, X) * X)
|
||||||
|
where m = dY * (W + offset) * silu(G)
|
||||||
|
"""
|
||||||
|
row_block_id = tl.program_id(0).to(tl.int64)
|
||||||
|
row_start = row_block_id * rows_per_program
|
||||||
|
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
|
||||||
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = col_offsets < n_cols
|
||||||
|
|
||||||
|
dW_acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||||
|
|
||||||
|
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
|
||||||
|
W_row = W_row.to(tl.float32) + offset
|
||||||
|
|
||||||
|
for row_idx in range(row_start, row_end):
|
||||||
|
dY_row = tl.load(
|
||||||
|
dY_ptr + row_idx * dY_row_stride + col_offsets, mask=mask, other=0.0
|
||||||
|
)
|
||||||
|
X_row = tl.load(
|
||||||
|
X_ptr + row_idx * X_row_stride + col_offsets, mask=mask, other=0.0
|
||||||
|
)
|
||||||
|
G_row = tl.load(
|
||||||
|
G_ptr + row_idx * G_row_stride + col_offsets, mask=mask, other=0.0
|
||||||
|
)
|
||||||
|
rstd_row = tl.load(RSTD_ptr + row_idx * RSTD_row_stride)
|
||||||
|
|
||||||
|
# Cast to fp32
|
||||||
|
dY_fp32 = dY_row.to(tl.float32)
|
||||||
|
X_fp32 = X_row.to(tl.float32)
|
||||||
|
G_fp32 = G_row.to(tl.float32)
|
||||||
|
|
||||||
|
# Recompute intermediates
|
||||||
|
X_norm = X_fp32 * rstd_row
|
||||||
|
sig_G = tl.sigmoid(G_fp32)
|
||||||
|
silu_G = G_fp32 * sig_G
|
||||||
|
|
||||||
|
# dW: accumulate dY * X_norm * silu(G)
|
||||||
|
dW_acc += dY_fp32 * X_norm * silu_G
|
||||||
|
|
||||||
|
# dG: dY * (W + offset) * X_norm * silu'(G)
|
||||||
|
# silu'(G) = sigmoid(G) * (1 + G * (1 - sigmoid(G)))
|
||||||
|
silu_prime_G = sig_G * (1.0 + G_fp32 * (1.0 - sig_G))
|
||||||
|
dG_row = dY_fp32 * W_row * X_norm * silu_prime_G
|
||||||
|
tl.store(
|
||||||
|
dG_ptr + row_idx * dG_row_stride + col_offsets,
|
||||||
|
dG_row.to(X_dtype),
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# dX: standard RMSNorm backward with effective gradient m = dY * W * silu(G)
|
||||||
|
m = dY_fp32 * W_row * silu_G
|
||||||
|
dX_row = rstd_row * m
|
||||||
|
dX_row += rstd_row * (
|
||||||
|
-(1.0 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_fp32, axis=0) * X_fp32
|
||||||
|
)
|
||||||
|
tl.store(
|
||||||
|
dX_ptr + row_idx * dX_row_stride + col_offsets,
|
||||||
|
dX_row.to(X_dtype),
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
tl.store(
|
||||||
|
dW_ptr + row_block_id * dW_row_stride + col_offsets,
|
||||||
|
dW_acc,
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rms_norm_gated_forward(X, G, W, eps, offset):
|
||||||
|
shape = X.shape
|
||||||
|
dim = shape[-1]
|
||||||
|
X = X.view(-1, dim)
|
||||||
|
G = G.view(-1, dim)
|
||||||
|
n_rows, n_cols = X.shape
|
||||||
|
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||||
|
|
||||||
|
Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
|
||||||
|
RSTD = torch.empty(n_rows, dtype=torch.float32, device=X.device)
|
||||||
|
|
||||||
|
assert X.shape[1] == W.shape[0], (
|
||||||
|
f"Incompatible hidden size: X.shape[1]={X.shape[1]} vs W.shape[0]={W.shape[0]}"
|
||||||
|
)
|
||||||
|
assert X.shape == G.shape, (
|
||||||
|
f"X and G must have same shape, got {X.shape} and {G.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_rms_norm_gated_forward_kernel[(n_rows,)](
|
||||||
|
Y,
|
||||||
|
Y.stride(0),
|
||||||
|
X,
|
||||||
|
X.stride(0),
|
||||||
|
G,
|
||||||
|
G.stride(0),
|
||||||
|
W,
|
||||||
|
W.stride(0),
|
||||||
|
RSTD,
|
||||||
|
RSTD.stride(0),
|
||||||
|
n_cols,
|
||||||
|
eps,
|
||||||
|
offset,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
num_warps=num_warps,
|
||||||
|
)
|
||||||
|
return Y.view(*shape), X, G, RSTD, BLOCK_SIZE, num_warps
|
||||||
|
|
||||||
|
|
||||||
|
def rms_norm_gated_backward(dY, X, G, W, RSTD, offset, BLOCK_SIZE, num_warps):
|
||||||
|
shape = dY.shape
|
||||||
|
dim = shape[-1]
|
||||||
|
dY = dY.view(-1, dim)
|
||||||
|
n_rows, n_cols = dY.shape
|
||||||
|
|
||||||
|
sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
|
||||||
|
|
||||||
|
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
|
||||||
|
dX = torch.empty_like(dY)
|
||||||
|
dG = torch.empty_like(dY)
|
||||||
|
|
||||||
|
rows_per_program = math.ceil(n_rows / sm_count)
|
||||||
|
grid = (sm_count,)
|
||||||
|
|
||||||
|
_rms_norm_gated_backward_kernel[grid](
|
||||||
|
dY,
|
||||||
|
dY.stride(0),
|
||||||
|
dX,
|
||||||
|
dX.stride(0),
|
||||||
|
dG,
|
||||||
|
dG.stride(0),
|
||||||
|
X,
|
||||||
|
X.stride(0),
|
||||||
|
torch_to_triton_dtype[X.dtype],
|
||||||
|
G,
|
||||||
|
G.stride(0),
|
||||||
|
W,
|
||||||
|
W.stride(0),
|
||||||
|
RSTD,
|
||||||
|
RSTD.stride(0),
|
||||||
|
_dW,
|
||||||
|
_dW.stride(0),
|
||||||
|
n_rows,
|
||||||
|
n_cols,
|
||||||
|
offset,
|
||||||
|
rows_per_program,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
num_warps=num_warps,
|
||||||
|
)
|
||||||
|
|
||||||
|
dX = dX.view(*shape)
|
||||||
|
dG = dG.view(*shape)
|
||||||
|
dW = _dW.sum(dim=0).to(W.dtype)
|
||||||
|
return dX, dG, dW
|
||||||
|
|
||||||
|
|
||||||
|
class FusedRMSNormGatedFunction(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
@ensure_contiguous
|
||||||
|
def forward(ctx, X, G, W, eps, offset=0.0):
|
||||||
|
"""
|
||||||
|
X: (B, T, H) or (BxT, H) — input hidden states
|
||||||
|
G: (B, T, H) or (BxT, H) — gate tensor
|
||||||
|
W: (H,) — weight parameter
|
||||||
|
"""
|
||||||
|
Y, X, G, RSTD, BLOCK_SIZE, num_warps = rms_norm_gated_forward(
|
||||||
|
X, G, W, eps, offset
|
||||||
|
)
|
||||||
|
ctx.offset = offset
|
||||||
|
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||||
|
ctx.num_warps = num_warps
|
||||||
|
ctx.save_for_backward(X, G, W, RSTD)
|
||||||
|
return Y
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@ensure_contiguous
|
||||||
|
def backward(ctx, dY):
|
||||||
|
X, G, W, RSTD = ctx.saved_tensors
|
||||||
|
dX, dG, dW = rms_norm_gated_backward(
|
||||||
|
dY, X, G, W, RSTD, ctx.offset, ctx.BLOCK_SIZE, ctx.num_warps
|
||||||
|
)
|
||||||
|
return dX, dG, dW, None, None
|
||||||
|
|
||||||
|
|
||||||
|
class FusedRMSNormGated(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Fused RMSNorm + SiLU Gate.
|
||||||
|
|
||||||
|
Computes: Y = W * RMSNorm(X) * silu(G)
|
||||||
|
|
||||||
|
Drop-in replacement for Qwen3_5RMSNormGated with matching
|
||||||
|
init signature: __init__(hidden_size, eps=1e-6, **kwargs)
|
||||||
|
and forward signature: forward(hidden_states, gate=None)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size, eps=1e-6, offset=0.0, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
self.offset = offset
|
||||||
|
|
||||||
|
def forward(self, hidden_states, gate=None):
|
||||||
|
if gate is None:
|
||||||
|
raise ValueError("FusedRMSNormGated requires a gate tensor")
|
||||||
|
if hidden_states.device.type != "cuda":
|
||||||
|
raise ValueError(
|
||||||
|
f"FusedRMSNormGated requires CUDA tensors, got device={hidden_states.device}"
|
||||||
|
)
|
||||||
|
return FusedRMSNormGatedFunction.apply(
|
||||||
|
hidden_states, gate, self.weight, self.variance_epsilon, self.offset
|
||||||
|
)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||||
@@ -12,6 +12,7 @@ from torch import nn
|
|||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
|
||||||
from axolotl.kernels.lora import (
|
from axolotl.kernels.lora import (
|
||||||
|
apply_lora_embedding,
|
||||||
apply_lora_mlp_geglu,
|
apply_lora_mlp_geglu,
|
||||||
apply_lora_mlp_swiglu,
|
apply_lora_mlp_swiglu,
|
||||||
apply_lora_o,
|
apply_lora_o,
|
||||||
@@ -370,13 +371,13 @@ def apply_lora_kernel_patches(
|
|||||||
active_adapter = model.active_adapter
|
active_adapter = model.active_adapter
|
||||||
lora_config = model.model.peft_config[active_adapter]
|
lora_config = model.model.peft_config[active_adapter]
|
||||||
|
|
||||||
# Only patch if conditions are met
|
# Log what features are active
|
||||||
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
|
if lora_config.lora_dropout > 0:
|
||||||
|
LOG.info(f"LoRA kernels: dropout={lora_config.lora_dropout} enabled")
|
||||||
if not can_patch:
|
if lora_config.bias != "none":
|
||||||
LOG.warning("Cannot patch layers - requires no dropout and no bias")
|
LOG.info(f"LoRA kernels: bias={lora_config.bias} enabled")
|
||||||
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
|
if lora_config.use_dora:
|
||||||
return model
|
LOG.info("LoRA kernels: DoRA enabled")
|
||||||
|
|
||||||
# This needs to be reset after patching
|
# This needs to be reset after patching
|
||||||
original_level = LOG.getEffectiveLevel()
|
original_level = LOG.getEffectiveLevel()
|
||||||
@@ -419,44 +420,33 @@ def apply_lora_kernel_patches(
|
|||||||
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
||||||
]
|
]
|
||||||
can_patch_qkv = all(
|
can_patch_qkv = all(
|
||||||
hasattr(module, "lora_A")
|
hasattr(module, "lora_A") for module in layer_modules
|
||||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_patch_qkv:
|
if can_patch_qkv:
|
||||||
# Add optimized implementation
|
|
||||||
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
self_attn.apply_qkv = types.MethodType(apply_lora_qkv, self_attn)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some attention QKV projections - requires LoRA "
|
"Cannot patch some attention QKV projections - requires LoRA adapters"
|
||||||
"adapters and no lora_magnitude_vector (DoRA)"
|
|
||||||
)
|
)
|
||||||
if cfg.lora_o_kernel:
|
if cfg.lora_o_kernel:
|
||||||
# Output patching
|
# Output patching
|
||||||
layer_modules = [
|
layer_modules = [
|
||||||
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
|
getattr(self_attn, linear_proj) for linear_proj in ["o_proj"]
|
||||||
]
|
]
|
||||||
can_patch_o = all(
|
can_patch_o = all(hasattr(module, "lora_A") for module in layer_modules)
|
||||||
hasattr(module, "lora_A")
|
|
||||||
and len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
|
|
||||||
if can_patch_o:
|
if can_patch_o:
|
||||||
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
|
self_attn.apply_o = types.MethodType(apply_lora_o, self_attn)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some attention output projection - requires LoRA "
|
"Cannot patch some attention output projection - requires LoRA adapters"
|
||||||
"adapters and no lora_magnitude_vector (DoRA)"
|
|
||||||
)
|
)
|
||||||
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
|
for gate_proj, up_proj, down_proj, mlp in find_mlp_in_layer(layer):
|
||||||
if cfg.lora_mlp_kernel:
|
if cfg.lora_mlp_kernel:
|
||||||
# MLP patching
|
# MLP patching
|
||||||
can_patch_mlp = all(
|
can_patch_mlp = all(
|
||||||
hasattr(proj, "lora_A")
|
hasattr(proj, "lora_A") for proj in (gate_proj, up_proj, down_proj)
|
||||||
and len(getattr(proj, "lora_magnitude_vector", []) or []) == 0
|
|
||||||
for proj in (gate_proj, up_proj, down_proj)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_patch_mlp:
|
if can_patch_mlp:
|
||||||
@@ -464,15 +454,50 @@ def apply_lora_kernel_patches(
|
|||||||
layer.mlp.forward = types.MethodType(apply_fn, mlp)
|
layer.mlp.forward = types.MethodType(apply_fn, mlp)
|
||||||
else:
|
else:
|
||||||
LOG.warning_once(
|
LOG.warning_once(
|
||||||
"Cannot patch some MLP layers - requires LoRA adapters and no "
|
"Cannot patch some MLP layers - requires LoRA adapters"
|
||||||
"lora_magnitude_vector (DoRA)"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Patch embedding layers (model-level, not per-layer)
|
||||||
|
if cfg.lora_embedding_kernel:
|
||||||
|
_patch_embedding_layers(model, cfg)
|
||||||
|
|
||||||
LOG.setLevel(original_level)
|
LOG.setLevel(original_level)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_embedding_layers(model: PeftModelForCausalLM, cfg: DictDefault):
|
||||||
|
"""Patch embedding layers with fused LoRA kernel.
|
||||||
|
|
||||||
|
Handles both embed_tokens (nn.Embedding with lora_embedding_A/B) and
|
||||||
|
lm_head (nn.Linear with lora_A/B, used when tied embeddings are untied by PEFT).
|
||||||
|
"""
|
||||||
|
pretrained_model = model.model
|
||||||
|
patched = 0
|
||||||
|
|
||||||
|
# Find embedding modules - check common locations
|
||||||
|
for attr_path in [
|
||||||
|
("model", "embed_tokens"),
|
||||||
|
("model", "language_model", "embed_tokens"),
|
||||||
|
]:
|
||||||
|
parent = pretrained_model
|
||||||
|
for attr in attr_path:
|
||||||
|
parent = getattr(parent, attr, None)
|
||||||
|
if parent is None:
|
||||||
|
break
|
||||||
|
if parent is not None and hasattr(parent, "lora_embedding_A"):
|
||||||
|
LOG.info(f"Patching embedding layer: {'.'.join(attr_path)}")
|
||||||
|
parent.forward = types.MethodType(apply_lora_embedding, parent)
|
||||||
|
patched += 1
|
||||||
|
|
||||||
|
# lm_head with LoRA is a Linear layer - already handled by LoRA_O/LoRA_W kernels
|
||||||
|
# when included in target_modules. No special embedding handling needed since
|
||||||
|
# PEFT wraps it as a Linear (not Embedding) even for tied models.
|
||||||
|
|
||||||
|
if not patched:
|
||||||
|
LOG.debug("No embedding layers with LoRA found to patch")
|
||||||
|
|
||||||
|
|
||||||
class FakeMLP(nn.Module):
|
class FakeMLP(nn.Module):
|
||||||
"""
|
"""
|
||||||
placeholder MLP for triton patching
|
placeholder MLP for triton patching
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
import typing
|
|
||||||
import weakref
|
import weakref
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
@@ -42,9 +41,6 @@ from axolotl.utils.schemas.enums import RLType
|
|||||||
from axolotl.utils.train import determine_last_checkpoint
|
from axolotl.utils.train import determine_last_checkpoint
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
if typing.TYPE_CHECKING:
|
|
||||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||||
@@ -487,7 +483,7 @@ def handle_untrained_tokens_fix(
|
|||||||
def setup_model_and_trainer(
|
def setup_model_and_trainer(
|
||||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
|
Trainer,
|
||||||
PeftModel | PreTrainedModel,
|
PeftModel | PreTrainedModel,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
PeftConfig | None,
|
PeftConfig | None,
|
||||||
@@ -554,6 +550,36 @@ def setup_model_and_trainer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_tui_enabled(cfg: DictDefault) -> bool:
|
||||||
|
"""Check if TUI is enabled via config or environment variable."""
|
||||||
|
if os.environ.get("AXOLOTL_TUI", "").lower() in ("1", "true", "yes"):
|
||||||
|
return True
|
||||||
|
tui = cfg.get("tui")
|
||||||
|
if tui is None:
|
||||||
|
return False
|
||||||
|
if isinstance(tui, bool):
|
||||||
|
return tui
|
||||||
|
if isinstance(tui, dict):
|
||||||
|
return tui.get("enabled", False)
|
||||||
|
if hasattr(tui, "enabled"):
|
||||||
|
return tui.enabled
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tui_config(cfg: DictDefault) -> dict:
|
||||||
|
"""Extract TUI config dict from cfg."""
|
||||||
|
tui = cfg.get("tui")
|
||||||
|
if tui is None or isinstance(tui, bool):
|
||||||
|
return {"enabled": True}
|
||||||
|
if isinstance(tui, dict):
|
||||||
|
return {**tui, "enabled": True}
|
||||||
|
if hasattr(tui, "model_dump"):
|
||||||
|
d = tui.model_dump()
|
||||||
|
d["enabled"] = True
|
||||||
|
return d
|
||||||
|
return {"enabled": True}
|
||||||
|
|
||||||
|
|
||||||
@send_errors
|
@send_errors
|
||||||
def train(
|
def train(
|
||||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||||
@@ -577,6 +603,37 @@ def train(
|
|||||||
processor,
|
processor,
|
||||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||||
|
|
||||||
|
# Register TUI callback if enabled and rank 0
|
||||||
|
tui_enabled = _is_tui_enabled(cfg)
|
||||||
|
if tui_enabled and cfg.local_rank == 0:
|
||||||
|
from axolotl.tui import AxolotlTUICallback
|
||||||
|
from axolotl.tui.config import TUIConfig
|
||||||
|
|
||||||
|
tui_config = _get_tui_config(cfg)
|
||||||
|
tui_config_obj = (
|
||||||
|
TUIConfig(**tui_config) if isinstance(tui_config, dict) else tui_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reuse the early-started renderer if available (started in do_train)
|
||||||
|
early_renderer = getattr(cfg, "_tui_renderer", None)
|
||||||
|
early_queue = getattr(cfg, "_tui_queue", None)
|
||||||
|
|
||||||
|
tui_callback = AxolotlTUICallback(config=tui_config_obj)
|
||||||
|
if early_renderer is not None and early_queue is not None:
|
||||||
|
# Reuse the already-running renderer and queue
|
||||||
|
tui_callback._renderer = early_renderer
|
||||||
|
tui_callback._queue = early_queue
|
||||||
|
tui_callback._renderer_started_early = True
|
||||||
|
trainer.add_callback(tui_callback)
|
||||||
|
|
||||||
|
# Stash model info so on_train_begin can emit a single unified run_info event
|
||||||
|
tui_callback._pending_run_info = {
|
||||||
|
"model_name": cfg.base_model or "",
|
||||||
|
"training_mode": str(cfg.rl) if cfg.rl else "sft",
|
||||||
|
"world_size": int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
|
}
|
||||||
|
LOG.info("TUI dashboard enabled")
|
||||||
|
|
||||||
# Handle untrained tokens if configured
|
# Handle untrained tokens if configured
|
||||||
train_dataset = dataset_meta.train_dataset
|
train_dataset = dataset_meta.train_dataset
|
||||||
handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)
|
handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)
|
||||||
|
|||||||
17
src/axolotl/tui/__init__.py
Normal file
17
src/axolotl/tui/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Axolotl Training TUI — rich-based terminal dashboard for monitoring training runs."""
|
||||||
|
|
||||||
|
from axolotl.tui.callback import AxolotlTUICallback
|
||||||
|
from axolotl.tui.config import TUIConfig
|
||||||
|
from axolotl.tui.io_capture import LineParser, register_parser
|
||||||
|
from axolotl.tui.panels import BasePanel, register_panel
|
||||||
|
from axolotl.tui.state import TUIState
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AxolotlTUICallback",
|
||||||
|
"BasePanel",
|
||||||
|
"LineParser",
|
||||||
|
"TUIConfig",
|
||||||
|
"TUIState",
|
||||||
|
"register_panel",
|
||||||
|
"register_parser",
|
||||||
|
]
|
||||||
142
src/axolotl/tui/callback.py
Normal file
142
src/axolotl/tui/callback.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
"""AxolotlTUICallback — HF TrainerCallback that feeds metrics to the TUI."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import queue
|
||||||
|
|
||||||
|
from transformers.trainer_callback import TrainerCallback
|
||||||
|
|
||||||
|
from axolotl.tui.config import TUIConfig
|
||||||
|
from axolotl.tui.renderer import TUIRenderer
|
||||||
|
|
||||||
|
|
||||||
|
class _TUILogHandler(logging.Handler):
|
||||||
|
"""Logging handler that pushes log records into the TUI metric queue."""
|
||||||
|
|
||||||
|
_LEVEL_MAP = {
|
||||||
|
logging.DEBUG: "debug",
|
||||||
|
logging.INFO: "info",
|
||||||
|
logging.WARNING: "warning",
|
||||||
|
logging.ERROR: "error",
|
||||||
|
logging.CRITICAL: "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, metric_queue: queue.Queue, min_level: str = "info"):
|
||||||
|
super().__init__()
|
||||||
|
level_name = min_level.upper()
|
||||||
|
self.setLevel(getattr(logging, level_name, logging.INFO))
|
||||||
|
self._queue = metric_queue
|
||||||
|
|
||||||
|
def emit(self, record: logging.LogRecord) -> None:
|
||||||
|
try:
|
||||||
|
level = self._LEVEL_MAP.get(record.levelno, "info")
|
||||||
|
msg = self.format(record)
|
||||||
|
self._queue.put_nowait(
|
||||||
|
{
|
||||||
|
"type": "log_line",
|
||||||
|
"level": level,
|
||||||
|
"message": msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except queue.Full:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
self.handleError(record)
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlTUICallback(TrainerCallback):
|
||||||
|
"""Pushes training metrics into a queue for the TUI renderer.
|
||||||
|
|
||||||
|
The callback never blocks on the render thread. The queue is bounded
|
||||||
|
(maxsize=512) with put_nowait; overflow is silently dropped.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: TUIConfig):
|
||||||
|
self._config = config
|
||||||
|
self._queue: queue.Queue = queue.Queue(maxsize=4096)
|
||||||
|
self._renderer = TUIRenderer(config=config, metric_queue=self._queue)
|
||||||
|
self._log_handler: _TUILogHandler | None = None
|
||||||
|
self._renderer_started_early: bool = False
|
||||||
|
self._pending_run_info: dict | None = None
|
||||||
|
|
||||||
|
def _put(self, event: dict) -> None:
|
||||||
|
try:
|
||||||
|
self._queue.put_nowait(event)
|
||||||
|
except queue.Full:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||||
|
# Send a single unified run_info event with all fields
|
||||||
|
run_info = {
|
||||||
|
"type": "run_info",
|
||||||
|
"run_name": getattr(args, "run_name", "") or "",
|
||||||
|
"total_steps": state.max_steps,
|
||||||
|
"total_epochs": float(args.num_train_epochs)
|
||||||
|
if args.num_train_epochs
|
||||||
|
else 1.0,
|
||||||
|
}
|
||||||
|
# Merge in model_name/training_mode/world_size if stashed by train.py
|
||||||
|
if self._pending_run_info:
|
||||||
|
run_info.update(self._pending_run_info)
|
||||||
|
self._pending_run_info = None
|
||||||
|
self._put(run_info)
|
||||||
|
|
||||||
|
if not self._renderer_started_early:
|
||||||
|
# Attach a logging handler to feed log messages into the events panel
|
||||||
|
self._log_handler = _TUILogHandler(
|
||||||
|
self._queue, min_level=self._config.log_level
|
||||||
|
)
|
||||||
|
self._log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
|
||||||
|
# Attach to both root and axolotl loggers (axolotl has propagate=False)
|
||||||
|
logging.getLogger().addHandler(self._log_handler)
|
||||||
|
logging.getLogger("axolotl").addHandler(self._log_handler)
|
||||||
|
|
||||||
|
# Start the renderer background thread
|
||||||
|
self._renderer.start()
|
||||||
|
|
||||||
|
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||||
|
if logs is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Filter out non-numeric keys and internal keys
|
||||||
|
filtered = {}
|
||||||
|
for key, value in logs.items():
|
||||||
|
if key.startswith("_"):
|
||||||
|
continue
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
filtered[key] = value
|
||||||
|
elif isinstance(value, str):
|
||||||
|
# HF Trainer sometimes passes string-encoded numbers
|
||||||
|
try:
|
||||||
|
filtered[key] = float(value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
if filtered:
|
||||||
|
self._put({"type": "metrics", "logs": filtered})
|
||||||
|
|
||||||
|
def on_step_end(self, args, state, control, **kwargs):
|
||||||
|
self._put(
|
||||||
|
{
|
||||||
|
"type": "step",
|
||||||
|
"step": state.global_step,
|
||||||
|
"total_steps": state.max_steps,
|
||||||
|
"epoch": state.epoch if state.epoch else 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_prediction_step(self, args, state, control, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_train_end(self, args, state, control, **kwargs):
|
||||||
|
self._put({"type": "done"})
|
||||||
|
# If renderer was started early, do_train's finally block handles stop
|
||||||
|
if not self._renderer_started_early:
|
||||||
|
self._renderer.stop()
|
||||||
|
|
||||||
|
# Remove the logging handler (only if we added it)
|
||||||
|
if self._log_handler:
|
||||||
|
logging.getLogger().removeHandler(self._log_handler)
|
||||||
|
logging.getLogger("axolotl").removeHandler(self._log_handler)
|
||||||
|
self._log_handler = None
|
||||||
38
src/axolotl/tui/config.py
Normal file
38
src/axolotl/tui/config.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""TUI configuration — Pydantic model for TUI settings."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class TUIConfig(BaseModel):
|
||||||
|
"""Configuration for the Axolotl Training TUI dashboard."""
|
||||||
|
|
||||||
|
enabled: bool = Field(
|
||||||
|
default=False,
|
||||||
|
json_schema_extra={"description": "Enable the TUI dashboard"},
|
||||||
|
)
|
||||||
|
refresh_rate: int = Field(
|
||||||
|
default=4,
|
||||||
|
json_schema_extra={"description": "Renders per second"},
|
||||||
|
)
|
||||||
|
log_level: str = Field(
|
||||||
|
default="debug",
|
||||||
|
json_schema_extra={"description": "Minimum log level shown in events panel"},
|
||||||
|
)
|
||||||
|
panels: list[str] = Field(
|
||||||
|
default_factory=lambda: ["progress", "training", "hardware", "events", "debug"],
|
||||||
|
json_schema_extra={"description": "Ordered list of panels to display"},
|
||||||
|
)
|
||||||
|
hardware_poll_interval: int = Field(
|
||||||
|
default=2,
|
||||||
|
json_schema_extra={"description": "Seconds between pynvml GPU queries"},
|
||||||
|
)
|
||||||
|
stdout_log_path: str = Field(
|
||||||
|
default="axolotl_stdout.log",
|
||||||
|
json_schema_extra={"description": "File path for captured stdout/stderr log"},
|
||||||
|
)
|
||||||
|
parser_plugins: list[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
json_schema_extra={"description": "List of extra parser classes to load"},
|
||||||
|
)
|
||||||
72
src/axolotl/tui/gpu.py
Normal file
72
src/axolotl/tui/gpu.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""GPU polling wrapper around pynvml with graceful fallback."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from axolotl.tui.state import GPUStats
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_nvml_available = False
|
||||||
|
try:
|
||||||
|
import pynvml
|
||||||
|
|
||||||
|
pynvml.nvmlInit()
|
||||||
|
_nvml_available = True
|
||||||
|
except Exception:
|
||||||
|
LOG.debug("pynvml unavailable — GPU stats will not be shown")
|
||||||
|
|
||||||
|
|
||||||
|
class GPUPoller:
|
||||||
|
"""Polls local GPU stats via pynvml. Falls back gracefully if unavailable."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._device_count = 0
|
||||||
|
if _nvml_available:
|
||||||
|
try:
|
||||||
|
self._device_count = pynvml.nvmlDeviceGetCount()
|
||||||
|
except Exception:
|
||||||
|
self._device_count = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available(self) -> bool:
|
||||||
|
return _nvml_available and self._device_count > 0
|
||||||
|
|
||||||
|
def poll(self) -> list[GPUStats]:
|
||||||
|
if not self.available:
|
||||||
|
return []
|
||||||
|
|
||||||
|
stats = []
|
||||||
|
for i in range(self._device_count):
|
||||||
|
try:
|
||||||
|
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
|
||||||
|
name = pynvml.nvmlDeviceGetName(handle)
|
||||||
|
if isinstance(name, bytes):
|
||||||
|
name = name.decode("utf-8")
|
||||||
|
|
||||||
|
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
|
||||||
|
mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||||
|
temp = pynvml.nvmlDeviceGetTemperature(
|
||||||
|
handle, pynvml.NVML_TEMPERATURE_GPU
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
|
||||||
|
except Exception:
|
||||||
|
power = None
|
||||||
|
|
||||||
|
stats.append(
|
||||||
|
GPUStats(
|
||||||
|
id=i,
|
||||||
|
name=name,
|
||||||
|
util_pct=util.gpu,
|
||||||
|
vram_used_gb=mem.used / (1024**3),
|
||||||
|
vram_total_gb=mem.total / (1024**3),
|
||||||
|
temp_c=temp,
|
||||||
|
power_w=power,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
LOG.debug("Error polling GPU device %d", i, exc_info=True)
|
||||||
|
return stats
|
||||||
196
src/axolotl/tui/io_capture.py
Normal file
196
src/axolotl/tui/io_capture.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
"""I/O capture: OS-level stdout/stderr redirect, line parser chain, and parser registry."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import queue
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import IO
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Parser registry
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_parser_registry: list[type[LineParser]] = []
|
||||||
|
|
||||||
|
|
||||||
|
def register_parser(cls: type[LineParser]) -> type[LineParser]:
|
||||||
|
"""Decorator to register a LineParser subclass."""
|
||||||
|
if cls not in _parser_registry:
|
||||||
|
_parser_registry.append(cls)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
def get_registered_parsers() -> list[type[LineParser]]:
|
||||||
|
return list(_parser_registry)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Base LineParser
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class LineParser(ABC):
|
||||||
|
"""Base class for stdout/stderr line parsers."""
|
||||||
|
|
||||||
|
priority: int = 50
|
||||||
|
name: str = ""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse(self, line: str, source: str) -> list[dict]:
|
||||||
|
"""Parse a single captured line.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
line: one line of captured output, trailing newline stripped.
|
||||||
|
source: "stdout" or "stderr".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of event dicts to push onto the metric queue.
|
||||||
|
Return [] if this line is not relevant.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ParserChain
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class ParserChain:
|
||||||
|
def __init__(self):
|
||||||
|
self._parsers: list[LineParser] = []
|
||||||
|
|
||||||
|
def register(self, parser: LineParser) -> None:
|
||||||
|
self._parsers.append(parser)
|
||||||
|
self._parsers.sort(key=lambda p: p.priority)
|
||||||
|
|
||||||
|
def parse(self, line: str, source: str = "stdout") -> list[dict]:
|
||||||
|
events: list[dict] = []
|
||||||
|
for parser in self._parsers:
|
||||||
|
events.extend(parser.parse(line, source))
|
||||||
|
return events
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# IOCapture — OS-level fd redirect to pipe
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class IOCapture:
|
||||||
|
"""Redirects fd 1 and fd 2 into an OS pipe, drains via a reader thread,
|
||||||
|
passes lines through a ParserChain, and tees to a log file."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, log_path: str, parser_chain: ParserChain, metric_queue: queue.Queue
|
||||||
|
):
|
||||||
|
self._parser_chain = parser_chain
|
||||||
|
self._queue = metric_queue
|
||||||
|
self._log_path = log_path
|
||||||
|
self._log_file: IO[str] | None = None
|
||||||
|
self._thread: threading.Thread | None = None
|
||||||
|
self._read_fd: int | None = None
|
||||||
|
self._write_fd: int | None = None
|
||||||
|
self._saved_stdout_fd: int | None = None
|
||||||
|
self._saved_stderr_fd: int | None = None
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
# Write run-start separator
|
||||||
|
self._log_file = open(self._log_path, "a", buffering=1) # noqa: SIM115
|
||||||
|
self._log_file.write(
|
||||||
|
f"\n=== axolotl run started {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ===\n"
|
||||||
|
)
|
||||||
|
self._log_file.flush()
|
||||||
|
|
||||||
|
# OS-level pipe
|
||||||
|
self._read_fd, self._write_fd = os.pipe()
|
||||||
|
|
||||||
|
# Save originals
|
||||||
|
self._saved_stdout_fd = os.dup(1)
|
||||||
|
self._saved_stderr_fd = os.dup(2)
|
||||||
|
|
||||||
|
# Redirect both stdout and stderr into the write end
|
||||||
|
os.dup2(self._write_fd, 1)
|
||||||
|
os.dup2(self._write_fd, 2)
|
||||||
|
os.close(self._write_fd) # write end now held by fds 1 and 2
|
||||||
|
|
||||||
|
# Also redirect Python-level handles
|
||||||
|
sys.stdout = open(1, "w", buffering=1, closefd=False) # noqa: SIM115
|
||||||
|
sys.stderr = open(2, "w", buffering=1, closefd=False) # noqa: SIM115
|
||||||
|
|
||||||
|
# Drain thread
|
||||||
|
self._thread = threading.Thread(target=self._drain, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
# Restore fds — closes the write end, causing reader to see EOF
|
||||||
|
if self._saved_stdout_fd is not None and self._saved_stderr_fd is not None:
|
||||||
|
sys.stdout = sys.__stdout__
|
||||||
|
sys.stderr = sys.__stderr__
|
||||||
|
os.dup2(self._saved_stdout_fd, 1)
|
||||||
|
os.dup2(self._saved_stderr_fd, 2)
|
||||||
|
os.close(self._saved_stdout_fd)
|
||||||
|
os.close(self._saved_stderr_fd)
|
||||||
|
self._saved_stdout_fd = None
|
||||||
|
self._saved_stderr_fd = None
|
||||||
|
|
||||||
|
if self._thread is not None:
|
||||||
|
self._thread.join(timeout=2.0)
|
||||||
|
if self._thread.is_alive():
|
||||||
|
logging.getLogger(__name__).warning(
|
||||||
|
"IO capture thread did not exit after 2s"
|
||||||
|
)
|
||||||
|
self._thread = None
|
||||||
|
|
||||||
|
if self._log_file is not None:
|
||||||
|
self._log_file.close()
|
||||||
|
self._log_file = None
|
||||||
|
|
||||||
|
def _drain(self) -> None:
|
||||||
|
# Read raw bytes and split on both \n and \r to handle tqdm progress bars
|
||||||
|
# which use \r for in-place updates without \n
|
||||||
|
assert self._read_fd is not None, "_drain called before start()"
|
||||||
|
with os.fdopen(self._read_fd, "rb") as pipe:
|
||||||
|
buf = b""
|
||||||
|
while True:
|
||||||
|
chunk = pipe.read(4096)
|
||||||
|
if not chunk:
|
||||||
|
# EOF — process remaining buffer
|
||||||
|
if buf:
|
||||||
|
self._process_line(buf.decode("utf-8", errors="replace"))
|
||||||
|
break
|
||||||
|
buf += chunk
|
||||||
|
# Split on \n or \r
|
||||||
|
while b"\n" in buf or b"\r" in buf:
|
||||||
|
# Find the earliest delimiter
|
||||||
|
idx_n = buf.find(b"\n")
|
||||||
|
idx_r = buf.find(b"\r")
|
||||||
|
if idx_n == -1:
|
||||||
|
idx = idx_r
|
||||||
|
elif idx_r == -1:
|
||||||
|
idx = idx_n
|
||||||
|
else:
|
||||||
|
idx = min(idx_n, idx_r)
|
||||||
|
line = buf[:idx].decode("utf-8", errors="replace")
|
||||||
|
buf = buf[idx + 1 :]
|
||||||
|
# Handle \r\n as single delimiter
|
||||||
|
if buf.startswith(b"\n"):
|
||||||
|
buf = buf[1:]
|
||||||
|
if line:
|
||||||
|
self._process_line(line)
|
||||||
|
|
||||||
|
def _process_line(self, line: str) -> None:
|
||||||
|
line = line.rstrip()
|
||||||
|
if not line:
|
||||||
|
return
|
||||||
|
if self._log_file:
|
||||||
|
self._log_file.write(line + "\n")
|
||||||
|
self._log_file.flush()
|
||||||
|
for event in self._parser_chain.parse(line):
|
||||||
|
try:
|
||||||
|
self._queue.put_nowait(event)
|
||||||
|
except queue.Full:
|
||||||
|
pass
|
||||||
63
src/axolotl/tui/panels/__init__.py
Normal file
63
src/axolotl/tui/panels/__init__.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""Panel registry and base class for TUI panels."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from rich.console import RenderableType
|
||||||
|
|
||||||
|
from axolotl.tui.state import TUIState
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Panel registry
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_panel_registry: dict[str, type[BasePanel]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_panel(position: str = "bottom", weight: int = 50):
|
||||||
|
"""Decorator to register a panel class with position and weight."""
|
||||||
|
|
||||||
|
def decorator(cls: type[BasePanel]) -> type[BasePanel]:
|
||||||
|
cls.position = position
|
||||||
|
cls.weight = weight
|
||||||
|
_panel_registry[cls.name] = cls
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def get_registered_panels() -> dict[str, type[BasePanel]]:
|
||||||
|
return dict(_panel_registry)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# BasePanel
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class BasePanel(ABC):
|
||||||
|
name: str = ""
|
||||||
|
position: str = "bottom"
|
||||||
|
weight: int = 50
|
||||||
|
min_height: int = 4
|
||||||
|
max_height: int | None = None
|
||||||
|
modes: list[str] = ["*"]
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def render(self, state: TUIState) -> RenderableType:
|
||||||
|
"""Return a rich renderable. Called every tick."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_event(self, event: dict) -> None: # noqa: B027
|
||||||
|
"""Optional: react to raw metric events before state is merged."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Auto-import built-in panels to trigger registration
|
||||||
|
from axolotl.tui.panels.completions import CompletionsPanel # noqa: E402, F401
|
||||||
|
from axolotl.tui.panels.debug import DebugPanel # noqa: E402, F401
|
||||||
|
from axolotl.tui.panels.events import EventsPanel # noqa: E402, F401
|
||||||
|
from axolotl.tui.panels.hardware import HardwarePanel # noqa: E402, F401
|
||||||
|
from axolotl.tui.panels.progress import ProgressPanel # noqa: E402, F401
|
||||||
|
from axolotl.tui.panels.training import TrainingPanel # noqa: E402, F401
|
||||||
61
src/axolotl/tui/panels/completions.py
Normal file
61
src/axolotl/tui/panels/completions.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
"""CompletionsPanel — shows recent RL/log_completions samples."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from rich.console import RenderableType
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
from axolotl.tui.panels import BasePanel, register_panel
|
||||||
|
from axolotl.tui.state import TUIState
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate(s: str, maxlen: int = 60) -> str:
|
||||||
|
return s[:maxlen] + "…" if len(s) > maxlen else s
|
||||||
|
|
||||||
|
|
||||||
|
@register_panel(position="bottom", weight=20)
|
||||||
|
class CompletionsPanel(BasePanel):
|
||||||
|
name = "completions"
|
||||||
|
min_height = 6
|
||||||
|
modes = ["grpo", "dpo"]
|
||||||
|
|
||||||
|
def render(self, state: TUIState) -> RenderableType:
|
||||||
|
if "*" not in self.modes and state.training_mode not in self.modes:
|
||||||
|
return Text("")
|
||||||
|
|
||||||
|
if not state.completions:
|
||||||
|
return Panel(
|
||||||
|
Text("No completions yet...", style="dim"),
|
||||||
|
title="Completions",
|
||||||
|
border_style="magenta",
|
||||||
|
)
|
||||||
|
|
||||||
|
table = Table(
|
||||||
|
show_header=True,
|
||||||
|
header_style="bold",
|
||||||
|
expand=True,
|
||||||
|
box=None,
|
||||||
|
pad_edge=False,
|
||||||
|
)
|
||||||
|
table.add_column("step", justify="right", width=6)
|
||||||
|
table.add_column("prompt", no_wrap=False, max_width=40)
|
||||||
|
table.add_column("completion", no_wrap=False, max_width=40)
|
||||||
|
table.add_column("reward", justify="right", width=8)
|
||||||
|
table.add_column("adv", justify="right", width=8)
|
||||||
|
|
||||||
|
for sample in list(state.completions)[-5:]:
|
||||||
|
reward_str = f"{sample.reward:.2f}" if sample.reward is not None else "--"
|
||||||
|
adv_str = (
|
||||||
|
f"{sample.advantage:+.2f}" if sample.advantage is not None else "--"
|
||||||
|
)
|
||||||
|
table.add_row(
|
||||||
|
str(sample.step),
|
||||||
|
_truncate(sample.prompt),
|
||||||
|
_truncate(sample.completion),
|
||||||
|
reward_str,
|
||||||
|
adv_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Panel(table, title="Completions", border_style="magenta")
|
||||||
34
src/axolotl/tui/panels/debug.py
Normal file
34
src/axolotl/tui/panels/debug.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""DebugPanel — scrolling log of debug-level messages, separate from main events."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from rich.console import RenderableType
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
from axolotl.tui.panels import BasePanel, register_panel
|
||||||
|
from axolotl.tui.state import TUIState
|
||||||
|
|
||||||
|
|
||||||
|
@register_panel(position="bottom", weight=30)
|
||||||
|
class DebugPanel(BasePanel):
|
||||||
|
name = "debug"
|
||||||
|
min_height = 6
|
||||||
|
max_height = 10
|
||||||
|
|
||||||
|
def render(self, state: TUIState) -> RenderableType:
|
||||||
|
lines = Text()
|
||||||
|
# Show last 8 debug-level log lines
|
||||||
|
debug_lines = [
|
||||||
|
log_entry for log_entry in state.log_lines if log_entry.level == "debug"
|
||||||
|
][-8:]
|
||||||
|
for log_line in debug_lines:
|
||||||
|
ts = log_line.timestamp.strftime("%H:%M:%S")
|
||||||
|
lines.append(f"[{ts}] ", style="dim")
|
||||||
|
lines.append(log_line.message[:200], style="dim")
|
||||||
|
lines.append("\n")
|
||||||
|
|
||||||
|
if not debug_lines:
|
||||||
|
lines = Text("No debug messages yet...", style="dim")
|
||||||
|
|
||||||
|
return Panel(lines, title="Debug", border_style="dim")
|
||||||
45
src/axolotl/tui/panels/events.py
Normal file
45
src/axolotl/tui/panels/events.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
"""EventsPanel — scrolling log of recent events, color-coded by level."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from rich.console import RenderableType
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
from axolotl.tui.panels import BasePanel, register_panel
|
||||||
|
from axolotl.tui.state import TUIState
|
||||||
|
|
||||||
|
_LEVEL_STYLES = {
|
||||||
|
"debug": "dim",
|
||||||
|
"info": "",
|
||||||
|
"warning": "yellow",
|
||||||
|
"error": "red bold",
|
||||||
|
"critical": "red bold",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@register_panel(position="bottom", weight=10)
|
||||||
|
class EventsPanel(BasePanel):
|
||||||
|
name = "events"
|
||||||
|
min_height = 8
|
||||||
|
max_height = 20
|
||||||
|
|
||||||
|
def render(self, state: TUIState) -> RenderableType:
|
||||||
|
lines = Text()
|
||||||
|
# Show last 15 non-debug log lines (debug goes to DebugPanel)
|
||||||
|
recent = [
|
||||||
|
log_entry for log_entry in state.log_lines if log_entry.level != "debug"
|
||||||
|
][-15:]
|
||||||
|
for log_line in recent:
|
||||||
|
ts = log_line.timestamp.strftime("%H:%M:%S")
|
||||||
|
level = log_line.level.upper()
|
||||||
|
style = _LEVEL_STYLES.get(log_line.level, "")
|
||||||
|
lines.append(f"[{ts}] ", style="dim")
|
||||||
|
lines.append(f"[{level}] ", style=style or "")
|
||||||
|
lines.append(log_line.message[:200], style=style or "")
|
||||||
|
lines.append("\n")
|
||||||
|
|
||||||
|
if not recent:
|
||||||
|
lines = Text("No events yet...", style="dim")
|
||||||
|
|
||||||
|
return Panel(lines, title="Events", border_style="yellow")
|
||||||
80
src/axolotl/tui/panels/hardware.py
Normal file
80
src/axolotl/tui/panels/hardware.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
"""HardwarePanel — per-GPU stats via pynvml."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from rich.console import RenderableType
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
from axolotl.tui.panels import BasePanel, register_panel
|
||||||
|
from axolotl.tui.state import TUIState
|
||||||
|
|
||||||
|
_BAR_FULL = "█"
|
||||||
|
_BAR_EMPTY = "░"
|
||||||
|
|
||||||
|
|
||||||
|
def _util_bar(pct: float, width: int = 6) -> Text:
|
||||||
|
filled = int(pct / 100 * width)
|
||||||
|
bar = _BAR_FULL * filled + _BAR_EMPTY * (width - filled)
|
||||||
|
color = "green" if pct < 70 else ("yellow" if pct < 90 else "red")
|
||||||
|
return Text.assemble((bar, color), f" {pct:3.0f}%")
|
||||||
|
|
||||||
|
|
||||||
|
@register_panel(position="right", weight=10)
|
||||||
|
class HardwarePanel(BasePanel):
|
||||||
|
name = "hardware"
|
||||||
|
min_height = 6
|
||||||
|
|
||||||
|
def render(self, state: TUIState) -> RenderableType:
|
||||||
|
if not state.gpus:
|
||||||
|
return Panel(
|
||||||
|
Text("GPU stats unavailable", style="dim"),
|
||||||
|
title="Hardware",
|
||||||
|
border_style="green",
|
||||||
|
)
|
||||||
|
|
||||||
|
table = Table(
|
||||||
|
show_header=True,
|
||||||
|
header_style="bold",
|
||||||
|
expand=True,
|
||||||
|
box=None,
|
||||||
|
pad_edge=False,
|
||||||
|
)
|
||||||
|
table.add_column("id", justify="right", width=3)
|
||||||
|
table.add_column("util", no_wrap=True)
|
||||||
|
table.add_column("vram", no_wrap=True)
|
||||||
|
table.add_column("°C", justify="right", width=4)
|
||||||
|
table.add_column("W", justify="right", width=5)
|
||||||
|
|
||||||
|
total_vram_used = 0.0
|
||||||
|
total_vram_total = 0.0
|
||||||
|
total_util = 0.0
|
||||||
|
|
||||||
|
for gpu in state.gpus:
|
||||||
|
total_vram_used += gpu.vram_used_gb
|
||||||
|
total_vram_total += gpu.vram_total_gb
|
||||||
|
total_util += gpu.util_pct
|
||||||
|
|
||||||
|
power_str = f"{gpu.power_w:.0f}" if gpu.power_w is not None else "--"
|
||||||
|
table.add_row(
|
||||||
|
str(gpu.id),
|
||||||
|
_util_bar(gpu.util_pct),
|
||||||
|
f"{gpu.vram_used_gb:.1f}/{gpu.vram_total_gb:.1f} GB",
|
||||||
|
str(gpu.temp_c),
|
||||||
|
power_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Footer with aggregates
|
||||||
|
n = len(state.gpus)
|
||||||
|
if n > 1:
|
||||||
|
avg_util = total_util / n
|
||||||
|
table.add_row(
|
||||||
|
"Σ",
|
||||||
|
Text(f"avg {avg_util:.0f}%", style="dim"),
|
||||||
|
Text(f"{total_vram_used:.1f}/{total_vram_total:.1f} GB", style="dim"),
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
|
||||||
|
return Panel(table, title="Hardware", border_style="green")
|
||||||
73
src/axolotl/tui/panels/progress.py
Normal file
73
src/axolotl/tui/panels/progress.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
"""ProgressPanel — top-bar progress display with step count, elapsed, ETA."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from rich.console import RenderableType
|
||||||
|
from rich.progress import BarColumn, Progress, TextColumn
|
||||||
|
from rich.table import Table
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
from axolotl.tui.panels import BasePanel, register_panel
|
||||||
|
from axolotl.tui.state import TUIState
|
||||||
|
|
||||||
|
|
||||||
|
def _fmt_time(seconds: float | None) -> str:
|
||||||
|
if seconds is None or seconds < 0:
|
||||||
|
return "--:--:--"
|
||||||
|
h = int(seconds) // 3600
|
||||||
|
m = (int(seconds) % 3600) // 60
|
||||||
|
s = int(seconds) % 60
|
||||||
|
return f"{h}:{m:02d}:{s:02d}"
|
||||||
|
|
||||||
|
|
||||||
|
def _fmt_eta(seconds: float | None) -> str:
|
||||||
|
if seconds is None or seconds < 0:
|
||||||
|
return "eta --"
|
||||||
|
h = int(seconds) // 3600
|
||||||
|
m = (int(seconds) % 3600) // 60
|
||||||
|
if h > 0:
|
||||||
|
return f"eta {h}h{m:02d}m"
|
||||||
|
return f"eta {m}m{int(seconds) % 60:02d}s"
|
||||||
|
|
||||||
|
|
||||||
|
@register_panel(position="top", weight=10)
|
||||||
|
class ProgressPanel(BasePanel):
|
||||||
|
name = "progress"
|
||||||
|
min_height = 3
|
||||||
|
max_height = 3
|
||||||
|
|
||||||
|
def render(self, state: TUIState) -> RenderableType:
|
||||||
|
pct = (
|
||||||
|
(state.current_step / state.total_steps * 100)
|
||||||
|
if state.total_steps > 0
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Header line
|
||||||
|
mode_upper = state.training_mode.upper() if state.training_mode else "SFT"
|
||||||
|
model_short = state.model_name.split("/")[-1] if state.model_name else "model"
|
||||||
|
header = Text.assemble(
|
||||||
|
("● ", "bold green"),
|
||||||
|
("AXOLOTL", "bold cyan"),
|
||||||
|
f" {mode_upper} · {model_short} ",
|
||||||
|
(
|
||||||
|
f"{state.current_step} / {state.total_steps}",
|
||||||
|
"bold",
|
||||||
|
),
|
||||||
|
f" · {_fmt_time(state.elapsed_seconds)} elapsed · {_fmt_eta(state.eta_seconds)} · {pct:.1f}%",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Progress bar
|
||||||
|
progress = Progress(
|
||||||
|
TextColumn(""),
|
||||||
|
BarColumn(bar_width=None),
|
||||||
|
TextColumn("{task.percentage:>3.0f}%"),
|
||||||
|
expand=True,
|
||||||
|
)
|
||||||
|
task = progress.add_task("", total=state.total_steps or 1)
|
||||||
|
progress.update(task, completed=state.current_step)
|
||||||
|
|
||||||
|
table = Table.grid(expand=True)
|
||||||
|
table.add_row(header)
|
||||||
|
table.add_row(progress)
|
||||||
|
return table
|
||||||
97
src/axolotl/tui/panels/training.py
Normal file
97
src/axolotl/tui/panels/training.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""TrainingPanel — live scalar metrics table with loss sparkline."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from rich.console import RenderableType
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
from axolotl.tui.panels import BasePanel, register_panel
|
||||||
|
from axolotl.tui.state import TUIState
|
||||||
|
|
||||||
|
# Braille sparkline characters (8 levels)
|
||||||
|
_SPARK_CHARS = "▁▂▃▄▅▆▇█"
|
||||||
|
|
||||||
|
|
||||||
|
def _sparkline(values: list[float] | None, width: int = 20) -> str:
|
||||||
|
if not values or len(values) < 2:
|
||||||
|
return ""
|
||||||
|
vals = list(values)[-width:]
|
||||||
|
lo, hi = min(vals), max(vals)
|
||||||
|
rng = hi - lo if hi != lo else 1.0
|
||||||
|
return "".join(_SPARK_CHARS[min(int((v - lo) / rng * 7), 7)] for v in vals)
|
||||||
|
|
||||||
|
|
||||||
|
# Known key ordering and formatting
|
||||||
|
_KNOWN_KEYS: list[tuple[str, str, str]] = [
|
||||||
|
("loss", "loss", ".4f"),
|
||||||
|
("grad_norm", "grad norm", ".3f"),
|
||||||
|
("learning_rate", "lr", ".2e"),
|
||||||
|
("tokens_per_second", "tok/s", ".1f"),
|
||||||
|
("samples_per_second", "samples/s", ".1f"),
|
||||||
|
("mfu", "MFU", ".1f"),
|
||||||
|
# RL-specific
|
||||||
|
("rewards_mean", "rewards/mean", ".4f"),
|
||||||
|
("rewards_std", "rewards/std", ".4f"),
|
||||||
|
("kl_divergence", "KL", ".4f"),
|
||||||
|
("clip_ratio", "clip ratio", ".3f"),
|
||||||
|
("queue_size", "queue", "d"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@register_panel(position="left", weight=10)
|
||||||
|
class TrainingPanel(BasePanel):
|
||||||
|
name = "training"
|
||||||
|
min_height = 8
|
||||||
|
|
||||||
|
def render(self, state: TUIState) -> RenderableType:
|
||||||
|
table = Table(
|
||||||
|
show_header=True,
|
||||||
|
header_style="bold",
|
||||||
|
expand=True,
|
||||||
|
box=None,
|
||||||
|
pad_edge=False,
|
||||||
|
)
|
||||||
|
table.add_column("metric", style="cyan", no_wrap=True)
|
||||||
|
table.add_column("value", justify="right")
|
||||||
|
table.add_column("trend", justify="left", no_wrap=True)
|
||||||
|
|
||||||
|
for attr, label, fmt in _KNOWN_KEYS:
|
||||||
|
val = getattr(state, attr, None)
|
||||||
|
if val is None:
|
||||||
|
# Also check extra dict
|
||||||
|
val = state.extra.get(attr)
|
||||||
|
if val is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
formatted = f"{val:{fmt}}"
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
formatted = str(val)
|
||||||
|
|
||||||
|
trend = ""
|
||||||
|
if attr == "loss":
|
||||||
|
trend = _sparkline(list(state.loss_history))
|
||||||
|
|
||||||
|
table.add_row(label, formatted, trend)
|
||||||
|
|
||||||
|
# Any extra keys not in _KNOWN_KEYS
|
||||||
|
known_attrs = {k for k, _, _ in _KNOWN_KEYS}
|
||||||
|
for key, val in sorted(state.extra.items()):
|
||||||
|
if key in known_attrs or val is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
formatted = f"{val:.4f}"
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
formatted = str(val)
|
||||||
|
table.add_row(key, formatted, "")
|
||||||
|
|
||||||
|
if table.row_count == 0:
|
||||||
|
return Panel(
|
||||||
|
Text("Waiting for first log step...", style="dim"),
|
||||||
|
title="Training",
|
||||||
|
border_style="blue",
|
||||||
|
)
|
||||||
|
|
||||||
|
return Panel(table, title="Training", border_style="blue")
|
||||||
7
src/axolotl/tui/parsers/__init__.py
Normal file
7
src/axolotl/tui/parsers/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""Built-in line parsers — auto-imported to trigger @register_parser decorators."""
|
||||||
|
|
||||||
|
from axolotl.tui.parsers.deepspeed import DeepSpeedParser # noqa: F401
|
||||||
|
from axolotl.tui.parsers.nccl import NCCLErrorParser # noqa: F401
|
||||||
|
from axolotl.tui.parsers.raw_log import RawLogParser # noqa: F401
|
||||||
|
from axolotl.tui.parsers.torch_compile import TorchCompileParser # noqa: F401
|
||||||
|
from axolotl.tui.parsers.tqdm import TqdmParser # noqa: F401
|
||||||
29
src/axolotl/tui/parsers/deepspeed.py
Normal file
29
src/axolotl/tui/parsers/deepspeed.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""DeepSpeedParser — extracts DeepSpeed stage info and throughput metrics."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from axolotl.tui.io_capture import LineParser, register_parser
|
||||||
|
|
||||||
|
|
||||||
|
@register_parser
|
||||||
|
class DeepSpeedParser(LineParser):
|
||||||
|
priority = 20
|
||||||
|
name = "deepspeed"
|
||||||
|
|
||||||
|
_SAMPLES_RE = re.compile(r"samples/sec=([0-9.]+)")
|
||||||
|
_STAGE_RE = re.compile(r"ZeRO Stage (\d)")
|
||||||
|
|
||||||
|
def parse(self, line: str, source: str) -> list[dict]:
|
||||||
|
events: list[dict] = []
|
||||||
|
if m := self._SAMPLES_RE.search(line):
|
||||||
|
events.append(
|
||||||
|
{
|
||||||
|
"type": "metrics",
|
||||||
|
"logs": {"samples_per_second": float(m.group(1))},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if m := self._STAGE_RE.search(line):
|
||||||
|
events.append({"type": "run_info", "zero_stage": int(m.group(1))})
|
||||||
|
return events
|
||||||
27
src/axolotl/tui/parsers/nccl.py
Normal file
27
src/axolotl/tui/parsers/nccl.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
"""NCCLErrorParser — surfaces NCCL errors as red alert events."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from axolotl.tui.io_capture import LineParser, register_parser
|
||||||
|
|
||||||
|
|
||||||
|
@register_parser
|
||||||
|
class NCCLErrorParser(LineParser):
|
||||||
|
priority = 10
|
||||||
|
name = "nccl_error"
|
||||||
|
|
||||||
|
_RE = re.compile(r"NCCL error|Unhandled NCCL", re.IGNORECASE)
|
||||||
|
|
||||||
|
def parse(self, line: str, source: str) -> list[dict]:
|
||||||
|
if self._RE.search(line):
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "log_line",
|
||||||
|
"level": "error",
|
||||||
|
"message": f"⚠ NCCL: {line}",
|
||||||
|
},
|
||||||
|
{"type": "alert", "severity": "error", "message": line},
|
||||||
|
]
|
||||||
|
return []
|
||||||
37
src/axolotl/tui/parsers/raw_log.py
Normal file
37
src/axolotl/tui/parsers/raw_log.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""RawLogParser — catches every line as a log_line event."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from axolotl.tui.io_capture import LineParser, register_parser
|
||||||
|
|
||||||
|
|
||||||
|
@register_parser
|
||||||
|
class RawLogParser(LineParser):
|
||||||
|
priority = 99
|
||||||
|
name = "raw_log"
|
||||||
|
|
||||||
|
_LOG_RE = re.compile(
|
||||||
|
r"^(?P<ts>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}[,\.]\d+)"
|
||||||
|
r"\s*[-]\s*(?P<level>DEBUG|INFO|WARNING|ERROR|CRITICAL)"
|
||||||
|
r"\s*[-]\s*(?P<msg>.+)$",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter out tqdm progress bar lines and other noisy output
|
||||||
|
_TQDM_RE = re.compile(r"^\s*\d+%\|.*\|")
|
||||||
|
_EMPTY_RE = re.compile(r"^\s*$")
|
||||||
|
|
||||||
|
def parse(self, line: str, source: str) -> list[dict]:
|
||||||
|
# Skip empty lines and tqdm progress bar updates
|
||||||
|
if self._EMPTY_RE.match(line) or self._TQDM_RE.match(line):
|
||||||
|
return []
|
||||||
|
|
||||||
|
m = self._LOG_RE.match(line)
|
||||||
|
level = (
|
||||||
|
m.group("level").lower()
|
||||||
|
if m
|
||||||
|
else ("error" if source == "stderr" else "info")
|
||||||
|
)
|
||||||
|
return [{"type": "log_line", "level": level, "message": line}]
|
||||||
26
src/axolotl/tui/parsers/torch_compile.py
Normal file
26
src/axolotl/tui/parsers/torch_compile.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""TorchCompileParser — detects torch.compile graph breaks and recompilations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from axolotl.tui.io_capture import LineParser, register_parser
|
||||||
|
|
||||||
|
|
||||||
|
@register_parser
|
||||||
|
class TorchCompileParser(LineParser):
|
||||||
|
priority = 20
|
||||||
|
name = "torch_compile"
|
||||||
|
|
||||||
|
_RE = re.compile(r"Graph break|Recompiling|torch\.compile", re.IGNORECASE)
|
||||||
|
|
||||||
|
def parse(self, line: str, source: str) -> list[dict]:
|
||||||
|
if self._RE.search(line):
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "log_line",
|
||||||
|
"level": "warning",
|
||||||
|
"message": f"⚡ compile: {line}",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return []
|
||||||
86
src/axolotl/tui/parsers/tqdm.py
Normal file
86
src/axolotl/tui/parsers/tqdm.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""TqdmParser — captures tqdm progress bar output and surfaces as structured events."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from axolotl.tui.io_capture import LineParser, register_parser
|
||||||
|
|
||||||
|
|
||||||
|
@register_parser
|
||||||
|
class TqdmParser(LineParser):
|
||||||
|
priority = 15
|
||||||
|
name = "tqdm"
|
||||||
|
|
||||||
|
# Match tqdm-style progress lines, e.g.:
|
||||||
|
# Tokenizing Prompts (num_proc=24): 35%|███▍ | 19008/54568 [00:02<00:02, 17417.65 examples/s]
|
||||||
|
# Loading weights: 53%|█████▎ | 77/146 [00:00<00:00, 396.39it/s]
|
||||||
|
# 0%| | 0/30 [00:00<?, ?it/s]
|
||||||
|
_TQDM_RE = re.compile(
|
||||||
|
r"(?P<desc>.*?)\s*"
|
||||||
|
r"(?P<pct>\d+)%\|[▏▎▍▌▋▊▉█░▓▒# ]*\|\s*"
|
||||||
|
r"(?P<current>[\d,]+)/(?P<total>[\d,]+)"
|
||||||
|
r"\s*\[(?P<elapsed>[^\]]*)\]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Also match simpler forms like:
|
||||||
|
# Fetching 0 files: 0it [00:00, ?it/s]
|
||||||
|
_FETCH_RE = re.compile(r"(?P<desc>[\w\s]+):\s*(?P<current>\d+)(?:it)?\s*\[.*?\]")
|
||||||
|
|
||||||
|
def parse(self, line: str, source: str) -> list[dict]:
|
||||||
|
m = self._TQDM_RE.search(line)
|
||||||
|
if m:
|
||||||
|
desc = m.group("desc").strip().rstrip(":")
|
||||||
|
pct = int(m.group("pct"))
|
||||||
|
current = int(m.group("current").replace(",", ""))
|
||||||
|
total = int(m.group("total").replace(",", ""))
|
||||||
|
|
||||||
|
events: list[dict] = []
|
||||||
|
|
||||||
|
# Surface as a log line with progress info
|
||||||
|
if pct == 100 or pct == 0 or pct % 25 == 0:
|
||||||
|
msg = (
|
||||||
|
f"[{desc}] {pct}% ({current}/{total})"
|
||||||
|
if desc
|
||||||
|
else f"{pct}% ({current}/{total})"
|
||||||
|
)
|
||||||
|
events.append(
|
||||||
|
{
|
||||||
|
"type": "log_line",
|
||||||
|
"level": "info",
|
||||||
|
"message": msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Also emit as a progress metric
|
||||||
|
cleaned_desc = desc.strip().lower().replace(" ", "_")
|
||||||
|
if not cleaned_desc:
|
||||||
|
cleaned_desc = "progress"
|
||||||
|
events.append(
|
||||||
|
{
|
||||||
|
"type": "metrics",
|
||||||
|
"logs": {
|
||||||
|
f"progress/{cleaned_desc}": pct / 100.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return events
|
||||||
|
|
||||||
|
# Fallback: try simpler fetch-style progress lines
|
||||||
|
m = self._FETCH_RE.search(line)
|
||||||
|
if m:
|
||||||
|
desc = m.group("desc").strip().rstrip(":")
|
||||||
|
current = int(m.group("current"))
|
||||||
|
cleaned_desc = desc.strip().lower().replace(" ", "_")
|
||||||
|
if not cleaned_desc:
|
||||||
|
cleaned_desc = "fetch"
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "log_line",
|
||||||
|
"level": "info",
|
||||||
|
"message": f"[{desc}] {current}" if desc else f"{current}",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
return []
|
||||||
449
src/axolotl/tui/renderer.py
Normal file
449
src/axolotl/tui/renderer.py
Normal file
@@ -0,0 +1,449 @@
|
|||||||
|
"""TUIRenderer — background daemon thread that drives the rich.live.Live display."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.layout import Layout
|
||||||
|
from rich.live import Live
|
||||||
|
|
||||||
|
from axolotl.tui.config import TUIConfig
|
||||||
|
from axolotl.tui.gpu import GPUPoller
|
||||||
|
from axolotl.tui.io_capture import (
|
||||||
|
IOCapture,
|
||||||
|
ParserChain,
|
||||||
|
get_registered_parsers,
|
||||||
|
)
|
||||||
|
from axolotl.tui.panels import BasePanel, get_registered_panels
|
||||||
|
from axolotl.tui.state import CompletionSample, LogLine, TUIState
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TUIRenderer:
|
||||||
|
"""Background thread that renders the TUI dashboard using rich.live.Live."""
|
||||||
|
|
||||||
|
def __init__(self, config: TUIConfig, metric_queue: queue.Queue):
|
||||||
|
self._config = config
|
||||||
|
self._queue = metric_queue
|
||||||
|
self._state = TUIState()
|
||||||
|
self._gpu_poller = GPUPoller()
|
||||||
|
self._panels: list[BasePanel] = []
|
||||||
|
self._thread: threading.Thread | None = None
|
||||||
|
self._stop_event = threading.Event()
|
||||||
|
self._io_capture: IOCapture | None = None
|
||||||
|
self._parser_chain: ParserChain | None = None
|
||||||
|
|
||||||
|
def _init_panels(self) -> None:
|
||||||
|
registry = get_registered_panels()
|
||||||
|
for panel_name in self._config.panels:
|
||||||
|
if panel_name in registry:
|
||||||
|
self._panels.append(registry[panel_name]())
|
||||||
|
|
||||||
|
def _init_parser_chain(self) -> None:
|
||||||
|
# Ensure built-in parsers are imported so @register_parser decorators fire
|
||||||
|
import axolotl.tui.parsers # noqa: F401
|
||||||
|
|
||||||
|
self._parser_chain = ParserChain()
|
||||||
|
# Register all built-in parsers
|
||||||
|
for parser_cls in get_registered_parsers():
|
||||||
|
self._parser_chain.register(parser_cls())
|
||||||
|
|
||||||
|
# Load plugin parsers
|
||||||
|
for plugin_spec in self._config.parser_plugins:
|
||||||
|
try:
|
||||||
|
if "::" in plugin_spec:
|
||||||
|
# file path :: class name
|
||||||
|
file_path, class_name = plugin_spec.split("::", 1)
|
||||||
|
import importlib.util
|
||||||
|
|
||||||
|
spec = importlib.util.spec_from_file_location(
|
||||||
|
"custom_parser", file_path
|
||||||
|
)
|
||||||
|
if spec is None or spec.loader is None:
|
||||||
|
raise ImportError(f"Cannot load spec for {file_path}")
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(mod)
|
||||||
|
parser_cls = getattr(mod, class_name)
|
||||||
|
else:
|
||||||
|
# dotted module path
|
||||||
|
module_path, class_name = plugin_spec.rsplit(".", 1)
|
||||||
|
mod = importlib.import_module(module_path)
|
||||||
|
parser_cls = getattr(mod, class_name)
|
||||||
|
self._parser_chain.register(parser_cls())
|
||||||
|
except Exception as exc:
|
||||||
|
LOG.warning(f"Failed to load parser plugin {plugin_spec}: {exc}")
|
||||||
|
|
||||||
|
def _build_layout(self) -> Layout:
|
||||||
|
layout = Layout()
|
||||||
|
|
||||||
|
top_panels = [p for p in self._panels if p.position == "top"]
|
||||||
|
left_panels = [p for p in self._panels if p.position == "left"]
|
||||||
|
right_panels = [p for p in self._panels if p.position == "right"]
|
||||||
|
bottom_panels = [p for p in self._panels if p.position == "bottom"]
|
||||||
|
|
||||||
|
sections = []
|
||||||
|
|
||||||
|
if top_panels:
|
||||||
|
layout_top = Layout(name="top", size=3)
|
||||||
|
sections.append(layout_top)
|
||||||
|
|
||||||
|
if left_panels or right_panels:
|
||||||
|
layout_middle = Layout(name="middle", ratio=3)
|
||||||
|
middle_parts = []
|
||||||
|
if left_panels:
|
||||||
|
middle_parts.append(Layout(name="left", ratio=1))
|
||||||
|
if right_panels:
|
||||||
|
middle_parts.append(Layout(name="right", ratio=1))
|
||||||
|
if middle_parts:
|
||||||
|
layout_middle.split_row(*middle_parts)
|
||||||
|
sections.append(layout_middle)
|
||||||
|
|
||||||
|
if bottom_panels:
|
||||||
|
layout_bottom = Layout(name="bottom", ratio=2)
|
||||||
|
if len(bottom_panels) > 1:
|
||||||
|
layout_bottom.split_row(
|
||||||
|
*[
|
||||||
|
Layout(name=f"bottom_{i}", ratio=1)
|
||||||
|
for i in range(len(bottom_panels))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
sections.append(layout_bottom)
|
||||||
|
|
||||||
|
if sections:
|
||||||
|
layout.split_column(*sections)
|
||||||
|
|
||||||
|
return layout
|
||||||
|
|
||||||
|
def _update_layout(self, layout: Layout) -> None:
|
||||||
|
top_panels = [p for p in self._panels if p.position == "top"]
|
||||||
|
left_panels = [p for p in self._panels if p.position == "left"]
|
||||||
|
right_panels = [p for p in self._panels if p.position == "right"]
|
||||||
|
bottom_panels = [p for p in self._panels if p.position == "bottom"]
|
||||||
|
|
||||||
|
if top_panels:
|
||||||
|
layout["top"].update(top_panels[0].render(self._state))
|
||||||
|
|
||||||
|
if left_panels:
|
||||||
|
layout["left"].update(left_panels[0].render(self._state))
|
||||||
|
|
||||||
|
if right_panels:
|
||||||
|
layout["right"].update(right_panels[0].render(self._state))
|
||||||
|
|
||||||
|
if bottom_panels:
|
||||||
|
if len(bottom_panels) == 1:
|
||||||
|
layout["bottom"].update(bottom_panels[0].render(self._state))
|
||||||
|
else:
|
||||||
|
for i, panel in enumerate(bottom_panels):
|
||||||
|
layout[f"bottom_{i}"].update(panel.render(self._state))
|
||||||
|
|
||||||
|
def _drain_queue(self) -> None:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
event = self._queue.get_nowait()
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Dispatch event to panels first
|
||||||
|
for panel in self._panels:
|
||||||
|
panel.on_event(event)
|
||||||
|
|
||||||
|
event_type = event.get("type")
|
||||||
|
|
||||||
|
if event_type == "metrics":
|
||||||
|
logs = event.get("logs", {})
|
||||||
|
self._apply_metrics(logs)
|
||||||
|
|
||||||
|
elif event_type == "step":
|
||||||
|
self._state.current_step = event.get("step", self._state.current_step)
|
||||||
|
self._state.total_steps = event.get(
|
||||||
|
"total_steps", self._state.total_steps
|
||||||
|
)
|
||||||
|
self._state.current_epoch = event.get(
|
||||||
|
"epoch", self._state.current_epoch
|
||||||
|
)
|
||||||
|
now = time.time()
|
||||||
|
self._state.elapsed_seconds = now - self._state.start_time.timestamp()
|
||||||
|
if self._state.current_step > 0 and self._state.total_steps > 0:
|
||||||
|
rate = self._state.elapsed_seconds / self._state.current_step
|
||||||
|
remaining = self._state.total_steps - self._state.current_step
|
||||||
|
self._state.eta_seconds = rate * remaining
|
||||||
|
|
||||||
|
elif event_type == "log_line":
|
||||||
|
level = event.get("level", "info")
|
||||||
|
message = event.get("message", "")
|
||||||
|
self._state.log_lines.append(
|
||||||
|
LogLine(
|
||||||
|
timestamp=datetime.now(),
|
||||||
|
level=level,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "completion":
|
||||||
|
self._state.completions.append(
|
||||||
|
CompletionSample(
|
||||||
|
step=event.get("step", 0),
|
||||||
|
prompt=event.get("prompt", ""),
|
||||||
|
completion=event.get("completion", ""),
|
||||||
|
reward=event.get("reward"),
|
||||||
|
advantage=event.get("advantage"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "run_info":
|
||||||
|
if "run_name" in event:
|
||||||
|
self._state.run_name = event["run_name"]
|
||||||
|
if "model_name" in event:
|
||||||
|
self._state.model_name = event["model_name"]
|
||||||
|
if "training_mode" in event:
|
||||||
|
self._state.training_mode = event["training_mode"]
|
||||||
|
if "world_size" in event:
|
||||||
|
self._state.world_size = event["world_size"]
|
||||||
|
if "total_steps" in event:
|
||||||
|
self._state.total_steps = event["total_steps"]
|
||||||
|
if "total_epochs" in event:
|
||||||
|
self._state.total_epochs = event["total_epochs"]
|
||||||
|
if "zero_stage" in event:
|
||||||
|
self._state.zero_stage = event["zero_stage"]
|
||||||
|
|
||||||
|
elif event_type == "done":
|
||||||
|
self._stop_event.set()
|
||||||
|
|
||||||
|
def _apply_metrics(self, logs: dict[str, Any]) -> None:
|
||||||
|
metric_map = {
|
||||||
|
"loss": "loss",
|
||||||
|
"grad_norm": "grad_norm",
|
||||||
|
"learning_rate": "learning_rate",
|
||||||
|
"tokens_per_second": "tokens_per_second",
|
||||||
|
"samples_per_second": "samples_per_second",
|
||||||
|
"mfu": "mfu",
|
||||||
|
"rewards/mean": "rewards_mean",
|
||||||
|
"rewards_mean": "rewards_mean",
|
||||||
|
"rewards/std": "rewards_std",
|
||||||
|
"rewards_std": "rewards_std",
|
||||||
|
"kl": "kl_divergence",
|
||||||
|
"kl_divergence": "kl_divergence",
|
||||||
|
"clip_ratio": "clip_ratio",
|
||||||
|
"queue_size": "queue_size",
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, value in logs.items():
|
||||||
|
if key in metric_map:
|
||||||
|
setattr(self._state, metric_map[key], value)
|
||||||
|
else:
|
||||||
|
self._state.extra[key] = value
|
||||||
|
|
||||||
|
if "loss" in logs and logs["loss"] is not None:
|
||||||
|
self._state.loss_history.append(logs["loss"])
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
self._init_panels()
|
||||||
|
self._init_parser_chain()
|
||||||
|
|
||||||
|
# Set up I/O capture
|
||||||
|
assert self._parser_chain is not None, "_init_parser_chain must be called first"
|
||||||
|
self._io_capture = IOCapture(
|
||||||
|
log_path=self._config.stdout_log_path,
|
||||||
|
parser_chain=self._parser_chain,
|
||||||
|
metric_queue=self._queue,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Monkeypatch tqdm to suppress terminal output and route through our queue.
|
||||||
|
# This prevents tqdm progress bars from flickering through the TUI and
|
||||||
|
# ensures all progress events appear in the Events panel.
|
||||||
|
self._install_tqdm_hook()
|
||||||
|
|
||||||
|
self._io_capture_ready = threading.Event()
|
||||||
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
self._io_capture_ready.wait(timeout=5.0)
|
||||||
|
|
||||||
|
def _install_tqdm_hook(self) -> None:
|
||||||
|
"""Replace tqdm's display method to route updates through TUI queue."""
|
||||||
|
try:
|
||||||
|
import io
|
||||||
|
|
||||||
|
import tqdm
|
||||||
|
import tqdm.auto
|
||||||
|
|
||||||
|
q = self._queue
|
||||||
|
self._tqdm_parser = None
|
||||||
|
# Find our tqdm parser in the chain
|
||||||
|
for p in self._parser_chain._parsers if self._parser_chain else []:
|
||||||
|
if p.name == "tqdm":
|
||||||
|
self._tqdm_parser = p
|
||||||
|
break
|
||||||
|
|
||||||
|
# Save originals for restore
|
||||||
|
self._orig_tqdm_class_auto = tqdm.auto.tqdm
|
||||||
|
self._orig_tqdm_class_tqdm = tqdm.tqdm
|
||||||
|
self._orig_tqdm_class_std = tqdm.std.tqdm
|
||||||
|
|
||||||
|
class TUITqdm(tqdm.tqdm):
|
||||||
|
"""tqdm subclass that sends progress to TUI instead of terminal."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
# Force output to devnull so nothing reaches the terminal
|
||||||
|
kwargs["file"] = io.StringIO()
|
||||||
|
kwargs["dynamic_ncols"] = False
|
||||||
|
kwargs["ncols"] = 80
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def display(self, msg=None, pos=None):
|
||||||
|
# Build a progress string and push to queue
|
||||||
|
if self.total and self.total > 0:
|
||||||
|
pct = self.n / self.total * 100
|
||||||
|
desc = self.desc.rstrip(": ") if self.desc else ""
|
||||||
|
# Emit events at milestones or at low frequency
|
||||||
|
is_milestone = (
|
||||||
|
self.n == 0 or self.n >= self.total or int(pct) % 25 == 0
|
||||||
|
)
|
||||||
|
if is_milestone:
|
||||||
|
try:
|
||||||
|
q.put_nowait(
|
||||||
|
{
|
||||||
|
"type": "log_line",
|
||||||
|
"level": "info",
|
||||||
|
"message": f"[{desc}] {pct:.0f}% ({self.n}/{self.total})"
|
||||||
|
if desc
|
||||||
|
else f"{pct:.0f}% ({self.n}/{self.total})",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
metric_key = (
|
||||||
|
f"progress/{desc.lower().replace(' ', '_')}"
|
||||||
|
if desc
|
||||||
|
else "progress/unknown"
|
||||||
|
)
|
||||||
|
q.put_nowait(
|
||||||
|
{
|
||||||
|
"type": "metrics",
|
||||||
|
"logs": {metric_key: pct / 100.0},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
# Emit final completion event
|
||||||
|
if self.total and self.total > 0 and self.n > 0:
|
||||||
|
desc = self.desc.rstrip(": ") if self.desc else ""
|
||||||
|
try:
|
||||||
|
q.put_nowait(
|
||||||
|
{
|
||||||
|
"type": "log_line",
|
||||||
|
"level": "info",
|
||||||
|
"message": f"[{desc}] 100% ({self.total}/{self.total}) done"
|
||||||
|
if desc
|
||||||
|
else f"100% ({self.total}/{self.total}) done",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
super().close()
|
||||||
|
|
||||||
|
# Replace tqdm globally
|
||||||
|
tqdm.auto.tqdm = TUITqdm
|
||||||
|
tqdm.tqdm = TUITqdm
|
||||||
|
# Also patch tqdm.std which some libraries use directly
|
||||||
|
tqdm.std.tqdm = TUITqdm
|
||||||
|
self._tui_tqdm_cls = TUITqdm
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
LOG.debug(f"Failed to install tqdm hook: {exc}")
|
||||||
|
|
||||||
|
def _uninstall_tqdm_hook(self) -> None:
|
||||||
|
"""Restore original tqdm."""
|
||||||
|
try:
|
||||||
|
import tqdm
|
||||||
|
import tqdm.auto
|
||||||
|
|
||||||
|
if hasattr(self, "_orig_tqdm_class_auto"):
|
||||||
|
tqdm.auto.tqdm = self._orig_tqdm_class_auto
|
||||||
|
if hasattr(self, "_orig_tqdm_class_tqdm"):
|
||||||
|
tqdm.tqdm = self._orig_tqdm_class_tqdm
|
||||||
|
if hasattr(self, "_orig_tqdm_class_std"):
|
||||||
|
tqdm.std.tqdm = self._orig_tqdm_class_std
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
self._stop_event.set()
|
||||||
|
self._uninstall_tqdm_hook()
|
||||||
|
if self._thread is not None:
|
||||||
|
self._thread.join(timeout=5.0)
|
||||||
|
|
||||||
|
def _run(self) -> None:
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Save a handle to the REAL terminal BEFORE IO capture redirects fds.
|
||||||
|
# This ensures rich.live.Live writes to the terminal, not the pipe.
|
||||||
|
saved_tty_fd = os.dup(1)
|
||||||
|
tty_file = os.fdopen(saved_tty_fd, "w", buffering=1, closefd=True)
|
||||||
|
console = Console(file=tty_file)
|
||||||
|
|
||||||
|
layout = self._build_layout()
|
||||||
|
tick_interval = 1.0 / max(self._config.refresh_rate, 1)
|
||||||
|
gpu_poll_counter = 0
|
||||||
|
gpu_poll_ticks = max(
|
||||||
|
1, int(self._config.hardware_poll_interval / tick_interval)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start I/O capture — redirects fd 1/2 to pipe AFTER we saved the tty fd
|
||||||
|
if self._io_capture:
|
||||||
|
self._io_capture.start()
|
||||||
|
|
||||||
|
# Signal that IO capture is live so start() can return
|
||||||
|
if hasattr(self, "_io_capture_ready"):
|
||||||
|
self._io_capture_ready.set()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with Live(
|
||||||
|
layout,
|
||||||
|
console=console,
|
||||||
|
refresh_per_second=self._config.refresh_rate,
|
||||||
|
screen=True,
|
||||||
|
redirect_stdout=False,
|
||||||
|
redirect_stderr=False,
|
||||||
|
) as live:
|
||||||
|
while not self._stop_event.is_set():
|
||||||
|
self._drain_queue()
|
||||||
|
|
||||||
|
# Poll GPU stats periodically
|
||||||
|
gpu_poll_counter += 1
|
||||||
|
if gpu_poll_counter >= gpu_poll_ticks:
|
||||||
|
gpu_poll_counter = 0
|
||||||
|
if self._gpu_poller.available:
|
||||||
|
self._state.gpus = self._gpu_poller.poll()
|
||||||
|
|
||||||
|
# Update elapsed time
|
||||||
|
self._state.elapsed_seconds = (
|
||||||
|
time.time() - self._state.start_time.timestamp()
|
||||||
|
)
|
||||||
|
|
||||||
|
self._update_layout(layout)
|
||||||
|
live.update(layout)
|
||||||
|
|
||||||
|
time.sleep(tick_interval)
|
||||||
|
|
||||||
|
# Final drain
|
||||||
|
self._drain_queue()
|
||||||
|
self._update_layout(layout)
|
||||||
|
live.update(layout)
|
||||||
|
finally:
|
||||||
|
if self._io_capture:
|
||||||
|
self._io_capture.stop()
|
||||||
|
try:
|
||||||
|
tty_file.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
88
src/axolotl/tui/state.py
Normal file
88
src/axolotl/tui/state.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
"""TUI shared data model — dataclasses for the dashboard state."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import deque
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GPUStats:
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
util_pct: float
|
||||||
|
vram_used_gb: float
|
||||||
|
vram_total_gb: float
|
||||||
|
temp_c: int
|
||||||
|
power_w: float | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LogLine:
|
||||||
|
timestamp: datetime
|
||||||
|
level: str # "info" | "debug" | "warning" | "error"
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CompletionSample:
|
||||||
|
step: int
|
||||||
|
prompt: str
|
||||||
|
completion: str
|
||||||
|
reward: float | None
|
||||||
|
advantage: float | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TUIState:
|
||||||
|
# Run metadata
|
||||||
|
run_name: str = ""
|
||||||
|
model_name: str = ""
|
||||||
|
training_mode: str = "sft"
|
||||||
|
world_size: int = 1
|
||||||
|
start_time: datetime = field(default_factory=datetime.now)
|
||||||
|
|
||||||
|
# Progress
|
||||||
|
current_step: int = 0
|
||||||
|
total_steps: int = 0
|
||||||
|
current_epoch: float = 0.0
|
||||||
|
total_epochs: float = 1.0
|
||||||
|
elapsed_seconds: float = 0.0
|
||||||
|
eta_seconds: float | None = None
|
||||||
|
|
||||||
|
# Training metrics (rolling window + current)
|
||||||
|
loss: float | None = None
|
||||||
|
grad_norm: float | None = None
|
||||||
|
learning_rate: float | None = None
|
||||||
|
tokens_per_second: float | None = None
|
||||||
|
samples_per_second: float | None = None
|
||||||
|
mfu: float | None = None
|
||||||
|
|
||||||
|
# RL-specific (None for non-RL modes)
|
||||||
|
rewards_mean: float | None = None
|
||||||
|
rewards_std: float | None = None
|
||||||
|
kl_divergence: float | None = None
|
||||||
|
clip_ratio: float | None = None
|
||||||
|
queue_size: int | None = None
|
||||||
|
|
||||||
|
# Per-GPU hardware (list indexed by local rank)
|
||||||
|
gpus: list[GPUStats] = field(default_factory=list)
|
||||||
|
|
||||||
|
# Recent log lines
|
||||||
|
log_lines: deque[LogLine] = field(default_factory=lambda: deque(maxlen=200))
|
||||||
|
|
||||||
|
# Recent completions (GRPO/SFT with log_completions)
|
||||||
|
completions: deque[CompletionSample] = field(
|
||||||
|
default_factory=lambda: deque(maxlen=20)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Loss history for sparkline
|
||||||
|
loss_history: deque[float] = field(default_factory=lambda: deque(maxlen=50))
|
||||||
|
|
||||||
|
# DeepSpeed zero stage (None if not using DeepSpeed)
|
||||||
|
zero_stage: int | None = None
|
||||||
|
|
||||||
|
# Arbitrary plugin state
|
||||||
|
extra: dict[str, Any] = field(default_factory=dict)
|
||||||
@@ -13,6 +13,7 @@ from pydantic import (
|
|||||||
model_validator,
|
model_validator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from axolotl.tui.config import TUIConfig
|
||||||
from axolotl.utils.datasets import get_default_process_count
|
from axolotl.utils.datasets import get_default_process_count
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.datasets import (
|
from axolotl.utils.schemas.datasets import (
|
||||||
@@ -140,6 +141,12 @@ class AxolotlInputConfig(
|
|||||||
vllm: VllmConfig | None = Field(
|
vllm: VllmConfig | None = Field(
|
||||||
default_factory=lambda: VllmConfig(),
|
default_factory=lambda: VllmConfig(),
|
||||||
)
|
)
|
||||||
|
tui: TUIConfig | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "TUI dashboard configuration. Set enabled: true to activate."
|
||||||
|
},
|
||||||
|
)
|
||||||
qat: QATConfig | None = None
|
qat: QATConfig | None = None
|
||||||
quantization: PTQConfig | None = None
|
quantization: PTQConfig | None = None
|
||||||
reward_model: bool | None = Field(
|
reward_model: bool | None = Field(
|
||||||
@@ -703,6 +710,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
"description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
lora_embedding_kernel: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Apply custom LoRA autograd function for embedding layers. See: https://docs.axolotl.ai/docs/lora_optims.html"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
chunked_cross_entropy: bool | None = Field(
|
chunked_cross_entropy: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -758,44 +771,6 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
llama4_linearized_experts: bool | None = None
|
llama4_linearized_experts: bool | None = None
|
||||||
|
|
||||||
# MoE aux-loss-free (AFB) toggles
|
|
||||||
moe_balance_type: Literal["gshard", "noaux_tc"] | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "MoE load balancing strategy: 'gshard' for auxiliary loss, 'noaux_tc' for aux-loss-free bias updates affecting top-k selection only. Defaults to model's native behavior when unset.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
moe_update_rate: float | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Per-step bias update rate (gamma). Recommended: 0.005–0.05. If unset, plugin default is 0.01.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
moe_update_momentum: float | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "EMA momentum for expert load smoothing (0–1). If unset, plugin default is 0.9.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
moe_bias_cap: float | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Absolute clamp for expert bias magnitude. If unset, plugin default is 2.0.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
moe_afb_warmup_steps: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Number of initial steps to delay aux-free bias updates, allowing routing to stabilize. If unset, plugin default is 0.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
moe_bias_sync_group: Literal["world", "ep"] | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Reduction group for expert load counts: 'world' (DP) or 'ep' (expert-parallel group if available). Defaults to 'world' when unset.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
deepspeed: str | dict[str, Any] | None = Field(
|
deepspeed: str | dict[str, Any] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -874,12 +849,6 @@ class AxolotlInputConfig(
|
|||||||
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
|
"description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
expert_parallel_size: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Number of processes participating in expert-parallel collectives. Set >1 to form EP groups for aux-free reductions; defaults to world when unset."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
special_tokens: SpecialTokensConfig | None = Field(
|
special_tokens: SpecialTokensConfig | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -1357,6 +1326,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
data.get("lora_mlp_kernel")
|
data.get("lora_mlp_kernel")
|
||||||
or data.get("lora_qkv_kernel")
|
or data.get("lora_qkv_kernel")
|
||||||
or data.get("lora_o_kernel")
|
or data.get("lora_o_kernel")
|
||||||
|
or data.get("lora_embedding_kernel")
|
||||||
):
|
):
|
||||||
capabilities = data.get("capabilities")
|
capabilities = data.get("capabilities")
|
||||||
is_fsdp = data.get("fsdp_config") is not None
|
is_fsdp = data.get("fsdp_config") is not None
|
||||||
@@ -1404,7 +1374,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("adapter") in ["lora", "qlora"]:
|
if data.get("adapter") in ["lora", "qlora"]:
|
||||||
# Skip if already set, using unsloth optimizations, or using 8-bit
|
# Skip if already set, using unsloth optimizations, or using 8-bit
|
||||||
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
||||||
kernel_fields = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
|
kernel_fields = [
|
||||||
|
"lora_mlp_kernel",
|
||||||
|
"lora_qkv_kernel",
|
||||||
|
"lora_o_kernel",
|
||||||
|
"lora_embedding_kernel",
|
||||||
|
]
|
||||||
if (
|
if (
|
||||||
any(data.get(k) is not None for k in kernel_fields)
|
any(data.get(k) is not None for k in kernel_fields)
|
||||||
or any(data.get(k) for k in unsloth_fields)
|
or any(data.get(k) for k in unsloth_fields)
|
||||||
@@ -1417,10 +1392,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("trust_remote_code"):
|
if data.get("trust_remote_code"):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
# Skip if dropout is not 0, as auto enabling it would just disable it during runtime patch checks
|
|
||||||
if data.get("lora_dropout") != 0:
|
|
||||||
return data
|
|
||||||
|
|
||||||
# Check multi-GPU compatibility
|
# Check multi-GPU compatibility
|
||||||
capabilities = data.get("capabilities")
|
capabilities = data.get("capabilities")
|
||||||
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
||||||
@@ -1442,6 +1413,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data.get("lora_o_kernel") is None:
|
if data.get("lora_o_kernel") is None:
|
||||||
data["lora_o_kernel"] = True
|
data["lora_o_kernel"] = True
|
||||||
|
|
||||||
|
if data.get("lora_embedding_kernel") is None:
|
||||||
|
data["lora_embedding_kernel"] = True
|
||||||
|
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Auto-enabling LoRA kernel optimizations for faster training. "
|
"Auto-enabling LoRA kernel optimizations for faster training. "
|
||||||
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
|
+ "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
|
||||||
|
|||||||
@@ -681,15 +681,7 @@ class LoRAValidationMixin:
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_lora_kernels_dora(cls, data):
|
def check_lora_kernels_dora(cls, data):
|
||||||
if (
|
# DoRA is now supported by lora kernels
|
||||||
data.get("lora_mlp_kernel")
|
|
||||||
or data.get("lora_qkv_kernel")
|
|
||||||
or data.get("lora_o_kernel")
|
|
||||||
) and data.get("peft_use_dora"):
|
|
||||||
raise ValueError(
|
|
||||||
"lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not "
|
|
||||||
"compatible with DoRA at the moment."
|
|
||||||
)
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -1386,14 +1378,6 @@ class ComplexValidationMixin:
|
|||||||
self.tensor_parallel_size = 1
|
self.tensor_parallel_size = 1
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def check_expert_parallel_size(self):
|
|
||||||
if not getattr(self, "expert_parallel_size", None):
|
|
||||||
self.expert_parallel_size = 1
|
|
||||||
elif self.expert_parallel_size < 1:
|
|
||||||
raise ValueError("expert_parallel_size must be >= 1")
|
|
||||||
return self
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_context_parallel_size(self):
|
def check_context_parallel_size(self):
|
||||||
if self.sequence_parallel_degree and not self.context_parallel_size:
|
if self.sequence_parallel_degree and not self.context_parallel_size:
|
||||||
|
|||||||
@@ -153,7 +153,7 @@ class TestLoraFP8Guard(unittest.TestCase):
|
|||||||
|
|
||||||
proj.base_layer = base_layer
|
proj.base_layer = base_layer
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(proj)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
|
||||||
# quant_state should be None since weight is bf16, not FP8
|
# quant_state should be None since weight is bf16, not FP8
|
||||||
self.assertIsNone(quant_state)
|
self.assertIsNone(quant_state)
|
||||||
|
|
||||||
@@ -174,7 +174,7 @@ class TestLoraFP8Guard(unittest.TestCase):
|
|||||||
scale_inv = torch.ones(1)
|
scale_inv = torch.ones(1)
|
||||||
base_layer.weight_scale_inv = scale_inv
|
base_layer.weight_scale_inv = scale_inv
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(proj)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(proj)
|
||||||
self.assertIs(quant_state, scale_inv)
|
self.assertIs(quant_state, scale_inv)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ def mock_proj():
|
|||||||
def test_get_lora_parameters(mock_proj):
|
def test_get_lora_parameters(mock_proj):
|
||||||
"""Tests get_lora_parameters function"""
|
"""Tests get_lora_parameters function"""
|
||||||
# Test with LoRA enabled
|
# Test with LoRA enabled
|
||||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
||||||
|
|
||||||
assert isinstance(W, torch.Tensor)
|
assert isinstance(W, torch.Tensor)
|
||||||
assert W.shape == (128, 64)
|
assert W.shape == (128, 64)
|
||||||
@@ -113,13 +113,13 @@ def test_get_lora_parameters(mock_proj):
|
|||||||
|
|
||||||
# Test with LoRA disabled
|
# Test with LoRA disabled
|
||||||
mock_proj.disable_adapters = True
|
mock_proj.disable_adapters = True
|
||||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
||||||
assert A is None and B is None and s is None
|
assert A is None and B is None and s is None
|
||||||
|
|
||||||
# Test with merged state
|
# Test with merged state
|
||||||
mock_proj.disable_adapters = False
|
mock_proj.disable_adapters = False
|
||||||
mock_proj.merged = True
|
mock_proj.merged = True
|
||||||
W, b, _, A, B, s = get_lora_parameters(mock_proj)
|
W, b, _, A, B, s, *_ = get_lora_parameters(mock_proj)
|
||||||
assert A is None and B is None and s is None
|
assert A is None and B is None and s is None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
1245
tests/e2e/kernels/test_lora_features.py
Normal file
1245
tests/e2e/kernels/test_lora_features.py
Normal file
File diff suppressed because it is too large
Load Diff
120
tests/e2e/multigpu/test_fsdp2_lora_kernels.py
Normal file
120
tests/e2e/multigpu/test_fsdp2_lora_kernels.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
"""Test LoRA kernels under FSDP2 multi-GPU training.
|
||||||
|
|
||||||
|
Verifies that lora_qkv_kernel, lora_o_kernel, lora_mlp_kernel, and
|
||||||
|
lora_embedding_kernel work correctly with FSDP2 sharding, including
|
||||||
|
with bias, dropout, and DoRA enabled.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from tests.e2e.utils import require_torch_2_7_0
|
||||||
|
|
||||||
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
def _run_training(temp_dir, cfg):
|
||||||
|
"""Write config and launch multi-GPU training."""
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"axolotl",
|
||||||
|
"train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"--main-process-port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _base_lora_fsdp2_config(temp_dir, **overrides):
|
||||||
|
"""Base config for LoRA + FSDP2 + kernel tests."""
|
||||||
|
cfg = {
|
||||||
|
"base_model": "Qwen/Qwen3-0.6B",
|
||||||
|
"sequence_len": 512,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "tatsu-lab/alpaca",
|
||||||
|
"type": "alpaca",
|
||||||
|
"split": "train[:1%]",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 3,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 1e-4,
|
||||||
|
"optimizer": "adamw_torch_fused",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"bf16": True,
|
||||||
|
"fsdp_version": 2,
|
||||||
|
"fsdp_config": {
|
||||||
|
"offload_params": False,
|
||||||
|
"cpu_ram_efficient_loading": False,
|
||||||
|
"transformer_layer_cls_to_wrap": "Qwen3DecoderLayer",
|
||||||
|
"state_dict_type": "FULL_STATE_DICT",
|
||||||
|
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
|
"reshard_after_forward": True,
|
||||||
|
},
|
||||||
|
# Enable all LoRA kernels
|
||||||
|
"lora_mlp_kernel": True,
|
||||||
|
"lora_qkv_kernel": True,
|
||||||
|
"lora_o_kernel": True,
|
||||||
|
"lora_embedding_kernel": True,
|
||||||
|
"save_safetensors": True,
|
||||||
|
}
|
||||||
|
cfg.update(overrides)
|
||||||
|
return DictDefault(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFSDP2LoRAKernels:
|
||||||
|
"""Test LoRA kernels under FSDP2."""
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_lora_kernels_basic(self, temp_dir):
|
||||||
|
"""Basic LoRA + kernels + FSDP2: no dropout, no bias, no DoRA."""
|
||||||
|
cfg = _base_lora_fsdp2_config(temp_dir)
|
||||||
|
_run_training(temp_dir, cfg)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_lora_kernels_with_dropout(self, temp_dir):
|
||||||
|
"""LoRA kernels + dropout + FSDP2."""
|
||||||
|
cfg = _base_lora_fsdp2_config(temp_dir, lora_dropout=0.1)
|
||||||
|
_run_training(temp_dir, cfg)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_lora_kernels_with_dora(self, temp_dir):
|
||||||
|
"""LoRA kernels + DoRA + FSDP2."""
|
||||||
|
cfg = _base_lora_fsdp2_config(temp_dir, peft_use_dora=True)
|
||||||
|
_run_training(temp_dir, cfg)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
|
|
||||||
|
@require_torch_2_7_0
|
||||||
|
def test_lora_kernels_with_dora_and_dropout(self, temp_dir):
|
||||||
|
"""LoRA kernels + DoRA + dropout + FSDP2."""
|
||||||
|
cfg = _base_lora_fsdp2_config(
|
||||||
|
temp_dir,
|
||||||
|
peft_use_dora=True,
|
||||||
|
lora_dropout=0.05,
|
||||||
|
)
|
||||||
|
_run_training(temp_dir, cfg)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
@@ -222,9 +222,9 @@ def test_model_specific_activation(model_name, expected_activation):
|
|||||||
|
|
||||||
|
|
||||||
def test_kernel_patch_conditions():
|
def test_kernel_patch_conditions():
|
||||||
"""Test various conditions that should prevent kernel patching."""
|
"""Test that kernels ARE patched even with dropout and bias (now supported)."""
|
||||||
test_configs = [
|
test_configs = [
|
||||||
# Dropout prevents patching
|
# Dropout — kernels now support this
|
||||||
{
|
{
|
||||||
"peft_type": "LORA",
|
"peft_type": "LORA",
|
||||||
"task_type": "CAUSAL_LM",
|
"task_type": "CAUSAL_LM",
|
||||||
@@ -234,7 +234,7 @@ def test_kernel_patch_conditions():
|
|||||||
"lora_dropout": 0.1,
|
"lora_dropout": 0.1,
|
||||||
"bias": "none",
|
"bias": "none",
|
||||||
},
|
},
|
||||||
# Bias prevents patching
|
# Bias — kernels now support this
|
||||||
{
|
{
|
||||||
"peft_type": "LORA",
|
"peft_type": "LORA",
|
||||||
"task_type": "CAUSAL_LM",
|
"task_type": "CAUSAL_LM",
|
||||||
@@ -252,13 +252,14 @@ def test_kernel_patch_conditions():
|
|||||||
model = PeftModelForCausalLM(model, peft_config)
|
model = PeftModelForCausalLM(model, peft_config)
|
||||||
cfg = DictDefault({"lora_mlp_kernel": True})
|
cfg = DictDefault({"lora_mlp_kernel": True})
|
||||||
|
|
||||||
# Should not patch
|
|
||||||
patched_model = apply_lora_kernel_patches(model, cfg)
|
patched_model = apply_lora_kernel_patches(model, cfg)
|
||||||
layer = patched_model.model.model.layers[0].mlp
|
layer = patched_model.model.model.layers[0].mlp
|
||||||
|
|
||||||
# Verify no patches applied
|
# Verify patches ARE applied (dropout and bias are now supported)
|
||||||
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
|
assert (
|
||||||
assert layer.forward.__func__ is not apply_lora_mlp_geglu
|
layer.forward.__func__ is apply_lora_mlp_swiglu
|
||||||
|
or layer.forward.__func__ is apply_lora_mlp_geglu
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_kernel_config_options():
|
def test_kernel_config_options():
|
||||||
@@ -511,7 +512,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
|||||||
|
|
||||||
|
|
||||||
def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
||||||
"""Test model loading with dropout non-zero should not patch."""
|
"""Test model loading with dropout non-zero DOES patch (now supported)."""
|
||||||
|
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
|
|
||||||
@@ -546,31 +547,18 @@ def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
|||||||
# Load config
|
# Load config
|
||||||
cfg = load_cfg(str(path))
|
cfg = load_cfg(str(path))
|
||||||
|
|
||||||
# Get original attention class
|
|
||||||
attention_cls = get_attention_cls_from_config(cfg)
|
|
||||||
|
|
||||||
# Store original state before patching
|
|
||||||
original_forward_method = attention_cls.forward
|
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
|
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
|
||||||
|
|
||||||
# We call modelloader as that's where the patches are applied
|
|
||||||
# despite the fact that we're not using it to load the model
|
|
||||||
model_loader = ModelLoader(cfg, tokenizer)
|
model_loader = ModelLoader(cfg, tokenizer)
|
||||||
|
|
||||||
# Apply patch
|
# Apply patches — should succeed even with dropout > 0
|
||||||
model_loader.patch_manager._apply_self_attention_lora_patch()
|
model_loader.patch_manager._apply_self_attention_lora_patch()
|
||||||
|
|
||||||
# Verify patch was not applied
|
|
||||||
assert attention_cls.forward == original_forward_method
|
|
||||||
|
|
||||||
# Apply apply_lora_kernel_patches
|
|
||||||
model_loader.patch_manager._apply_lora_kernel_patch(model)
|
model_loader.patch_manager._apply_lora_kernel_patch(model)
|
||||||
|
|
||||||
# Verify patch was not applied
|
# Verify patches WERE applied (dropout is now supported by kernels)
|
||||||
layers = get_layers(model)
|
layers = get_layers(model)
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
for self_attn in find_self_attn_in_layer(layer):
|
for self_attn in find_self_attn_in_layer(layer):
|
||||||
assert not hasattr(self_attn, "apply_qkv")
|
assert hasattr(self_attn, "apply_qkv")
|
||||||
assert not hasattr(self_attn, "apply_o")
|
assert hasattr(self_attn, "apply_o")
|
||||||
|
|||||||
@@ -1,79 +0,0 @@
|
|||||||
"""
|
|
||||||
E2E smoke tests for Aux-Loss-Free MoE routing via plugin
|
|
||||||
"""
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
|
||||||
from axolotl.train import train
|
|
||||||
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
|
||||||
|
|
||||||
|
|
||||||
class TestMoeAuxFree(unittest.TestCase):
|
|
||||||
"""Smoke tests to ensure aux-free plugin enables and runs on Mixtral tiny."""
|
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_mixtral_aux_free_smoke(self, temp_dir):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
|
||||||
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
|
||||||
"flash_attention": False,
|
|
||||||
"sequence_len": 512,
|
|
||||||
"bf16": False,
|
|
||||||
"fp16": False,
|
|
||||||
"val_set_size": 0.02,
|
|
||||||
"special_tokens": {},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 1e-5,
|
|
||||||
"optimizer": "adamw_torch",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"max_steps": 5,
|
|
||||||
"save_steps": 0,
|
|
||||||
"eval_steps": 0,
|
|
||||||
"save_first_step": False,
|
|
||||||
# Aux-free plugin and toggles
|
|
||||||
"plugins": [
|
|
||||||
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
|
||||||
],
|
|
||||||
"moe_balance_type": "noaux_tc",
|
|
||||||
"moe_update_rate": 0.01,
|
|
||||||
"moe_update_momentum": 0.9,
|
|
||||||
"moe_bias_cap": 2.0,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
prepare_plugins(cfg)
|
|
||||||
dataset_meta = load_datasets(cfg=cfg)
|
|
||||||
|
|
||||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
|
|
||||||
# Inspect model modules for a patched MoE layer
|
|
||||||
patched = None
|
|
||||||
for m in model.modules():
|
|
||||||
if hasattr(m, "_afb_patched") and getattr(m, "_afb_patched") is True:
|
|
||||||
patched = m
|
|
||||||
break
|
|
||||||
assert patched is not None, "No MoE layer patched by aux-free plugin"
|
|
||||||
assert hasattr(patched, "_afb_bias") and patched._afb_bias.ndim == 1
|
|
||||||
assert hasattr(patched, "_afb_counts") and patched._afb_counts.ndim == 1
|
|
||||||
# ensure counts buffer got reset by callback (best effort)
|
|
||||||
assert torch.all(patched._afb_counts == 0)
|
|
||||||
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
"""
|
|
||||||
Parity test comparing aux-loss (gshard) vs aux-loss-free (noaux_tc) on Mixtral-tiny.
|
|
||||||
Checks that aux-free training loss does not degrade beyond a small tolerance.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
|
||||||
from axolotl.train import train
|
|
||||||
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from .utils import with_temp_dir
|
|
||||||
|
|
||||||
|
|
||||||
def _last_logged_loss(trainer) -> float | None:
|
|
||||||
# Scan log_history for the most recent entry with a 'loss' key
|
|
||||||
for entry in reversed(trainer.state.log_history):
|
|
||||||
if isinstance(entry, dict) and "loss" in entry:
|
|
||||||
return float(entry["loss"])
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class TestMoeAuxParity(unittest.TestCase):
|
|
||||||
@with_temp_dir
|
|
||||||
def test_mixtral_auxfree_vs_auxloss_loss_parity(self, temp_dir):
|
|
||||||
base_cfg = {
|
|
||||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
|
||||||
"tokenizer_config": "LoneStriker/Mixtral-8x7B-v0.1-HF",
|
|
||||||
"flash_attention": False,
|
|
||||||
"sequence_len": 512,
|
|
||||||
"bf16": False,
|
|
||||||
"fp16": False,
|
|
||||||
"val_set_size": 0.02,
|
|
||||||
"special_tokens": {},
|
|
||||||
"datasets": [
|
|
||||||
{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"learning_rate": 1e-5,
|
|
||||||
"optimizer": "adamw_torch",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"max_steps": 8,
|
|
||||||
"save_steps": 0,
|
|
||||||
"eval_steps": 0,
|
|
||||||
"save_first_step": False,
|
|
||||||
"seed": 42,
|
|
||||||
"logging_steps": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Baseline: aux-loss (gshard)
|
|
||||||
cfg0 = DictDefault(dict(base_cfg))
|
|
||||||
cfg0.output_dir = f"{temp_dir}/baseline"
|
|
||||||
cfg0 = validate_config(cfg0)
|
|
||||||
normalize_config(cfg0)
|
|
||||||
# baseline uses default aux-loss routing; no plugin registration
|
|
||||||
dataset_meta0 = load_datasets(cfg=cfg0)
|
|
||||||
model0, _, trainer0 = train(cfg=cfg0, dataset_meta=dataset_meta0)
|
|
||||||
loss0 = _last_logged_loss(trainer0)
|
|
||||||
assert loss0 is not None
|
|
||||||
|
|
||||||
# Aux-free: plugin + noaux_tc
|
|
||||||
cfg1 = DictDefault(dict(base_cfg))
|
|
||||||
cfg1.output_dir = f"{temp_dir}/auxfree"
|
|
||||||
cfg1.plugins = [
|
|
||||||
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
|
||||||
]
|
|
||||||
cfg1.moe_balance_type = "noaux_tc"
|
|
||||||
cfg1.moe_update_rate = 0.01
|
|
||||||
cfg1.moe_update_momentum = 0.9
|
|
||||||
cfg1.moe_bias_cap = 2.0
|
|
||||||
cfg1 = validate_config(cfg1)
|
|
||||||
normalize_config(cfg1)
|
|
||||||
prepare_plugins(cfg1)
|
|
||||||
dataset_meta1 = load_datasets(cfg=cfg1)
|
|
||||||
model1, _, trainer1 = train(cfg=cfg1, dataset_meta=dataset_meta1)
|
|
||||||
loss1 = _last_logged_loss(trainer1)
|
|
||||||
assert loss1 is not None
|
|
||||||
|
|
||||||
# Assert aux-free loss is within 10% of aux-loss baseline
|
|
||||||
assert loss1 <= 1.1 * loss0, f"aux-free loss {loss1} > 1.1 * baseline {loss0}"
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
"""
|
|
||||||
E2E smoke test for Aux-Loss-Free MoE routing on Qwen3-MoE tiny
|
|
||||||
"""
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
|
||||||
from axolotl.train import train
|
|
||||||
from axolotl.utils.config import normalize_config, validate_config, prepare_plugins
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
|
||||||
|
|
||||||
|
|
||||||
class TestQwen3MoeAuxFree(unittest.TestCase):
|
|
||||||
@with_temp_dir
|
|
||||||
def test_qwen3_moe_aux_free_smoke(self, temp_dir):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
|
|
||||||
"tokenizer_config": "trl-internal-testing/tiny-Qwen3MoeForCausalLM",
|
|
||||||
"flash_attention": False,
|
|
||||||
"sequence_len": 512,
|
|
||||||
"bf16": False,
|
|
||||||
"fp16": False,
|
|
||||||
"val_set_size": 0.02,
|
|
||||||
"special_tokens": {},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"micro_batch_size": 2,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 1e-5,
|
|
||||||
"optimizer": "adamw_torch",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"max_steps": 5,
|
|
||||||
"save_steps": 0,
|
|
||||||
"eval_steps": 0,
|
|
||||||
"save_first_step": False,
|
|
||||||
# Aux-free plugin and toggles
|
|
||||||
"plugins": [
|
|
||||||
"axolotl.integrations.aux_free_router.plugin.AuxFreeMoEPlugin",
|
|
||||||
],
|
|
||||||
"moe_balance_type": "noaux_tc",
|
|
||||||
"moe_update_rate": 0.01,
|
|
||||||
"moe_update_momentum": 0.9,
|
|
||||||
"moe_bias_cap": 2.0,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
prepare_plugins(cfg)
|
|
||||||
dataset_meta = load_datasets(cfg=cfg)
|
|
||||||
|
|
||||||
model, _, _ = train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
|
|
||||||
# check that at least one sparse MoE block has been patched
|
|
||||||
found = False
|
|
||||||
for m in model.modules():
|
|
||||||
if m.__class__.__name__.endswith("SparseMoeBlock") and hasattr(m, "_afb_patched"):
|
|
||||||
assert m._afb_patched is True
|
|
||||||
assert hasattr(m, "_afb_bias") and m._afb_bias.ndim == 1
|
|
||||||
assert hasattr(m, "_afb_counts") and m._afb_counts.ndim == 1
|
|
||||||
found = True
|
|
||||||
break
|
|
||||||
assert found, "No Qwen3-MoE sparse block patched by aux-free plugin"
|
|
||||||
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
|
||||||
@@ -12,10 +12,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
try:
|
from tbparse import SummaryReader
|
||||||
from tbparse import SummaryReader
|
|
||||||
except ImportError: # pragma: no cover - optional dependency
|
|
||||||
SummaryReader = None
|
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
@@ -182,14 +179,12 @@ def check_tensorboard(
|
|||||||
tag: str,
|
tag: str,
|
||||||
lt_val: float,
|
lt_val: float,
|
||||||
assertion_err: str,
|
assertion_err: str,
|
||||||
rtol: float = 0.02,
|
rtol: float = 0.05,
|
||||||
gt_zero: bool = True,
|
gt_zero: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
helper function to parse and check tensorboard logs
|
helper function to parse and check tensorboard logs
|
||||||
"""
|
"""
|
||||||
if SummaryReader is None:
|
|
||||||
raise unittest.SkipTest("tbparse is not installed; skipping tensorboard assertions")
|
|
||||||
tb_log_path = most_recent_subdir(temp_run_dir)
|
tb_log_path = most_recent_subdir(temp_run_dir)
|
||||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||||
reader = SummaryReader(event_file)
|
reader = SummaryReader(event_file)
|
||||||
|
|||||||
229
tests/kernels/test_rms_norm_gated.py
Normal file
229
tests/kernels/test_rms_norm_gated.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
"""
|
||||||
|
Correctness tests for fused RMSNorm + SiLU Gate kernel.
|
||||||
|
|
||||||
|
Tests against the eager Qwen3_5RMSNormGated implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
pytest.importorskip("triton", reason="triton required for fused kernels")
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
pytest.skip("CUDA required for fused kernel tests", allow_module_level=True)
|
||||||
|
|
||||||
|
from axolotl.kernels.rms_norm_gated import FusedRMSNormGated
|
||||||
|
|
||||||
|
|
||||||
|
class EagerRMSNormGated(torch.nn.Module):
|
||||||
|
"""Reference implementation matching Qwen3_5RMSNormGated exactly."""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, hidden_states, gate=None):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
hidden_states = self.weight * hidden_states.to(input_dtype)
|
||||||
|
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
|
||||||
|
return hidden_states.to(input_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_weights(eager_mod, fused_mod):
|
||||||
|
"""Copy weights from eager to fused module."""
|
||||||
|
fused_mod.weight.data.copy_(eager_mod.weight.data)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"shape",
|
||||||
|
[
|
||||||
|
(2, 128, 256),
|
||||||
|
(4, 64, 512),
|
||||||
|
(1, 32, 1024),
|
||||||
|
(2, 16, 2560), # Qwen3.5-4B hidden_size
|
||||||
|
(2, 16, 4096), # Qwen3.5-9B hidden_size
|
||||||
|
(1, 8, 5120), # Qwen3.5-27B hidden_size
|
||||||
|
(4, 16, 2048), # Qwen3.5-35B-A3B (MoE) hidden_size
|
||||||
|
(4, 16, 3072), # Qwen3.5-122B-A10B (MoE) hidden_size
|
||||||
|
],
|
||||||
|
)
|
||||||
|
class TestRMSNormGatedForward:
|
||||||
|
def test_output_matches_eager(self, dtype, shape):
|
||||||
|
torch.manual_seed(42)
|
||||||
|
B, T, H = shape
|
||||||
|
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||||
|
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
|
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||||
|
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||||
|
_sync_weights(eager, fused)
|
||||||
|
|
||||||
|
y_eager = eager(X, gate=G)
|
||||||
|
y_fused = fused(X, gate=G)
|
||||||
|
|
||||||
|
if dtype == torch.float32:
|
||||||
|
torch.testing.assert_close(y_fused, y_eager, atol=1e-5, rtol=1e-5)
|
||||||
|
else:
|
||||||
|
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
||||||
|
|
||||||
|
def test_output_shape(self, dtype, shape):
|
||||||
|
B, T, H = shape
|
||||||
|
X = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||||
|
G = torch.randn(B, T, H, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
|
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||||
|
y = fused(X, gate=G)
|
||||||
|
assert y.shape == (B, T, H)
|
||||||
|
assert y.dtype == dtype
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"shape",
|
||||||
|
[
|
||||||
|
(2, 32, 256),
|
||||||
|
(2, 16, 512),
|
||||||
|
(2, 16, 2560), # Qwen3.5-4B
|
||||||
|
(1, 8, 4096), # Qwen3.5-9B
|
||||||
|
(1, 8, 5120), # Qwen3.5-27B
|
||||||
|
(2, 16, 2048), # Qwen3.5-35B-A3B (MoE)
|
||||||
|
(2, 16, 3072), # Qwen3.5-122B-A10B (MoE)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
class TestRMSNormGatedBackward:
|
||||||
|
def test_grad_x(self, dtype, shape):
|
||||||
|
torch.manual_seed(42)
|
||||||
|
B, T, H = shape
|
||||||
|
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||||
|
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||||
|
X_ref = X.detach().clone().requires_grad_(True)
|
||||||
|
G_ref = G.detach().clone().requires_grad_(True)
|
||||||
|
|
||||||
|
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||||
|
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||||
|
_sync_weights(eager, fused)
|
||||||
|
|
||||||
|
y_eager = eager(X_ref, gate=G_ref)
|
||||||
|
y_fused = fused(X, gate=G)
|
||||||
|
|
||||||
|
grad_out = torch.randn_like(y_eager)
|
||||||
|
y_eager.backward(grad_out)
|
||||||
|
y_fused.backward(grad_out)
|
||||||
|
|
||||||
|
if dtype == torch.float32:
|
||||||
|
atol, rtol = 1e-4, 1e-4
|
||||||
|
else:
|
||||||
|
atol, rtol = 5e-2, 5e-2
|
||||||
|
|
||||||
|
torch.testing.assert_close(X.grad, X_ref.grad, atol=atol, rtol=rtol)
|
||||||
|
|
||||||
|
def test_grad_gate(self, dtype, shape):
|
||||||
|
torch.manual_seed(42)
|
||||||
|
B, T, H = shape
|
||||||
|
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||||
|
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||||
|
X_ref = X.detach().clone().requires_grad_(True)
|
||||||
|
G_ref = G.detach().clone().requires_grad_(True)
|
||||||
|
|
||||||
|
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||||
|
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||||
|
_sync_weights(eager, fused)
|
||||||
|
|
||||||
|
y_eager = eager(X_ref, gate=G_ref)
|
||||||
|
y_fused = fused(X, gate=G)
|
||||||
|
|
||||||
|
grad_out = torch.randn_like(y_eager)
|
||||||
|
y_eager.backward(grad_out)
|
||||||
|
y_fused.backward(grad_out)
|
||||||
|
|
||||||
|
if dtype == torch.float32:
|
||||||
|
atol, rtol = 1e-4, 1e-4
|
||||||
|
else:
|
||||||
|
atol, rtol = 5e-2, 5e-2
|
||||||
|
|
||||||
|
torch.testing.assert_close(G.grad, G_ref.grad, atol=atol, rtol=rtol)
|
||||||
|
|
||||||
|
def test_grad_weight(self, dtype, shape):
|
||||||
|
torch.manual_seed(42)
|
||||||
|
B, T, H = shape
|
||||||
|
X = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||||
|
G = torch.randn(B, T, H, dtype=dtype, device="cuda", requires_grad=True)
|
||||||
|
X_ref = X.detach().clone().requires_grad_(True)
|
||||||
|
G_ref = G.detach().clone().requires_grad_(True)
|
||||||
|
|
||||||
|
eager = EagerRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||||
|
fused = FusedRMSNormGated(H).to(dtype=dtype, device="cuda")
|
||||||
|
_sync_weights(eager, fused)
|
||||||
|
|
||||||
|
y_eager = eager(X_ref, gate=G_ref)
|
||||||
|
y_fused = fused(X, gate=G)
|
||||||
|
|
||||||
|
grad_out = torch.randn_like(y_eager)
|
||||||
|
y_eager.backward(grad_out)
|
||||||
|
y_fused.backward(grad_out)
|
||||||
|
|
||||||
|
if dtype == torch.float32:
|
||||||
|
atol, rtol = 1e-4, 1e-4
|
||||||
|
else:
|
||||||
|
atol, rtol = 5e-2, 5e-2
|
||||||
|
|
||||||
|
torch.testing.assert_close(
|
||||||
|
fused.weight.grad, eager.weight.grad, atol=atol, rtol=rtol
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRMSNormGatedEdgeCases:
|
||||||
|
def test_gate_none_raises(self):
|
||||||
|
fused = FusedRMSNormGated(256).cuda()
|
||||||
|
X = torch.randn(2, 4, 256, device="cuda")
|
||||||
|
with pytest.raises(ValueError, match="requires a gate tensor"):
|
||||||
|
fused(X, gate=None)
|
||||||
|
|
||||||
|
def test_2d_input(self):
|
||||||
|
"""Test with (BxT, H) shaped input instead of (B, T, H)."""
|
||||||
|
torch.manual_seed(42)
|
||||||
|
H = 512
|
||||||
|
X = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
||||||
|
G = torch.randn(64, H, dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
||||||
|
X_ref = X.detach().clone().requires_grad_(True)
|
||||||
|
G_ref = G.detach().clone().requires_grad_(True)
|
||||||
|
|
||||||
|
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||||
|
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||||
|
_sync_weights(eager, fused)
|
||||||
|
|
||||||
|
y_eager = eager(X_ref, gate=G_ref)
|
||||||
|
y_fused = fused(X, gate=G)
|
||||||
|
|
||||||
|
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
||||||
|
|
||||||
|
grad_out = torch.randn_like(y_eager)
|
||||||
|
y_eager.backward(grad_out)
|
||||||
|
y_fused.backward(grad_out)
|
||||||
|
|
||||||
|
torch.testing.assert_close(X.grad, X_ref.grad, atol=5e-2, rtol=5e-2)
|
||||||
|
torch.testing.assert_close(G.grad, G_ref.grad, atol=5e-2, rtol=5e-2)
|
||||||
|
|
||||||
|
def test_random_weight_init(self):
|
||||||
|
"""Test with non-default weight values."""
|
||||||
|
torch.manual_seed(123)
|
||||||
|
H = 256
|
||||||
|
X = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
|
||||||
|
G = torch.randn(2, 16, H, dtype=torch.bfloat16, device="cuda")
|
||||||
|
|
||||||
|
eager = EagerRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||||
|
# Randomize weights
|
||||||
|
eager.weight.data = torch.randn_like(eager.weight.data)
|
||||||
|
|
||||||
|
fused = FusedRMSNormGated(H).to(dtype=torch.bfloat16, device="cuda")
|
||||||
|
_sync_weights(eager, fused)
|
||||||
|
|
||||||
|
y_eager = eager(X, gate=G)
|
||||||
|
y_fused = fused(X, gate=G)
|
||||||
|
torch.testing.assert_close(y_fused, y_eager, atol=1e-2, rtol=1e-2)
|
||||||
@@ -1,267 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import tempfile
|
|
||||||
import unittest
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
import torch.nn as nn
|
|
||||||
from importlib import util as importlib_util
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
from axolotl.integrations.aux_free_router.plugin import AuxFreeMoEPlugin
|
|
||||||
|
|
||||||
|
|
||||||
def _cfg(**overrides):
|
|
||||||
defaults = dict(
|
|
||||||
moe_balance_type="noaux_tc",
|
|
||||||
moe_update_rate=0.1,
|
|
||||||
moe_update_momentum=0.9,
|
|
||||||
moe_bias_cap=2.0,
|
|
||||||
moe_afb_warmup_steps=0,
|
|
||||||
moe_bias_sync_group="world",
|
|
||||||
expert_parallel_size=1,
|
|
||||||
)
|
|
||||||
defaults.update(overrides)
|
|
||||||
return SimpleNamespace(**defaults)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_bailing_modules():
|
|
||||||
repo_dir = snapshot_download(
|
|
||||||
repo_id="inclusionAI/Ring-mini-2.0",
|
|
||||||
allow_patterns=[
|
|
||||||
"configuration_bailing_moe_v2.py",
|
|
||||||
"modeling_bailing_moe_v2.py",
|
|
||||||
"__init__.py",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
repo = Path(repo_dir)
|
|
||||||
config_path = repo / "configuration_bailing_moe_v2.py"
|
|
||||||
modeling_path = repo / "modeling_bailing_moe_v2.py"
|
|
||||||
|
|
||||||
config_name = "bailing_moe_v2.configuration_bailing_moe_v2"
|
|
||||||
if config_name not in sys.modules:
|
|
||||||
spec = importlib_util.spec_from_file_location(config_name, config_path)
|
|
||||||
module = importlib_util.module_from_spec(spec)
|
|
||||||
sys.modules[config_name] = module
|
|
||||||
sys.modules["configuration_bailing_moe_v2"] = module
|
|
||||||
assert spec.loader is not None
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
config_module = sys.modules[config_name]
|
|
||||||
|
|
||||||
modeling_name = "bailing_moe_v2.modeling_bailing_moe_v2"
|
|
||||||
if modeling_name not in sys.modules:
|
|
||||||
spec = importlib_util.spec_from_file_location(modeling_name, modeling_path)
|
|
||||||
module = importlib_util.module_from_spec(spec)
|
|
||||||
sys.modules[modeling_name] = module
|
|
||||||
sys.modules["modeling_bailing_moe_v2"] = module
|
|
||||||
assert spec.loader is not None
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
modeling_module = sys.modules[modeling_name]
|
|
||||||
|
|
||||||
BailingMoeV2Config = config_module.BailingMoeV2Config
|
|
||||||
BailingMoeV2SparseMoeBlock = modeling_module.BailingMoeV2SparseMoeBlock
|
|
||||||
|
|
||||||
return BailingMoeV2Config, BailingMoeV2SparseMoeBlock
|
|
||||||
|
|
||||||
|
|
||||||
def _build_bailing_model():
|
|
||||||
BailingConfig, BailingBlock = _load_bailing_modules()
|
|
||||||
config = BailingConfig(
|
|
||||||
hidden_size=16,
|
|
||||||
intermediate_size=32,
|
|
||||||
moe_intermediate_size=32,
|
|
||||||
num_experts=4,
|
|
||||||
num_shared_experts=None,
|
|
||||||
num_experts_per_tok=2,
|
|
||||||
n_group=1,
|
|
||||||
topk_group=1,
|
|
||||||
routed_scaling_factor=1.0,
|
|
||||||
)
|
|
||||||
block = BailingBlock(config)
|
|
||||||
|
|
||||||
class DummyModel(nn.Module):
|
|
||||||
def __init__(self, layer):
|
|
||||||
super().__init__()
|
|
||||||
self.block = layer
|
|
||||||
self.config = SimpleNamespace(model_type="bailing_moe")
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
return self.block(hidden_states)
|
|
||||||
|
|
||||||
return DummyModel(block), block
|
|
||||||
|
|
||||||
|
|
||||||
def _build_llama4_model():
|
|
||||||
from transformers import Llama4TextConfig
|
|
||||||
from transformers.models.llama4.modeling_llama4 import Llama4TextMoe
|
|
||||||
|
|
||||||
config = Llama4TextConfig(
|
|
||||||
hidden_size=16,
|
|
||||||
intermediate_size=32,
|
|
||||||
num_local_experts=4,
|
|
||||||
num_attention_heads=2,
|
|
||||||
num_key_value_heads=2,
|
|
||||||
num_experts_per_tok=2,
|
|
||||||
)
|
|
||||||
layer = Llama4TextMoe(config)
|
|
||||||
|
|
||||||
class DummyModel(nn.Module):
|
|
||||||
def __init__(self, moe_layer):
|
|
||||||
super().__init__()
|
|
||||||
self.moe = moe_layer
|
|
||||||
self.config = SimpleNamespace(model_type="llama4")
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
return self.moe(hidden_states)
|
|
||||||
|
|
||||||
return DummyModel(layer), layer
|
|
||||||
|
|
||||||
|
|
||||||
def _build_mixtral_model():
|
|
||||||
from transformers import MixtralConfig
|
|
||||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
|
||||||
|
|
||||||
config = MixtralConfig(
|
|
||||||
hidden_size=16,
|
|
||||||
intermediate_size=32,
|
|
||||||
num_local_experts=4,
|
|
||||||
num_experts_per_tok=2,
|
|
||||||
num_attention_heads=2,
|
|
||||||
num_key_value_heads=2,
|
|
||||||
)
|
|
||||||
layer = MixtralSparseMoeBlock(config)
|
|
||||||
layer.config = config
|
|
||||||
|
|
||||||
class DummyModel(nn.Module):
|
|
||||||
def __init__(self, moe_layer):
|
|
||||||
super().__init__()
|
|
||||||
self.moe = moe_layer
|
|
||||||
self.config = SimpleNamespace(model_type="mixtral")
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
return self.moe(hidden_states)
|
|
||||||
|
|
||||||
return DummyModel(layer), layer
|
|
||||||
|
|
||||||
|
|
||||||
def _run_callback(plugin, cfg):
|
|
||||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
|
||||||
assert callbacks, "expected aux-free callback to be registered"
|
|
||||||
callback = callbacks[0]
|
|
||||||
dummy = SimpleNamespace()
|
|
||||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAuxFreeAdapters(unittest.TestCase):
|
|
||||||
def test_bailing_adapter_updates_counts_and_bias(self):
|
|
||||||
model, block = _build_bailing_model()
|
|
||||||
cfg = _cfg()
|
|
||||||
plugin = AuxFreeMoEPlugin()
|
|
||||||
plugin.post_model_build(cfg, model)
|
|
||||||
|
|
||||||
self.assertTrue(hasattr(block, "_afb_bias"))
|
|
||||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
|
||||||
block(hidden)
|
|
||||||
self.assertGreater(torch.count_nonzero(block._afb_counts), 0)
|
|
||||||
|
|
||||||
_run_callback(plugin, cfg)
|
|
||||||
self.assertEqual(torch.count_nonzero(block._afb_counts), 0)
|
|
||||||
self.assertFalse(torch.allclose(block._afb_ema, torch.zeros_like(block._afb_ema)))
|
|
||||||
|
|
||||||
def test_llama4_adapter_biases_router_selection(self):
|
|
||||||
model, layer = _build_llama4_model()
|
|
||||||
cfg = _cfg()
|
|
||||||
plugin = AuxFreeMoEPlugin()
|
|
||||||
plugin.post_model_build(cfg, model)
|
|
||||||
|
|
||||||
self.assertTrue(hasattr(layer, "_afb_bias"))
|
|
||||||
hidden = torch.randn(2, 4, layer.hidden_dim)
|
|
||||||
layer(hidden)
|
|
||||||
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
|
||||||
|
|
||||||
_run_callback(plugin, cfg)
|
|
||||||
self.assertEqual(torch.count_nonzero(layer._afb_counts), 0)
|
|
||||||
self.assertFalse(torch.allclose(layer._afb_ema, torch.zeros_like(layer._afb_ema)))
|
|
||||||
|
|
||||||
def test_bias_warmup_respected(self):
|
|
||||||
model, block = _build_bailing_model()
|
|
||||||
cfg = _cfg(moe_afb_warmup_steps=2)
|
|
||||||
plugin = AuxFreeMoEPlugin()
|
|
||||||
plugin.post_model_build(cfg, model)
|
|
||||||
|
|
||||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
|
||||||
self.assertTrue(callbacks)
|
|
||||||
callback = callbacks[0]
|
|
||||||
dummy = SimpleNamespace()
|
|
||||||
|
|
||||||
def _step():
|
|
||||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
|
||||||
block(hidden)
|
|
||||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
|
||||||
|
|
||||||
# Warmup steps should leave bias untouched.
|
|
||||||
_step()
|
|
||||||
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
|
|
||||||
|
|
||||||
_step()
|
|
||||||
self.assertTrue(torch.allclose(block._afb_bias, torch.zeros_like(block._afb_bias)))
|
|
||||||
|
|
||||||
# Third step exceeds warmup -> bias should update.
|
|
||||||
_step()
|
|
||||||
self.assertGreater(torch.count_nonzero(block._afb_bias), 0)
|
|
||||||
|
|
||||||
def test_mixtral_adapter_respects_native_forward(self):
|
|
||||||
model, layer = _build_mixtral_model()
|
|
||||||
layer.jitter_noise = 0.0 # avoid stochasticity for comparison
|
|
||||||
|
|
||||||
hidden_dim = layer.config.hidden_size
|
|
||||||
hidden = torch.randn(2, 3, hidden_dim)
|
|
||||||
baseline_out, baseline_logits = layer(hidden.clone())
|
|
||||||
|
|
||||||
cfg = _cfg()
|
|
||||||
plugin = AuxFreeMoEPlugin()
|
|
||||||
plugin.post_model_build(cfg, model)
|
|
||||||
|
|
||||||
patched_out, patched_logits = layer(hidden.clone())
|
|
||||||
self.assertTrue(torch.allclose(baseline_out, patched_out))
|
|
||||||
self.assertTrue(torch.allclose(baseline_logits, patched_logits))
|
|
||||||
self.assertGreater(torch.count_nonzero(layer._afb_counts), 0)
|
|
||||||
_run_callback(plugin, cfg)
|
|
||||||
|
|
||||||
def test_ep_group_resolution_deferred_until_dist_ready(self):
|
|
||||||
if dist.is_available() and dist.is_initialized():
|
|
||||||
dist.destroy_process_group()
|
|
||||||
|
|
||||||
model, block = _build_bailing_model()
|
|
||||||
cfg = _cfg(moe_bias_sync_group="ep", expert_parallel_size=1)
|
|
||||||
plugin = AuxFreeMoEPlugin()
|
|
||||||
plugin.post_model_build(cfg, model)
|
|
||||||
|
|
||||||
self.assertIsNotNone(plugin._shim)
|
|
||||||
self.assertIsNone(plugin._shim.ep_group)
|
|
||||||
|
|
||||||
callbacks = plugin.add_callbacks_post_trainer(cfg, trainer=SimpleNamespace())
|
|
||||||
self.assertTrue(callbacks)
|
|
||||||
callback = callbacks[0]
|
|
||||||
dummy = SimpleNamespace()
|
|
||||||
|
|
||||||
tmp_init = tempfile.NamedTemporaryFile(delete=False)
|
|
||||||
tmp_init.close()
|
|
||||||
init_method = f"file://{tmp_init.name}"
|
|
||||||
dist.init_process_group(backend="gloo", init_method=init_method, world_size=1, rank=0)
|
|
||||||
try:
|
|
||||||
hidden = torch.randn(2, 3, block.config.hidden_size)
|
|
||||||
block(hidden)
|
|
||||||
callback.on_step_end(args=dummy, state=dummy, control=dummy)
|
|
||||||
self.assertIs(plugin._shim.ep_group, dist.group.WORLD)
|
|
||||||
finally:
|
|
||||||
dist.destroy_process_group()
|
|
||||||
os.unlink(tmp_init.name)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -28,20 +28,22 @@ class TestLoRAConfigValidation:
|
|||||||
result = validate_config(valid_config)
|
result = validate_config(valid_config)
|
||||||
assert result["adapter"] == "lora"
|
assert result["adapter"] == "lora"
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="not compatible with DoRA"):
|
# DoRA is now compatible with lora kernels
|
||||||
invalid_config = DictDefault(
|
dora_kernel_config = DictDefault(
|
||||||
{
|
{
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_mlp_kernel": True,
|
"lora_mlp_kernel": True,
|
||||||
"peft_use_dora": True,
|
"peft_use_dora": True,
|
||||||
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
"datasets": [{"path": "dummy_dataset", "type": "alpaca"}],
|
||||||
"micro_batch_size": 1,
|
"micro_batch_size": 1,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"learning_rate": 1e-5,
|
"learning_rate": 1e-5,
|
||||||
"base_model": "dummy_model",
|
"base_model": "dummy_model",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
validate_config(invalid_config)
|
result = validate_config(dora_kernel_config)
|
||||||
|
assert result["lora_mlp_kernel"] is True
|
||||||
|
assert result["peft_use_dora"] is True
|
||||||
|
|
||||||
def test_qlora_4bit_validation(self):
|
def test_qlora_4bit_validation(self):
|
||||||
"""Test QLoRA 4-bit configuration validation"""
|
"""Test QLoRA 4-bit configuration validation"""
|
||||||
|
|||||||
@@ -38,6 +38,11 @@ class TestLoRAParameterFreezing:
|
|||||||
|
|
||||||
mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
|
mock_layer.lora_A["default"].weight = torch.randn(16, 256, dtype=self.dtype)
|
||||||
mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
|
mock_layer.lora_B["default"].weight = torch.randn(512, 16, dtype=self.dtype)
|
||||||
|
mock_layer.lora_B["default"].bias = None
|
||||||
|
|
||||||
|
# Required by get_lora_parameters for dropout/DoRA extraction
|
||||||
|
mock_layer.lora_dropout = {}
|
||||||
|
mock_layer.lora_magnitude_vector = None
|
||||||
else:
|
else:
|
||||||
mock_layer.weight = base_layer.weight
|
mock_layer.weight = base_layer.weight
|
||||||
mock_layer.bias = base_layer.bias
|
mock_layer.bias = base_layer.bias
|
||||||
@@ -48,7 +53,7 @@ class TestLoRAParameterFreezing:
|
|||||||
"""Test that LoRA parameters are None when adapters are disabled."""
|
"""Test that LoRA parameters are None when adapters are disabled."""
|
||||||
layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)
|
layer = self.create_mock_lora_layer(has_adapters=True, adapters_disabled=True)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
# Base parameters should be returned
|
# Base parameters should be returned
|
||||||
assert W is not None
|
assert W is not None
|
||||||
@@ -62,7 +67,7 @@ class TestLoRAParameterFreezing:
|
|||||||
"""Test that LoRA parameters are None when adapters are merged."""
|
"""Test that LoRA parameters are None when adapters are merged."""
|
||||||
layer = self.create_mock_lora_layer(has_adapters=True, merged=True)
|
layer = self.create_mock_lora_layer(has_adapters=True, merged=True)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
# Base parameters should be returned
|
# Base parameters should be returned
|
||||||
assert W is not None
|
assert W is not None
|
||||||
@@ -77,7 +82,7 @@ class TestLoRAParameterFreezing:
|
|||||||
"""Test parameter behavior when no adapters are present."""
|
"""Test parameter behavior when no adapters are present."""
|
||||||
layer = self.create_mock_lora_layer(has_adapters=False)
|
layer = self.create_mock_lora_layer(has_adapters=False)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
# Base parameters should be returned
|
# Base parameters should be returned
|
||||||
assert W is not None
|
assert W is not None
|
||||||
@@ -94,7 +99,7 @@ class TestLoRAParameterFreezing:
|
|||||||
has_adapters=True, adapters_disabled=False, merged=False
|
has_adapters=True, adapters_disabled=False, merged=False
|
||||||
)
|
)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
# All parameters should be returned
|
# All parameters should be returned
|
||||||
assert W is not None
|
assert W is not None
|
||||||
@@ -110,7 +115,7 @@ class TestLoRAParameterFreezing:
|
|||||||
has_adapters=True, adapters_disabled=False, merged=False
|
has_adapters=True, adapters_disabled=False, merged=False
|
||||||
)
|
)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
# Check shape consistency
|
# Check shape consistency
|
||||||
assert W.shape == (512, 256)
|
assert W.shape == (512, 256)
|
||||||
@@ -124,7 +129,7 @@ class TestLoRAParameterFreezing:
|
|||||||
has_adapters=True, adapters_disabled=False, merged=False
|
has_adapters=True, adapters_disabled=False, merged=False
|
||||||
)
|
)
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
assert W.dtype == self.dtype
|
assert W.dtype == self.dtype
|
||||||
assert b.dtype == self.dtype
|
assert b.dtype == self.dtype
|
||||||
@@ -138,7 +143,7 @@ class TestLoRAParameterFreezing:
|
|||||||
quant_state_mock = Mock()
|
quant_state_mock = Mock()
|
||||||
layer.base_layer.weight.quant_state = quant_state_mock
|
layer.base_layer.weight.quant_state = quant_state_mock
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
assert quant_state == quant_state_mock
|
assert quant_state == quant_state_mock
|
||||||
|
|
||||||
@@ -157,7 +162,7 @@ class TestLoRAParameterFreezing:
|
|||||||
|
|
||||||
layer.active_adapters = ["adapter2"]
|
layer.active_adapters = ["adapter2"]
|
||||||
|
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(layer)
|
||||||
|
|
||||||
assert s == 0.2
|
assert s == 0.2
|
||||||
assert torch.equal(A, layer.lora_A["adapter2"].weight)
|
assert torch.equal(A, layer.lora_A["adapter2"].weight)
|
||||||
@@ -192,13 +197,13 @@ class TestLoRAParameterFreezingIntegration:
|
|||||||
model = get_peft_model(base_model, lora_config)
|
model = get_peft_model(base_model, lora_config)
|
||||||
lora_layer = model.base_model.model.linear
|
lora_layer = model.base_model.model.linear
|
||||||
# Test with adapters enabled
|
# Test with adapters enabled
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
|
||||||
assert A is not None
|
assert A is not None
|
||||||
assert B is not None
|
assert B is not None
|
||||||
assert s is not None
|
assert s is not None
|
||||||
# Test with adapters disabled
|
# Test with adapters disabled
|
||||||
model.disable_adapter_layers()
|
model.disable_adapter_layers()
|
||||||
W, b, quant_state, A, B, s = get_lora_parameters(lora_layer)
|
W, b, quant_state, A, B, s, *_ = get_lora_parameters(lora_layer)
|
||||||
assert A is None
|
assert A is None
|
||||||
assert B is None
|
assert B is None
|
||||||
assert s is None
|
assert s is None
|
||||||
|
|||||||
Reference in New Issue
Block a user