chore: lint

This commit is contained in:
Wing Lian
2026-03-19 00:00:11 -04:00
committed by Wing Lian
parent 35d06c8087
commit db6af43f3b
16 changed files with 189 additions and 137 deletions

View File

@@ -2,6 +2,7 @@
import gc import gc
import os import os
import queue
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@@ -36,7 +37,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
# Start TUI early (before data loading) so it captures preprocessing events # Start TUI early (before data loading) so it captures preprocessing events
tui_renderer = None tui_renderer = None
tui_queue = None tui_queue: queue.Queue | None = None
is_rank_0 = int(os.getenv("LOCAL_RANK", "0")) == 0 is_rank_0 = int(os.getenv("LOCAL_RANK", "0")) == 0
if is_rank_0: if is_rank_0:
from axolotl.train import _is_tui_enabled from axolotl.train import _is_tui_enabled
@@ -44,12 +45,16 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
if _is_tui_enabled(cfg): if _is_tui_enabled(cfg):
import queue as _queue import queue as _queue
from axolotl.train import _get_tui_config
from axolotl.tui.config import TUIConfig from axolotl.tui.config import TUIConfig
from axolotl.tui.renderer import TUIRenderer from axolotl.tui.renderer import TUIRenderer
from axolotl.train import _get_tui_config
tui_config_dict = _get_tui_config(cfg) tui_config_dict = _get_tui_config(cfg)
tui_config = TUIConfig(**tui_config_dict) if isinstance(tui_config_dict, dict) else tui_config_dict tui_config = (
TUIConfig(**tui_config_dict)
if isinstance(tui_config_dict, dict)
else tui_config_dict
)
tui_queue = _queue.Queue(maxsize=4096) tui_queue = _queue.Queue(maxsize=4096)
tui_renderer = TUIRenderer(config=tui_config, metric_queue=tui_queue) tui_renderer = TUIRenderer(config=tui_config, metric_queue=tui_queue)
@@ -58,12 +63,14 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
training_mode = str(cfg.rl) if cfg.rl else "sft" training_mode = str(cfg.rl) if cfg.rl else "sft"
world_size = int(os.environ.get("WORLD_SIZE", 1)) world_size = int(os.environ.get("WORLD_SIZE", 1))
try: try:
tui_queue.put_nowait({ tui_queue.put_nowait(
"type": "run_info", {
"model_name": model_name, "type": "run_info",
"training_mode": training_mode, "model_name": model_name,
"world_size": world_size, "training_mode": training_mode,
}) "world_size": world_size,
}
)
except _queue.Full: except _queue.Full:
pass pass
@@ -74,7 +81,9 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
from axolotl.tui.callback import _TUILogHandler from axolotl.tui.callback import _TUILogHandler
_early_log_handler = _TUILogHandler(tui_queue, min_level=tui_config.log_level) _early_log_handler = _TUILogHandler(
tui_queue, min_level=tui_config.log_level
)
_early_log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s")) _early_log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
# Attach to BOTH root and axolotl loggers because axolotl logger # Attach to BOTH root and axolotl loggers because axolotl logger
# has propagate=False so root handler never sees axolotl.* messages # has propagate=False so root handler never sees axolotl.* messages
@@ -110,8 +119,9 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
# (e.g., error during data loading), clean up here # (e.g., error during data loading), clean up here
if tui_renderer is not None and not tui_renderer._stop_event.is_set(): if tui_renderer is not None and not tui_renderer._stop_event.is_set():
try: try:
tui_queue.put_nowait({"type": "done"}) if tui_queue is not None:
except Exception: tui_queue.put_nowait({"type": "done"})
except queue.Full:
pass pass
tui_renderer.stop() tui_renderer.stop()
# Remove early log handler from both root and axolotl loggers # Remove early log handler from both root and axolotl loggers

View File

@@ -9,7 +9,6 @@ import os
import shutil import shutil
import signal import signal
import sys import sys
import typing
import weakref import weakref
from collections import OrderedDict from collections import OrderedDict
from contextlib import ExitStack from contextlib import ExitStack
@@ -42,9 +41,6 @@ from axolotl.utils.schemas.enums import RLType
from axolotl.utils.train import determine_last_checkpoint from axolotl.utils.train import determine_last_checkpoint
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
if typing.TYPE_CHECKING:
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
LOG = get_logger(__name__) LOG = get_logger(__name__)
TELEMETRY_MANAGER = TelemetryManager.get_instance() TELEMETRY_MANAGER = TelemetryManager.get_instance()
@@ -487,7 +483,7 @@ def handle_untrained_tokens_fix(
def setup_model_and_trainer( def setup_model_and_trainer(
cfg: DictDefault, dataset_meta: TrainDatasetMeta cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[ ) -> tuple[
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder", Trainer,
PeftModel | PreTrainedModel, PeftModel | PreTrainedModel,
PreTrainedTokenizer, PreTrainedTokenizer,
PeftConfig | None, PeftConfig | None,
@@ -614,7 +610,9 @@ def train(
from axolotl.tui.config import TUIConfig from axolotl.tui.config import TUIConfig
tui_config = _get_tui_config(cfg) tui_config = _get_tui_config(cfg)
tui_config_obj = TUIConfig(**tui_config) if isinstance(tui_config, dict) else tui_config tui_config_obj = (
TUIConfig(**tui_config) if isinstance(tui_config, dict) else tui_config
)
# Reuse the early-started renderer if available (started in do_train) # Reuse the early-started renderer if available (started in do_train)
early_renderer = getattr(cfg, "_tui_renderer", None) early_renderer = getattr(cfg, "_tui_renderer", None)
@@ -628,16 +626,12 @@ def train(
tui_callback._renderer_started_early = True tui_callback._renderer_started_early = True
trainer.add_callback(tui_callback) trainer.add_callback(tui_callback)
# Send model info to the callback # Stash model info so on_train_begin can emit a single unified run_info event
model_name = cfg.base_model or "" tui_callback._pending_run_info = {
training_mode = str(cfg.rl) if cfg.rl else "sft" "model_name": cfg.base_model or "",
world_size = int(os.environ.get("WORLD_SIZE", 1)) "training_mode": str(cfg.rl) if cfg.rl else "sft",
tui_callback._put({ "world_size": int(os.environ.get("WORLD_SIZE", 1)),
"type": "run_info", }
"model_name": model_name,
"training_mode": training_mode,
"world_size": world_size,
})
LOG.info("TUI dashboard enabled") LOG.info("TUI dashboard enabled")
# Handle untrained tokens if configured # Handle untrained tokens if configured

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import logging import logging
import queue import queue
from typing import Any
from transformers.trainer_callback import TrainerCallback from transformers.trainer_callback import TrainerCallback
@@ -33,11 +32,13 @@ class _TUILogHandler(logging.Handler):
try: try:
level = self._LEVEL_MAP.get(record.levelno, "info") level = self._LEVEL_MAP.get(record.levelno, "info")
msg = self.format(record) msg = self.format(record)
self._queue.put_nowait({ self._queue.put_nowait(
"type": "log_line", {
"level": level, "type": "log_line",
"message": msg, "level": level,
}) "message": msg,
}
)
except queue.Full: except queue.Full:
pass pass
except Exception: except Exception:
@@ -57,6 +58,7 @@ class AxolotlTUICallback(TrainerCallback):
self._renderer = TUIRenderer(config=config, metric_queue=self._queue) self._renderer = TUIRenderer(config=config, metric_queue=self._queue)
self._log_handler: _TUILogHandler | None = None self._log_handler: _TUILogHandler | None = None
self._renderer_started_early: bool = False self._renderer_started_early: bool = False
self._pending_run_info: dict | None = None
def _put(self, event: dict) -> None: def _put(self, event: dict) -> None:
try: try:
@@ -65,25 +67,27 @@ class AxolotlTUICallback(TrainerCallback):
pass pass
def on_train_begin(self, args, state, control, model=None, **kwargs): def on_train_begin(self, args, state, control, model=None, **kwargs):
# Send run info # Send a single unified run_info event with all fields
run_name = getattr(args, "run_name", "") or "" run_info = {
self._put( "type": "run_info",
{ "run_name": getattr(args, "run_name", "") or "",
"type": "run_info", "total_steps": state.max_steps,
"run_name": run_name, "total_epochs": float(args.num_train_epochs)
"total_steps": state.max_steps, if args.num_train_epochs
"total_epochs": int(args.num_train_epochs) if args.num_train_epochs else 1, else 1.0,
} }
) # Merge in model_name/training_mode/world_size if stashed by train.py
if self._pending_run_info:
run_info.update(self._pending_run_info)
self._pending_run_info = None
self._put(run_info)
if not self._renderer_started_early: if not self._renderer_started_early:
# Attach a logging handler to feed log messages into the events panel # Attach a logging handler to feed log messages into the events panel
self._log_handler = _TUILogHandler( self._log_handler = _TUILogHandler(
self._queue, min_level=self._config.log_level self._queue, min_level=self._config.log_level
) )
self._log_handler.setFormatter( self._log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
logging.Formatter("[%(name)s] %(message)s")
)
# Attach to both root and axolotl loggers (axolotl has propagate=False) # Attach to both root and axolotl loggers (axolotl has propagate=False)
logging.getLogger().addHandler(self._log_handler) logging.getLogger().addHandler(self._log_handler)
logging.getLogger("axolotl").addHandler(self._log_handler) logging.getLogger("axolotl").addHandler(self._log_handler)

View File

@@ -18,15 +18,11 @@ class TUIConfig(BaseModel):
) )
log_level: str = Field( log_level: str = Field(
default="debug", default="debug",
json_schema_extra={ json_schema_extra={"description": "Minimum log level shown in events panel"},
"description": "Minimum log level shown in events panel"
},
) )
panels: list[str] = Field( panels: list[str] = Field(
default_factory=lambda: ["progress", "training", "hardware", "events", "debug"], default_factory=lambda: ["progress", "training", "hardware", "events", "debug"],
json_schema_extra={ json_schema_extra={"description": "Ordered list of panels to display"},
"description": "Ordered list of panels to display"
},
) )
hardware_poll_interval: int = Field( hardware_poll_interval: int = Field(
default=2, default=2,
@@ -34,13 +30,9 @@ class TUIConfig(BaseModel):
) )
stdout_log_path: str = Field( stdout_log_path: str = Field(
default="axolotl_stdout.log", default="axolotl_stdout.log",
json_schema_extra={ json_schema_extra={"description": "File path for captured stdout/stderr log"},
"description": "File path for captured stdout/stderr log"
},
) )
parser_plugins: list[str] = Field( parser_plugins: list[str] = Field(
default_factory=list, default_factory=list,
json_schema_extra={ json_schema_extra={"description": "List of extra parser classes to load"},
"description": "List of extra parser classes to load"
},
) )

View File

@@ -68,5 +68,5 @@ class GPUPoller:
) )
) )
except Exception: except Exception:
pass LOG.debug("Error polling GPU device %d", i, exc_info=True)
return stats return stats

View File

@@ -2,17 +2,14 @@
from __future__ import annotations from __future__ import annotations
import logging
import os import os
import queue import queue
import re
import sys import sys
import threading import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING from typing import IO
if TYPE_CHECKING:
pass
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Parser registry # Parser registry
@@ -23,7 +20,8 @@ _parser_registry: list[type[LineParser]] = []
def register_parser(cls: type[LineParser]) -> type[LineParser]: def register_parser(cls: type[LineParser]) -> type[LineParser]:
"""Decorator to register a LineParser subclass.""" """Decorator to register a LineParser subclass."""
_parser_registry.append(cls) if cls not in _parser_registry:
_parser_registry.append(cls)
return cls return cls
@@ -92,7 +90,7 @@ class IOCapture:
self._parser_chain = parser_chain self._parser_chain = parser_chain
self._queue = metric_queue self._queue = metric_queue
self._log_path = log_path self._log_path = log_path
self._log_file = None self._log_file: IO[str] | None = None
self._thread: threading.Thread | None = None self._thread: threading.Thread | None = None
self._read_fd: int | None = None self._read_fd: int | None = None
self._write_fd: int | None = None self._write_fd: int | None = None
@@ -129,7 +127,7 @@ class IOCapture:
def stop(self) -> None: def stop(self) -> None:
# Restore fds — closes the write end, causing reader to see EOF # Restore fds — closes the write end, causing reader to see EOF
if self._saved_stdout_fd is not None: if self._saved_stdout_fd is not None and self._saved_stderr_fd is not None:
sys.stdout = sys.__stdout__ sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__ sys.stderr = sys.__stderr__
os.dup2(self._saved_stdout_fd, 1) os.dup2(self._saved_stdout_fd, 1)
@@ -141,6 +139,10 @@ class IOCapture:
if self._thread is not None: if self._thread is not None:
self._thread.join(timeout=2.0) self._thread.join(timeout=2.0)
if self._thread.is_alive():
logging.getLogger(__name__).warning(
"IO capture thread did not exit after 2s"
)
self._thread = None self._thread = None
if self._log_file is not None: if self._log_file is not None:
@@ -150,6 +152,7 @@ class IOCapture:
def _drain(self) -> None: def _drain(self) -> None:
# Read raw bytes and split on both \n and \r to handle tqdm progress bars # Read raw bytes and split on both \n and \r to handle tqdm progress bars
# which use \r for in-place updates without \n # which use \r for in-place updates without \n
assert self._read_fd is not None, "_drain called before start()"
with os.fdopen(self._read_fd, "rb") as pipe: with os.fdopen(self._read_fd, "rb") as pipe:
buf = b"" buf = b""
while True: while True:

View File

@@ -3,7 +3,6 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any
from rich.console import RenderableType from rich.console import RenderableType
@@ -50,7 +49,7 @@ class BasePanel(ABC):
"""Return a rich renderable. Called every tick.""" """Return a rich renderable. Called every tick."""
... ...
def on_event(self, event: dict) -> None: def on_event(self, event: dict) -> None: # noqa: B027
"""Optional: react to raw metric events before state is merged.""" """Optional: react to raw metric events before state is merged."""
pass pass

View File

@@ -22,6 +22,9 @@ class CompletionsPanel(BasePanel):
modes = ["grpo", "dpo"] modes = ["grpo", "dpo"]
def render(self, state: TUIState) -> RenderableType: def render(self, state: TUIState) -> RenderableType:
if "*" not in self.modes and state.training_mode not in self.modes:
return Text("")
if not state.completions: if not state.completions:
return Panel( return Panel(
Text("No completions yet...", style="dim"), Text("No completions yet...", style="dim"),

View File

@@ -2,15 +2,12 @@
from __future__ import annotations from __future__ import annotations
from collections import deque
from datetime import datetime
from rich.console import RenderableType from rich.console import RenderableType
from rich.panel import Panel from rich.panel import Panel
from rich.text import Text from rich.text import Text
from axolotl.tui.panels import BasePanel, register_panel from axolotl.tui.panels import BasePanel, register_panel
from axolotl.tui.state import LogLine, TUIState from axolotl.tui.state import TUIState
@register_panel(position="bottom", weight=30) @register_panel(position="bottom", weight=30)
@@ -22,7 +19,9 @@ class DebugPanel(BasePanel):
def render(self, state: TUIState) -> RenderableType: def render(self, state: TUIState) -> RenderableType:
lines = Text() lines = Text()
# Show last 8 debug-level log lines # Show last 8 debug-level log lines
debug_lines = [l for l in state.log_lines if l.level == "debug"][-8:] debug_lines = [
log_entry for log_entry in state.log_lines if log_entry.level == "debug"
][-8:]
for log_line in debug_lines: for log_line in debug_lines:
ts = log_line.timestamp.strftime("%H:%M:%S") ts = log_line.timestamp.strftime("%H:%M:%S")
lines.append(f"[{ts}] ", style="dim") lines.append(f"[{ts}] ", style="dim")

View File

@@ -27,7 +27,9 @@ class EventsPanel(BasePanel):
def render(self, state: TUIState) -> RenderableType: def render(self, state: TUIState) -> RenderableType:
lines = Text() lines = Text()
# Show last 15 non-debug log lines (debug goes to DebugPanel) # Show last 15 non-debug log lines (debug goes to DebugPanel)
recent = [l for l in state.log_lines if l.level != "debug"][-15:] recent = [
log_entry for log_entry in state.log_lines if log_entry.level != "debug"
][-15:]
for log_line in recent: for log_line in recent:
ts = log_line.timestamp.strftime("%H:%M:%S") ts = log_line.timestamp.strftime("%H:%M:%S")
level = log_line.level.upper() level = log_line.level.upper()

View File

@@ -3,5 +3,5 @@
from axolotl.tui.parsers.deepspeed import DeepSpeedParser # noqa: F401 from axolotl.tui.parsers.deepspeed import DeepSpeedParser # noqa: F401
from axolotl.tui.parsers.nccl import NCCLErrorParser # noqa: F401 from axolotl.tui.parsers.nccl import NCCLErrorParser # noqa: F401
from axolotl.tui.parsers.raw_log import RawLogParser # noqa: F401 from axolotl.tui.parsers.raw_log import RawLogParser # noqa: F401
from axolotl.tui.parsers.tqdm import TqdmParser # noqa: F401
from axolotl.tui.parsers.torch_compile import TorchCompileParser # noqa: F401 from axolotl.tui.parsers.torch_compile import TorchCompileParser # noqa: F401
from axolotl.tui.parsers.tqdm import TqdmParser # noqa: F401

View File

@@ -14,8 +14,8 @@ class RawLogParser(LineParser):
_LOG_RE = re.compile( _LOG_RE = re.compile(
r"^(?P<ts>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}[,\.]\d+)" r"^(?P<ts>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}[,\.]\d+)"
r"\s*[-]\s*(?P<level>DEBUG|INFO|WARNING|ERROR|CRITICAL)" r"\s*[-]\s*(?P<level>DEBUG|INFO|WARNING|ERROR|CRITICAL)"
r"\s*[-]\s*(?P<msg>.+)$", r"\s*[-]\s*(?P<msg>.+)$",
re.IGNORECASE, re.IGNORECASE,
) )

View File

@@ -25,9 +25,7 @@ class TqdmParser(LineParser):
# Also match simpler forms like: # Also match simpler forms like:
# Fetching 0 files: 0it [00:00, ?it/s] # Fetching 0 files: 0it [00:00, ?it/s]
_FETCH_RE = re.compile( _FETCH_RE = re.compile(r"(?P<desc>[\w\s]+):\s*(?P<current>\d+)(?:it)?\s*\[.*?\]")
r"(?P<desc>[\w\s]+):\s*(?P<current>\d+)(?:it)?\s*\[.*?\]"
)
def parse(self, line: str, source: str) -> list[dict]: def parse(self, line: str, source: str) -> list[dict]:
m = self._TQDM_RE.search(line) m = self._TQDM_RE.search(line)
@@ -41,21 +39,48 @@ class TqdmParser(LineParser):
# Surface as a log line with progress info # Surface as a log line with progress info
if pct == 100 or pct == 0 or pct % 25 == 0: if pct == 100 or pct == 0 or pct % 25 == 0:
msg = f"[{desc}] {pct}% ({current}/{total})" if desc else f"{pct}% ({current}/{total})" msg = (
events.append({ f"[{desc}] {pct}% ({current}/{total})"
"type": "log_line", if desc
"level": "info", else f"{pct}% ({current}/{total})"
"message": msg, )
}) events.append(
{
"type": "log_line",
"level": "info",
"message": msg,
}
)
# Also emit as a progress metric # Also emit as a progress metric
events.append({ cleaned_desc = desc.strip().lower().replace(" ", "_")
"type": "metrics", if not cleaned_desc:
"logs": { cleaned_desc = "progress"
f"progress/{desc.lower().replace(' ', '_')}": pct / 100.0, events.append(
}, {
}) "type": "metrics",
"logs": {
f"progress/{cleaned_desc}": pct / 100.0,
},
}
)
return events return events
# Fallback: try simpler fetch-style progress lines
m = self._FETCH_RE.search(line)
if m:
desc = m.group("desc").strip().rstrip(":")
current = int(m.group("current"))
cleaned_desc = desc.strip().lower().replace(" ", "_")
if not cleaned_desc:
cleaned_desc = "fetch"
return [
{
"type": "log_line",
"level": "info",
"message": f"[{desc}] {current}" if desc else f"{current}",
}
]
return [] return []

View File

@@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import importlib
import logging import logging
import queue import queue
import threading import threading
@@ -13,7 +12,6 @@ from typing import Any
from rich.console import Console from rich.console import Console
from rich.layout import Layout from rich.layout import Layout
from rich.live import Live from rich.live import Live
from rich.text import Text
from axolotl.tui.config import TUIConfig from axolotl.tui.config import TUIConfig
from axolotl.tui.gpu import GPUPoller from axolotl.tui.gpu import GPUPoller
@@ -49,6 +47,9 @@ class TUIRenderer:
self._panels.append(registry[panel_name]()) self._panels.append(registry[panel_name]())
def _init_parser_chain(self) -> None: def _init_parser_chain(self) -> None:
# Ensure built-in parsers are imported so @register_parser decorators fire
import axolotl.tui.parsers # noqa: F401
self._parser_chain = ParserChain() self._parser_chain = ParserChain()
# Register all built-in parsers # Register all built-in parsers
for parser_cls in get_registered_parsers(): for parser_cls in get_registered_parsers():
@@ -65,6 +66,8 @@ class TUIRenderer:
spec = importlib.util.spec_from_file_location( spec = importlib.util.spec_from_file_location(
"custom_parser", file_path "custom_parser", file_path
) )
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load spec for {file_path}")
mod = importlib.util.module_from_spec(spec) mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) spec.loader.exec_module(mod)
parser_cls = getattr(mod, class_name) parser_cls = getattr(mod, class_name)
@@ -167,10 +170,7 @@ class TUIRenderer:
) )
now = time.time() now = time.time()
self._state.elapsed_seconds = now - self._state.start_time.timestamp() self._state.elapsed_seconds = now - self._state.start_time.timestamp()
if ( if self._state.current_step > 0 and self._state.total_steps > 0:
self._state.current_step > 0
and self._state.total_steps > 0
):
rate = self._state.elapsed_seconds / self._state.current_step rate = self._state.elapsed_seconds / self._state.current_step
remaining = self._state.total_steps - self._state.current_step remaining = self._state.total_steps - self._state.current_step
self._state.eta_seconds = rate * remaining self._state.eta_seconds = rate * remaining
@@ -210,6 +210,8 @@ class TUIRenderer:
self._state.total_steps = event["total_steps"] self._state.total_steps = event["total_steps"]
if "total_epochs" in event: if "total_epochs" in event:
self._state.total_epochs = event["total_epochs"] self._state.total_epochs = event["total_epochs"]
if "zero_stage" in event:
self._state.zero_stage = event["zero_stage"]
elif event_type == "done": elif event_type == "done":
self._stop_event.set() self._stop_event.set()
@@ -246,6 +248,7 @@ class TUIRenderer:
self._init_parser_chain() self._init_parser_chain()
# Set up I/O capture # Set up I/O capture
assert self._parser_chain is not None, "_init_parser_chain must be called first"
self._io_capture = IOCapture( self._io_capture = IOCapture(
log_path=self._config.stdout_log_path, log_path=self._config.stdout_log_path,
parser_chain=self._parser_chain, parser_chain=self._parser_chain,
@@ -257,29 +260,31 @@ class TUIRenderer:
# ensures all progress events appear in the Events panel. # ensures all progress events appear in the Events panel.
self._install_tqdm_hook() self._install_tqdm_hook()
self._io_capture_ready = threading.Event()
self._thread = threading.Thread(target=self._run, daemon=True) self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start() self._thread.start()
self._io_capture_ready.wait(timeout=5.0)
def _install_tqdm_hook(self) -> None: def _install_tqdm_hook(self) -> None:
"""Replace tqdm's display method to route updates through TUI queue.""" """Replace tqdm's display method to route updates through TUI queue."""
try: try:
import tqdm
import tqdm.auto
import io import io
import tqdm
import tqdm.auto
q = self._queue q = self._queue
parser = self._tqdm_parser = None self._tqdm_parser = None
# Find our tqdm parser in the chain # Find our tqdm parser in the chain
for p in (self._parser_chain._parsers if self._parser_chain else []): for p in self._parser_chain._parsers if self._parser_chain else []:
if p.name == "tqdm": if p.name == "tqdm":
self._tqdm_parser = p self._tqdm_parser = p
break break
# Save originals for restore # Save originals for restore
self._orig_tqdm_init = tqdm.tqdm.__init__ self._orig_tqdm_class_auto = tqdm.auto.tqdm
self._orig_tqdm_class = tqdm.auto.tqdm self._orig_tqdm_class_tqdm = tqdm.tqdm
self._orig_tqdm_class_std = tqdm.std.tqdm
renderer_self = self
class TUITqdm(tqdm.tqdm): class TUITqdm(tqdm.tqdm):
"""tqdm subclass that sends progress to TUI instead of terminal.""" """tqdm subclass that sends progress to TUI instead of terminal."""
@@ -296,31 +301,36 @@ class TUIRenderer:
if self.total and self.total > 0: if self.total and self.total > 0:
pct = self.n / self.total * 100 pct = self.n / self.total * 100
desc = self.desc.rstrip(": ") if self.desc else "" desc = self.desc.rstrip(": ") if self.desc else ""
elapsed = self.format_interval(self.elapsed) if hasattr(self, 'elapsed') and self.elapsed else "?"
rate = self.format_sizeof(1.0 / self.avg_time) if hasattr(self, 'avg_time') and self.avg_time else "?"
# Emit events at milestones or at low frequency # Emit events at milestones or at low frequency
is_milestone = ( is_milestone = (
self.n == 0 or self.n == 0 or self.n >= self.total or int(pct) % 25 == 0
self.n >= self.total or
int(pct) % 25 == 0
) )
if is_milestone: if is_milestone:
try: try:
q.put_nowait({ q.put_nowait(
"type": "log_line", {
"level": "info", "type": "log_line",
"message": f"[{desc}] {pct:.0f}% ({self.n}/{self.total})" if desc else f"{pct:.0f}% ({self.n}/{self.total})", "level": "info",
}) "message": f"[{desc}] {pct:.0f}% ({self.n}/{self.total})"
if desc
else f"{pct:.0f}% ({self.n}/{self.total})",
}
)
except Exception: except Exception:
pass pass
try: try:
metric_key = f"progress/{desc.lower().replace(' ', '_')}" if desc else "progress/unknown" metric_key = (
q.put_nowait({ f"progress/{desc.lower().replace(' ', '_')}"
"type": "metrics", if desc
"logs": {metric_key: pct / 100.0}, else "progress/unknown"
}) )
q.put_nowait(
{
"type": "metrics",
"logs": {metric_key: pct / 100.0},
}
)
except Exception: except Exception:
pass pass
@@ -329,11 +339,15 @@ class TUIRenderer:
if self.total and self.total > 0 and self.n > 0: if self.total and self.total > 0 and self.n > 0:
desc = self.desc.rstrip(": ") if self.desc else "" desc = self.desc.rstrip(": ") if self.desc else ""
try: try:
q.put_nowait({ q.put_nowait(
"type": "log_line", {
"level": "info", "type": "log_line",
"message": f"[{desc}] 100% ({self.total}/{self.total}) done" if desc else f"100% ({self.total}/{self.total}) done", "level": "info",
}) "message": f"[{desc}] 100% ({self.total}/{self.total}) done"
if desc
else f"100% ({self.total}/{self.total}) done",
}
)
except Exception: except Exception:
pass pass
super().close() super().close()
@@ -354,11 +368,12 @@ class TUIRenderer:
import tqdm import tqdm
import tqdm.auto import tqdm.auto
if hasattr(self, "_orig_tqdm_class"): if hasattr(self, "_orig_tqdm_class_auto"):
tqdm.auto.tqdm = self._orig_tqdm_class tqdm.auto.tqdm = self._orig_tqdm_class_auto
if hasattr(self, "_orig_tqdm_init"): if hasattr(self, "_orig_tqdm_class_tqdm"):
tqdm.tqdm.__init__ = self._orig_tqdm_init tqdm.tqdm = self._orig_tqdm_class_tqdm
tqdm.std.tqdm = tqdm.tqdm if hasattr(self, "_orig_tqdm_class_std"):
tqdm.std.tqdm = self._orig_tqdm_class_std
except Exception: except Exception:
pass pass
@@ -388,6 +403,10 @@ class TUIRenderer:
if self._io_capture: if self._io_capture:
self._io_capture.start() self._io_capture.start()
# Signal that IO capture is live so start() can return
if hasattr(self, "_io_capture_ready"):
self._io_capture_ready.set()
try: try:
with Live( with Live(
layout, layout,

View File

@@ -48,7 +48,7 @@ class TUIState:
current_step: int = 0 current_step: int = 0
total_steps: int = 0 total_steps: int = 0
current_epoch: float = 0.0 current_epoch: float = 0.0
total_epochs: int = 1 total_epochs: float = 1.0
elapsed_seconds: float = 0.0 elapsed_seconds: float = 0.0
eta_seconds: float | None = None eta_seconds: float | None = None
@@ -81,5 +81,8 @@ class TUIState:
# Loss history for sparkline # Loss history for sparkline
loss_history: deque[float] = field(default_factory=lambda: deque(maxlen=50)) loss_history: deque[float] = field(default_factory=lambda: deque(maxlen=50))
# DeepSpeed zero stage (None if not using DeepSpeed)
zero_stage: int | None = None
# Arbitrary plugin state # Arbitrary plugin state
extra: dict[str, Any] = field(default_factory=dict) extra: dict[str, Any] = field(default_factory=dict)

View File

@@ -13,6 +13,7 @@ from pydantic import (
model_validator, model_validator,
) )
from axolotl.tui.config import TUIConfig
from axolotl.utils.datasets import get_default_process_count from axolotl.utils.datasets import get_default_process_count
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import ( from axolotl.utils.schemas.datasets import (
@@ -52,8 +53,6 @@ from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.validation import ValidationMixin from axolotl.utils.schemas.validation import ValidationMixin
from axolotl.utils.schemas.vllm import VllmConfig from axolotl.utils.schemas.vllm import VllmConfig
from axolotl.tui.config import TUIConfig
LOG = get_logger(__name__) LOG = get_logger(__name__)