def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
    """Load datasets, run training, and (on rank 0) drive the TUI dashboard.

    The TUI renderer is started *before* data loading so preprocessing
    events show up in the dashboard; ``train()`` later reuses the running
    renderer via private attributes stashed on ``cfg``.

    Args:
        cfg: Parsed axolotl config.
        cli_args: Trainer CLI arguments.
    """
    import logging
    import queue as _queue

    # Read LOCAL_RANK once; rank 0 owns the token check and the TUI.
    is_rank_0 = int(os.getenv("LOCAL_RANK", "0")) == 0
    if is_rank_0:
        check_user_token()

    tui_renderer = None
    tui_queue = None
    if is_rank_0:
        from axolotl.train import _get_tui_config, _is_tui_enabled

        if _is_tui_enabled(cfg):
            from axolotl.tui.callback import _TUILogHandler
            from axolotl.tui.config import TUIConfig
            from axolotl.tui.renderer import TUIRenderer

            tui_config_dict = _get_tui_config(cfg)
            tui_config = (
                TUIConfig(**tui_config_dict)
                if isinstance(tui_config_dict, dict)
                else tui_config_dict
            )
            tui_queue = _queue.Queue(maxsize=4096)
            tui_renderer = TUIRenderer(config=tui_config, metric_queue=tui_queue)

            # Seed the dashboard with static run metadata before it starts.
            try:
                tui_queue.put_nowait(
                    {
                        "type": "run_info",
                        "model_name": cfg.base_model or "",
                        "training_mode": str(cfg.rl) if cfg.rl else "sft",
                        "world_size": int(os.environ.get("WORLD_SIZE", 1)),
                    }
                )
            except _queue.Full:
                pass

            tui_renderer.start()

            # Attach the log handler to BOTH the root and the "axolotl"
            # loggers: the axolotl logger has propagate=False, so a root
            # handler alone would never see axolotl.* messages.
            _early_log_handler = _TUILogHandler(
                tui_queue, min_level=tui_config.log_level
            )
            _early_log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
            logging.getLogger().addHandler(_early_log_handler)
            logging.getLogger("axolotl").addHandler(_early_log_handler)

            # Stash refs on cfg so train() can reuse the running renderer.
            cfg._tui_renderer = tui_renderer
            cfg._tui_queue = tui_queue
            cfg._tui_early_log_handler = _early_log_handler

    try:
        plugin_manager = PluginManager.get_instance()
        dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
        if not dataset_meta:
            if cfg.rl:
                dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
            else:
                dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)

        del model, tokenizer, trainer
        gc.collect()

        plugin_manager = PluginManager.get_instance()
        plugin_manager.post_train_unload(cfg)
    finally:
        # If the renderer started early but train() never stopped it
        # (e.g. an error during data loading), shut it down here.
        # NOTE(review): this peeks at the renderer's private _stop_event —
        # consider exposing a public "stopped" property on TUIRenderer.
        if tui_renderer is not None and not tui_renderer._stop_event.is_set():
            try:
                tui_queue.put_nowait({"type": "done"})
            except Exception:
                pass
            tui_renderer.stop()
        if hasattr(cfg, "_tui_early_log_handler"):
            logging.getLogger().removeHandler(cfg._tui_early_log_handler)
            logging.getLogger("axolotl").removeHandler(cfg._tui_early_log_handler)
dataset_meta=dataset_meta) + + del model, tokenizer, trainer + + gc.collect() + + plugin_manager = PluginManager.get_instance() + plugin_manager.post_train_unload(cfg) + finally: + # If the TUI renderer started early but train() didn't get to stop it + # (e.g., error during data loading), clean up here + if tui_renderer is not None and not tui_renderer._stop_event.is_set(): + try: + tui_queue.put_nowait({"type": "done"}) + except Exception: + pass + tui_renderer.stop() + # Remove early log handler from both root and axolotl loggers + if hasattr(cfg, "_tui_early_log_handler"): + import logging + + logging.getLogger().removeHandler(cfg._tui_early_log_handler) + logging.getLogger("axolotl").removeHandler(cfg._tui_early_log_handler) def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 522dd7e28..85629bca2 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -554,6 +554,36 @@ def setup_model_and_trainer( ) +def _is_tui_enabled(cfg: DictDefault) -> bool: + """Check if TUI is enabled via config or environment variable.""" + if os.environ.get("AXOLOTL_TUI", "").lower() in ("1", "true", "yes"): + return True + tui = cfg.get("tui") + if tui is None: + return False + if isinstance(tui, bool): + return tui + if isinstance(tui, dict): + return tui.get("enabled", False) + if hasattr(tui, "enabled"): + return tui.enabled + return False + + +def _get_tui_config(cfg: DictDefault) -> dict: + """Extract TUI config dict from cfg.""" + tui = cfg.get("tui") + if tui is None or isinstance(tui, bool): + return {"enabled": True} + if isinstance(tui, dict): + return {**tui, "enabled": True} + if hasattr(tui, "model_dump"): + d = tui.model_dump() + d["enabled"] = True + return d + return {"enabled": True} + + @send_errors def train( cfg: DictDefault, dataset_meta: TrainDatasetMeta @@ -577,6 +607,39 @@ def train( processor, ) = setup_model_and_trainer(cfg, dataset_meta) + # Register TUI callback if 
enabled and rank 0 + tui_enabled = _is_tui_enabled(cfg) + if tui_enabled and cfg.local_rank == 0: + from axolotl.tui import AxolotlTUICallback + from axolotl.tui.config import TUIConfig + + tui_config = _get_tui_config(cfg) + tui_config_obj = TUIConfig(**tui_config) if isinstance(tui_config, dict) else tui_config + + # Reuse the early-started renderer if available (started in do_train) + early_renderer = getattr(cfg, "_tui_renderer", None) + early_queue = getattr(cfg, "_tui_queue", None) + + tui_callback = AxolotlTUICallback(config=tui_config_obj) + if early_renderer is not None and early_queue is not None: + # Reuse the already-running renderer and queue + tui_callback._renderer = early_renderer + tui_callback._queue = early_queue + tui_callback._renderer_started_early = True + trainer.add_callback(tui_callback) + + # Send model info to the callback + model_name = cfg.base_model or "" + training_mode = str(cfg.rl) if cfg.rl else "sft" + world_size = int(os.environ.get("WORLD_SIZE", 1)) + tui_callback._put({ + "type": "run_info", + "model_name": model_name, + "training_mode": training_mode, + "world_size": world_size, + }) + LOG.info("TUI dashboard enabled") + # Handle untrained tokens if configured train_dataset = dataset_meta.train_dataset handle_untrained_tokens_fix(cfg, model, tokenizer, train_dataset) diff --git a/src/axolotl/tui/__init__.py b/src/axolotl/tui/__init__.py new file mode 100644 index 000000000..e01811814 --- /dev/null +++ b/src/axolotl/tui/__init__.py @@ -0,0 +1,17 @@ +"""Axolotl Training TUI — rich-based terminal dashboard for monitoring training runs.""" + +from axolotl.tui.callback import AxolotlTUICallback +from axolotl.tui.config import TUIConfig +from axolotl.tui.io_capture import LineParser, register_parser +from axolotl.tui.panels import BasePanel, register_panel +from axolotl.tui.state import TUIState + +__all__ = [ + "AxolotlTUICallback", + "BasePanel", + "LineParser", + "TUIConfig", + "TUIState", + "register_panel", + 
"register_parser", +] diff --git a/src/axolotl/tui/callback.py b/src/axolotl/tui/callback.py new file mode 100644 index 000000000..57db5c519 --- /dev/null +++ b/src/axolotl/tui/callback.py @@ -0,0 +1,138 @@ +"""AxolotlTUICallback — HF TrainerCallback that feeds metrics to the TUI.""" + +from __future__ import annotations + +import logging +import queue +from typing import Any + +from transformers.trainer_callback import TrainerCallback + +from axolotl.tui.config import TUIConfig +from axolotl.tui.renderer import TUIRenderer + + +class _TUILogHandler(logging.Handler): + """Logging handler that pushes log records into the TUI metric queue.""" + + _LEVEL_MAP = { + logging.DEBUG: "debug", + logging.INFO: "info", + logging.WARNING: "warning", + logging.ERROR: "error", + logging.CRITICAL: "error", + } + + def __init__(self, metric_queue: queue.Queue, min_level: str = "info"): + super().__init__() + level_name = min_level.upper() + self.setLevel(getattr(logging, level_name, logging.INFO)) + self._queue = metric_queue + + def emit(self, record: logging.LogRecord) -> None: + try: + level = self._LEVEL_MAP.get(record.levelno, "info") + msg = self.format(record) + self._queue.put_nowait({ + "type": "log_line", + "level": level, + "message": msg, + }) + except queue.Full: + pass + except Exception: + self.handleError(record) + + +class AxolotlTUICallback(TrainerCallback): + """Pushes training metrics into a queue for the TUI renderer. + + The callback never blocks on the render thread. The queue is bounded + (maxsize=512) with put_nowait; overflow is silently dropped. 
class AxolotlTUICallback(TrainerCallback):
    """Pushes training metrics into a queue for the TUI renderer.

    The callback never blocks on the render thread. The queue is bounded
    (maxsize=4096) with put_nowait; overflow is silently dropped.
    """

    def __init__(self, config: TUIConfig):
        self._config = config
        self._queue: queue.Queue = queue.Queue(maxsize=4096)
        self._renderer = TUIRenderer(config=config, metric_queue=self._queue)
        self._log_handler: _TUILogHandler | None = None
        # Set externally (by train()) when do_train already started the
        # renderer; controls whether this callback owns start/stop.
        self._renderer_started_early: bool = False

    def _put(self, event: dict) -> None:
        """Enqueue an event without blocking; drop it if the queue is full."""
        try:
            self._queue.put_nowait(event)
        except queue.Full:
            pass

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Send run metadata and, unless pre-started, boot the renderer."""
        run_name = getattr(args, "run_name", "") or ""
        self._put(
            {
                "type": "run_info",
                "run_name": run_name,
                "total_steps": state.max_steps,
                "total_epochs": int(args.num_train_epochs) if args.num_train_epochs else 1,
            }
        )

        if not self._renderer_started_early:
            # Attach a logging handler to feed log messages into the events panel
            self._log_handler = _TUILogHandler(
                self._queue, min_level=self._config.log_level
            )
            self._log_handler.setFormatter(
                logging.Formatter("[%(name)s] %(message)s")
            )
            # Attach to both root and axolotl loggers (axolotl has propagate=False)
            logging.getLogger().addHandler(self._log_handler)
            logging.getLogger("axolotl").addHandler(self._log_handler)

            # Start the renderer background thread
            self._renderer.start()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Forward the numeric subset of the trainer's log dict."""
        if logs is None:
            return

        # Filter out non-numeric keys and internal (underscore-prefixed) keys.
        filtered = {}
        for key, value in logs.items():
            if key.startswith("_"):
                continue
            if isinstance(value, (int, float)):
                filtered[key] = value
            elif isinstance(value, str):
                # HF Trainer sometimes passes string-encoded numbers
                try:
                    filtered[key] = float(value)
                except (ValueError, TypeError):
                    pass

        if filtered:
            self._put({"type": "metrics", "logs": filtered})

    def on_step_end(self, args, state, control, **kwargs):
        """Report step progress after every optimizer step."""
        self._put(
            {
                "type": "step",
                "step": state.global_step,
                "total_steps": state.max_steps,
                "epoch": state.epoch if state.epoch else 0,
            }
        )

    def on_prediction_step(self, args, state, control, **kwargs):
        # Intentionally a no-op; evaluation progress is not surfaced.
        pass

    def on_train_end(self, args, state, control, **kwargs):
        """Signal completion and detach resources this callback owns."""
        self._put({"type": "done"})
        # If renderer was started early, do_train's finally block handles stop
        if not self._renderer_started_early:
            self._renderer.stop()

        # Remove the logging handler (only if we added it)
        if self._log_handler:
            logging.getLogger().removeHandler(self._log_handler)
            logging.getLogger("axolotl").removeHandler(self._log_handler)
            self._log_handler = None
LOG = logging.getLogger(__name__)

# Module-level probe: initialise NVML exactly once at import time.
_nvml_available = False
try:
    import pynvml

    pynvml.nvmlInit()
    _nvml_available = True
except Exception:
    LOG.debug("pynvml unavailable — GPU stats will not be shown")


class GPUPoller:
    """Polls local GPU stats via pynvml; degrades to a no-op without NVML."""

    def __init__(self):
        # Device count is fixed for the process lifetime, so query it once.
        count = 0
        if _nvml_available:
            try:
                count = pynvml.nvmlDeviceGetCount()
            except Exception:
                count = 0
        self._device_count = count

    @property
    def available(self) -> bool:
        """True when NVML initialised and at least one device is visible."""
        return _nvml_available and self._device_count > 0

    def poll(self) -> "list[GPUStats]":
        """Return one GPUStats per device; empty list when unavailable."""
        if not self.available:
            return []

        snapshot = []
        for index in range(self._device_count):
            try:
                handle = pynvml.nvmlDeviceGetHandleByIndex(index)
                raw_name = pynvml.nvmlDeviceGetName(handle)
                # Older NVML bindings return bytes for the device name.
                device_name = (
                    raw_name.decode("utf-8") if isinstance(raw_name, bytes) else raw_name
                )

                util = pynvml.nvmlDeviceGetUtilizationRates(handle)
                mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
                temp = pynvml.nvmlDeviceGetTemperature(
                    handle, pynvml.NVML_TEMPERATURE_GPU
                )

                # Power query is optional — some devices don't report it.
                try:
                    power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
                except Exception:
                    power = None

                snapshot.append(
                    GPUStats(
                        id=index,
                        name=device_name,
                        util_pct=util.gpu,
                        vram_used_gb=mem.used / (1024**3),
                        vram_total_gb=mem.total / (1024**3),
                        temp_c=temp,
                        power_w=power,
                    )
                )
            except Exception:
                # Skip devices that error mid-query rather than failing the poll.
                pass
        return snapshot
# ---------------------------------------------------------------------------
# Parser registry
# ---------------------------------------------------------------------------

_parser_registry: "list[type[LineParser]]" = []


def register_parser(cls: "type[LineParser]") -> "type[LineParser]":
    """Class decorator recording a LineParser subclass in the registry."""
    _parser_registry.append(cls)
    return cls


def get_registered_parsers() -> "list[type[LineParser]]":
    """Return a shallow copy of the registered parser classes."""
    return list(_parser_registry)


# ---------------------------------------------------------------------------
# Base LineParser
# ---------------------------------------------------------------------------


class LineParser(ABC):
    """Base class for stdout/stderr line parsers."""

    # Lower priority runs earlier in a ParserChain.
    priority: int = 50
    name: str = ""

    @abstractmethod
    def parse(self, line: str, source: str) -> "list[dict]":
        """Parse a single captured line.

        Args:
            line: one line of captured output, trailing newline stripped.
            source: "stdout" or "stderr".

        Returns:
            List of event dicts to push onto the metric queue.
            Return [] if this line is not relevant.
        """
        ...
# ---------------------------------------------------------------------------
# ParserChain
# ---------------------------------------------------------------------------


class ParserChain:
    """Ordered collection of parsers applied to every captured line."""

    def __init__(self):
        self._parsers = []

    def register(self, parser: "LineParser") -> None:
        """Add a parser, keeping the chain sorted by ascending priority."""
        self._parsers.append(parser)
        self._parsers.sort(key=lambda item: item.priority)

    def parse(self, line: str, source: str = "stdout") -> "list[dict]":
        """Run every parser over the line and concatenate their events."""
        collected: "list[dict]" = []
        for parser in self._parsers:
            collected.extend(parser.parse(line, source))
        return collected


# ---------------------------------------------------------------------------
# IOCapture — OS-level fd redirect to pipe
# ---------------------------------------------------------------------------


class IOCapture:
    """Redirects fd 1 and fd 2 into an OS pipe, drains via a reader thread,
    passes lines through a ParserChain, and tees to a log file."""

    def __init__(
        self, log_path: str, parser_chain: ParserChain, metric_queue: queue.Queue
    ):
        self._parser_chain = parser_chain
        self._queue = metric_queue
        self._log_path = log_path
        self._log_file = None
        self._thread = None
        self._read_fd = None
        self._write_fd = None
        self._saved_stdout_fd = None
        self._saved_stderr_fd = None

    def start(self) -> None:
        """Begin capturing: open the log, swap fds, spawn the drain thread."""
        # Line-buffered append; each run gets a dated separator.
        self._log_file = open(self._log_path, "a", buffering=1)  # noqa: SIM115
        self._log_file.write(
            f"\n=== axolotl run started {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ===\n"
        )
        self._log_file.flush()

        # OS-level pipe: everything written to fd 1/2 lands in the read end.
        self._read_fd, self._write_fd = os.pipe()

        # Save the original fds so stop() can restore them.
        self._saved_stdout_fd = os.dup(1)
        self._saved_stderr_fd = os.dup(2)

        # Point both stdout and stderr at the pipe's write end.
        os.dup2(self._write_fd, 1)
        os.dup2(self._write_fd, 2)
        os.close(self._write_fd)  # write end now held by fds 1 and 2

        # Rebind the Python-level handles onto the redirected fds.
        sys.stdout = open(1, "w", buffering=1, closefd=False)  # noqa: SIM115
        sys.stderr = open(2, "w", buffering=1, closefd=False)  # noqa: SIM115

        # Drain thread pulls from the pipe until EOF.
        self._thread = threading.Thread(target=self._drain, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Restore fds (EOF-ing the reader), join the thread, close the log."""
        if self._saved_stdout_fd is not None:
            sys.stdout = sys.__stdout__
            sys.stderr = sys.__stderr__
            os.dup2(self._saved_stdout_fd, 1)
            os.dup2(self._saved_stderr_fd, 2)
            os.close(self._saved_stdout_fd)
            os.close(self._saved_stderr_fd)
            self._saved_stdout_fd = None
            self._saved_stderr_fd = None

        if self._thread is not None:
            self._thread.join(timeout=2.0)
            self._thread = None

        if self._log_file is not None:
            self._log_file.close()
            self._log_file = None

    def _drain(self) -> None:
        """Reader-thread loop: split the byte stream into lines and process.

        Splits on both \\n and \\r so tqdm-style in-place updates (which
        end with \\r only) are still captured line by line.
        """
        with os.fdopen(self._read_fd, "rb") as pipe:
            pending = b""
            while True:
                chunk = pipe.read(4096)
                if not chunk:
                    # EOF — flush whatever is left in the buffer.
                    if pending:
                        self._process_line(pending.decode("utf-8", errors="replace"))
                    break
                pending += chunk
                while b"\n" in pending or b"\r" in pending:
                    # Cut at the earliest of the two delimiters.
                    pos_nl = pending.find(b"\n")
                    pos_cr = pending.find(b"\r")
                    if pos_nl == -1:
                        cut = pos_cr
                    elif pos_cr == -1:
                        cut = pos_nl
                    else:
                        cut = min(pos_nl, pos_cr)
                    text = pending[:cut].decode("utf-8", errors="replace")
                    pending = pending[cut + 1 :]
                    # Treat \r\n as a single delimiter.
                    if pending.startswith(b"\n"):
                        pending = pending[1:]
                    if text:
                        self._process_line(text)

    def _process_line(self, line: str) -> None:
        """Tee one line to the log file and feed parser events to the queue."""
        line = line.rstrip()
        if not line:
            return
        if self._log_file:
            self._log_file.write(line + "\n")
            self._log_file.flush()
        for event in self._parser_chain.parse(line):
            try:
                self._queue.put_nowait(event)
            except queue.Full:
                pass
# ---------------------------------------------------------------------------
# Panel registry
# ---------------------------------------------------------------------------

_panel_registry: "dict[str, type[BasePanel]]" = {}


def register_panel(position: str = "bottom", weight: int = 50):
    """Class decorator registering a panel under its ``name``.

    Stamps ``position`` and ``weight`` onto the class before recording it.
    """

    def decorator(cls: "type[BasePanel]") -> "type[BasePanel]":
        cls.position = position
        cls.weight = weight
        _panel_registry[cls.name] = cls
        return cls

    return decorator


def get_registered_panels() -> "dict[str, type[BasePanel]]":
    """Return a shallow copy of the panel registry."""
    return dict(_panel_registry)


# ---------------------------------------------------------------------------
# BasePanel
# ---------------------------------------------------------------------------


class BasePanel(ABC):
    """Abstract base for TUI panels; subclasses implement render()."""

    # Registry metadata — position/weight are overwritten by @register_panel.
    name: str = ""
    position: str = "bottom"
    weight: int = 50
    min_height: int = 4
    max_height: "int | None" = None
    # Training modes this panel applies to; "*" means all modes.
    modes: "list[str]" = ["*"]

    @abstractmethod
    def render(self, state: TUIState) -> RenderableType:
        """Return a rich renderable. Called every tick."""
        ...

    def on_event(self, event: dict) -> None:
        """Optional: react to raw metric events before state is merged."""
        pass


# Auto-import built-in panels to trigger registration
from axolotl.tui.panels.completions import CompletionsPanel  # noqa: E402, F401
from axolotl.tui.panels.debug import DebugPanel  # noqa: E402, F401
from axolotl.tui.panels.events import EventsPanel  # noqa: E402, F401
from axolotl.tui.panels.hardware import HardwarePanel  # noqa: E402, F401
from axolotl.tui.panels.progress import ProgressPanel  # noqa: E402, F401
from axolotl.tui.panels.training import TrainingPanel  # noqa: E402, F401
@register_panel(position="bottom", weight=30)
class DebugPanel(BasePanel):
    """Scrolling view of the most recent debug-level log lines."""

    name = "debug"
    min_height = 6
    max_height = 10

    def render(self, state: TUIState) -> RenderableType:
        """Render the last 8 debug records, or a placeholder when empty."""
        # Only debug-level records; EventsPanel shows everything else.
        tail = [entry for entry in state.log_lines if entry.level == "debug"][-8:]
        if not tail:
            return Panel(
                Text("No debug messages yet...", style="dim"),
                title="Debug",
                border_style="dim",
            )

        body = Text()
        for entry in tail:
            stamp = entry.timestamp.strftime("%H:%M:%S")
            body.append(f"[{stamp}] ", style="dim")
            # Cap message length so one long line can't blow out the panel.
            body.append(entry.message[:200], style="dim")
            body.append("\n")
        return Panel(body, title="Debug", border_style="dim")
# Color styling per log level; empty string means the default style.
_LEVEL_STYLES = {
    "debug": "dim",
    "info": "",
    "warning": "yellow",
    "error": "red bold",
    "critical": "red bold",
}


@register_panel(position="bottom", weight=10)
class EventsPanel(BasePanel):
    """Scrolling log of recent events, color-coded by level."""

    name = "events"
    min_height = 8
    max_height = 20

    def render(self, state: TUIState) -> RenderableType:
        """Render the last 15 non-debug records, or a placeholder when empty."""
        # Debug records are routed to DebugPanel instead.
        recent = [entry for entry in state.log_lines if entry.level != "debug"][-15:]
        if not recent:
            return Panel(
                Text("No events yet...", style="dim"),
                title="Events",
                border_style="yellow",
            )

        body = Text()
        for entry in recent:
            stamp = entry.timestamp.strftime("%H:%M:%S")
            style = _LEVEL_STYLES.get(entry.level, "")
            body.append(f"[{stamp}] ", style="dim")
            body.append(f"[{entry.level.upper()}] ", style=style or "")
            # Cap message length so one long line can't blow out the panel.
            body.append(entry.message[:200], style=style or "")
            body.append("\n")
        return Panel(body, title="Events", border_style="yellow")
_BAR_FULL = "█"
_BAR_EMPTY = "░"


def _util_bar(pct: float, width: int = 6) -> Text:
    """Render a small utilisation bar, colored green/yellow/red by load."""
    filled = int(pct / 100 * width)
    bar = _BAR_FULL * filled + _BAR_EMPTY * (width - filled)
    color = "green" if pct < 70 else ("yellow" if pct < 90 else "red")
    return Text.assemble((bar, color), f" {pct:3.0f}%")


@register_panel(position="right", weight=10)
class HardwarePanel(BasePanel):
    """Per-GPU utilisation, VRAM, temperature, and power table."""

    name = "hardware"
    min_height = 6

    def render(self, state: TUIState) -> RenderableType:
        """Render one row per GPU plus an aggregate footer for multi-GPU."""
        if not state.gpus:
            return Panel(
                Text("GPU stats unavailable", style="dim"),
                title="Hardware",
                border_style="green",
            )

        table = Table(
            show_header=True,
            header_style="bold",
            expand=True,
            box=None,
            pad_edge=False,
        )
        table.add_column("id", justify="right", width=3)
        table.add_column("util", no_wrap=True)
        table.add_column("vram", no_wrap=True)
        table.add_column("°C", justify="right", width=4)
        table.add_column("W", justify="right", width=5)

        for gpu in state.gpus:
            watts = f"{gpu.power_w:.0f}" if gpu.power_w is not None else "--"
            table.add_row(
                str(gpu.id),
                _util_bar(gpu.util_pct),
                f"{gpu.vram_used_gb:.1f}/{gpu.vram_total_gb:.1f} GB",
                str(gpu.temp_c),
                watts,
            )

        # Aggregate footer only makes sense with more than one device.
        count = len(state.gpus)
        if count > 1:
            avg_util = sum(g.util_pct for g in state.gpus) / count
            used_total = sum(g.vram_used_gb for g in state.gpus)
            capacity_total = sum(g.vram_total_gb for g in state.gpus)
            table.add_row(
                "Σ",
                Text(f"avg {avg_util:.0f}%", style="dim"),
                Text(f"{used_total:.1f}/{capacity_total:.1f} GB", style="dim"),
                "",
                "",
            )

        return Panel(table, title="Hardware", border_style="green")
def _fmt_time(seconds: "float | None") -> str:
    """Format elapsed seconds as H:MM:SS; placeholder for unknown/negative."""
    if seconds is None or seconds < 0:
        return "--:--:--"
    hours, remainder = divmod(int(seconds), 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours}:{minutes:02d}:{secs:02d}"


def _fmt_eta(seconds: "float | None") -> str:
    """Format an ETA compactly (XhYYm above an hour, XmYYs below)."""
    if seconds is None or seconds < 0:
        return "eta --"
    hours, remainder = divmod(int(seconds), 3600)
    minutes, secs = divmod(remainder, 60)
    if hours > 0:
        return f"eta {hours}h{minutes:02d}m"
    return f"eta {minutes}m{secs:02d}s"


@register_panel(position="top", weight=10)
class ProgressPanel(BasePanel):
    """Top-bar progress display: step count, elapsed time, and ETA."""

    name = "progress"
    min_height = 3
    max_height = 3

    def render(self, state: TUIState) -> RenderableType:
        """Render a one-line header plus a full-width progress bar."""
        pct = (
            (state.current_step / state.total_steps * 100)
            if state.total_steps > 0
            else 0
        )

        # Header line: mode, model basename, step counter, timing.
        mode_upper = state.training_mode.upper() if state.training_mode else "SFT"
        model_short = state.model_name.split("/")[-1] if state.model_name else "model"
        header = Text.assemble(
            ("● ", "bold green"),
            ("AXOLOTL", "bold cyan"),
            f" {mode_upper} · {model_short} ",
            (
                f"{state.current_step} / {state.total_steps}",
                "bold",
            ),
            f" · {_fmt_time(state.elapsed_seconds)} elapsed · {_fmt_eta(state.eta_seconds)} · {pct:.1f}%",
        )

        # Full-width bar; total falls back to 1 so an unknown step count
        # still renders an empty bar instead of raising.
        bar = Progress(
            TextColumn(""),
            BarColumn(bar_width=None),
            TextColumn("{task.percentage:>3.0f}%"),
            expand=True,
        )
        task_id = bar.add_task("", total=state.total_steps or 1)
        bar.update(task_id, completed=state.current_step)

        layout = Table.grid(expand=True)
        layout.add_row(header)
        layout.add_row(bar)
        return layout
# ── axolotl/tui/panels/training.py ─────────────────────────────────────────

# Eight glyphs, lowest → highest, used to draw the loss-trend sparkline.
_SPARK_CHARS = "▁▂▃▄▅▆▇█"


def _sparkline(values: list[float] | None, width: int = 20) -> str:
    """Render the trailing *width* samples of *values* as a unicode sparkline.

    Returns "" when fewer than two samples exist (no trend to show).
    """
    if not values or len(values) < 2:
        return ""
    window = list(values)[-width:]
    lo = min(window)
    span = max(window) - lo
    if span == 0:
        # Flat series: avoid division by zero; every glyph maps to the lowest.
        span = 1.0
    glyphs = []
    for v in window:
        level = int((v - lo) / span * 7)
        glyphs.append(_SPARK_CHARS[level if level < 7 else 7])
    return "".join(glyphs)


# (TUIState attribute name, display label, format spec) — fixed render order.
_KNOWN_KEYS: list[tuple[str, str, str]] = [
    ("loss", "loss", ".4f"),
    ("grad_norm", "grad norm", ".3f"),
    ("learning_rate", "lr", ".2e"),
    ("tokens_per_second", "tok/s", ".1f"),
    ("samples_per_second", "samples/s", ".1f"),
    ("mfu", "MFU", ".1f"),
    # RL-specific
    ("rewards_mean", "rewards/mean", ".4f"),
    ("rewards_std", "rewards/std", ".4f"),
    ("kl_divergence", "KL", ".4f"),
    ("clip_ratio", "clip ratio", ".3f"),
    ("queue_size", "queue", "d"),
]


@register_panel(position="left", weight=10)
class TrainingPanel(BasePanel):
    """Left-hand panel: live scalar metrics table with a loss sparkline."""

    name = "training"
    min_height = 8

    @staticmethod
    def _fmt(value, spec: str) -> str:
        # Fall back to str() for values that reject the spec
        # (e.g. a float arriving for the int "d" spec).
        try:
            return f"{value:{spec}}"
        except (ValueError, TypeError):
            return str(value)

    def render(self, state: TUIState) -> RenderableType:
        table = Table(
            show_header=True,
            header_style="bold",
            expand=True,
            box=None,
            pad_edge=False,
        )
        table.add_column("metric", style="cyan", no_wrap=True)
        table.add_column("value", justify="right")
        table.add_column("trend", justify="left", no_wrap=True)

        for attr, label, spec in _KNOWN_KEYS:
            value = getattr(state, attr, None)
            if value is None:
                # Metrics that are not first-class TUIState fields land in `extra`.
                value = state.extra.get(attr)
            if value is None:
                continue
            trend = _sparkline(list(state.loss_history)) if attr == "loss" else ""
            table.add_row(label, self._fmt(value, spec), trend)

        # Unrecognised metrics follow the known ones, alphabetically.
        known = {attr for attr, _, _ in _KNOWN_KEYS}
        for key, value in sorted(state.extra.items()):
            if key in known or value is None:
                continue
            table.add_row(key, self._fmt(value, ".4f"), "")

        if table.row_count == 0:
            return Panel(
                Text("Waiting for first log step...", style="dim"),
                title="Training",
                border_style="blue",
            )
        return Panel(table, title="Training", border_style="blue")


# ── axolotl/tui/parsers/__init__.py ────────────────────────────────────────
"""Built-in line parsers — auto-imported to trigger @register_parser decorators."""

from axolotl.tui.parsers.deepspeed import DeepSpeedParser  # noqa: F401
from axolotl.tui.parsers.nccl import NCCLErrorParser  # noqa: F401
from axolotl.tui.parsers.raw_log import RawLogParser  # noqa: F401
from axolotl.tui.parsers.tqdm import TqdmParser  # noqa: F401
from axolotl.tui.parsers.torch_compile import TorchCompileParser  # noqa: F401


# ── axolotl/tui/parsers/deepspeed.py ───────────────────────────────────────
"""DeepSpeedParser — extracts DeepSpeed stage info and throughput metrics."""

import re

from axolotl.tui.io_capture import LineParser, register_parser


@register_parser
class DeepSpeedParser(LineParser):
    """Pulls samples/sec throughput and the ZeRO stage out of DeepSpeed logs."""

    priority = 20
    name = "deepspeed"

    _SAMPLES_RE = re.compile(r"samples/sec=([0-9.]+)")
    _STAGE_RE = re.compile(r"ZeRO Stage (\d)")

    def parse(self, line: str, source: str) -> list[dict]:
        found: list[dict] = []
        match = self._SAMPLES_RE.search(line)
        if match:
            found.append(
                {
                    "type": "metrics",
                    "logs": {"samples_per_second": float(match.group(1))},
                }
            )
        match = self._STAGE_RE.search(line)
        if match:
            found.append({"type": "run_info", "zero_stage": int(match.group(1))})
        return found
# ── axolotl/tui/parsers/nccl.py ────────────────────────────────────────────
"""NCCLErrorParser — surfaces NCCL errors as red alert events."""

import re

from axolotl.tui.io_capture import LineParser, register_parser


@register_parser
class NCCLErrorParser(LineParser):
    """Turns any line mentioning an NCCL failure into an error log + alert."""

    priority = 10
    name = "nccl_error"

    _RE = re.compile(r"NCCL error|Unhandled NCCL", re.IGNORECASE)

    def parse(self, line: str, source: str) -> list[dict]:
        """Emit both a log_line (Events panel) and an alert event for NCCL errors."""
        if not self._RE.search(line):
            return []
        return [
            {
                "type": "log_line",
                "level": "error",
                "message": f"⚠ NCCL: {line}",
            },
            {"type": "alert", "severity": "error", "message": line},
        ]


# ── axolotl/tui/parsers/raw_log.py ─────────────────────────────────────────
"""RawLogParser — catches every line as a log_line event."""


@register_parser
class RawLogParser(LineParser):
    """Lowest-priority fallback: every non-noise line becomes a log_line event."""

    priority = 99
    name = "raw_log"

    # FIX: the named-group names had been lost from this pattern
    # (`(?P\d…` is not valid Python regex and raises at import time).
    # "level" is restored from the m.group("level") call below;
    # "ts"/"msg" are reconstructed names — TODO confirm against the original.
    _LOG_RE = re.compile(
        r"^(?P<ts>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}[,\.]\d+)"
        r"\s*[-–]\s*(?P<level>DEBUG|INFO|WARNING|ERROR|CRITICAL)"
        r"\s*[-–]\s*(?P<msg>.+)$",
        re.IGNORECASE,
    )

    # Filter out tqdm progress bar lines and other noisy output.
    _TQDM_RE = re.compile(r"^\s*\d+%\|.*\|")
    _EMPTY_RE = re.compile(r"^\s*$")

    def parse(self, line: str, source: str) -> list[dict]:
        """Wrap *line* as a log_line, inferring the level from its prefix.

        Lines without a recognizable "timestamp - LEVEL - msg" prefix get a
        level inferred from the stream they arrived on (stderr → error).
        """
        # Skip empty lines and tqdm progress bar updates.
        if self._EMPTY_RE.match(line) or self._TQDM_RE.match(line):
            return []

        m = self._LOG_RE.match(line)
        level = (
            m.group("level").lower()
            if m
            else ("error" if source == "stderr" else "info")
        )
        return [{"type": "log_line", "level": level, "message": line}]
# ── axolotl/tui/parsers/torch_compile.py ───────────────────────────────────
"""TorchCompileParser — detects torch.compile graph breaks and recompilations."""

import re

from axolotl.tui.io_capture import LineParser, register_parser


@register_parser
class TorchCompileParser(LineParser):
    """Flags torch.compile chatter (graph breaks, recompiles) as warnings."""

    priority = 20
    name = "torch_compile"

    _RE = re.compile(r"Graph break|Recompiling|torch\.compile", re.IGNORECASE)

    def parse(self, line: str, source: str) -> list[dict]:
        """Emit a warning-level log_line for any torch.compile-related line."""
        if self._RE.search(line):
            return [
                {
                    "type": "log_line",
                    "level": "warning",
                    "message": f"⚡ compile: {line}",
                }
            ]
        return []


# ── axolotl/tui/parsers/tqdm.py ────────────────────────────────────────────
"""TqdmParser — captures tqdm progress bar output and surfaces as structured events."""


@register_parser
class TqdmParser(LineParser):
    """Parses tqdm-style progress lines into log-line + progress-metric events."""

    priority = 15
    name = "tqdm"

    # Match tqdm-style progress lines, e.g.:
    #   Tokenizing Prompts (num_proc=24):  35%|███▍   | 19008/54568 [00:02<00:02, 17417.65 examples/s]
    #   Loading weights:  53%|█████▎  | 77/146 [00:00<00:00, 396.39it/s]
    # FIX: the named-group names had been stripped (`(?P.*?)` is invalid
    # regex and raises at import). "desc"/"pct"/"current"/"total" are restored
    # from the m.group(...) calls below; "elapsed" is a reconstructed name
    # (unused by parse) — TODO confirm against the original.
    _TQDM_RE = re.compile(
        r"(?P<desc>.*?)\s*"
        r"(?P<pct>\d+)%\|[▏▎▍▌▋▊▉█░▓▒# ]*\|\s*"
        r"(?P<current>[\d,]+)/(?P<total>[\d,]+)"
        r"\s*\[(?P<elapsed>[^\]]*)\]"
    )

    # Also match simpler forms like:  Fetching 0 files: 0it [00:00, ?it/s]
    # NOTE(review): currently unused by parse(); group names reconstructed.
    _FETCH_RE = re.compile(
        r"(?P<desc>[\w\s]+):\s*(?P<count>\d+)(?:it)?\s*\[.*?\]"
    )

    def parse(self, line: str, source: str) -> list[dict]:
        """Turn a tqdm progress line into milestone log lines plus a metric."""
        m = self._TQDM_RE.search(line)
        if not m:
            return []

        desc = m.group("desc").strip().rstrip(":")
        pct = int(m.group("pct"))
        current = int(m.group("current").replace(",", ""))
        total = int(m.group("total").replace(",", ""))

        events: list[dict] = []

        # Surface a log line only at 25% milestones to avoid flooding the
        # Events panel (pct % 25 == 0 already covers both 0 and 100).
        if pct % 25 == 0:
            msg = (
                f"[{desc}] {pct}% ({current}/{total})"
                if desc
                else f"{pct}% ({current}/{total})"
            )
            events.append({"type": "log_line", "level": "info", "message": msg})

        # Always emit the fractional progress as a metric.
        events.append(
            {
                "type": "metrics",
                "logs": {f"progress/{desc.lower().replace(' ', '_')}": pct / 100.0},
            }
        )
        return events


# ── axolotl/tui/renderer.py ────────────────────────────────────────────────
"""TUIRenderer — background daemon thread that drives the rich.live.Live display."""

import importlib
import logging
import os
import queue
import threading
import time
from datetime import datetime
from typing import Any

from rich.console import Console
from rich.layout import Layout
from rich.live import Live
from rich.text import Text

from axolotl.tui.config import TUIConfig
from axolotl.tui.gpu import GPUPoller
from axolotl.tui.io_capture import (
    IOCapture,
    ParserChain,
    get_registered_parsers,
)
from axolotl.tui.panels import BasePanel, get_registered_panels
from axolotl.tui.state import CompletionSample, LogLine, TUIState

LOG = logging.getLogger(__name__)


class TUIRenderer:
    """Background thread that renders the TUI dashboard using rich.live.Live.

    Events arrive on *metric_queue* (dicts with a "type" key) and are folded
    into a TUIState that the registered panels render each tick.
    """

    def __init__(self, config: TUIConfig, metric_queue: queue.Queue):
        self._config = config
        self._queue = metric_queue
        self._state = TUIState()
        self._gpu_poller = GPUPoller()
        self._panels: list[BasePanel] = []
        self._thread: threading.Thread | None = None
        self._stop_event = threading.Event()
        self._io_capture: IOCapture | None = None
        self._parser_chain: ParserChain | None = None

    def _init_panels(self) -> None:
        """Instantiate configured panels, preserving the config's order."""
        registry = get_registered_panels()
        for panel_name in self._config.panels:
            if panel_name in registry:
                self._panels.append(registry[panel_name]())

    def _init_parser_chain(self) -> None:
        """Build the parser chain: built-ins first, then config-declared plugins."""
        self._parser_chain = ParserChain()
        for parser_cls in get_registered_parsers():
            self._parser_chain.register(parser_cls())

        # Plugin specs are either "path/to/file.py::ClassName" or a dotted
        # "pkg.module.ClassName". A bad plugin is logged and skipped, never fatal.
        for plugin_spec in self._config.parser_plugins:
            try:
                if "::" in plugin_spec:
                    file_path, class_name = plugin_spec.split("::", 1)
                    import importlib.util

                    spec = importlib.util.spec_from_file_location(
                        "custom_parser", file_path
                    )
                    mod = importlib.util.module_from_spec(spec)
                    spec.loader.exec_module(mod)
                    parser_cls = getattr(mod, class_name)
                else:
                    module_path, class_name = plugin_spec.rsplit(".", 1)
                    mod = importlib.import_module(module_path)
                    parser_cls = getattr(mod, class_name)
                self._parser_chain.register(parser_cls())
            except Exception as exc:
                LOG.warning(f"Failed to load parser plugin {plugin_spec}: {exc}")

    def _build_layout(self) -> Layout:
        """Assemble the rich Layout skeleton from the panels' declared positions."""
        layout = Layout()

        top_panels = [p for p in self._panels if p.position == "top"]
        left_panels = [p for p in self._panels if p.position == "left"]
        right_panels = [p for p in self._panels if p.position == "right"]
        bottom_panels = [p for p in self._panels if p.position == "bottom"]

        sections = []

        if top_panels:
            sections.append(Layout(name="top", size=3))

        if left_panels or right_panels:
            layout_middle = Layout(name="middle", ratio=3)
            middle_parts = []
            if left_panels:
                middle_parts.append(Layout(name="left", ratio=1))
            if right_panels:
                middle_parts.append(Layout(name="right", ratio=1))
            if middle_parts:
                layout_middle.split_row(*middle_parts)
            sections.append(layout_middle)

        if bottom_panels:
            layout_bottom = Layout(name="bottom", ratio=2)
            # Multiple bottom panels share the row equally.
            if len(bottom_panels) > 1:
                layout_bottom.split_row(
                    *[
                        Layout(name=f"bottom_{i}", ratio=1)
                        for i in range(len(bottom_panels))
                    ]
                )
            sections.append(layout_bottom)

        if sections:
            layout.split_column(*sections)

        return layout

    def _update_layout(self, layout: Layout) -> None:
        """Re-render each panel into its named slot of *layout*.

        NOTE(review): only the first panel per top/left/right position is
        rendered — extra panels there are silently ignored (matches
        _build_layout, which allocates one slot per side).
        """
        top_panels = [p for p in self._panels if p.position == "top"]
        left_panels = [p for p in self._panels if p.position == "left"]
        right_panels = [p for p in self._panels if p.position == "right"]
        bottom_panels = [p for p in self._panels if p.position == "bottom"]

        if top_panels:
            layout["top"].update(top_panels[0].render(self._state))

        if left_panels:
            layout["left"].update(left_panels[0].render(self._state))

        if right_panels:
            layout["right"].update(right_panels[0].render(self._state))

        if bottom_panels:
            if len(bottom_panels) == 1:
                layout["bottom"].update(bottom_panels[0].render(self._state))
            else:
                for i, panel in enumerate(bottom_panels):
                    layout[f"bottom_{i}"].update(panel.render(self._state))

    def _drain_queue(self) -> None:
        """Consume every pending event and fold it into the shared state."""
        while True:
            try:
                event = self._queue.get_nowait()
            except queue.Empty:
                break

            # Panels get first look at every raw event.
            for panel in self._panels:
                panel.on_event(event)

            event_type = event.get("type")

            if event_type == "metrics":
                self._apply_metrics(event.get("logs", {}))

            elif event_type == "step":
                self._state.current_step = event.get("step", self._state.current_step)
                self._state.total_steps = event.get(
                    "total_steps", self._state.total_steps
                )
                self._state.current_epoch = event.get(
                    "epoch", self._state.current_epoch
                )
                now = time.time()
                self._state.elapsed_seconds = now - self._state.start_time.timestamp()
                # Simple linear ETA: average seconds/step × steps remaining.
                if self._state.current_step > 0 and self._state.total_steps > 0:
                    rate = self._state.elapsed_seconds / self._state.current_step
                    remaining = self._state.total_steps - self._state.current_step
                    self._state.eta_seconds = rate * remaining

            elif event_type == "log_line":
                self._state.log_lines.append(
                    LogLine(
                        timestamp=datetime.now(),
                        level=event.get("level", "info"),
                        message=event.get("message", ""),
                    )
                )

            elif event_type == "completion":
                self._state.completions.append(
                    CompletionSample(
                        step=event.get("step", 0),
                        prompt=event.get("prompt", ""),
                        completion=event.get("completion", ""),
                        reward=event.get("reward"),
                        advantage=event.get("advantage"),
                    )
                )

            elif event_type == "run_info":
                if "run_name" in event:
                    self._state.run_name = event["run_name"]
                if "model_name" in event:
                    self._state.model_name = event["model_name"]
                if "training_mode" in event:
                    self._state.training_mode = event["training_mode"]
                if "world_size" in event:
                    self._state.world_size = event["world_size"]
                if "total_steps" in event:
                    self._state.total_steps = event["total_steps"]
                if "total_epochs" in event:
                    self._state.total_epochs = event["total_epochs"]

            elif event_type == "done":
                self._stop_event.set()

    def _apply_metrics(self, logs: dict[str, Any]) -> None:
        """Map known metric keys onto TUIState attributes; stash the rest in extra."""
        metric_map = {
            "loss": "loss",
            "grad_norm": "grad_norm",
            "learning_rate": "learning_rate",
            "tokens_per_second": "tokens_per_second",
            "samples_per_second": "samples_per_second",
            "mfu": "mfu",
            "rewards/mean": "rewards_mean",
            "rewards_mean": "rewards_mean",
            "rewards/std": "rewards_std",
            "rewards_std": "rewards_std",
            "kl": "kl_divergence",
            "kl_divergence": "kl_divergence",
            "clip_ratio": "clip_ratio",
            "queue_size": "queue_size",
        }

        for key, value in logs.items():
            if key in metric_map:
                setattr(self._state, metric_map[key], value)
            else:
                self._state.extra[key] = value

        if "loss" in logs and logs["loss"] is not None:
            self._state.loss_history.append(logs["loss"])

    def start(self) -> None:
        """Initialise panels/parsers/IO capture and launch the render thread."""
        self._init_panels()
        self._init_parser_chain()

        self._io_capture = IOCapture(
            log_path=self._config.stdout_log_path,
            parser_chain=self._parser_chain,
            metric_queue=self._queue,
        )

        # Monkeypatch tqdm to suppress terminal output and route through our
        # queue. This prevents tqdm progress bars from flickering through the
        # TUI and ensures all progress events appear in the Events panel.
        self._install_tqdm_hook()

        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _install_tqdm_hook(self) -> None:
        """Replace tqdm globally with a subclass that reports through the TUI queue."""
        try:
            import io

            import tqdm
            import tqdm.auto

            q = self._queue

            # Locate our tqdm line parser in the chain (kept for introspection).
            self._tqdm_parser = None
            for p in self._parser_chain._parsers if self._parser_chain else []:
                if p.name == "tqdm":
                    self._tqdm_parser = p
                    break

            # FIX: save the original *class objects* so _uninstall_tqdm_hook
            # can restore them. (Previously only __init__ and tqdm.auto.tqdm
            # were saved, so tqdm.tqdm / tqdm.std.tqdm stayed patched and the
            # restored __init__ was assigned onto the replacement class.)
            self._orig_tqdm_main = tqdm.tqdm
            self._orig_tqdm_std = tqdm.std.tqdm
            self._orig_tqdm_auto = tqdm.auto.tqdm

            class TUITqdm(tqdm.std.tqdm):
                """tqdm subclass that sends progress to TUI instead of terminal."""

                def __init__(self, *args, **kwargs):
                    # Force output to an in-memory buffer so nothing reaches
                    # the real terminal underneath the Live display.
                    kwargs["file"] = io.StringIO()
                    kwargs["dynamic_ncols"] = False
                    kwargs["ncols"] = 80
                    super().__init__(*args, **kwargs)

                def display(self, msg=None, pos=None):
                    # Push a progress summary to the queue instead of drawing.
                    if not self.total or self.total <= 0:
                        return
                    pct = self.n / self.total * 100
                    desc = self.desc.rstrip(": ") if self.desc else ""

                    # Log lines only at milestones to avoid flooding the queue.
                    is_milestone = (
                        self.n == 0 or self.n >= self.total or int(pct) % 25 == 0
                    )
                    if is_milestone:
                        text = (
                            f"[{desc}] {pct:.0f}% ({self.n}/{self.total})"
                            if desc
                            else f"{pct:.0f}% ({self.n}/{self.total})"
                        )
                        try:
                            q.put_nowait(
                                {"type": "log_line", "level": "info", "message": text}
                            )
                        except Exception:
                            pass

                    metric_key = (
                        f"progress/{desc.lower().replace(' ', '_')}"
                        if desc
                        else "progress/unknown"
                    )
                    try:
                        q.put_nowait(
                            {"type": "metrics", "logs": {metric_key: pct / 100.0}}
                        )
                    except Exception:
                        pass

                def close(self):
                    # Emit a final completion event before tearing down.
                    if self.total and self.total > 0 and self.n > 0:
                        desc = self.desc.rstrip(": ") if self.desc else ""
                        text = (
                            f"[{desc}] 100% ({self.total}/{self.total}) done"
                            if desc
                            else f"100% ({self.total}/{self.total}) done"
                        )
                        try:
                            q.put_nowait(
                                {"type": "log_line", "level": "info", "message": text}
                            )
                        except Exception:
                            pass
                    super().close()

            # Replace tqdm globally, including tqdm.std which some libraries
            # import directly.
            tqdm.auto.tqdm = TUITqdm
            tqdm.tqdm = TUITqdm
            tqdm.std.tqdm = TUITqdm
            self._tui_tqdm_cls = TUITqdm

        except Exception as exc:
            LOG.debug(f"Failed to install tqdm hook: {exc}")

    def _uninstall_tqdm_hook(self) -> None:
        """Restore the original tqdm classes saved by _install_tqdm_hook."""
        try:
            import tqdm
            import tqdm.auto

            if hasattr(self, "_orig_tqdm_main"):
                tqdm.tqdm = self._orig_tqdm_main
            if hasattr(self, "_orig_tqdm_std"):
                tqdm.std.tqdm = self._orig_tqdm_std
            if hasattr(self, "_orig_tqdm_auto"):
                tqdm.auto.tqdm = self._orig_tqdm_auto
        except Exception:
            pass

    def stop(self) -> None:
        """Signal the render loop to exit, undo the tqdm patch, join the thread."""
        self._stop_event.set()
        self._uninstall_tqdm_hook()
        if self._thread is not None:
            self._thread.join(timeout=5.0)

    def _run(self) -> None:
        # Save a handle to the REAL terminal BEFORE IO capture redirects fds.
        # This ensures rich.live.Live writes to the terminal, not the pipe.
        saved_tty_fd = os.dup(1)
        tty_file = os.fdopen(saved_tty_fd, "w", buffering=1, closefd=True)
        console = Console(file=tty_file)

        layout = self._build_layout()
        tick_interval = 1.0 / max(self._config.refresh_rate, 1)
        gpu_poll_counter = 0
        gpu_poll_ticks = max(
            1, int(self._config.hardware_poll_interval / tick_interval)
        )

        # Start I/O capture — redirects fd 1/2 to pipe AFTER we saved the tty fd.
        if self._io_capture:
            self._io_capture.start()

        try:
            with Live(
                layout,
                console=console,
                refresh_per_second=self._config.refresh_rate,
                screen=True,
                redirect_stdout=False,
                redirect_stderr=False,
            ) as live:
                while not self._stop_event.is_set():
                    self._drain_queue()

                    # Poll GPU stats periodically (cheaper than every tick).
                    gpu_poll_counter += 1
                    if gpu_poll_counter >= gpu_poll_ticks:
                        gpu_poll_counter = 0
                        if self._gpu_poller.available:
                            self._state.gpus = self._gpu_poller.poll()

                    self._state.elapsed_seconds = (
                        time.time() - self._state.start_time.timestamp()
                    )

                    self._update_layout(layout)
                    live.update(layout)

                    time.sleep(tick_interval)

                # Final drain so the last events are visible before teardown.
                self._drain_queue()
                self._update_layout(layout)
                live.update(layout)
        finally:
            if self._io_capture:
                self._io_capture.stop()
            try:
                tty_file.close()
            except Exception:
                pass


# ── axolotl/tui/state.py (head) ────────────────────────────────────────────
"""TUI shared data model — dataclasses for the dashboard state."""

from collections import deque
from dataclasses import dataclass, field


@dataclass
class GPUStats:
    """Snapshot of one GPU as reported by the hardware poller."""

    id: int
    name: str
    util_pct: float
    vram_used_gb: float
    vram_total_gb: float
    temp_c: int
    power_w: float | None


@dataclass
class LogLine:
    """One captured log record shown in the Events panel."""

    timestamp: datetime
    level: str  # "info" | "debug" | "warning" | "error"
    message: str
@dataclass
class CompletionSample:
    """One generated sample (GRPO / SFT with log_completions) for display."""

    step: int
    prompt: str
    completion: str
    reward: float | None
    advantage: float | None


@dataclass
class TUIState:
    """Mutable snapshot of everything the dashboard panels render.

    Field order is part of the constructor signature — do not reorder.
    """

    # ── run metadata ──
    run_name: str = ""
    model_name: str = ""
    training_mode: str = "sft"
    world_size: int = 1
    start_time: datetime = field(default_factory=datetime.now)

    # ── progress ──
    current_step: int = 0
    total_steps: int = 0
    current_epoch: float = 0.0
    total_epochs: int = 1
    elapsed_seconds: float = 0.0
    eta_seconds: float | None = None  # None until enough steps for an estimate

    # ── training metrics (current values; history kept separately below) ──
    loss: float | None = None
    grad_norm: float | None = None
    learning_rate: float | None = None
    tokens_per_second: float | None = None
    samples_per_second: float | None = None
    mfu: float | None = None

    # ── RL-specific metrics (stay None for non-RL modes) ──
    rewards_mean: float | None = None
    rewards_std: float | None = None
    kl_divergence: float | None = None
    clip_ratio: float | None = None
    queue_size: int | None = None

    # ── per-GPU hardware stats (list indexed by local rank) ──
    gpus: list[GPUStats] = field(default_factory=list)

    # ── recent log lines (bounded ring buffer) ──
    log_lines: deque[LogLine] = field(default_factory=lambda: deque(maxlen=200))

    # ── recent completion samples (bounded ring buffer) ──
    completions: deque[CompletionSample] = field(
        default_factory=lambda: deque(maxlen=20)
    )

    # ── loss history feeding the sparkline ──
    loss_history: deque[float] = field(default_factory=lambda: deque(maxlen=50))

    # ── arbitrary plugin-supplied state ──
    extra: dict[str, Any] = field(default_factory=dict)
@@ -140,6 +142,12 @@ class AxolotlInputConfig( vllm: VllmConfig | None = Field( default_factory=lambda: VllmConfig(), ) + tui: TUIConfig | None = Field( + default=None, + json_schema_extra={ + "description": "TUI dashboard configuration. Set enabled: true to activate." + }, + ) qat: QATConfig | None = None quantization: PTQConfig | None = None reward_model: bool | None = Field(