Compare commits


2 Commits
main ... textui

Author       SHA1         Message       Date
Wing Lian    db6af43f3b   chore: lint   2026-03-23 04:54:00 +00:00
Wing Lian    35d06c8087   add textui    2026-03-23 04:54:00 +00:00
24 changed files with 1834 additions and 17 deletions

View File

@@ -91,6 +91,7 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
type=click.Path(exists=True, path_type=str),
help="YAML config for sweeping hyperparameters",
)
@click.option("--tui", is_flag=True, default=False, help="Enable TUI dashboard")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@@ -101,6 +102,7 @@ def train(
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
cloud: str | None = None,
sweep: str | None = None,
tui: bool = False,
**kwargs,
):
"""
@@ -118,6 +120,10 @@ def train(
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
# Handle --tui flag: set env var so subprocess workers pick it up
if tui:
os.environ["AXOLOTL_TUI"] = "1"
# Handle Ray launcher override
_launcher = None if kwargs.get("use_ray") else launcher

View File

@@ -2,6 +2,7 @@
import gc
import os
import queue
from pathlib import Path
from typing import Union
@@ -34,22 +35,101 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
     if int(os.getenv("LOCAL_RANK", "0")) == 0:
         check_user_token()
-    plugin_manager = PluginManager.get_instance()
-    dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
-    if not dataset_meta:
-        if cfg.rl:
-            dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
-        else:
-            dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-    model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
-    del model, tokenizer, trainer
-    gc.collect()
-    plugin_manager = PluginManager.get_instance()
-    plugin_manager.post_train_unload(cfg)
+    # Start TUI early (before data loading) so it captures preprocessing events
+    tui_renderer = None
+    tui_queue: queue.Queue | None = None
+    is_rank_0 = int(os.getenv("LOCAL_RANK", "0")) == 0
+    if is_rank_0:
+        from axolotl.train import _is_tui_enabled
+
+        if _is_tui_enabled(cfg):
+            import queue as _queue
+
+            from axolotl.train import _get_tui_config
+            from axolotl.tui.config import TUIConfig
+            from axolotl.tui.renderer import TUIRenderer
+
+            tui_config_dict = _get_tui_config(cfg)
+            tui_config = (
+                TUIConfig(**tui_config_dict)
+                if isinstance(tui_config_dict, dict)
+                else tui_config_dict
+            )
+            tui_queue = _queue.Queue(maxsize=4096)
+            tui_renderer = TUIRenderer(config=tui_config, metric_queue=tui_queue)
+
+            # Send initial run info
+            model_name = cfg.base_model or ""
+            training_mode = str(cfg.rl) if cfg.rl else "sft"
+            world_size = int(os.environ.get("WORLD_SIZE", 1))
+            try:
+                tui_queue.put_nowait(
+                    {
+                        "type": "run_info",
+                        "model_name": model_name,
+                        "training_mode": training_mode,
+                        "world_size": world_size,
+                    }
+                )
+            except _queue.Full:
+                pass
+            tui_renderer.start()
+
+            # Attach logging handler early
+            import logging
+
+            from axolotl.tui.callback import _TUILogHandler
+
+            _early_log_handler = _TUILogHandler(
+                tui_queue, min_level=tui_config.log_level
+            )
+            _early_log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
+            # Attach to BOTH root and axolotl loggers because axolotl logger
+            # has propagate=False so root handler never sees axolotl.* messages
+            root_logger = logging.getLogger()
+            root_logger.addHandler(_early_log_handler)
+            axolotl_logger = logging.getLogger("axolotl")
+            axolotl_logger.addHandler(_early_log_handler)
+
+            # Stash refs on cfg so train() can reuse the renderer
+            cfg._tui_renderer = tui_renderer
+            cfg._tui_queue = tui_queue
+            cfg._tui_early_log_handler = _early_log_handler
+
+    try:
+        plugin_manager = PluginManager.get_instance()
+        dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
+        if not dataset_meta:
+            if cfg.rl:
+                dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
+            else:
+                dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
+
+        del model, tokenizer, trainer
+        gc.collect()
+
+        plugin_manager = PluginManager.get_instance()
+        plugin_manager.post_train_unload(cfg)
+    finally:
+        # If the TUI renderer started early but train() didn't get to stop it
+        # (e.g., error during data loading), clean up here
+        if tui_renderer is not None and not tui_renderer._stop_event.is_set():
+            try:
+                if tui_queue is not None:
+                    tui_queue.put_nowait({"type": "done"})
+            except queue.Full:
+                pass
+            tui_renderer.stop()
+        # Remove early log handler from both root and axolotl loggers
+        if hasattr(cfg, "_tui_early_log_handler"):
+            import logging
+
+            logging.getLogger().removeHandler(cfg._tui_early_log_handler)
+            logging.getLogger("axolotl").removeHandler(cfg._tui_early_log_handler)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):

View File

@@ -9,7 +9,6 @@ import os
 import shutil
 import signal
 import sys
-import typing
 import weakref
 from collections import OrderedDict
 from contextlib import ExitStack
@@ -42,9 +41,6 @@ from axolotl.utils.schemas.enums import RLType
 from axolotl.utils.train import determine_last_checkpoint
 from axolotl.utils.trainer import setup_trainer
-if typing.TYPE_CHECKING:
-    from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
 LOG = get_logger(__name__)
 TELEMETRY_MANAGER = TelemetryManager.get_instance()
@@ -487,7 +483,7 @@ def handle_untrained_tokens_fix(
 def setup_model_and_trainer(
     cfg: DictDefault, dataset_meta: TrainDatasetMeta
 ) -> tuple[
-    "HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
+    Trainer,
     PeftModel | PreTrainedModel,
     PreTrainedTokenizer,
     PeftConfig | None,
@@ -554,6 +550,36 @@ def setup_model_and_trainer(
)
def _is_tui_enabled(cfg: DictDefault) -> bool:
"""Check if TUI is enabled via config or environment variable."""
if os.environ.get("AXOLOTL_TUI", "").lower() in ("1", "true", "yes"):
return True
tui = cfg.get("tui")
if tui is None:
return False
if isinstance(tui, bool):
return tui
if isinstance(tui, dict):
return tui.get("enabled", False)
if hasattr(tui, "enabled"):
return tui.enabled
return False
def _get_tui_config(cfg: DictDefault) -> dict:
"""Extract TUI config dict from cfg."""
tui = cfg.get("tui")
if tui is None or isinstance(tui, bool):
return {"enabled": True}
if isinstance(tui, dict):
return {**tui, "enabled": True}
if hasattr(tui, "model_dump"):
d = tui.model_dump()
d["enabled"] = True
return d
return {"enabled": True}
@send_errors
def train(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
@@ -577,6 +603,37 @@ def train(
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
# Register TUI callback if enabled and rank 0
tui_enabled = _is_tui_enabled(cfg)
if tui_enabled and cfg.local_rank == 0:
from axolotl.tui import AxolotlTUICallback
from axolotl.tui.config import TUIConfig
tui_config = _get_tui_config(cfg)
tui_config_obj = (
TUIConfig(**tui_config) if isinstance(tui_config, dict) else tui_config
)
# Reuse the early-started renderer if available (started in do_train)
early_renderer = getattr(cfg, "_tui_renderer", None)
early_queue = getattr(cfg, "_tui_queue", None)
tui_callback = AxolotlTUICallback(config=tui_config_obj)
if early_renderer is not None and early_queue is not None:
# Reuse the already-running renderer and queue
tui_callback._renderer = early_renderer
tui_callback._queue = early_queue
tui_callback._renderer_started_early = True
trainer.add_callback(tui_callback)
# Stash model info so on_train_begin can emit a single unified run_info event
tui_callback._pending_run_info = {
"model_name": cfg.base_model or "",
"training_mode": str(cfg.rl) if cfg.rl else "sft",
"world_size": int(os.environ.get("WORLD_SIZE", 1)),
}
LOG.info("TUI dashboard enabled")
# Handle untrained tokens if configured
train_dataset = dataset_meta.train_dataset
handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset)
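
For reference, a quick sketch (not part of the diff) of the three ways the _is_tui_enabled helper added above can fire. It assumes axolotl is importable and uses plain dicts in place of DictDefault, which is enough here because the helper only calls cfg.get():

# Editorial sketch, not part of the diff: exercising _is_tui_enabled from axolotl.train.
import os

from axolotl.train import _is_tui_enabled

os.environ.pop("AXOLOTL_TUI", None)
assert _is_tui_enabled({"tui": True})               # tui: true in the YAML config
assert _is_tui_enabled({"tui": {"enabled": True}})  # tui: {enabled: true}
assert not _is_tui_enabled({})                      # no tui key, no env var

os.environ["AXOLOTL_TUI"] = "1"                     # what axolotl train ... --tui sets
assert _is_tui_enabled({})                          # env var alone is enough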

View File

@@ -0,0 +1,17 @@
"""Axolotl Training TUI — rich-based terminal dashboard for monitoring training runs."""
from axolotl.tui.callback import AxolotlTUICallback
from axolotl.tui.config import TUIConfig
from axolotl.tui.io_capture import LineParser, register_parser
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
__all__ = [
"AxolotlTUICallback",
"BasePanel",
"LineParser",
"TUIConfig",
"TUIState",
"register_panel",
"register_parser",
]

src/axolotl/tui/callback.py (new file, +142 lines)
View File

@@ -0,0 +1,142 @@
"""AxolotlTUICallback — HF TrainerCallback that feeds metrics to the TUI."""
from __future__ import annotations
import logging
import queue
from transformers.trainer_callback import TrainerCallback
from axolotl.tui.config import TUIConfig
from axolotl.tui.renderer import TUIRenderer
class _TUILogHandler(logging.Handler):
"""Logging handler that pushes log records into the TUI metric queue."""
_LEVEL_MAP = {
logging.DEBUG: "debug",
logging.INFO: "info",
logging.WARNING: "warning",
logging.ERROR: "error",
logging.CRITICAL: "error",
}
def __init__(self, metric_queue: queue.Queue, min_level: str = "info"):
super().__init__()
level_name = min_level.upper()
self.setLevel(getattr(logging, level_name, logging.INFO))
self._queue = metric_queue
def emit(self, record: logging.LogRecord) -> None:
try:
level = self._LEVEL_MAP.get(record.levelno, "info")
msg = self.format(record)
self._queue.put_nowait(
{
"type": "log_line",
"level": level,
"message": msg,
}
)
except queue.Full:
pass
except Exception:
self.handleError(record)
class AxolotlTUICallback(TrainerCallback):
"""Pushes training metrics into a queue for the TUI renderer.
The callback never blocks on the render thread. The queue is bounded
(maxsize=4096) with put_nowait; overflow is silently dropped.
"""
def __init__(self, config: TUIConfig):
self._config = config
self._queue: queue.Queue = queue.Queue(maxsize=4096)
self._renderer = TUIRenderer(config=config, metric_queue=self._queue)
self._log_handler: _TUILogHandler | None = None
self._renderer_started_early: bool = False
self._pending_run_info: dict | None = None
def _put(self, event: dict) -> None:
try:
self._queue.put_nowait(event)
except queue.Full:
pass
def on_train_begin(self, args, state, control, model=None, **kwargs):
# Send a single unified run_info event with all fields
run_info = {
"type": "run_info",
"run_name": getattr(args, "run_name", "") or "",
"total_steps": state.max_steps,
"total_epochs": float(args.num_train_epochs)
if args.num_train_epochs
else 1.0,
}
# Merge in model_name/training_mode/world_size if stashed by train.py
if self._pending_run_info:
run_info.update(self._pending_run_info)
self._pending_run_info = None
self._put(run_info)
if not self._renderer_started_early:
# Attach a logging handler to feed log messages into the events panel
self._log_handler = _TUILogHandler(
self._queue, min_level=self._config.log_level
)
self._log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
# Attach to both root and axolotl loggers (axolotl has propagate=False)
logging.getLogger().addHandler(self._log_handler)
logging.getLogger("axolotl").addHandler(self._log_handler)
# Start the renderer background thread
self._renderer.start()
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is None:
return
# Filter out non-numeric keys and internal keys
filtered = {}
for key, value in logs.items():
if key.startswith("_"):
continue
if isinstance(value, (int, float)):
filtered[key] = value
elif isinstance(value, str):
# HF Trainer sometimes passes string-encoded numbers
try:
filtered[key] = float(value)
except (ValueError, TypeError):
pass
if filtered:
self._put({"type": "metrics", "logs": filtered})
def on_step_end(self, args, state, control, **kwargs):
self._put(
{
"type": "step",
"step": state.global_step,
"total_steps": state.max_steps,
"epoch": state.epoch if state.epoch else 0,
}
)
def on_prediction_step(self, args, state, control, **kwargs):
pass
def on_train_end(self, args, state, control, **kwargs):
self._put({"type": "done"})
# If renderer was started early, do_train's finally block handles stop
if not self._renderer_started_early:
self._renderer.stop()
# Remove the logging handler (only if we added it)
if self._log_handler:
logging.getLogger().removeHandler(self._log_handler)
logging.getLogger("axolotl").removeHandler(self._log_handler)
self._log_handler = None
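
For orientation, a sketch (editorial, not in the diff) of the event dict shapes that flow through the metric queue; every type string and key below is taken from the callback and renderer code in this change:

EXAMPLE_EVENTS = [
    {"type": "run_info", "run_name": "demo", "total_steps": 100, "total_epochs": 1.0},
    {"type": "metrics", "logs": {"loss": 1.23, "learning_rate": 2e-5}},
    {"type": "step", "step": 10, "total_steps": 100, "epoch": 0.1},
    {"type": "log_line", "level": "info", "message": "[axolotl] loading model"},
    {"type": "done"},  # tells the renderer to shut down
]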

src/axolotl/tui/config.py (new file, +38 lines)
View File

@@ -0,0 +1,38 @@
"""TUI configuration — Pydantic model for TUI settings."""
from __future__ import annotations
from pydantic import BaseModel, Field
class TUIConfig(BaseModel):
"""Configuration for the Axolotl Training TUI dashboard."""
enabled: bool = Field(
default=False,
json_schema_extra={"description": "Enable the TUI dashboard"},
)
refresh_rate: int = Field(
default=4,
json_schema_extra={"description": "Renders per second"},
)
log_level: str = Field(
default="debug",
json_schema_extra={"description": "Minimum log level shown in events panel"},
)
panels: list[str] = Field(
default_factory=lambda: ["progress", "training", "hardware", "events", "debug"],
json_schema_extra={"description": "Ordered list of panels to display"},
)
hardware_poll_interval: int = Field(
default=2,
json_schema_extra={"description": "Seconds between pynvml GPU queries"},
)
stdout_log_path: str = Field(
default="axolotl_stdout.log",
json_schema_extra={"description": "File path for captured stdout/stderr log"},
)
parser_plugins: list[str] = Field(
default_factory=list,
json_schema_extra={"description": "List of extra parser classes to load"},
)
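
A minimal construction sketch (editorial, not in the diff); the field names and defaults come from the model above, and anything not passed keeps its default:

from axolotl.tui.config import TUIConfig

tui_cfg = TUIConfig(enabled=True, refresh_rate=4, panels=["progress", "training", "events"])
print(tui_cfg.hardware_poll_interval)  # 2 (default seconds between pynvml polls)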

src/axolotl/tui/gpu.py (new file, +72 lines)
View File

@@ -0,0 +1,72 @@
"""GPU polling wrapper around pynvml with graceful fallback."""
from __future__ import annotations
import logging
from axolotl.tui.state import GPUStats
LOG = logging.getLogger(__name__)
_nvml_available = False
try:
import pynvml
pynvml.nvmlInit()
_nvml_available = True
except Exception:
LOG.debug("pynvml unavailable — GPU stats will not be shown")
class GPUPoller:
"""Polls local GPU stats via pynvml. Falls back gracefully if unavailable."""
def __init__(self):
self._device_count = 0
if _nvml_available:
try:
self._device_count = pynvml.nvmlDeviceGetCount()
except Exception:
self._device_count = 0
@property
def available(self) -> bool:
return _nvml_available and self._device_count > 0
def poll(self) -> list[GPUStats]:
if not self.available:
return []
stats = []
for i in range(self._device_count):
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
name = pynvml.nvmlDeviceGetName(handle)
if isinstance(name, bytes):
name = name.decode("utf-8")
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
temp = pynvml.nvmlDeviceGetTemperature(
handle, pynvml.NVML_TEMPERATURE_GPU
)
try:
power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
except Exception:
power = None
stats.append(
GPUStats(
id=i,
name=name,
util_pct=util.gpu,
vram_used_gb=mem.used / (1024**3),
vram_total_gb=mem.total / (1024**3),
temp_c=temp,
power_w=power,
)
)
except Exception:
LOG.debug("Error polling GPU device %d", i, exc_info=True)
return stats

View File

@@ -0,0 +1,196 @@
"""I/O capture: OS-level stdout/stderr redirect, line parser chain, and parser registry."""
from __future__ import annotations
import logging
import os
import queue
import sys
import threading
from abc import ABC, abstractmethod
from datetime import datetime
from typing import IO
# ---------------------------------------------------------------------------
# Parser registry
# ---------------------------------------------------------------------------
_parser_registry: list[type[LineParser]] = []
def register_parser(cls: type[LineParser]) -> type[LineParser]:
"""Decorator to register a LineParser subclass."""
if cls not in _parser_registry:
_parser_registry.append(cls)
return cls
def get_registered_parsers() -> list[type[LineParser]]:
return list(_parser_registry)
# ---------------------------------------------------------------------------
# Base LineParser
# ---------------------------------------------------------------------------
class LineParser(ABC):
"""Base class for stdout/stderr line parsers."""
priority: int = 50
name: str = ""
@abstractmethod
def parse(self, line: str, source: str) -> list[dict]:
"""Parse a single captured line.
Args:
line: one line of captured output, trailing newline stripped.
source: "stdout" or "stderr".
Returns:
List of event dicts to push onto the metric queue.
Return [] if this line is not relevant.
"""
...
# ---------------------------------------------------------------------------
# ParserChain
# ---------------------------------------------------------------------------
class ParserChain:
def __init__(self):
self._parsers: list[LineParser] = []
def register(self, parser: LineParser) -> None:
self._parsers.append(parser)
self._parsers.sort(key=lambda p: p.priority)
def parse(self, line: str, source: str = "stdout") -> list[dict]:
events: list[dict] = []
for parser in self._parsers:
events.extend(parser.parse(line, source))
return events
# ---------------------------------------------------------------------------
# IOCapture — OS-level fd redirect to pipe
# ---------------------------------------------------------------------------
class IOCapture:
"""Redirects fd 1 and fd 2 into an OS pipe, drains via a reader thread,
passes lines through a ParserChain, and tees to a log file."""
def __init__(
self, log_path: str, parser_chain: ParserChain, metric_queue: queue.Queue
):
self._parser_chain = parser_chain
self._queue = metric_queue
self._log_path = log_path
self._log_file: IO[str] | None = None
self._thread: threading.Thread | None = None
self._read_fd: int | None = None
self._write_fd: int | None = None
self._saved_stdout_fd: int | None = None
self._saved_stderr_fd: int | None = None
def start(self) -> None:
# Write run-start separator
self._log_file = open(self._log_path, "a", buffering=1) # noqa: SIM115
self._log_file.write(
f"\n=== axolotl run started {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ===\n"
)
self._log_file.flush()
# OS-level pipe
self._read_fd, self._write_fd = os.pipe()
# Save originals
self._saved_stdout_fd = os.dup(1)
self._saved_stderr_fd = os.dup(2)
# Redirect both stdout and stderr into the write end
os.dup2(self._write_fd, 1)
os.dup2(self._write_fd, 2)
os.close(self._write_fd) # write end now held by fds 1 and 2
# Also redirect Python-level handles
sys.stdout = open(1, "w", buffering=1, closefd=False) # noqa: SIM115
sys.stderr = open(2, "w", buffering=1, closefd=False) # noqa: SIM115
# Drain thread
self._thread = threading.Thread(target=self._drain, daemon=True)
self._thread.start()
def stop(self) -> None:
# Restore fds — closes the write end, causing reader to see EOF
if self._saved_stdout_fd is not None and self._saved_stderr_fd is not None:
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
os.dup2(self._saved_stdout_fd, 1)
os.dup2(self._saved_stderr_fd, 2)
os.close(self._saved_stdout_fd)
os.close(self._saved_stderr_fd)
self._saved_stdout_fd = None
self._saved_stderr_fd = None
if self._thread is not None:
self._thread.join(timeout=2.0)
if self._thread.is_alive():
logging.getLogger(__name__).warning(
"IO capture thread did not exit after 2s"
)
self._thread = None
if self._log_file is not None:
self._log_file.close()
self._log_file = None
def _drain(self) -> None:
# Read raw bytes and split on both \n and \r to handle tqdm progress bars
# which use \r for in-place updates without \n
assert self._read_fd is not None, "_drain called before start()"
with os.fdopen(self._read_fd, "rb") as pipe:
buf = b""
while True:
chunk = pipe.read(4096)
if not chunk:
# EOF — process remaining buffer
if buf:
self._process_line(buf.decode("utf-8", errors="replace"))
break
buf += chunk
# Split on \n or \r
while b"\n" in buf or b"\r" in buf:
# Find the earliest delimiter
idx_n = buf.find(b"\n")
idx_r = buf.find(b"\r")
if idx_n == -1:
idx = idx_r
elif idx_r == -1:
idx = idx_n
else:
idx = min(idx_n, idx_r)
line = buf[:idx].decode("utf-8", errors="replace")
buf = buf[idx + 1 :]
# Handle \r\n as single delimiter
if buf.startswith(b"\n"):
buf = buf[1:]
if line:
self._process_line(line)
def _process_line(self, line: str) -> None:
line = line.rstrip()
if not line:
return
if self._log_file:
self._log_file.write(line + "\n")
self._log_file.flush()
for event in self._parser_chain.parse(line):
try:
self._queue.put_nowait(event)
except queue.Full:
pass
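
To illustrate the plugin surface, a hypothetical parser (not part of the diff) that flags CUDA OOM lines. It uses only register_parser and LineParser from this file; an out-of-tree parser like this could also be loaded through the parser_plugins config field, which the renderer later in this diff resolves as either module.path.ClassName or path/to/file.py::ClassName:

import re

from axolotl.tui.io_capture import LineParser, register_parser


@register_parser
class CudaOOMParser(LineParser):  # hypothetical example
    priority = 10
    name = "cuda_oom"

    _RE = re.compile(r"CUDA out of memory", re.IGNORECASE)

    def parse(self, line: str, source: str) -> list[dict]:
        if self._RE.search(line):
            return [{"type": "log_line", "level": "error", "message": line}]
        return []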

View File

@@ -0,0 +1,63 @@
"""Panel registry and base class for TUI panels."""
from __future__ import annotations
from abc import ABC, abstractmethod
from rich.console import RenderableType
from axolotl.tui.state import TUIState
# ---------------------------------------------------------------------------
# Panel registry
# ---------------------------------------------------------------------------
_panel_registry: dict[str, type[BasePanel]] = {}
def register_panel(position: str = "bottom", weight: int = 50):
"""Decorator to register a panel class with position and weight."""
def decorator(cls: type[BasePanel]) -> type[BasePanel]:
cls.position = position
cls.weight = weight
_panel_registry[cls.name] = cls
return cls
return decorator
def get_registered_panels() -> dict[str, type[BasePanel]]:
return dict(_panel_registry)
# ---------------------------------------------------------------------------
# BasePanel
# ---------------------------------------------------------------------------
class BasePanel(ABC):
name: str = ""
position: str = "bottom"
weight: int = 50
min_height: int = 4
max_height: int | None = None
modes: list[str] = ["*"]
@abstractmethod
def render(self, state: TUIState) -> RenderableType:
"""Return a rich renderable. Called every tick."""
...
def on_event(self, event: dict) -> None: # noqa: B027
"""Optional: react to raw metric events before state is merged."""
pass
# Auto-import built-in panels to trigger registration
from axolotl.tui.panels.completions import CompletionsPanel # noqa: E402, F401
from axolotl.tui.panels.debug import DebugPanel # noqa: E402, F401
from axolotl.tui.panels.events import EventsPanel # noqa: E402, F401
from axolotl.tui.panels.hardware import HardwarePanel # noqa: E402, F401
from axolotl.tui.panels.progress import ProgressPanel # noqa: E402, F401
from axolotl.tui.panels.training import TrainingPanel # noqa: E402, F401
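
As a sketch of the panel API (hypothetical, not in the diff), a minimal panel registered with register_panel. For it to actually appear, its name would also need to be listed in TUIConfig.panels, since the renderer only instantiates panels named there:

from rich.panel import Panel
from rich.text import Text

from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState


@register_panel(position="bottom", weight=40)
class EpochPanel(BasePanel):  # hypothetical example
    name = "epoch"
    min_height = 3

    def render(self, state: TUIState) -> Panel:
        return Panel(
            Text(f"epoch {state.current_epoch:.2f} / {state.total_epochs:.1f}"),
            title="Epoch",
        )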

View File

@@ -0,0 +1,61 @@
"""CompletionsPanel — shows recent RL/log_completions samples."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
def _truncate(s: str, maxlen: int = 60) -> str:
return s[:maxlen] + "…" if len(s) > maxlen else s
@register_panel(position="bottom", weight=20)
class CompletionsPanel(BasePanel):
name = "completions"
min_height = 6
modes = ["grpo", "dpo"]
def render(self, state: TUIState) -> RenderableType:
if "*" not in self.modes and state.training_mode not in self.modes:
return Text("")
if not state.completions:
return Panel(
Text("No completions yet...", style="dim"),
title="Completions",
border_style="magenta",
)
table = Table(
show_header=True,
header_style="bold",
expand=True,
box=None,
pad_edge=False,
)
table.add_column("step", justify="right", width=6)
table.add_column("prompt", no_wrap=False, max_width=40)
table.add_column("completion", no_wrap=False, max_width=40)
table.add_column("reward", justify="right", width=8)
table.add_column("adv", justify="right", width=8)
for sample in list(state.completions)[-5:]:
reward_str = f"{sample.reward:.2f}" if sample.reward is not None else "--"
adv_str = (
f"{sample.advantage:+.2f}" if sample.advantage is not None else "--"
)
table.add_row(
str(sample.step),
_truncate(sample.prompt),
_truncate(sample.completion),
reward_str,
adv_str,
)
return Panel(table, title="Completions", border_style="magenta")

View File

@@ -0,0 +1,34 @@
"""DebugPanel — scrolling log of debug-level messages, separate from main events."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
@register_panel(position="bottom", weight=30)
class DebugPanel(BasePanel):
name = "debug"
min_height = 6
max_height = 10
def render(self, state: TUIState) -> RenderableType:
lines = Text()
# Show last 8 debug-level log lines
debug_lines = [
log_entry for log_entry in state.log_lines if log_entry.level == "debug"
][-8:]
for log_line in debug_lines:
ts = log_line.timestamp.strftime("%H:%M:%S")
lines.append(f"[{ts}] ", style="dim")
lines.append(log_line.message[:200], style="dim")
lines.append("\n")
if not debug_lines:
lines = Text("No debug messages yet...", style="dim")
return Panel(lines, title="Debug", border_style="dim")

View File

@@ -0,0 +1,45 @@
"""EventsPanel — scrolling log of recent events, color-coded by level."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
_LEVEL_STYLES = {
"debug": "dim",
"info": "",
"warning": "yellow",
"error": "red bold",
"critical": "red bold",
}
@register_panel(position="bottom", weight=10)
class EventsPanel(BasePanel):
name = "events"
min_height = 8
max_height = 20
def render(self, state: TUIState) -> RenderableType:
lines = Text()
# Show last 15 non-debug log lines (debug goes to DebugPanel)
recent = [
log_entry for log_entry in state.log_lines if log_entry.level != "debug"
][-15:]
for log_line in recent:
ts = log_line.timestamp.strftime("%H:%M:%S")
level = log_line.level.upper()
style = _LEVEL_STYLES.get(log_line.level, "")
lines.append(f"[{ts}] ", style="dim")
lines.append(f"[{level}] ", style=style or "")
lines.append(log_line.message[:200], style=style or "")
lines.append("\n")
if not recent:
lines = Text("No events yet...", style="dim")
return Panel(lines, title="Events", border_style="yellow")

View File

@@ -0,0 +1,80 @@
"""HardwarePanel — per-GPU stats via pynvml."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
_BAR_FULL = "█"
_BAR_EMPTY = "░"
def _util_bar(pct: float, width: int = 6) -> Text:
filled = int(pct / 100 * width)
bar = _BAR_FULL * filled + _BAR_EMPTY * (width - filled)
color = "green" if pct < 70 else ("yellow" if pct < 90 else "red")
return Text.assemble((bar, color), f" {pct:3.0f}%")
@register_panel(position="right", weight=10)
class HardwarePanel(BasePanel):
name = "hardware"
min_height = 6
def render(self, state: TUIState) -> RenderableType:
if not state.gpus:
return Panel(
Text("GPU stats unavailable", style="dim"),
title="Hardware",
border_style="green",
)
table = Table(
show_header=True,
header_style="bold",
expand=True,
box=None,
pad_edge=False,
)
table.add_column("id", justify="right", width=3)
table.add_column("util", no_wrap=True)
table.add_column("vram", no_wrap=True)
table.add_column("°C", justify="right", width=4)
table.add_column("W", justify="right", width=5)
total_vram_used = 0.0
total_vram_total = 0.0
total_util = 0.0
for gpu in state.gpus:
total_vram_used += gpu.vram_used_gb
total_vram_total += gpu.vram_total_gb
total_util += gpu.util_pct
power_str = f"{gpu.power_w:.0f}" if gpu.power_w is not None else "--"
table.add_row(
str(gpu.id),
_util_bar(gpu.util_pct),
f"{gpu.vram_used_gb:.1f}/{gpu.vram_total_gb:.1f} GB",
str(gpu.temp_c),
power_str,
)
# Footer with aggregates
n = len(state.gpus)
if n > 1:
avg_util = total_util / n
table.add_row(
"Σ",
Text(f"avg {avg_util:.0f}%", style="dim"),
Text(f"{total_vram_used:.1f}/{total_vram_total:.1f} GB", style="dim"),
"",
"",
)
return Panel(table, title="Hardware", border_style="green")

View File

@@ -0,0 +1,73 @@
"""ProgressPanel — top-bar progress display with step count, elapsed, ETA."""
from __future__ import annotations
from rich.console import RenderableType
from rich.progress import BarColumn, Progress, TextColumn
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
def _fmt_time(seconds: float | None) -> str:
if seconds is None or seconds < 0:
return "--:--:--"
h = int(seconds) // 3600
m = (int(seconds) % 3600) // 60
s = int(seconds) % 60
return f"{h}:{m:02d}:{s:02d}"
def _fmt_eta(seconds: float | None) -> str:
if seconds is None or seconds < 0:
return "eta --"
h = int(seconds) // 3600
m = (int(seconds) % 3600) // 60
if h > 0:
return f"eta {h}h{m:02d}m"
return f"eta {m}m{int(seconds) % 60:02d}s"
@register_panel(position="top", weight=10)
class ProgressPanel(BasePanel):
name = "progress"
min_height = 3
max_height = 3
def render(self, state: TUIState) -> RenderableType:
pct = (
(state.current_step / state.total_steps * 100)
if state.total_steps > 0
else 0
)
# Header line
mode_upper = state.training_mode.upper() if state.training_mode else "SFT"
model_short = state.model_name.split("/")[-1] if state.model_name else "model"
header = Text.assemble(
("", "bold green"),
("AXOLOTL", "bold cyan"),
f" {mode_upper} · {model_short} ",
(
f"{state.current_step} / {state.total_steps}",
"bold",
),
f" · {_fmt_time(state.elapsed_seconds)} elapsed · {_fmt_eta(state.eta_seconds)} · {pct:.1f}%",
)
# Progress bar
progress = Progress(
TextColumn(""),
BarColumn(bar_width=None),
TextColumn("{task.percentage:>3.0f}%"),
expand=True,
)
task = progress.add_task("", total=state.total_steps or 1)
progress.update(task, completed=state.current_step)
table = Table.grid(expand=True)
table.add_row(header)
table.add_row(progress)
return table

View File

@@ -0,0 +1,97 @@
"""TrainingPanel — live scalar metrics table with loss sparkline."""
from __future__ import annotations
from rich.console import RenderableType
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import TUIState
# Unicode block-element sparkline characters (8 levels)
_SPARK_CHARS = "▁▂▃▄▅▆▇█"
def _sparkline(values: list[float] | None, width: int = 20) -> str:
if not values or len(values) < 2:
return ""
vals = list(values)[-width:]
lo, hi = min(vals), max(vals)
rng = hi - lo if hi != lo else 1.0
return "".join(_SPARK_CHARS[min(int((v - lo) / rng * 7), 7)] for v in vals)
# Known key ordering and formatting
_KNOWN_KEYS: list[tuple[str, str, str]] = [
("loss", "loss", ".4f"),
("grad_norm", "grad norm", ".3f"),
("learning_rate", "lr", ".2e"),
("tokens_per_second", "tok/s", ".1f"),
("samples_per_second", "samples/s", ".1f"),
("mfu", "MFU", ".1f"),
# RL-specific
("rewards_mean", "rewards/mean", ".4f"),
("rewards_std", "rewards/std", ".4f"),
("kl_divergence", "KL", ".4f"),
("clip_ratio", "clip ratio", ".3f"),
("queue_size", "queue", "d"),
]
@register_panel(position="left", weight=10)
class TrainingPanel(BasePanel):
name = "training"
min_height = 8
def render(self, state: TUIState) -> RenderableType:
table = Table(
show_header=True,
header_style="bold",
expand=True,
box=None,
pad_edge=False,
)
table.add_column("metric", style="cyan", no_wrap=True)
table.add_column("value", justify="right")
table.add_column("trend", justify="left", no_wrap=True)
for attr, label, fmt in _KNOWN_KEYS:
val = getattr(state, attr, None)
if val is None:
# Also check extra dict
val = state.extra.get(attr)
if val is None:
continue
try:
formatted = f"{val:{fmt}}"
except (ValueError, TypeError):
formatted = str(val)
trend = ""
if attr == "loss":
trend = _sparkline(list(state.loss_history))
table.add_row(label, formatted, trend)
# Any extra keys not in _KNOWN_KEYS
known_attrs = {k for k, _, _ in _KNOWN_KEYS}
for key, val in sorted(state.extra.items()):
if key in known_attrs or val is None:
continue
try:
formatted = f"{val:.4f}"
except (ValueError, TypeError):
formatted = str(val)
table.add_row(key, formatted, "")
if table.row_count == 0:
return Panel(
Text("Waiting for first log step...", style="dim"),
title="Training",
border_style="blue",
)
return Panel(table, title="Training", border_style="blue")
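
A worked example of the sparkline mapping (editorial, not in the diff), assuming the package and rich are importable:

from axolotl.tui.panels.training import _sparkline

# lo=0.5, hi=1.0, so the values map to indices 7, 4, 1, 0 of "▁▂▃▄▅▆▇█"
assert _sparkline([1.0, 0.8, 0.6, 0.5]) == "█▅▂▁"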

View File

@@ -0,0 +1,7 @@
"""Built-in line parsers — auto-imported to trigger @register_parser decorators."""
from axolotl.tui.parsers.deepspeed import DeepSpeedParser # noqa: F401
from axolotl.tui.parsers.nccl import NCCLErrorParser # noqa: F401
from axolotl.tui.parsers.raw_log import RawLogParser # noqa: F401
from axolotl.tui.parsers.torch_compile import TorchCompileParser # noqa: F401
from axolotl.tui.parsers.tqdm import TqdmParser # noqa: F401

View File

@@ -0,0 +1,29 @@
"""DeepSpeedParser — extracts DeepSpeed stage info and throughput metrics."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class DeepSpeedParser(LineParser):
priority = 20
name = "deepspeed"
_SAMPLES_RE = re.compile(r"samples/sec=([0-9.]+)")
_STAGE_RE = re.compile(r"ZeRO Stage (\d)")
def parse(self, line: str, source: str) -> list[dict]:
events: list[dict] = []
if m := self._SAMPLES_RE.search(line):
events.append(
{
"type": "metrics",
"logs": {"samples_per_second": float(m.group(1))},
}
)
if m := self._STAGE_RE.search(line):
events.append({"type": "run_info", "zero_stage": int(m.group(1))})
return events

View File

@@ -0,0 +1,27 @@
"""NCCLErrorParser — surfaces NCCL errors as red alert events."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class NCCLErrorParser(LineParser):
priority = 10
name = "nccl_error"
_RE = re.compile(r"NCCL error|Unhandled NCCL", re.IGNORECASE)
def parse(self, line: str, source: str) -> list[dict]:
if self._RE.search(line):
return [
{
"type": "log_line",
"level": "error",
"message": f"⚠ NCCL: {line}",
},
{"type": "alert", "severity": "error", "message": line},
]
return []

View File

@@ -0,0 +1,37 @@
"""RawLogParser — catches every line as a log_line event."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class RawLogParser(LineParser):
priority = 99
name = "raw_log"
_LOG_RE = re.compile(
r"^(?P<ts>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}[,\.]\d+)"
r"\s*[-]\s*(?P<level>DEBUG|INFO|WARNING|ERROR|CRITICAL)"
r"\s*[-]\s*(?P<msg>.+)$",
re.IGNORECASE,
)
# Filter out tqdm progress bar lines and other noisy output
_TQDM_RE = re.compile(r"^\s*\d+%\|.*\|")
_EMPTY_RE = re.compile(r"^\s*$")
def parse(self, line: str, source: str) -> list[dict]:
# Skip empty lines and tqdm progress bar updates
if self._EMPTY_RE.match(line) or self._TQDM_RE.match(line):
return []
m = self._LOG_RE.match(line)
level = (
m.group("level").lower()
if m
else ("error" if source == "stderr" else "info")
)
return [{"type": "log_line", "level": level, "message": line}]

View File

@@ -0,0 +1,26 @@
"""TorchCompileParser — detects torch.compile graph breaks and recompilations."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class TorchCompileParser(LineParser):
priority = 20
name = "torch_compile"
_RE = re.compile(r"Graph break|Recompiling|torch\.compile", re.IGNORECASE)
def parse(self, line: str, source: str) -> list[dict]:
if self._RE.search(line):
return [
{
"type": "log_line",
"level": "warning",
"message": f"⚡ compile: {line}",
}
]
return []

View File

@@ -0,0 +1,86 @@
"""TqdmParser — captures tqdm progress bar output and surfaces as structured events."""
from __future__ import annotations
import re
from axolotl.tui.io_capture import LineParser, register_parser
@register_parser
class TqdmParser(LineParser):
priority = 15
name = "tqdm"
# Match tqdm-style progress lines, e.g.:
# Tokenizing Prompts (num_proc=24): 35%|███▍ | 19008/54568 [00:02<00:02, 17417.65 examples/s]
# Loading weights: 53%|█████▎ | 77/146 [00:00<00:00, 396.39it/s]
# 0%| | 0/30 [00:00<?, ?it/s]
_TQDM_RE = re.compile(
r"(?P<desc>.*?)\s*"
r"(?P<pct>\d+)%\|[▏▎▍▌▋▊▉█░▓▒# ]*\|\s*"
r"(?P<current>[\d,]+)/(?P<total>[\d,]+)"
r"\s*\[(?P<elapsed>[^\]]*)\]"
)
# Also match simpler forms like:
# Fetching 0 files: 0it [00:00, ?it/s]
_FETCH_RE = re.compile(r"(?P<desc>[\w\s]+):\s*(?P<current>\d+)(?:it)?\s*\[.*?\]")
def parse(self, line: str, source: str) -> list[dict]:
m = self._TQDM_RE.search(line)
if m:
desc = m.group("desc").strip().rstrip(":")
pct = int(m.group("pct"))
current = int(m.group("current").replace(",", ""))
total = int(m.group("total").replace(",", ""))
events: list[dict] = []
# Surface as a log line with progress info
if pct == 100 or pct == 0 or pct % 25 == 0:
msg = (
f"[{desc}] {pct}% ({current}/{total})"
if desc
else f"{pct}% ({current}/{total})"
)
events.append(
{
"type": "log_line",
"level": "info",
"message": msg,
}
)
# Also emit as a progress metric
cleaned_desc = desc.strip().lower().replace(" ", "_")
if not cleaned_desc:
cleaned_desc = "progress"
events.append(
{
"type": "metrics",
"logs": {
f"progress/{cleaned_desc}": pct / 100.0,
},
}
)
return events
# Fallback: try simpler fetch-style progress lines
m = self._FETCH_RE.search(line)
if m:
desc = m.group("desc").strip().rstrip(":")
current = int(m.group("current"))
cleaned_desc = desc.strip().lower().replace(" ", "_")
if not cleaned_desc:
cleaned_desc = "fetch"
return [
{
"type": "log_line",
"level": "info",
"message": f"[{desc}] {current}" if desc else f"{current}",
}
]
return []
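
A sketch (editorial, not in the diff) of what this parser emits for the "Loading weights" example line from the comment above; 53% is not a 25% milestone, so only the progress metric is produced:

from axolotl.tui.parsers.tqdm import TqdmParser

events = TqdmParser().parse(
    "Loading weights:  53%|█████▎    | 77/146 [00:00<00:00, 396.39it/s]", "stdout"
)
# events == [{"type": "metrics", "logs": {"progress/loading_weights": 0.53}}]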

src/axolotl/tui/renderer.py (new file, +449 lines)
View File

@@ -0,0 +1,449 @@
"""TUIRenderer — background daemon thread that drives the rich.live.Live display."""
from __future__ import annotations
import logging
import queue
import threading
import time
from datetime import datetime
from typing import Any
from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from axolotl.tui.config import TUIConfig
from axolotl.tui.gpu import GPUPoller
from axolotl.tui.io_capture import (
IOCapture,
ParserChain,
get_registered_parsers,
)
from axolotl.tui.panels import BasePanel, get_registered_panels
from axolotl.tui.state import CompletionSample, LogLine, TUIState
LOG = logging.getLogger(__name__)
class TUIRenderer:
"""Background thread that renders the TUI dashboard using rich.live.Live."""
def __init__(self, config: TUIConfig, metric_queue: queue.Queue):
self._config = config
self._queue = metric_queue
self._state = TUIState()
self._gpu_poller = GPUPoller()
self._panels: list[BasePanel] = []
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()
self._io_capture: IOCapture | None = None
self._parser_chain: ParserChain | None = None
def _init_panels(self) -> None:
registry = get_registered_panels()
for panel_name in self._config.panels:
if panel_name in registry:
self._panels.append(registry[panel_name]())
def _init_parser_chain(self) -> None:
# Ensure built-in parsers are imported so @register_parser decorators fire
import axolotl.tui.parsers # noqa: F401
self._parser_chain = ParserChain()
# Register all built-in parsers
for parser_cls in get_registered_parsers():
self._parser_chain.register(parser_cls())
# Load plugin parsers
for plugin_spec in self._config.parser_plugins:
try:
if "::" in plugin_spec:
# file path :: class name
file_path, class_name = plugin_spec.split("::", 1)
import importlib.util
spec = importlib.util.spec_from_file_location(
"custom_parser", file_path
)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load spec for {file_path}")
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
parser_cls = getattr(mod, class_name)
else:
# dotted module path
module_path, class_name = plugin_spec.rsplit(".", 1)
mod = importlib.import_module(module_path)
parser_cls = getattr(mod, class_name)
self._parser_chain.register(parser_cls())
except Exception as exc:
LOG.warning(f"Failed to load parser plugin {plugin_spec}: {exc}")
def _build_layout(self) -> Layout:
layout = Layout()
top_panels = [p for p in self._panels if p.position == "top"]
left_panels = [p for p in self._panels if p.position == "left"]
right_panels = [p for p in self._panels if p.position == "right"]
bottom_panels = [p for p in self._panels if p.position == "bottom"]
sections = []
if top_panels:
layout_top = Layout(name="top", size=3)
sections.append(layout_top)
if left_panels or right_panels:
layout_middle = Layout(name="middle", ratio=3)
middle_parts = []
if left_panels:
middle_parts.append(Layout(name="left", ratio=1))
if right_panels:
middle_parts.append(Layout(name="right", ratio=1))
if middle_parts:
layout_middle.split_row(*middle_parts)
sections.append(layout_middle)
if bottom_panels:
layout_bottom = Layout(name="bottom", ratio=2)
if len(bottom_panels) > 1:
layout_bottom.split_row(
*[
Layout(name=f"bottom_{i}", ratio=1)
for i in range(len(bottom_panels))
]
)
sections.append(layout_bottom)
if sections:
layout.split_column(*sections)
return layout
def _update_layout(self, layout: Layout) -> None:
top_panels = [p for p in self._panels if p.position == "top"]
left_panels = [p for p in self._panels if p.position == "left"]
right_panels = [p for p in self._panels if p.position == "right"]
bottom_panels = [p for p in self._panels if p.position == "bottom"]
if top_panels:
layout["top"].update(top_panels[0].render(self._state))
if left_panels:
layout["left"].update(left_panels[0].render(self._state))
if right_panels:
layout["right"].update(right_panels[0].render(self._state))
if bottom_panels:
if len(bottom_panels) == 1:
layout["bottom"].update(bottom_panels[0].render(self._state))
else:
for i, panel in enumerate(bottom_panels):
layout[f"bottom_{i}"].update(panel.render(self._state))
def _drain_queue(self) -> None:
while True:
try:
event = self._queue.get_nowait()
except queue.Empty:
break
# Dispatch event to panels first
for panel in self._panels:
panel.on_event(event)
event_type = event.get("type")
if event_type == "metrics":
logs = event.get("logs", {})
self._apply_metrics(logs)
elif event_type == "step":
self._state.current_step = event.get("step", self._state.current_step)
self._state.total_steps = event.get(
"total_steps", self._state.total_steps
)
self._state.current_epoch = event.get(
"epoch", self._state.current_epoch
)
now = time.time()
self._state.elapsed_seconds = now - self._state.start_time.timestamp()
if self._state.current_step > 0 and self._state.total_steps > 0:
rate = self._state.elapsed_seconds / self._state.current_step
remaining = self._state.total_steps - self._state.current_step
self._state.eta_seconds = rate * remaining
elif event_type == "log_line":
level = event.get("level", "info")
message = event.get("message", "")
self._state.log_lines.append(
LogLine(
timestamp=datetime.now(),
level=level,
message=message,
)
)
elif event_type == "completion":
self._state.completions.append(
CompletionSample(
step=event.get("step", 0),
prompt=event.get("prompt", ""),
completion=event.get("completion", ""),
reward=event.get("reward"),
advantage=event.get("advantage"),
)
)
elif event_type == "run_info":
if "run_name" in event:
self._state.run_name = event["run_name"]
if "model_name" in event:
self._state.model_name = event["model_name"]
if "training_mode" in event:
self._state.training_mode = event["training_mode"]
if "world_size" in event:
self._state.world_size = event["world_size"]
if "total_steps" in event:
self._state.total_steps = event["total_steps"]
if "total_epochs" in event:
self._state.total_epochs = event["total_epochs"]
if "zero_stage" in event:
self._state.zero_stage = event["zero_stage"]
elif event_type == "done":
self._stop_event.set()
def _apply_metrics(self, logs: dict[str, Any]) -> None:
metric_map = {
"loss": "loss",
"grad_norm": "grad_norm",
"learning_rate": "learning_rate",
"tokens_per_second": "tokens_per_second",
"samples_per_second": "samples_per_second",
"mfu": "mfu",
"rewards/mean": "rewards_mean",
"rewards_mean": "rewards_mean",
"rewards/std": "rewards_std",
"rewards_std": "rewards_std",
"kl": "kl_divergence",
"kl_divergence": "kl_divergence",
"clip_ratio": "clip_ratio",
"queue_size": "queue_size",
}
for key, value in logs.items():
if key in metric_map:
setattr(self._state, metric_map[key], value)
else:
self._state.extra[key] = value
if "loss" in logs and logs["loss"] is not None:
self._state.loss_history.append(logs["loss"])
def start(self) -> None:
self._init_panels()
self._init_parser_chain()
# Set up I/O capture
assert self._parser_chain is not None, "_init_parser_chain must be called first"
self._io_capture = IOCapture(
log_path=self._config.stdout_log_path,
parser_chain=self._parser_chain,
metric_queue=self._queue,
)
# Monkeypatch tqdm to suppress terminal output and route through our queue.
# This prevents tqdm progress bars from flickering through the TUI and
# ensures all progress events appear in the Events panel.
self._install_tqdm_hook()
self._io_capture_ready = threading.Event()
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
self._io_capture_ready.wait(timeout=5.0)
def _install_tqdm_hook(self) -> None:
"""Replace tqdm's display method to route updates through TUI queue."""
try:
import io
import tqdm
import tqdm.auto
q = self._queue
self._tqdm_parser = None
# Find our tqdm parser in the chain
for p in self._parser_chain._parsers if self._parser_chain else []:
if p.name == "tqdm":
self._tqdm_parser = p
break
# Save originals for restore
self._orig_tqdm_class_auto = tqdm.auto.tqdm
self._orig_tqdm_class_tqdm = tqdm.tqdm
self._orig_tqdm_class_std = tqdm.std.tqdm
class TUITqdm(tqdm.tqdm):
"""tqdm subclass that sends progress to TUI instead of terminal."""
def __init__(self, *args, **kwargs):
# Force output to devnull so nothing reaches the terminal
kwargs["file"] = io.StringIO()
kwargs["dynamic_ncols"] = False
kwargs["ncols"] = 80
super().__init__(*args, **kwargs)
def display(self, msg=None, pos=None):
# Build a progress string and push to queue
if self.total and self.total > 0:
pct = self.n / self.total * 100
desc = self.desc.rstrip(": ") if self.desc else ""
# Emit events at milestones or at low frequency
is_milestone = (
self.n == 0 or self.n >= self.total or int(pct) % 25 == 0
)
if is_milestone:
try:
q.put_nowait(
{
"type": "log_line",
"level": "info",
"message": f"[{desc}] {pct:.0f}% ({self.n}/{self.total})"
if desc
else f"{pct:.0f}% ({self.n}/{self.total})",
}
)
except Exception:
pass
try:
metric_key = (
f"progress/{desc.lower().replace(' ', '_')}"
if desc
else "progress/unknown"
)
q.put_nowait(
{
"type": "metrics",
"logs": {metric_key: pct / 100.0},
}
)
except Exception:
pass
def close(self):
# Emit final completion event
if self.total and self.total > 0 and self.n > 0:
desc = self.desc.rstrip(": ") if self.desc else ""
try:
q.put_nowait(
{
"type": "log_line",
"level": "info",
"message": f"[{desc}] 100% ({self.total}/{self.total}) done"
if desc
else f"100% ({self.total}/{self.total}) done",
}
)
except Exception:
pass
super().close()
# Replace tqdm globally
tqdm.auto.tqdm = TUITqdm
tqdm.tqdm = TUITqdm
# Also patch tqdm.std which some libraries use directly
tqdm.std.tqdm = TUITqdm
self._tui_tqdm_cls = TUITqdm
except Exception as exc:
LOG.debug(f"Failed to install tqdm hook: {exc}")
def _uninstall_tqdm_hook(self) -> None:
"""Restore original tqdm."""
try:
import tqdm
import tqdm.auto
if hasattr(self, "_orig_tqdm_class_auto"):
tqdm.auto.tqdm = self._orig_tqdm_class_auto
if hasattr(self, "_orig_tqdm_class_tqdm"):
tqdm.tqdm = self._orig_tqdm_class_tqdm
if hasattr(self, "_orig_tqdm_class_std"):
tqdm.std.tqdm = self._orig_tqdm_class_std
except Exception:
pass
def stop(self) -> None:
self._stop_event.set()
self._uninstall_tqdm_hook()
if self._thread is not None:
self._thread.join(timeout=5.0)
def _run(self) -> None:
import os
# Save a handle to the REAL terminal BEFORE IO capture redirects fds.
# This ensures rich.live.Live writes to the terminal, not the pipe.
saved_tty_fd = os.dup(1)
tty_file = os.fdopen(saved_tty_fd, "w", buffering=1, closefd=True)
console = Console(file=tty_file)
layout = self._build_layout()
tick_interval = 1.0 / max(self._config.refresh_rate, 1)
gpu_poll_counter = 0
gpu_poll_ticks = max(
1, int(self._config.hardware_poll_interval / tick_interval)
)
# Start I/O capture — redirects fd 1/2 to pipe AFTER we saved the tty fd
if self._io_capture:
self._io_capture.start()
# Signal that IO capture is live so start() can return
if hasattr(self, "_io_capture_ready"):
self._io_capture_ready.set()
try:
with Live(
layout,
console=console,
refresh_per_second=self._config.refresh_rate,
screen=True,
redirect_stdout=False,
redirect_stderr=False,
) as live:
while not self._stop_event.is_set():
self._drain_queue()
# Poll GPU stats periodically
gpu_poll_counter += 1
if gpu_poll_counter >= gpu_poll_ticks:
gpu_poll_counter = 0
if self._gpu_poller.available:
self._state.gpus = self._gpu_poller.poll()
# Update elapsed time
self._state.elapsed_seconds = (
time.time() - self._state.start_time.timestamp()
)
self._update_layout(layout)
live.update(layout)
time.sleep(tick_interval)
# Final drain
self._drain_queue()
self._update_layout(layout)
live.update(layout)
finally:
if self._io_capture:
self._io_capture.stop()
try:
tty_file.close()
except Exception:
pass
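
A minimal lifecycle sketch (editorial, not in the diff) of driving the renderer by hand, mirroring what do_train and AxolotlTUICallback do in this change. It needs a real terminal, since the renderer takes over the screen and redirects fds 1 and 2:

import queue
import time

from axolotl.tui.config import TUIConfig
from axolotl.tui.renderer import TUIRenderer

q: queue.Queue = queue.Queue(maxsize=4096)
renderer = TUIRenderer(config=TUIConfig(enabled=True), metric_queue=q)
renderer.start()
try:
    q.put_nowait({"type": "run_info", "model_name": "demo", "total_steps": 10})
    for step in range(1, 11):
        q.put_nowait({"type": "metrics", "logs": {"loss": 1.0 / step}})
        q.put_nowait({"type": "step", "step": step, "total_steps": 10})
        time.sleep(0.5)
finally:
    q.put_nowait({"type": "done"})
    renderer.stop()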

src/axolotl/tui/state.py (new file, +88 lines)
View File

@@ -0,0 +1,88 @@
"""TUI shared data model — dataclasses for the dashboard state."""
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
@dataclass
class GPUStats:
id: int
name: str
util_pct: float
vram_used_gb: float
vram_total_gb: float
temp_c: int
power_w: float | None
@dataclass
class LogLine:
timestamp: datetime
level: str # "info" | "debug" | "warning" | "error"
message: str
@dataclass
class CompletionSample:
step: int
prompt: str
completion: str
reward: float | None
advantage: float | None
@dataclass
class TUIState:
# Run metadata
run_name: str = ""
model_name: str = ""
training_mode: str = "sft"
world_size: int = 1
start_time: datetime = field(default_factory=datetime.now)
# Progress
current_step: int = 0
total_steps: int = 0
current_epoch: float = 0.0
total_epochs: float = 1.0
elapsed_seconds: float = 0.0
eta_seconds: float | None = None
# Training metrics (rolling window + current)
loss: float | None = None
grad_norm: float | None = None
learning_rate: float | None = None
tokens_per_second: float | None = None
samples_per_second: float | None = None
mfu: float | None = None
# RL-specific (None for non-RL modes)
rewards_mean: float | None = None
rewards_std: float | None = None
kl_divergence: float | None = None
clip_ratio: float | None = None
queue_size: int | None = None
# Per-GPU hardware (list indexed by local rank)
gpus: list[GPUStats] = field(default_factory=list)
# Recent log lines
log_lines: deque[LogLine] = field(default_factory=lambda: deque(maxlen=200))
# Recent completions (GRPO/SFT with log_completions)
completions: deque[CompletionSample] = field(
default_factory=lambda: deque(maxlen=20)
)
# Loss history for sparkline
loss_history: deque[float] = field(default_factory=lambda: deque(maxlen=50))
# DeepSpeed zero stage (None if not using DeepSpeed)
zero_stage: int | None = None
# Arbitrary plugin state
extra: dict[str, Any] = field(default_factory=dict)

View File

@@ -13,6 +13,7 @@ from pydantic import (
model_validator,
)
from axolotl.tui.config import TUIConfig
from axolotl.utils.datasets import get_default_process_count
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import (
@@ -140,6 +141,12 @@ class AxolotlInputConfig(
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(),
)
tui: TUIConfig | None = Field(
default=None,
json_schema_extra={
"description": "TUI dashboard configuration. Set enabled: true to activate."
},
)
qat: QATConfig | None = None
quantization: PTQConfig | None = None
reward_model: bool | None = Field(