Compare commits
2 Commits
scattermoe
...
textui
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
db6af43f3b | ||
|
|
35d06c8087 |
@@ -91,6 +91,7 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
|
||||
type=click.Path(exists=True, path_type=str),
|
||||
help="YAML config for sweeping hyperparameters",
|
||||
)
|
||||
@click.option("--tui", is_flag=True, default=False, help="Enable TUI dashboard")
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
@filter_none_kwargs
|
||||
@@ -101,6 +102,7 @@ def train(
|
||||
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
|
||||
cloud: str | None = None,
|
||||
sweep: str | None = None,
|
||||
tui: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -118,6 +120,10 @@ def train(
|
||||
# Extract launcher args from extra args (after --)
|
||||
launcher_args = ctx.args if ctx.args else []
|
||||
|
||||
# Handle --tui flag: set env var so subprocess workers pick it up
|
||||
if tui:
|
||||
os.environ["AXOLOTL_TUI"] = "1"
|
||||
|
||||
# Handle Ray launcher override
|
||||
_launcher = None if kwargs.get("use_ray") else launcher
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import gc
|
||||
import os
|
||||
import queue
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@@ -34,22 +35,101 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
check_user_token()
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
|
||||
if not dataset_meta:
|
||||
if cfg.rl:
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
# Start TUI early (before data loading) so it captures preprocessing events
|
||||
tui_renderer = None
|
||||
tui_queue: queue.Queue | None = None
|
||||
is_rank_0 = int(os.getenv("LOCAL_RANK", "0")) == 0
|
||||
if is_rank_0:
|
||||
from axolotl.train import _is_tui_enabled
|
||||
|
||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
if _is_tui_enabled(cfg):
|
||||
import queue as _queue
|
||||
|
||||
del model, tokenizer, trainer
|
||||
from axolotl.train import _get_tui_config
|
||||
from axolotl.tui.config import TUIConfig
|
||||
from axolotl.tui.renderer import TUIRenderer
|
||||
|
||||
gc.collect()
|
||||
tui_config_dict = _get_tui_config(cfg)
|
||||
tui_config = (
|
||||
TUIConfig(**tui_config_dict)
|
||||
if isinstance(tui_config_dict, dict)
|
||||
else tui_config_dict
|
||||
)
|
||||
tui_queue = _queue.Queue(maxsize=4096)
|
||||
tui_renderer = TUIRenderer(config=tui_config, metric_queue=tui_queue)
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_train_unload(cfg)
|
||||
# Send initial run info
|
||||
model_name = cfg.base_model or ""
|
||||
training_mode = str(cfg.rl) if cfg.rl else "sft"
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
try:
|
||||
tui_queue.put_nowait(
|
||||
{
|
||||
"type": "run_info",
|
||||
"model_name": model_name,
|
||||
"training_mode": training_mode,
|
||||
"world_size": world_size,
|
||||
}
|
||||
)
|
||||
except _queue.Full:
|
||||
pass
|
||||
|
||||
tui_renderer.start()
|
||||
|
||||
# Attach logging handler early
|
||||
import logging
|
||||
|
||||
from axolotl.tui.callback import _TUILogHandler
|
||||
|
||||
_early_log_handler = _TUILogHandler(
|
||||
tui_queue, min_level=tui_config.log_level
|
||||
)
|
||||
_early_log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
|
||||
# Attach to BOTH root and axolotl loggers because axolotl logger
|
||||
# has propagate=False so root handler never sees axolotl.* messages
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.addHandler(_early_log_handler)
|
||||
axolotl_logger = logging.getLogger("axolotl")
|
||||
axolotl_logger.addHandler(_early_log_handler)
|
||||
|
||||
# Stash refs on cfg so train() can reuse the renderer
|
||||
cfg._tui_renderer = tui_renderer
|
||||
cfg._tui_queue = tui_queue
|
||||
cfg._tui_early_log_handler = _early_log_handler
|
||||
|
||||
try:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
|
||||
if not dataset_meta:
|
||||
if cfg.rl:
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
del model, tokenizer, trainer
|
||||
|
||||
gc.collect()
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_train_unload(cfg)
|
||||
finally:
|
||||
# If the TUI renderer started early but train() didn't get to stop it
|
||||
# (e.g., error during data loading), clean up here
|
||||
if tui_renderer is not None and not tui_renderer._stop_event.is_set():
|
||||
try:
|
||||
if tui_queue is not None:
|
||||
tui_queue.put_nowait({"type": "done"})
|
||||
except queue.Full:
|
||||
pass
|
||||
tui_renderer.stop()
|
||||
# Remove early log handler from both root and axolotl loggers
|
||||
if hasattr(cfg, "_tui_early_log_handler"):
|
||||
import logging
|
||||
|
||||
logging.getLogger().removeHandler(cfg._tui_early_log_handler)
|
||||
logging.getLogger("axolotl").removeHandler(cfg._tui_early_log_handler)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
|
||||
@@ -36,8 +36,6 @@ SPARSE_MOE_BLOCK = {
|
||||
"glm4v_moe": "Glm4vMoeTextMoE",
|
||||
# sigmoid -> topk routing (no group selection)
|
||||
"minimax_m2": "MiniMaxM2SparseMoeBlock",
|
||||
# sigmoid -> topk routing, non-gated experts (up_proj + down_proj, no gate_up_proj)
|
||||
"nemotron_h": "NemotronHMoE",
|
||||
# Models below need custom routing (not yet implemented):
|
||||
# "ernie4_5_moe": "Ernie4_5_MoeSparseMoeBlock", # softmax->topk, e_score_correction_bias between softmax and topk
|
||||
# "deepseek_v2": "DeepseekV2Moe", # softmax->topk, group_limited_greedy, different attr names (num_group)
|
||||
|
||||
@@ -168,9 +168,6 @@ def _unwrap_experts_lora(experts_module):
|
||||
-> base_layer: ParamWrapper(gate_up_proj)
|
||||
-> base_layer: OlmoeExperts (the real module)
|
||||
|
||||
For non-gated experts (e.g. NemotronH), the chain targets ``up_proj``
|
||||
instead of ``gate_up_proj``.
|
||||
|
||||
This function walks the chain, collects LoRA params keyed by
|
||||
``parameter_name``, and returns the base experts module.
|
||||
|
||||
@@ -179,7 +176,6 @@ def _unwrap_experts_lora(experts_module):
|
||||
|
||||
Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
|
||||
A/B are already in scattermoe layout.
|
||||
For non-gated experts, ``gup_lora`` holds the ``up_proj`` LoRA.
|
||||
"""
|
||||
# Collect ParamWrapper layers by their parameter_name
|
||||
wrappers = {}
|
||||
@@ -199,15 +195,13 @@ def _unwrap_experts_lora(experts_module):
|
||||
num_experts = getattr(base_experts, "num_experts", None)
|
||||
if num_experts is None:
|
||||
# Fallback: infer from parameter shape
|
||||
for attr in ("gate_up_proj", "up_proj"):
|
||||
param = getattr(base_experts, attr, None)
|
||||
if param is not None:
|
||||
num_experts = param.shape[0]
|
||||
break
|
||||
gup = getattr(base_experts, "gate_up_proj", None)
|
||||
if gup is not None:
|
||||
num_experts = gup.shape[0]
|
||||
|
||||
# Extract gate_up_proj or up_proj LoRA (needs A<->B swap due to transposition)
|
||||
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
|
||||
gup_lora = None
|
||||
gup_wrapper = wrappers.get("gate_up_proj") or wrappers.get("up_proj")
|
||||
gup_wrapper = wrappers.get("gate_up_proj")
|
||||
if gup_wrapper is not None:
|
||||
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
|
||||
if lora_A is not None:
|
||||
@@ -447,12 +441,10 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
Supports:
|
||||
|
||||
* **Softmax→topk routing**: OLMoE, Qwen2/3MoE, Mixtral, MiniMax
|
||||
* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2, NemotronH
|
||||
* **Sigmoid→topk routing**: GLM, DeepSeek V3, MiniMax M2
|
||||
* **Full-parameter training**: uses ``parallel_linear`` (base ScatterMoE)
|
||||
* **LoRA fine-tuning**: detects peft ``ParamWrapper`` on ``self.experts``,
|
||||
extracts adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
|
||||
* **Non-gated experts**: NemotronH (up_proj + down_proj, no gate_up_proj)
|
||||
* **Latent projections**: NemotronH (fc1/fc2_latent_proj wrapping experts)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@@ -475,7 +467,7 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
hidden_states_flat = layer_input.view(-1, hidden_dim)
|
||||
|
||||
# ====================================================================
|
||||
# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3, NemotronH)
|
||||
# Shared Expert (if present, e.g. Qwen2MoE, DeepSeek V3)
|
||||
# ====================================================================
|
||||
shared_expert_output = _compute_shared_expert(self, hidden_states_flat)
|
||||
|
||||
@@ -497,22 +489,6 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
# ====================================================================
|
||||
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
|
||||
|
||||
# ====================================================================
|
||||
# Detect non-gated experts (e.g. NemotronH: up_proj + down_proj only)
|
||||
# ====================================================================
|
||||
is_gated = hasattr(experts, "gate_up_proj")
|
||||
up_proj_attr = "gate_up_proj" if is_gated else "up_proj"
|
||||
|
||||
# ====================================================================
|
||||
# Optional latent projection (NemotronH: fc1/fc2_latent_proj)
|
||||
# ====================================================================
|
||||
fc1_latent_proj = getattr(self, "fc1_latent_proj", None)
|
||||
fc2_latent_proj = getattr(self, "fc2_latent_proj", None)
|
||||
|
||||
expert_input = hidden_states_flat
|
||||
if fc1_latent_proj is not None and not isinstance(fc1_latent_proj, nn.Identity):
|
||||
expert_input = fc1_latent_proj(hidden_states_flat)
|
||||
|
||||
# ====================================================================
|
||||
# Selective expert weight dequantization
|
||||
# ====================================================================
|
||||
@@ -522,7 +498,7 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
use_selective = (
|
||||
getattr(self, "_use_selective_dequant", False)
|
||||
and hasattr(experts, "parametrizations")
|
||||
and up_proj_attr in experts.parametrizations
|
||||
and "gate_up_proj" in experts.parametrizations
|
||||
)
|
||||
|
||||
if use_selective:
|
||||
@@ -541,11 +517,11 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
num_experts,
|
||||
)
|
||||
# Dequantize only active experts' weights
|
||||
up_W = selective_expert_weights(
|
||||
gate_up_W = selective_expert_weights(
|
||||
experts,
|
||||
up_proj_attr,
|
||||
"gate_up_proj",
|
||||
active_experts,
|
||||
).transpose(2, 1)
|
||||
).transpose(2, 1) # [num_active, hidden, 2*inter]
|
||||
|
||||
# Remap LoRA weights to match compact expert indices
|
||||
if gup_lora is not None:
|
||||
@@ -562,18 +538,18 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
sei_gup = remapped_expert_idxs
|
||||
eo_gup = compact_offsets
|
||||
else:
|
||||
up_W = getattr(experts, up_proj_attr).transpose(2, 1)
|
||||
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
|
||||
sei_gup = sorted_expert_idxs
|
||||
eo_gup = expert_offsets
|
||||
|
||||
# ====================================================================
|
||||
# Up projection (gated: gate_up_proj; non-gated: up_proj)
|
||||
# Gate + Up projection
|
||||
# ====================================================================
|
||||
if gup_lora is not None:
|
||||
gup_A, gup_B, gup_scaling = gup_lora
|
||||
up_out = parallel_linear_lora(
|
||||
expert_input,
|
||||
up_W,
|
||||
gup = parallel_linear_lora(
|
||||
hidden_states_flat,
|
||||
gate_up_W,
|
||||
top_k,
|
||||
sei_gup,
|
||||
sorted_scattered_idxs,
|
||||
@@ -587,9 +563,9 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
use_fused_gather=True,
|
||||
)
|
||||
else:
|
||||
up_out = parallel_linear(
|
||||
expert_input,
|
||||
up_W,
|
||||
gup = parallel_linear(
|
||||
hidden_states_flat,
|
||||
gate_up_W,
|
||||
top_k,
|
||||
sei_gup,
|
||||
sorted_scattered_idxs,
|
||||
@@ -598,14 +574,8 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
grouped_out=True,
|
||||
)
|
||||
|
||||
# ====================================================================
|
||||
# Activation: gated (act_fn(gate) * up) vs non-gated (act_fn(up))
|
||||
# ====================================================================
|
||||
if is_gated:
|
||||
gates, h = up_out.chunk(2, dim=-1)
|
||||
h = experts.act_fn(gates) * h
|
||||
else:
|
||||
h = experts.act_fn(up_out)
|
||||
gates, h = gup.chunk(2, dim=-1)
|
||||
h = experts.act_fn(gates) * h
|
||||
|
||||
# ====================================================================
|
||||
# Down projection
|
||||
@@ -665,12 +635,6 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
gates=routing_weights,
|
||||
)
|
||||
|
||||
# ====================================================================
|
||||
# Optional latent projection back to hidden_size (NemotronH)
|
||||
# ====================================================================
|
||||
if fc2_latent_proj is not None and not isinstance(fc2_latent_proj, nn.Identity):
|
||||
expert_output = fc2_latent_proj(expert_output)
|
||||
|
||||
# ====================================================================
|
||||
# Combine with shared expert and reshape
|
||||
# ====================================================================
|
||||
|
||||
@@ -9,7 +9,6 @@ import os
|
||||
import shutil
|
||||
import signal
|
||||
import sys
|
||||
import typing
|
||||
import weakref
|
||||
from collections import OrderedDict
|
||||
from contextlib import ExitStack
|
||||
@@ -42,9 +41,6 @@ from axolotl.utils.schemas.enums import RLType
|
||||
from axolotl.utils.train import determine_last_checkpoint
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||
@@ -487,7 +483,7 @@ def handle_untrained_tokens_fix(
|
||||
def setup_model_and_trainer(
|
||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||
) -> tuple[
|
||||
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
|
||||
Trainer,
|
||||
PeftModel | PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PeftConfig | None,
|
||||
@@ -554,6 +550,36 @@ def setup_model_and_trainer(
|
||||
)
|
||||
|
||||
|
||||
def _is_tui_enabled(cfg: DictDefault) -> bool:
|
||||
"""Check if TUI is enabled via config or environment variable."""
|
||||
if os.environ.get("AXOLOTL_TUI", "").lower() in ("1", "true", "yes"):
|
||||
return True
|
||||
tui = cfg.get("tui")
|
||||
if tui is None:
|
||||
return False
|
||||
if isinstance(tui, bool):
|
||||
return tui
|
||||
if isinstance(tui, dict):
|
||||
return tui.get("enabled", False)
|
||||
if hasattr(tui, "enabled"):
|
||||
return tui.enabled
|
||||
return False
|
||||
|
||||
|
||||
def _get_tui_config(cfg: DictDefault) -> dict:
|
||||
"""Extract TUI config dict from cfg."""
|
||||
tui = cfg.get("tui")
|
||||
if tui is None or isinstance(tui, bool):
|
||||
return {"enabled": True}
|
||||
if isinstance(tui, dict):
|
||||
return {**tui, "enabled": True}
|
||||
if hasattr(tui, "model_dump"):
|
||||
d = tui.model_dump()
|
||||
d["enabled"] = True
|
||||
return d
|
||||
return {"enabled": True}
|
||||
|
||||
|
||||
@send_errors
|
||||
def train(
|
||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||
@@ -577,6 +603,37 @@ def train(
|
||||
processor,
|
||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||
|
||||
# Register TUI callback if enabled and rank 0
|
||||
tui_enabled = _is_tui_enabled(cfg)
|
||||
if tui_enabled and cfg.local_rank == 0:
|
||||
from axolotl.tui import AxolotlTUICallback
|
||||
from axolotl.tui.config import TUIConfig
|
||||
|
||||
tui_config = _get_tui_config(cfg)
|
||||
tui_config_obj = (
|
||||
TUIConfig(**tui_config) if isinstance(tui_config, dict) else tui_config
|
||||
)
|
||||
|
||||
# Reuse the early-started renderer if available (started in do_train)
|
||||
early_renderer = getattr(cfg, "_tui_renderer", None)
|
||||
early_queue = getattr(cfg, "_tui_queue", None)
|
||||
|
||||
tui_callback = AxolotlTUICallback(config=tui_config_obj)
|
||||
if early_renderer is not None and early_queue is not None:
|
||||
# Reuse the already-running renderer and queue
|
||||
tui_callback._renderer = early_renderer
|
||||
tui_callback._queue = early_queue
|
||||
tui_callback._renderer_started_early = True
|
||||
trainer.add_callback(tui_callback)
|
||||
|
||||
# Stash model info so on_train_begin can emit a single unified run_info event
|
||||
tui_callback._pending_run_info = {
|
||||
"model_name": cfg.base_model or "",
|
||||
"training_mode": str(cfg.rl) if cfg.rl else "sft",
|
||||
"world_size": int(os.environ.get("WORLD_SIZE", 1)),
|
||||
}
|
||||
LOG.info("TUI dashboard enabled")
|
||||
|
||||
# Handle untrained tokens if configured
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)
|
||||
|
||||
17
src/axolotl/tui/__init__.py
Normal file
17
src/axolotl/tui/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Axolotl Training TUI — rich-based terminal dashboard for monitoring training runs."""
|
||||
|
||||
from axolotl.tui.callback import AxolotlTUICallback
|
||||
from axolotl.tui.config import TUIConfig
|
||||
from axolotl.tui.io_capture import LineParser, register_parser
|
||||
from axolotl.tui.panels import BasePanel, register_panel
|
||||
from axolotl.tui.state import TUIState
|
||||
|
||||
__all__ = [
|
||||
"AxolotlTUICallback",
|
||||
"BasePanel",
|
||||
"LineParser",
|
||||
"TUIConfig",
|
||||
"TUIState",
|
||||
"register_panel",
|
||||
"register_parser",
|
||||
]
|
||||
142
src/axolotl/tui/callback.py
Normal file
142
src/axolotl/tui/callback.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""AxolotlTUICallback — HF TrainerCallback that feeds metrics to the TUI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import queue
|
||||
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
|
||||
from axolotl.tui.config import TUIConfig
|
||||
from axolotl.tui.renderer import TUIRenderer
|
||||
|
||||
|
||||
class _TUILogHandler(logging.Handler):
|
||||
"""Logging handler that pushes log records into the TUI metric queue."""
|
||||
|
||||
_LEVEL_MAP = {
|
||||
logging.DEBUG: "debug",
|
||||
logging.INFO: "info",
|
||||
logging.WARNING: "warning",
|
||||
logging.ERROR: "error",
|
||||
logging.CRITICAL: "error",
|
||||
}
|
||||
|
||||
def __init__(self, metric_queue: queue.Queue, min_level: str = "info"):
|
||||
super().__init__()
|
||||
level_name = min_level.upper()
|
||||
self.setLevel(getattr(logging, level_name, logging.INFO))
|
||||
self._queue = metric_queue
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
try:
|
||||
level = self._LEVEL_MAP.get(record.levelno, "info")
|
||||
msg = self.format(record)
|
||||
self._queue.put_nowait(
|
||||
{
|
||||
"type": "log_line",
|
||||
"level": level,
|
||||
"message": msg,
|
||||
}
|
||||
)
|
||||
except queue.Full:
|
||||
pass
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
|
||||
class AxolotlTUICallback(TrainerCallback):
|
||||
"""Pushes training metrics into a queue for the TUI renderer.
|
||||
|
||||
The callback never blocks on the render thread. The queue is bounded
|
||||
(maxsize=512) with put_nowait; overflow is silently dropped.
|
||||
"""
|
||||
|
||||
def __init__(self, config: TUIConfig):
|
||||
self._config = config
|
||||
self._queue: queue.Queue = queue.Queue(maxsize=4096)
|
||||
self._renderer = TUIRenderer(config=config, metric_queue=self._queue)
|
||||
self._log_handler: _TUILogHandler | None = None
|
||||
self._renderer_started_early: bool = False
|
||||
self._pending_run_info: dict | None = None
|
||||
|
||||
def _put(self, event: dict) -> None:
|
||||
try:
|
||||
self._queue.put_nowait(event)
|
||||
except queue.Full:
|
||||
pass
|
||||
|
||||
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||
# Send a single unified run_info event with all fields
|
||||
run_info = {
|
||||
"type": "run_info",
|
||||
"run_name": getattr(args, "run_name", "") or "",
|
||||
"total_steps": state.max_steps,
|
||||
"total_epochs": float(args.num_train_epochs)
|
||||
if args.num_train_epochs
|
||||
else 1.0,
|
||||
}
|
||||
# Merge in model_name/training_mode/world_size if stashed by train.py
|
||||
if self._pending_run_info:
|
||||
run_info.update(self._pending_run_info)
|
||||
self._pending_run_info = None
|
||||
self._put(run_info)
|
||||
|
||||
if not self._renderer_started_early:
|
||||
# Attach a logging handler to feed log messages into the events panel
|
||||
self._log_handler = _TUILogHandler(
|
||||
self._queue, min_level=self._config.log_level
|
||||
)
|
||||
self._log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
|
||||
# Attach to both root and axolotl loggers (axolotl has propagate=False)
|
||||
logging.getLogger().addHandler(self._log_handler)
|
||||
logging.getLogger("axolotl").addHandler(self._log_handler)
|
||||
|
||||
# Start the renderer background thread
|
||||
self._renderer.start()
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if logs is None:
|
||||
return
|
||||
|
||||
# Filter out non-numeric keys and internal keys
|
||||
filtered = {}
|
||||
for key, value in logs.items():
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
if isinstance(value, (int, float)):
|
||||
filtered[key] = value
|
||||
elif isinstance(value, str):
|
||||
# HF Trainer sometimes passes string-encoded numbers
|
||||
try:
|
||||
filtered[key] = float(value)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
if filtered:
|
||||
self._put({"type": "metrics", "logs": filtered})
|
||||
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
self._put(
|
||||
{
|
||||
"type": "step",
|
||||
"step": state.global_step,
|
||||
"total_steps": state.max_steps,
|
||||
"epoch": state.epoch if state.epoch else 0,
|
||||
}
|
||||
)
|
||||
|
||||
def on_prediction_step(self, args, state, control, **kwargs):
|
||||
pass
|
||||
|
||||
def on_train_end(self, args, state, control, **kwargs):
|
||||
self._put({"type": "done"})
|
||||
# If renderer was started early, do_train's finally block handles stop
|
||||
if not self._renderer_started_early:
|
||||
self._renderer.stop()
|
||||
|
||||
# Remove the logging handler (only if we added it)
|
||||
if self._log_handler:
|
||||
logging.getLogger().removeHandler(self._log_handler)
|
||||
logging.getLogger("axolotl").removeHandler(self._log_handler)
|
||||
self._log_handler = None
|
||||
38
src/axolotl/tui/config.py
Normal file
38
src/axolotl/tui/config.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""TUI configuration — Pydantic model for TUI settings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TUIConfig(BaseModel):
|
||||
"""Configuration for the Axolotl Training TUI dashboard."""
|
||||
|
||||
enabled: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={"description": "Enable the TUI dashboard"},
|
||||
)
|
||||
refresh_rate: int = Field(
|
||||
default=4,
|
||||
json_schema_extra={"description": "Renders per second"},
|
||||
)
|
||||
log_level: str = Field(
|
||||
default="debug",
|
||||
json_schema_extra={"description": "Minimum log level shown in events panel"},
|
||||
)
|
||||
panels: list[str] = Field(
|
||||
default_factory=lambda: ["progress", "training", "hardware", "events", "debug"],
|
||||
json_schema_extra={"description": "Ordered list of panels to display"},
|
||||
)
|
||||
hardware_poll_interval: int = Field(
|
||||
default=2,
|
||||
json_schema_extra={"description": "Seconds between pynvml GPU queries"},
|
||||
)
|
||||
stdout_log_path: str = Field(
|
||||
default="axolotl_stdout.log",
|
||||
json_schema_extra={"description": "File path for captured stdout/stderr log"},
|
||||
)
|
||||
parser_plugins: list[str] = Field(
|
||||
default_factory=list,
|
||||
json_schema_extra={"description": "List of extra parser classes to load"},
|
||||
)
|
||||
72
src/axolotl/tui/gpu.py
Normal file
72
src/axolotl/tui/gpu.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""GPU polling wrapper around pynvml with graceful fallback."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from axolotl.tui.state import GPUStats
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
_nvml_available = False
|
||||
try:
|
||||
import pynvml
|
||||
|
||||
pynvml.nvmlInit()
|
||||
_nvml_available = True
|
||||
except Exception:
|
||||
LOG.debug("pynvml unavailable — GPU stats will not be shown")
|
||||
|
||||
|
||||
class GPUPoller:
|
||||
"""Polls local GPU stats via pynvml. Falls back gracefully if unavailable."""
|
||||
|
||||
def __init__(self):
|
||||
self._device_count = 0
|
||||
if _nvml_available:
|
||||
try:
|
||||
self._device_count = pynvml.nvmlDeviceGetCount()
|
||||
except Exception:
|
||||
self._device_count = 0
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
return _nvml_available and self._device_count > 0
|
||||
|
||||
def poll(self) -> list[GPUStats]:
|
||||
if not self.available:
|
||||
return []
|
||||
|
||||
stats = []
|
||||
for i in range(self._device_count):
|
||||
try:
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
|
||||
name = pynvml.nvmlDeviceGetName(handle)
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode("utf-8")
|
||||
|
||||
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
|
||||
mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
temp = pynvml.nvmlDeviceGetTemperature(
|
||||
handle, pynvml.NVML_TEMPERATURE_GPU
|
||||
)
|
||||
|
||||
try:
|
||||
power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
|
||||
except Exception:
|
||||
power = None
|
||||
|
||||
stats.append(
|
||||
GPUStats(
|
||||
id=i,
|
||||
name=name,
|
||||
util_pct=util.gpu,
|
||||
vram_used_gb=mem.used / (1024**3),
|
||||
vram_total_gb=mem.total / (1024**3),
|
||||
temp_c=temp,
|
||||
power_w=power,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
LOG.debug("Error polling GPU device %d", i, exc_info=True)
|
||||
return stats
|
||||
196
src/axolotl/tui/io_capture.py
Normal file
196
src/axolotl/tui/io_capture.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""I/O capture: OS-level stdout/stderr redirect, line parser chain, and parser registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import sys
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import IO
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parser registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_parser_registry: list[type[LineParser]] = []
|
||||
|
||||
|
||||
def register_parser(cls: type[LineParser]) -> type[LineParser]:
|
||||
"""Decorator to register a LineParser subclass."""
|
||||
if cls not in _parser_registry:
|
||||
_parser_registry.append(cls)
|
||||
return cls
|
||||
|
||||
|
||||
def get_registered_parsers() -> list[type[LineParser]]:
|
||||
return list(_parser_registry)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Base LineParser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class LineParser(ABC):
|
||||
"""Base class for stdout/stderr line parsers."""
|
||||
|
||||
priority: int = 50
|
||||
name: str = ""
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, line: str, source: str) -> list[dict]:
|
||||
"""Parse a single captured line.
|
||||
|
||||
Args:
|
||||
line: one line of captured output, trailing newline stripped.
|
||||
source: "stdout" or "stderr".
|
||||
|
||||
Returns:
|
||||
List of event dicts to push onto the metric queue.
|
||||
Return [] if this line is not relevant.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ParserChain
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ParserChain:
|
||||
def __init__(self):
|
||||
self._parsers: list[LineParser] = []
|
||||
|
||||
def register(self, parser: LineParser) -> None:
|
||||
self._parsers.append(parser)
|
||||
self._parsers.sort(key=lambda p: p.priority)
|
||||
|
||||
def parse(self, line: str, source: str = "stdout") -> list[dict]:
|
||||
events: list[dict] = []
|
||||
for parser in self._parsers:
|
||||
events.extend(parser.parse(line, source))
|
||||
return events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# IOCapture — OS-level fd redirect to pipe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class IOCapture:
|
||||
"""Redirects fd 1 and fd 2 into an OS pipe, drains via a reader thread,
|
||||
passes lines through a ParserChain, and tees to a log file."""
|
||||
|
||||
def __init__(
|
||||
self, log_path: str, parser_chain: ParserChain, metric_queue: queue.Queue
|
||||
):
|
||||
self._parser_chain = parser_chain
|
||||
self._queue = metric_queue
|
||||
self._log_path = log_path
|
||||
self._log_file: IO[str] | None = None
|
||||
self._thread: threading.Thread | None = None
|
||||
self._read_fd: int | None = None
|
||||
self._write_fd: int | None = None
|
||||
self._saved_stdout_fd: int | None = None
|
||||
self._saved_stderr_fd: int | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
# Write run-start separator
|
||||
self._log_file = open(self._log_path, "a", buffering=1) # noqa: SIM115
|
||||
self._log_file.write(
|
||||
f"\n=== axolotl run started {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ===\n"
|
||||
)
|
||||
self._log_file.flush()
|
||||
|
||||
# OS-level pipe
|
||||
self._read_fd, self._write_fd = os.pipe()
|
||||
|
||||
# Save originals
|
||||
self._saved_stdout_fd = os.dup(1)
|
||||
self._saved_stderr_fd = os.dup(2)
|
||||
|
||||
# Redirect both stdout and stderr into the write end
|
||||
os.dup2(self._write_fd, 1)
|
||||
os.dup2(self._write_fd, 2)
|
||||
os.close(self._write_fd) # write end now held by fds 1 and 2
|
||||
|
||||
# Also redirect Python-level handles
|
||||
sys.stdout = open(1, "w", buffering=1, closefd=False) # noqa: SIM115
|
||||
sys.stderr = open(2, "w", buffering=1, closefd=False) # noqa: SIM115
|
||||
|
||||
# Drain thread
|
||||
self._thread = threading.Thread(target=self._drain, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
# Restore fds — closes the write end, causing reader to see EOF
|
||||
if self._saved_stdout_fd is not None and self._saved_stderr_fd is not None:
|
||||
sys.stdout = sys.__stdout__
|
||||
sys.stderr = sys.__stderr__
|
||||
os.dup2(self._saved_stdout_fd, 1)
|
||||
os.dup2(self._saved_stderr_fd, 2)
|
||||
os.close(self._saved_stdout_fd)
|
||||
os.close(self._saved_stderr_fd)
|
||||
self._saved_stdout_fd = None
|
||||
self._saved_stderr_fd = None
|
||||
|
||||
if self._thread is not None:
|
||||
self._thread.join(timeout=2.0)
|
||||
if self._thread.is_alive():
|
||||
logging.getLogger(__name__).warning(
|
||||
"IO capture thread did not exit after 2s"
|
||||
)
|
||||
self._thread = None
|
||||
|
||||
if self._log_file is not None:
|
||||
self._log_file.close()
|
||||
self._log_file = None
|
||||
|
||||
def _drain(self) -> None:
|
||||
# Read raw bytes and split on both \n and \r to handle tqdm progress bars
|
||||
# which use \r for in-place updates without \n
|
||||
assert self._read_fd is not None, "_drain called before start()"
|
||||
with os.fdopen(self._read_fd, "rb") as pipe:
|
||||
buf = b""
|
||||
while True:
|
||||
chunk = pipe.read(4096)
|
||||
if not chunk:
|
||||
# EOF — process remaining buffer
|
||||
if buf:
|
||||
self._process_line(buf.decode("utf-8", errors="replace"))
|
||||
break
|
||||
buf += chunk
|
||||
# Split on \n or \r
|
||||
while b"\n" in buf or b"\r" in buf:
|
||||
# Find the earliest delimiter
|
||||
idx_n = buf.find(b"\n")
|
||||
idx_r = buf.find(b"\r")
|
||||
if idx_n == -1:
|
||||
idx = idx_r
|
||||
elif idx_r == -1:
|
||||
idx = idx_n
|
||||
else:
|
||||
idx = min(idx_n, idx_r)
|
||||
line = buf[:idx].decode("utf-8", errors="replace")
|
||||
buf = buf[idx + 1 :]
|
||||
# Handle \r\n as single delimiter
|
||||
if buf.startswith(b"\n"):
|
||||
buf = buf[1:]
|
||||
if line:
|
||||
self._process_line(line)
|
||||
|
||||
def _process_line(self, line: str) -> None:
|
||||
line = line.rstrip()
|
||||
if not line:
|
||||
return
|
||||
if self._log_file:
|
||||
self._log_file.write(line + "\n")
|
||||
self._log_file.flush()
|
||||
for event in self._parser_chain.parse(line):
|
||||
try:
|
||||
self._queue.put_nowait(event)
|
||||
except queue.Full:
|
||||
pass
|
||||
63
src/axolotl/tui/panels/__init__.py
Normal file
63
src/axolotl/tui/panels/__init__.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Panel registry and base class for TUI panels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from rich.console import RenderableType
|
||||
|
||||
from axolotl.tui.state import TUIState
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Panel registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_panel_registry: dict[str, type[BasePanel]] = {}
|
||||
|
||||
|
||||
def register_panel(position: str = "bottom", weight: int = 50):
|
||||
"""Decorator to register a panel class with position and weight."""
|
||||
|
||||
def decorator(cls: type[BasePanel]) -> type[BasePanel]:
|
||||
cls.position = position
|
||||
cls.weight = weight
|
||||
_panel_registry[cls.name] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_registered_panels() -> dict[str, type[BasePanel]]:
|
||||
return dict(_panel_registry)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BasePanel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BasePanel(ABC):
|
||||
name: str = ""
|
||||
position: str = "bottom"
|
||||
weight: int = 50
|
||||
min_height: int = 4
|
||||
max_height: int | None = None
|
||||
modes: list[str] = ["*"]
|
||||
|
||||
@abstractmethod
|
||||
def render(self, state: TUIState) -> RenderableType:
|
||||
"""Return a rich renderable. Called every tick."""
|
||||
...
|
||||
|
||||
def on_event(self, event: dict) -> None: # noqa: B027
|
||||
"""Optional: react to raw metric events before state is merged."""
|
||||
pass
|
||||
|
||||
|
||||
# Auto-import built-in panels to trigger registration
|
||||
from axolotl.tui.panels.completions import CompletionsPanel # noqa: E402, F401
|
||||
from axolotl.tui.panels.debug import DebugPanel # noqa: E402, F401
|
||||
from axolotl.tui.panels.events import EventsPanel # noqa: E402, F401
|
||||
from axolotl.tui.panels.hardware import HardwarePanel # noqa: E402, F401
|
||||
from axolotl.tui.panels.progress import ProgressPanel # noqa: E402, F401
|
||||
from axolotl.tui.panels.training import TrainingPanel # noqa: E402, F401
|
||||
61
src/axolotl/tui/panels/completions.py
Normal file
61
src/axolotl/tui/panels/completions.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""CompletionsPanel — shows recent RL/log_completions samples."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from rich.console import RenderableType
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from axolotl.tui.panels import BasePanel, register_panel
|
||||
from axolotl.tui.state import TUIState
|
||||
|
||||
|
||||
def _truncate(s: str, maxlen: int = 60) -> str:
|
||||
return s[:maxlen] + "…" if len(s) > maxlen else s
|
||||
|
||||
|
||||
@register_panel(position="bottom", weight=20)
|
||||
class CompletionsPanel(BasePanel):
|
||||
name = "completions"
|
||||
min_height = 6
|
||||
modes = ["grpo", "dpo"]
|
||||
|
||||
def render(self, state: TUIState) -> RenderableType:
|
||||
if "*" not in self.modes and state.training_mode not in self.modes:
|
||||
return Text("")
|
||||
|
||||
if not state.completions:
|
||||
return Panel(
|
||||
Text("No completions yet...", style="dim"),
|
||||
title="Completions",
|
||||
border_style="magenta",
|
||||
)
|
||||
|
||||
table = Table(
|
||||
show_header=True,
|
||||
header_style="bold",
|
||||
expand=True,
|
||||
box=None,
|
||||
pad_edge=False,
|
||||
)
|
||||
table.add_column("step", justify="right", width=6)
|
||||
table.add_column("prompt", no_wrap=False, max_width=40)
|
||||
table.add_column("completion", no_wrap=False, max_width=40)
|
||||
table.add_column("reward", justify="right", width=8)
|
||||
table.add_column("adv", justify="right", width=8)
|
||||
|
||||
for sample in list(state.completions)[-5:]:
|
||||
reward_str = f"{sample.reward:.2f}" if sample.reward is not None else "--"
|
||||
adv_str = (
|
||||
f"{sample.advantage:+.2f}" if sample.advantage is not None else "--"
|
||||
)
|
||||
table.add_row(
|
||||
str(sample.step),
|
||||
_truncate(sample.prompt),
|
||||
_truncate(sample.completion),
|
||||
reward_str,
|
||||
adv_str,
|
||||
)
|
||||
|
||||
return Panel(table, title="Completions", border_style="magenta")
|
||||
34
src/axolotl/tui/panels/debug.py
Normal file
34
src/axolotl/tui/panels/debug.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""DebugPanel — scrolling log of debug-level messages, separate from main events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from rich.console import RenderableType
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from axolotl.tui.panels import BasePanel, register_panel
|
||||
from axolotl.tui.state import TUIState
|
||||
|
||||
|
||||
@register_panel(position="bottom", weight=30)
|
||||
class DebugPanel(BasePanel):
|
||||
name = "debug"
|
||||
min_height = 6
|
||||
max_height = 10
|
||||
|
||||
def render(self, state: TUIState) -> RenderableType:
|
||||
lines = Text()
|
||||
# Show last 8 debug-level log lines
|
||||
debug_lines = [
|
||||
log_entry for log_entry in state.log_lines if log_entry.level == "debug"
|
||||
][-8:]
|
||||
for log_line in debug_lines:
|
||||
ts = log_line.timestamp.strftime("%H:%M:%S")
|
||||
lines.append(f"[{ts}] ", style="dim")
|
||||
lines.append(log_line.message[:200], style="dim")
|
||||
lines.append("\n")
|
||||
|
||||
if not debug_lines:
|
||||
lines = Text("No debug messages yet...", style="dim")
|
||||
|
||||
return Panel(lines, title="Debug", border_style="dim")
|
||||
45
src/axolotl/tui/panels/events.py
Normal file
45
src/axolotl/tui/panels/events.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""EventsPanel — scrolling log of recent events, color-coded by level."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from rich.console import RenderableType
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from axolotl.tui.panels import BasePanel, register_panel
|
||||
from axolotl.tui.state import TUIState
|
||||
|
||||
_LEVEL_STYLES = {
|
||||
"debug": "dim",
|
||||
"info": "",
|
||||
"warning": "yellow",
|
||||
"error": "red bold",
|
||||
"critical": "red bold",
|
||||
}
|
||||
|
||||
|
||||
@register_panel(position="bottom", weight=10)
|
||||
class EventsPanel(BasePanel):
|
||||
name = "events"
|
||||
min_height = 8
|
||||
max_height = 20
|
||||
|
||||
def render(self, state: TUIState) -> RenderableType:
|
||||
lines = Text()
|
||||
# Show last 15 non-debug log lines (debug goes to DebugPanel)
|
||||
recent = [
|
||||
log_entry for log_entry in state.log_lines if log_entry.level != "debug"
|
||||
][-15:]
|
||||
for log_line in recent:
|
||||
ts = log_line.timestamp.strftime("%H:%M:%S")
|
||||
level = log_line.level.upper()
|
||||
style = _LEVEL_STYLES.get(log_line.level, "")
|
||||
lines.append(f"[{ts}] ", style="dim")
|
||||
lines.append(f"[{level}] ", style=style or "")
|
||||
lines.append(log_line.message[:200], style=style or "")
|
||||
lines.append("\n")
|
||||
|
||||
if not recent:
|
||||
lines = Text("No events yet...", style="dim")
|
||||
|
||||
return Panel(lines, title="Events", border_style="yellow")
|
||||
80
src/axolotl/tui/panels/hardware.py
Normal file
80
src/axolotl/tui/panels/hardware.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""HardwarePanel — per-GPU stats via pynvml."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from rich.console import RenderableType
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from axolotl.tui.panels import BasePanel, register_panel
|
||||
from axolotl.tui.state import TUIState
|
||||
|
||||
_BAR_FULL = "█"
|
||||
_BAR_EMPTY = "░"
|
||||
|
||||
|
||||
def _util_bar(pct: float, width: int = 6) -> Text:
|
||||
filled = int(pct / 100 * width)
|
||||
bar = _BAR_FULL * filled + _BAR_EMPTY * (width - filled)
|
||||
color = "green" if pct < 70 else ("yellow" if pct < 90 else "red")
|
||||
return Text.assemble((bar, color), f" {pct:3.0f}%")
|
||||
|
||||
|
||||
@register_panel(position="right", weight=10)
|
||||
class HardwarePanel(BasePanel):
|
||||
name = "hardware"
|
||||
min_height = 6
|
||||
|
||||
def render(self, state: TUIState) -> RenderableType:
|
||||
if not state.gpus:
|
||||
return Panel(
|
||||
Text("GPU stats unavailable", style="dim"),
|
||||
title="Hardware",
|
||||
border_style="green",
|
||||
)
|
||||
|
||||
table = Table(
|
||||
show_header=True,
|
||||
header_style="bold",
|
||||
expand=True,
|
||||
box=None,
|
||||
pad_edge=False,
|
||||
)
|
||||
table.add_column("id", justify="right", width=3)
|
||||
table.add_column("util", no_wrap=True)
|
||||
table.add_column("vram", no_wrap=True)
|
||||
table.add_column("°C", justify="right", width=4)
|
||||
table.add_column("W", justify="right", width=5)
|
||||
|
||||
total_vram_used = 0.0
|
||||
total_vram_total = 0.0
|
||||
total_util = 0.0
|
||||
|
||||
for gpu in state.gpus:
|
||||
total_vram_used += gpu.vram_used_gb
|
||||
total_vram_total += gpu.vram_total_gb
|
||||
total_util += gpu.util_pct
|
||||
|
||||
power_str = f"{gpu.power_w:.0f}" if gpu.power_w is not None else "--"
|
||||
table.add_row(
|
||||
str(gpu.id),
|
||||
_util_bar(gpu.util_pct),
|
||||
f"{gpu.vram_used_gb:.1f}/{gpu.vram_total_gb:.1f} GB",
|
||||
str(gpu.temp_c),
|
||||
power_str,
|
||||
)
|
||||
|
||||
# Footer with aggregates
|
||||
n = len(state.gpus)
|
||||
if n > 1:
|
||||
avg_util = total_util / n
|
||||
table.add_row(
|
||||
"Σ",
|
||||
Text(f"avg {avg_util:.0f}%", style="dim"),
|
||||
Text(f"{total_vram_used:.1f}/{total_vram_total:.1f} GB", style="dim"),
|
||||
"",
|
||||
"",
|
||||
)
|
||||
|
||||
return Panel(table, title="Hardware", border_style="green")
|
||||
73
src/axolotl/tui/panels/progress.py
Normal file
73
src/axolotl/tui/panels/progress.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""ProgressPanel — top-bar progress display with step count, elapsed, ETA."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from rich.console import RenderableType
|
||||
from rich.progress import BarColumn, Progress, TextColumn
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from axolotl.tui.panels import BasePanel, register_panel
|
||||
from axolotl.tui.state import TUIState
|
||||
|
||||
|
||||
def _fmt_time(seconds: float | None) -> str:
|
||||
if seconds is None or seconds < 0:
|
||||
return "--:--:--"
|
||||
h = int(seconds) // 3600
|
||||
m = (int(seconds) % 3600) // 60
|
||||
s = int(seconds) % 60
|
||||
return f"{h}:{m:02d}:{s:02d}"
|
||||
|
||||
|
||||
def _fmt_eta(seconds: float | None) -> str:
|
||||
if seconds is None or seconds < 0:
|
||||
return "eta --"
|
||||
h = int(seconds) // 3600
|
||||
m = (int(seconds) % 3600) // 60
|
||||
if h > 0:
|
||||
return f"eta {h}h{m:02d}m"
|
||||
return f"eta {m}m{int(seconds) % 60:02d}s"
|
||||
|
||||
|
||||
@register_panel(position="top", weight=10)
|
||||
class ProgressPanel(BasePanel):
|
||||
name = "progress"
|
||||
min_height = 3
|
||||
max_height = 3
|
||||
|
||||
def render(self, state: TUIState) -> RenderableType:
|
||||
pct = (
|
||||
(state.current_step / state.total_steps * 100)
|
||||
if state.total_steps > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
# Header line
|
||||
mode_upper = state.training_mode.upper() if state.training_mode else "SFT"
|
||||
model_short = state.model_name.split("/")[-1] if state.model_name else "model"
|
||||
header = Text.assemble(
|
||||
("● ", "bold green"),
|
||||
("AXOLOTL", "bold cyan"),
|
||||
f" {mode_upper} · {model_short} ",
|
||||
(
|
||||
f"{state.current_step} / {state.total_steps}",
|
||||
"bold",
|
||||
),
|
||||
f" · {_fmt_time(state.elapsed_seconds)} elapsed · {_fmt_eta(state.eta_seconds)} · {pct:.1f}%",
|
||||
)
|
||||
|
||||
# Progress bar
|
||||
progress = Progress(
|
||||
TextColumn(""),
|
||||
BarColumn(bar_width=None),
|
||||
TextColumn("{task.percentage:>3.0f}%"),
|
||||
expand=True,
|
||||
)
|
||||
task = progress.add_task("", total=state.total_steps or 1)
|
||||
progress.update(task, completed=state.current_step)
|
||||
|
||||
table = Table.grid(expand=True)
|
||||
table.add_row(header)
|
||||
table.add_row(progress)
|
||||
return table
|
||||
97
src/axolotl/tui/panels/training.py
Normal file
97
src/axolotl/tui/panels/training.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""TrainingPanel — live scalar metrics table with loss sparkline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from rich.console import RenderableType
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from axolotl.tui.panels import BasePanel, register_panel
|
||||
from axolotl.tui.state import TUIState
|
||||
|
||||
# Braille sparkline characters (8 levels)
|
||||
_SPARK_CHARS = "▁▂▃▄▅▆▇█"
|
||||
|
||||
|
||||
def _sparkline(values: list[float] | None, width: int = 20) -> str:
|
||||
if not values or len(values) < 2:
|
||||
return ""
|
||||
vals = list(values)[-width:]
|
||||
lo, hi = min(vals), max(vals)
|
||||
rng = hi - lo if hi != lo else 1.0
|
||||
return "".join(_SPARK_CHARS[min(int((v - lo) / rng * 7), 7)] for v in vals)
|
||||
|
||||
|
||||
# Known key ordering and formatting
|
||||
_KNOWN_KEYS: list[tuple[str, str, str]] = [
|
||||
("loss", "loss", ".4f"),
|
||||
("grad_norm", "grad norm", ".3f"),
|
||||
("learning_rate", "lr", ".2e"),
|
||||
("tokens_per_second", "tok/s", ".1f"),
|
||||
("samples_per_second", "samples/s", ".1f"),
|
||||
("mfu", "MFU", ".1f"),
|
||||
# RL-specific
|
||||
("rewards_mean", "rewards/mean", ".4f"),
|
||||
("rewards_std", "rewards/std", ".4f"),
|
||||
("kl_divergence", "KL", ".4f"),
|
||||
("clip_ratio", "clip ratio", ".3f"),
|
||||
("queue_size", "queue", "d"),
|
||||
]
|
||||
|
||||
|
||||
@register_panel(position="left", weight=10)
|
||||
class TrainingPanel(BasePanel):
|
||||
name = "training"
|
||||
min_height = 8
|
||||
|
||||
def render(self, state: TUIState) -> RenderableType:
|
||||
table = Table(
|
||||
show_header=True,
|
||||
header_style="bold",
|
||||
expand=True,
|
||||
box=None,
|
||||
pad_edge=False,
|
||||
)
|
||||
table.add_column("metric", style="cyan", no_wrap=True)
|
||||
table.add_column("value", justify="right")
|
||||
table.add_column("trend", justify="left", no_wrap=True)
|
||||
|
||||
for attr, label, fmt in _KNOWN_KEYS:
|
||||
val = getattr(state, attr, None)
|
||||
if val is None:
|
||||
# Also check extra dict
|
||||
val = state.extra.get(attr)
|
||||
if val is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
formatted = f"{val:{fmt}}"
|
||||
except (ValueError, TypeError):
|
||||
formatted = str(val)
|
||||
|
||||
trend = ""
|
||||
if attr == "loss":
|
||||
trend = _sparkline(list(state.loss_history))
|
||||
|
||||
table.add_row(label, formatted, trend)
|
||||
|
||||
# Any extra keys not in _KNOWN_KEYS
|
||||
known_attrs = {k for k, _, _ in _KNOWN_KEYS}
|
||||
for key, val in sorted(state.extra.items()):
|
||||
if key in known_attrs or val is None:
|
||||
continue
|
||||
try:
|
||||
formatted = f"{val:.4f}"
|
||||
except (ValueError, TypeError):
|
||||
formatted = str(val)
|
||||
table.add_row(key, formatted, "")
|
||||
|
||||
if table.row_count == 0:
|
||||
return Panel(
|
||||
Text("Waiting for first log step...", style="dim"),
|
||||
title="Training",
|
||||
border_style="blue",
|
||||
)
|
||||
|
||||
return Panel(table, title="Training", border_style="blue")
|
||||
7
src/axolotl/tui/parsers/__init__.py
Normal file
7
src/axolotl/tui/parsers/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Built-in line parsers — auto-imported to trigger @register_parser decorators."""
|
||||
|
||||
from axolotl.tui.parsers.deepspeed import DeepSpeedParser # noqa: F401
|
||||
from axolotl.tui.parsers.nccl import NCCLErrorParser # noqa: F401
|
||||
from axolotl.tui.parsers.raw_log import RawLogParser # noqa: F401
|
||||
from axolotl.tui.parsers.torch_compile import TorchCompileParser # noqa: F401
|
||||
from axolotl.tui.parsers.tqdm import TqdmParser # noqa: F401
|
||||
29
src/axolotl/tui/parsers/deepspeed.py
Normal file
29
src/axolotl/tui/parsers/deepspeed.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""DeepSpeedParser — extracts DeepSpeed stage info and throughput metrics."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from axolotl.tui.io_capture import LineParser, register_parser
|
||||
|
||||
|
||||
@register_parser
|
||||
class DeepSpeedParser(LineParser):
|
||||
priority = 20
|
||||
name = "deepspeed"
|
||||
|
||||
_SAMPLES_RE = re.compile(r"samples/sec=([0-9.]+)")
|
||||
_STAGE_RE = re.compile(r"ZeRO Stage (\d)")
|
||||
|
||||
def parse(self, line: str, source: str) -> list[dict]:
|
||||
events: list[dict] = []
|
||||
if m := self._SAMPLES_RE.search(line):
|
||||
events.append(
|
||||
{
|
||||
"type": "metrics",
|
||||
"logs": {"samples_per_second": float(m.group(1))},
|
||||
}
|
||||
)
|
||||
if m := self._STAGE_RE.search(line):
|
||||
events.append({"type": "run_info", "zero_stage": int(m.group(1))})
|
||||
return events
|
||||
27
src/axolotl/tui/parsers/nccl.py
Normal file
27
src/axolotl/tui/parsers/nccl.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""NCCLErrorParser — surfaces NCCL errors as red alert events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from axolotl.tui.io_capture import LineParser, register_parser
|
||||
|
||||
|
||||
@register_parser
|
||||
class NCCLErrorParser(LineParser):
|
||||
priority = 10
|
||||
name = "nccl_error"
|
||||
|
||||
_RE = re.compile(r"NCCL error|Unhandled NCCL", re.IGNORECASE)
|
||||
|
||||
def parse(self, line: str, source: str) -> list[dict]:
|
||||
if self._RE.search(line):
|
||||
return [
|
||||
{
|
||||
"type": "log_line",
|
||||
"level": "error",
|
||||
"message": f"⚠ NCCL: {line}",
|
||||
},
|
||||
{"type": "alert", "severity": "error", "message": line},
|
||||
]
|
||||
return []
|
||||
37
src/axolotl/tui/parsers/raw_log.py
Normal file
37
src/axolotl/tui/parsers/raw_log.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""RawLogParser — catches every line as a log_line event."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from axolotl.tui.io_capture import LineParser, register_parser
|
||||
|
||||
|
||||
@register_parser
|
||||
class RawLogParser(LineParser):
|
||||
priority = 99
|
||||
name = "raw_log"
|
||||
|
||||
_LOG_RE = re.compile(
|
||||
r"^(?P<ts>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}[,\.]\d+)"
|
||||
r"\s*[-]\s*(?P<level>DEBUG|INFO|WARNING|ERROR|CRITICAL)"
|
||||
r"\s*[-]\s*(?P<msg>.+)$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Filter out tqdm progress bar lines and other noisy output
|
||||
_TQDM_RE = re.compile(r"^\s*\d+%\|.*\|")
|
||||
_EMPTY_RE = re.compile(r"^\s*$")
|
||||
|
||||
def parse(self, line: str, source: str) -> list[dict]:
|
||||
# Skip empty lines and tqdm progress bar updates
|
||||
if self._EMPTY_RE.match(line) or self._TQDM_RE.match(line):
|
||||
return []
|
||||
|
||||
m = self._LOG_RE.match(line)
|
||||
level = (
|
||||
m.group("level").lower()
|
||||
if m
|
||||
else ("error" if source == "stderr" else "info")
|
||||
)
|
||||
return [{"type": "log_line", "level": level, "message": line}]
|
||||
26
src/axolotl/tui/parsers/torch_compile.py
Normal file
26
src/axolotl/tui/parsers/torch_compile.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""TorchCompileParser — detects torch.compile graph breaks and recompilations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from axolotl.tui.io_capture import LineParser, register_parser
|
||||
|
||||
|
||||
@register_parser
|
||||
class TorchCompileParser(LineParser):
|
||||
priority = 20
|
||||
name = "torch_compile"
|
||||
|
||||
_RE = re.compile(r"Graph break|Recompiling|torch\.compile", re.IGNORECASE)
|
||||
|
||||
def parse(self, line: str, source: str) -> list[dict]:
|
||||
if self._RE.search(line):
|
||||
return [
|
||||
{
|
||||
"type": "log_line",
|
||||
"level": "warning",
|
||||
"message": f"⚡ compile: {line}",
|
||||
}
|
||||
]
|
||||
return []
|
||||
86
src/axolotl/tui/parsers/tqdm.py
Normal file
86
src/axolotl/tui/parsers/tqdm.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""TqdmParser — captures tqdm progress bar output and surfaces as structured events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from axolotl.tui.io_capture import LineParser, register_parser
|
||||
|
||||
|
||||
@register_parser
|
||||
class TqdmParser(LineParser):
|
||||
priority = 15
|
||||
name = "tqdm"
|
||||
|
||||
# Match tqdm-style progress lines, e.g.:
|
||||
# Tokenizing Prompts (num_proc=24): 35%|███▍ | 19008/54568 [00:02<00:02, 17417.65 examples/s]
|
||||
# Loading weights: 53%|█████▎ | 77/146 [00:00<00:00, 396.39it/s]
|
||||
# 0%| | 0/30 [00:00<?, ?it/s]
|
||||
_TQDM_RE = re.compile(
|
||||
r"(?P<desc>.*?)\s*"
|
||||
r"(?P<pct>\d+)%\|[▏▎▍▌▋▊▉█░▓▒# ]*\|\s*"
|
||||
r"(?P<current>[\d,]+)/(?P<total>[\d,]+)"
|
||||
r"\s*\[(?P<elapsed>[^\]]*)\]"
|
||||
)
|
||||
|
||||
# Also match simpler forms like:
|
||||
# Fetching 0 files: 0it [00:00, ?it/s]
|
||||
_FETCH_RE = re.compile(r"(?P<desc>[\w\s]+):\s*(?P<current>\d+)(?:it)?\s*\[.*?\]")
|
||||
|
||||
def parse(self, line: str, source: str) -> list[dict]:
|
||||
m = self._TQDM_RE.search(line)
|
||||
if m:
|
||||
desc = m.group("desc").strip().rstrip(":")
|
||||
pct = int(m.group("pct"))
|
||||
current = int(m.group("current").replace(",", ""))
|
||||
total = int(m.group("total").replace(",", ""))
|
||||
|
||||
events: list[dict] = []
|
||||
|
||||
# Surface as a log line with progress info
|
||||
if pct == 100 or pct == 0 or pct % 25 == 0:
|
||||
msg = (
|
||||
f"[{desc}] {pct}% ({current}/{total})"
|
||||
if desc
|
||||
else f"{pct}% ({current}/{total})"
|
||||
)
|
||||
events.append(
|
||||
{
|
||||
"type": "log_line",
|
||||
"level": "info",
|
||||
"message": msg,
|
||||
}
|
||||
)
|
||||
|
||||
# Also emit as a progress metric
|
||||
cleaned_desc = desc.strip().lower().replace(" ", "_")
|
||||
if not cleaned_desc:
|
||||
cleaned_desc = "progress"
|
||||
events.append(
|
||||
{
|
||||
"type": "metrics",
|
||||
"logs": {
|
||||
f"progress/{cleaned_desc}": pct / 100.0,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
# Fallback: try simpler fetch-style progress lines
|
||||
m = self._FETCH_RE.search(line)
|
||||
if m:
|
||||
desc = m.group("desc").strip().rstrip(":")
|
||||
current = int(m.group("current"))
|
||||
cleaned_desc = desc.strip().lower().replace(" ", "_")
|
||||
if not cleaned_desc:
|
||||
cleaned_desc = "fetch"
|
||||
return [
|
||||
{
|
||||
"type": "log_line",
|
||||
"level": "info",
|
||||
"message": f"[{desc}] {current}" if desc else f"{current}",
|
||||
}
|
||||
]
|
||||
|
||||
return []
|
||||
449
src/axolotl/tui/renderer.py
Normal file
449
src/axolotl/tui/renderer.py
Normal file
@@ -0,0 +1,449 @@
|
||||
"""TUIRenderer — background daemon thread that drives the rich.live.Live display."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from rich.console import Console
|
||||
from rich.layout import Layout
|
||||
from rich.live import Live
|
||||
|
||||
from axolotl.tui.config import TUIConfig
|
||||
from axolotl.tui.gpu import GPUPoller
|
||||
from axolotl.tui.io_capture import (
|
||||
IOCapture,
|
||||
ParserChain,
|
||||
get_registered_parsers,
|
||||
)
|
||||
from axolotl.tui.panels import BasePanel, get_registered_panels
|
||||
from axolotl.tui.state import CompletionSample, LogLine, TUIState
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TUIRenderer:
|
||||
"""Background thread that renders the TUI dashboard using rich.live.Live."""
|
||||
|
||||
def __init__(self, config: TUIConfig, metric_queue: queue.Queue):
|
||||
self._config = config
|
||||
self._queue = metric_queue
|
||||
self._state = TUIState()
|
||||
self._gpu_poller = GPUPoller()
|
||||
self._panels: list[BasePanel] = []
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._io_capture: IOCapture | None = None
|
||||
self._parser_chain: ParserChain | None = None
|
||||
|
||||
def _init_panels(self) -> None:
|
||||
registry = get_registered_panels()
|
||||
for panel_name in self._config.panels:
|
||||
if panel_name in registry:
|
||||
self._panels.append(registry[panel_name]())
|
||||
|
||||
def _init_parser_chain(self) -> None:
|
||||
# Ensure built-in parsers are imported so @register_parser decorators fire
|
||||
import axolotl.tui.parsers # noqa: F401
|
||||
|
||||
self._parser_chain = ParserChain()
|
||||
# Register all built-in parsers
|
||||
for parser_cls in get_registered_parsers():
|
||||
self._parser_chain.register(parser_cls())
|
||||
|
||||
# Load plugin parsers
|
||||
for plugin_spec in self._config.parser_plugins:
|
||||
try:
|
||||
if "::" in plugin_spec:
|
||||
# file path :: class name
|
||||
file_path, class_name = plugin_spec.split("::", 1)
|
||||
import importlib.util
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"custom_parser", file_path
|
||||
)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Cannot load spec for {file_path}")
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
parser_cls = getattr(mod, class_name)
|
||||
else:
|
||||
# dotted module path
|
||||
module_path, class_name = plugin_spec.rsplit(".", 1)
|
||||
mod = importlib.import_module(module_path)
|
||||
parser_cls = getattr(mod, class_name)
|
||||
self._parser_chain.register(parser_cls())
|
||||
except Exception as exc:
|
||||
LOG.warning(f"Failed to load parser plugin {plugin_spec}: {exc}")
|
||||
|
||||
def _build_layout(self) -> Layout:
|
||||
layout = Layout()
|
||||
|
||||
top_panels = [p for p in self._panels if p.position == "top"]
|
||||
left_panels = [p for p in self._panels if p.position == "left"]
|
||||
right_panels = [p for p in self._panels if p.position == "right"]
|
||||
bottom_panels = [p for p in self._panels if p.position == "bottom"]
|
||||
|
||||
sections = []
|
||||
|
||||
if top_panels:
|
||||
layout_top = Layout(name="top", size=3)
|
||||
sections.append(layout_top)
|
||||
|
||||
if left_panels or right_panels:
|
||||
layout_middle = Layout(name="middle", ratio=3)
|
||||
middle_parts = []
|
||||
if left_panels:
|
||||
middle_parts.append(Layout(name="left", ratio=1))
|
||||
if right_panels:
|
||||
middle_parts.append(Layout(name="right", ratio=1))
|
||||
if middle_parts:
|
||||
layout_middle.split_row(*middle_parts)
|
||||
sections.append(layout_middle)
|
||||
|
||||
if bottom_panels:
|
||||
layout_bottom = Layout(name="bottom", ratio=2)
|
||||
if len(bottom_panels) > 1:
|
||||
layout_bottom.split_row(
|
||||
*[
|
||||
Layout(name=f"bottom_{i}", ratio=1)
|
||||
for i in range(len(bottom_panels))
|
||||
]
|
||||
)
|
||||
sections.append(layout_bottom)
|
||||
|
||||
if sections:
|
||||
layout.split_column(*sections)
|
||||
|
||||
return layout
|
||||
|
||||
def _update_layout(self, layout: Layout) -> None:
|
||||
top_panels = [p for p in self._panels if p.position == "top"]
|
||||
left_panels = [p for p in self._panels if p.position == "left"]
|
||||
right_panels = [p for p in self._panels if p.position == "right"]
|
||||
bottom_panels = [p for p in self._panels if p.position == "bottom"]
|
||||
|
||||
if top_panels:
|
||||
layout["top"].update(top_panels[0].render(self._state))
|
||||
|
||||
if left_panels:
|
||||
layout["left"].update(left_panels[0].render(self._state))
|
||||
|
||||
if right_panels:
|
||||
layout["right"].update(right_panels[0].render(self._state))
|
||||
|
||||
if bottom_panels:
|
||||
if len(bottom_panels) == 1:
|
||||
layout["bottom"].update(bottom_panels[0].render(self._state))
|
||||
else:
|
||||
for i, panel in enumerate(bottom_panels):
|
||||
layout[f"bottom_{i}"].update(panel.render(self._state))
|
||||
|
||||
def _drain_queue(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
event = self._queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# Dispatch event to panels first
|
||||
for panel in self._panels:
|
||||
panel.on_event(event)
|
||||
|
||||
event_type = event.get("type")
|
||||
|
||||
if event_type == "metrics":
|
||||
logs = event.get("logs", {})
|
||||
self._apply_metrics(logs)
|
||||
|
||||
elif event_type == "step":
|
||||
self._state.current_step = event.get("step", self._state.current_step)
|
||||
self._state.total_steps = event.get(
|
||||
"total_steps", self._state.total_steps
|
||||
)
|
||||
self._state.current_epoch = event.get(
|
||||
"epoch", self._state.current_epoch
|
||||
)
|
||||
now = time.time()
|
||||
self._state.elapsed_seconds = now - self._state.start_time.timestamp()
|
||||
if self._state.current_step > 0 and self._state.total_steps > 0:
|
||||
rate = self._state.elapsed_seconds / self._state.current_step
|
||||
remaining = self._state.total_steps - self._state.current_step
|
||||
self._state.eta_seconds = rate * remaining
|
||||
|
||||
elif event_type == "log_line":
|
||||
level = event.get("level", "info")
|
||||
message = event.get("message", "")
|
||||
self._state.log_lines.append(
|
||||
LogLine(
|
||||
timestamp=datetime.now(),
|
||||
level=level,
|
||||
message=message,
|
||||
)
|
||||
)
|
||||
|
||||
elif event_type == "completion":
|
||||
self._state.completions.append(
|
||||
CompletionSample(
|
||||
step=event.get("step", 0),
|
||||
prompt=event.get("prompt", ""),
|
||||
completion=event.get("completion", ""),
|
||||
reward=event.get("reward"),
|
||||
advantage=event.get("advantage"),
|
||||
)
|
||||
)
|
||||
|
||||
elif event_type == "run_info":
|
||||
if "run_name" in event:
|
||||
self._state.run_name = event["run_name"]
|
||||
if "model_name" in event:
|
||||
self._state.model_name = event["model_name"]
|
||||
if "training_mode" in event:
|
||||
self._state.training_mode = event["training_mode"]
|
||||
if "world_size" in event:
|
||||
self._state.world_size = event["world_size"]
|
||||
if "total_steps" in event:
|
||||
self._state.total_steps = event["total_steps"]
|
||||
if "total_epochs" in event:
|
||||
self._state.total_epochs = event["total_epochs"]
|
||||
if "zero_stage" in event:
|
||||
self._state.zero_stage = event["zero_stage"]
|
||||
|
||||
elif event_type == "done":
|
||||
self._stop_event.set()
|
||||
|
||||
def _apply_metrics(self, logs: dict[str, Any]) -> None:
|
||||
metric_map = {
|
||||
"loss": "loss",
|
||||
"grad_norm": "grad_norm",
|
||||
"learning_rate": "learning_rate",
|
||||
"tokens_per_second": "tokens_per_second",
|
||||
"samples_per_second": "samples_per_second",
|
||||
"mfu": "mfu",
|
||||
"rewards/mean": "rewards_mean",
|
||||
"rewards_mean": "rewards_mean",
|
||||
"rewards/std": "rewards_std",
|
||||
"rewards_std": "rewards_std",
|
||||
"kl": "kl_divergence",
|
||||
"kl_divergence": "kl_divergence",
|
||||
"clip_ratio": "clip_ratio",
|
||||
"queue_size": "queue_size",
|
||||
}
|
||||
|
||||
for key, value in logs.items():
|
||||
if key in metric_map:
|
||||
setattr(self._state, metric_map[key], value)
|
||||
else:
|
||||
self._state.extra[key] = value
|
||||
|
||||
if "loss" in logs and logs["loss"] is not None:
|
||||
self._state.loss_history.append(logs["loss"])
|
||||
|
||||
def start(self) -> None:
|
||||
self._init_panels()
|
||||
self._init_parser_chain()
|
||||
|
||||
# Set up I/O capture
|
||||
assert self._parser_chain is not None, "_init_parser_chain must be called first"
|
||||
self._io_capture = IOCapture(
|
||||
log_path=self._config.stdout_log_path,
|
||||
parser_chain=self._parser_chain,
|
||||
metric_queue=self._queue,
|
||||
)
|
||||
|
||||
# Monkeypatch tqdm to suppress terminal output and route through our queue.
|
||||
# This prevents tqdm progress bars from flickering through the TUI and
|
||||
# ensures all progress events appear in the Events panel.
|
||||
self._install_tqdm_hook()
|
||||
|
||||
self._io_capture_ready = threading.Event()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
self._io_capture_ready.wait(timeout=5.0)
|
||||
|
||||
def _install_tqdm_hook(self) -> None:
|
||||
"""Replace tqdm's display method to route updates through TUI queue."""
|
||||
try:
|
||||
import io
|
||||
|
||||
import tqdm
|
||||
import tqdm.auto
|
||||
|
||||
q = self._queue
|
||||
self._tqdm_parser = None
|
||||
# Find our tqdm parser in the chain
|
||||
for p in self._parser_chain._parsers if self._parser_chain else []:
|
||||
if p.name == "tqdm":
|
||||
self._tqdm_parser = p
|
||||
break
|
||||
|
||||
# Save originals for restore
|
||||
self._orig_tqdm_class_auto = tqdm.auto.tqdm
|
||||
self._orig_tqdm_class_tqdm = tqdm.tqdm
|
||||
self._orig_tqdm_class_std = tqdm.std.tqdm
|
||||
|
||||
class TUITqdm(tqdm.tqdm):
|
||||
"""tqdm subclass that sends progress to TUI instead of terminal."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Force output to devnull so nothing reaches the terminal
|
||||
kwargs["file"] = io.StringIO()
|
||||
kwargs["dynamic_ncols"] = False
|
||||
kwargs["ncols"] = 80
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def display(self, msg=None, pos=None):
|
||||
# Build a progress string and push to queue
|
||||
if self.total and self.total > 0:
|
||||
pct = self.n / self.total * 100
|
||||
desc = self.desc.rstrip(": ") if self.desc else ""
|
||||
# Emit events at milestones or at low frequency
|
||||
is_milestone = (
|
||||
self.n == 0 or self.n >= self.total or int(pct) % 25 == 0
|
||||
)
|
||||
if is_milestone:
|
||||
try:
|
||||
q.put_nowait(
|
||||
{
|
||||
"type": "log_line",
|
||||
"level": "info",
|
||||
"message": f"[{desc}] {pct:.0f}% ({self.n}/{self.total})"
|
||||
if desc
|
||||
else f"{pct:.0f}% ({self.n}/{self.total})",
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
metric_key = (
|
||||
f"progress/{desc.lower().replace(' ', '_')}"
|
||||
if desc
|
||||
else "progress/unknown"
|
||||
)
|
||||
q.put_nowait(
|
||||
{
|
||||
"type": "metrics",
|
||||
"logs": {metric_key: pct / 100.0},
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
# Emit final completion event
|
||||
if self.total and self.total > 0 and self.n > 0:
|
||||
desc = self.desc.rstrip(": ") if self.desc else ""
|
||||
try:
|
||||
q.put_nowait(
|
||||
{
|
||||
"type": "log_line",
|
||||
"level": "info",
|
||||
"message": f"[{desc}] 100% ({self.total}/{self.total}) done"
|
||||
if desc
|
||||
else f"100% ({self.total}/{self.total}) done",
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
super().close()
|
||||
|
||||
# Replace tqdm globally
|
||||
tqdm.auto.tqdm = TUITqdm
|
||||
tqdm.tqdm = TUITqdm
|
||||
# Also patch tqdm.std which some libraries use directly
|
||||
tqdm.std.tqdm = TUITqdm
|
||||
self._tui_tqdm_cls = TUITqdm
|
||||
|
||||
except Exception as exc:
|
||||
LOG.debug(f"Failed to install tqdm hook: {exc}")
|
||||
|
||||
def _uninstall_tqdm_hook(self) -> None:
|
||||
"""Restore original tqdm."""
|
||||
try:
|
||||
import tqdm
|
||||
import tqdm.auto
|
||||
|
||||
if hasattr(self, "_orig_tqdm_class_auto"):
|
||||
tqdm.auto.tqdm = self._orig_tqdm_class_auto
|
||||
if hasattr(self, "_orig_tqdm_class_tqdm"):
|
||||
tqdm.tqdm = self._orig_tqdm_class_tqdm
|
||||
if hasattr(self, "_orig_tqdm_class_std"):
|
||||
tqdm.std.tqdm = self._orig_tqdm_class_std
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop_event.set()
|
||||
self._uninstall_tqdm_hook()
|
||||
if self._thread is not None:
|
||||
self._thread.join(timeout=5.0)
|
||||
|
||||
def _run(self) -> None:
|
||||
import os
|
||||
|
||||
# Save a handle to the REAL terminal BEFORE IO capture redirects fds.
|
||||
# This ensures rich.live.Live writes to the terminal, not the pipe.
|
||||
saved_tty_fd = os.dup(1)
|
||||
tty_file = os.fdopen(saved_tty_fd, "w", buffering=1, closefd=True)
|
||||
console = Console(file=tty_file)
|
||||
|
||||
layout = self._build_layout()
|
||||
tick_interval = 1.0 / max(self._config.refresh_rate, 1)
|
||||
gpu_poll_counter = 0
|
||||
gpu_poll_ticks = max(
|
||||
1, int(self._config.hardware_poll_interval / tick_interval)
|
||||
)
|
||||
|
||||
# Start I/O capture — redirects fd 1/2 to pipe AFTER we saved the tty fd
|
||||
if self._io_capture:
|
||||
self._io_capture.start()
|
||||
|
||||
# Signal that IO capture is live so start() can return
|
||||
if hasattr(self, "_io_capture_ready"):
|
||||
self._io_capture_ready.set()
|
||||
|
||||
try:
|
||||
with Live(
|
||||
layout,
|
||||
console=console,
|
||||
refresh_per_second=self._config.refresh_rate,
|
||||
screen=True,
|
||||
redirect_stdout=False,
|
||||
redirect_stderr=False,
|
||||
) as live:
|
||||
while not self._stop_event.is_set():
|
||||
self._drain_queue()
|
||||
|
||||
# Poll GPU stats periodically
|
||||
gpu_poll_counter += 1
|
||||
if gpu_poll_counter >= gpu_poll_ticks:
|
||||
gpu_poll_counter = 0
|
||||
if self._gpu_poller.available:
|
||||
self._state.gpus = self._gpu_poller.poll()
|
||||
|
||||
# Update elapsed time
|
||||
self._state.elapsed_seconds = (
|
||||
time.time() - self._state.start_time.timestamp()
|
||||
)
|
||||
|
||||
self._update_layout(layout)
|
||||
live.update(layout)
|
||||
|
||||
time.sleep(tick_interval)
|
||||
|
||||
# Final drain
|
||||
self._drain_queue()
|
||||
self._update_layout(layout)
|
||||
live.update(layout)
|
||||
finally:
|
||||
if self._io_capture:
|
||||
self._io_capture.stop()
|
||||
try:
|
||||
tty_file.close()
|
||||
except Exception:
|
||||
pass
|
||||
88
src/axolotl/tui/state.py
Normal file
88
src/axolotl/tui/state.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""TUI shared data model — dataclasses for the dashboard state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPUStats:
|
||||
id: int
|
||||
name: str
|
||||
util_pct: float
|
||||
vram_used_gb: float
|
||||
vram_total_gb: float
|
||||
temp_c: int
|
||||
power_w: float | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LogLine:
|
||||
timestamp: datetime
|
||||
level: str # "info" | "debug" | "warning" | "error"
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompletionSample:
|
||||
step: int
|
||||
prompt: str
|
||||
completion: str
|
||||
reward: float | None
|
||||
advantage: float | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TUIState:
|
||||
# Run metadata
|
||||
run_name: str = ""
|
||||
model_name: str = ""
|
||||
training_mode: str = "sft"
|
||||
world_size: int = 1
|
||||
start_time: datetime = field(default_factory=datetime.now)
|
||||
|
||||
# Progress
|
||||
current_step: int = 0
|
||||
total_steps: int = 0
|
||||
current_epoch: float = 0.0
|
||||
total_epochs: float = 1.0
|
||||
elapsed_seconds: float = 0.0
|
||||
eta_seconds: float | None = None
|
||||
|
||||
# Training metrics (rolling window + current)
|
||||
loss: float | None = None
|
||||
grad_norm: float | None = None
|
||||
learning_rate: float | None = None
|
||||
tokens_per_second: float | None = None
|
||||
samples_per_second: float | None = None
|
||||
mfu: float | None = None
|
||||
|
||||
# RL-specific (None for non-RL modes)
|
||||
rewards_mean: float | None = None
|
||||
rewards_std: float | None = None
|
||||
kl_divergence: float | None = None
|
||||
clip_ratio: float | None = None
|
||||
queue_size: int | None = None
|
||||
|
||||
# Per-GPU hardware (list indexed by local rank)
|
||||
gpus: list[GPUStats] = field(default_factory=list)
|
||||
|
||||
# Recent log lines
|
||||
log_lines: deque[LogLine] = field(default_factory=lambda: deque(maxlen=200))
|
||||
|
||||
# Recent completions (GRPO/SFT with log_completions)
|
||||
completions: deque[CompletionSample] = field(
|
||||
default_factory=lambda: deque(maxlen=20)
|
||||
)
|
||||
|
||||
# Loss history for sparkline
|
||||
loss_history: deque[float] = field(default_factory=lambda: deque(maxlen=50))
|
||||
|
||||
# DeepSpeed zero stage (None if not using DeepSpeed)
|
||||
zero_stage: int | None = None
|
||||
|
||||
# Arbitrary plugin state
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
@@ -13,6 +13,7 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from axolotl.tui.config import TUIConfig
|
||||
from axolotl.utils.datasets import get_default_process_count
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.datasets import (
|
||||
@@ -140,6 +141,12 @@ class AxolotlInputConfig(
|
||||
vllm: VllmConfig | None = Field(
|
||||
default_factory=lambda: VllmConfig(),
|
||||
)
|
||||
tui: TUIConfig | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "TUI dashboard configuration. Set enabled: true to activate."
|
||||
},
|
||||
)
|
||||
qat: QATConfig | None = None
|
||||
quantization: PTQConfig | None = None
|
||||
reward_model: bool | None = Field(
|
||||
@@ -1385,39 +1392,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
if data.get("trust_remote_code"):
|
||||
return data
|
||||
|
||||
# Skip auto-enable for MoE models when native grouped_mm is unavailable
|
||||
# (torch < 2.9). The grouped_mm fallback in transformers uses torch.mm
|
||||
# with out= which bypasses autocast and fails on mixed dtypes during eval.
|
||||
env_capabilities = data.get("env_capabilities", {})
|
||||
torch_version = env_capabilities.get("torch_version")
|
||||
if torch_version is None:
|
||||
import torch
|
||||
|
||||
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||
has_grouped_mm = version.parse(torch_version) >= version.parse("2.9.0")
|
||||
if not has_grouped_mm:
|
||||
is_moe = False
|
||||
model_type = data.get("model_config_type", "")
|
||||
if model_type and "moe" in model_type.lower():
|
||||
is_moe = True
|
||||
if not is_moe:
|
||||
try:
|
||||
from transformers import AutoConfig
|
||||
|
||||
base_model = data.get("base_model")
|
||||
if base_model:
|
||||
auto_cfg = AutoConfig.from_pretrained(
|
||||
base_model, trust_remote_code=False
|
||||
)
|
||||
if getattr(auto_cfg, "num_local_experts", None) or getattr(
|
||||
auto_cfg, "num_experts", None
|
||||
):
|
||||
is_moe = True
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
if is_moe:
|
||||
return data
|
||||
|
||||
# Check multi-GPU compatibility
|
||||
capabilities = data.get("capabilities")
|
||||
is_multi_gpu = capabilities and capabilities.get("n_gpu", 0) > 1
|
||||
|
||||
@@ -176,31 +176,24 @@ def test_lora_mlp_direct(sample_tensors, activation_forward, activation_backward
|
||||
X.requires_grad = True
|
||||
output = LoRA_MLP.apply(
|
||||
X,
|
||||
None, # X_drop
|
||||
gate_proj.weight,
|
||||
gate_proj.bias,
|
||||
None, # gate_quant
|
||||
None, # gate_A
|
||||
None, # gate_B
|
||||
None, # gate_scale
|
||||
None, # gate_lora_bias
|
||||
None, # gate_magnitude
|
||||
up_proj.weight,
|
||||
up_proj.bias,
|
||||
None, # up_quant
|
||||
None, # up_A
|
||||
None, # up_B
|
||||
None, # up_scale
|
||||
None, # up_lora_bias
|
||||
None, # up_magnitude
|
||||
down_proj.weight,
|
||||
down_proj.bias,
|
||||
None, # down_quant
|
||||
None, # down_A
|
||||
None, # down_B
|
||||
None, # down_scale
|
||||
None, # down_lora_bias
|
||||
None, # down_magnitude
|
||||
activation_forward,
|
||||
activation_backward,
|
||||
True, # inplace
|
||||
@@ -254,31 +247,24 @@ def test_lora_mlp_with_adapters(
|
||||
# Forward pass with adapters
|
||||
output = LoRA_MLP.apply(
|
||||
X,
|
||||
None, # X_drop
|
||||
gate_proj.weight,
|
||||
gate_proj.bias,
|
||||
None,
|
||||
gate_A,
|
||||
gate_B,
|
||||
scale,
|
||||
None, # gate_lora_bias
|
||||
None, # gate_magnitude
|
||||
up_proj.weight,
|
||||
up_proj.bias,
|
||||
None,
|
||||
up_A,
|
||||
up_B,
|
||||
scale,
|
||||
None, # up_lora_bias
|
||||
None, # up_magnitude
|
||||
down_proj.weight,
|
||||
down_proj.bias,
|
||||
None,
|
||||
down_A,
|
||||
down_B,
|
||||
scale,
|
||||
None, # down_lora_bias
|
||||
None, # down_magnitude
|
||||
activation_forward,
|
||||
activation_backward,
|
||||
True,
|
||||
@@ -348,32 +334,25 @@ def test_lora_qkv(sample_tensors):
|
||||
|
||||
Q1, K1, V1 = LoRA_QKV.apply(
|
||||
X,
|
||||
None, # X_drop
|
||||
q_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None, # Q: weight, bias, quant, A, B, scale, lora_bias, magnitude
|
||||
k_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None, # K
|
||||
v_weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None, # V
|
||||
True, # inplace
|
||||
True,
|
||||
)
|
||||
|
||||
assert Q1.shape == K1.shape == V1.shape == X.shape
|
||||
@@ -387,32 +366,25 @@ def test_lora_qkv(sample_tensors):
|
||||
# Test with LoRA adapters
|
||||
Q2, K2, V2 = LoRA_QKV.apply(
|
||||
X,
|
||||
None, # X_drop
|
||||
q_weight,
|
||||
None,
|
||||
None,
|
||||
q_A,
|
||||
q_B,
|
||||
scale,
|
||||
None,
|
||||
None, # Q
|
||||
k_weight,
|
||||
None,
|
||||
None,
|
||||
k_A,
|
||||
k_B,
|
||||
scale,
|
||||
None,
|
||||
None, # K
|
||||
v_weight,
|
||||
None,
|
||||
None,
|
||||
v_A,
|
||||
v_B,
|
||||
scale,
|
||||
None,
|
||||
None, # V
|
||||
True, # inplace
|
||||
True,
|
||||
)
|
||||
|
||||
assert Q2.shape == K2.shape == V2.shape == X.shape
|
||||
@@ -455,9 +427,7 @@ def test_lora_o(sample_tensors):
|
||||
|
||||
# Test forward pass
|
||||
X.requires_grad = True
|
||||
output = LoRA_O.apply(
|
||||
X, None, W, b, None, A, B, scale, None, None
|
||||
) # X_drop, ..., lora_bias, magnitude
|
||||
output = LoRA_O.apply(X, W, b, None, A, B, scale)
|
||||
|
||||
assert output.shape == (X.shape[0], X.shape[1], W.shape[0])
|
||||
|
||||
@@ -572,7 +542,6 @@ def test_inplace_operations(sample_tensors, apply_function):
|
||||
"down_proj": nn.Linear(shapes["out"], shapes["hidden"]).to(
|
||||
device="cuda", dtype=torch.float16
|
||||
),
|
||||
"training": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user