chore: lint
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import gc
|
||||
import os
|
||||
import queue
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@@ -36,7 +37,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
|
||||
# Start TUI early (before data loading) so it captures preprocessing events
|
||||
tui_renderer = None
|
||||
tui_queue = None
|
||||
tui_queue: queue.Queue | None = None
|
||||
is_rank_0 = int(os.getenv("LOCAL_RANK", "0")) == 0
|
||||
if is_rank_0:
|
||||
from axolotl.train import _is_tui_enabled
|
||||
@@ -44,12 +45,16 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
if _is_tui_enabled(cfg):
|
||||
import queue as _queue
|
||||
|
||||
from axolotl.train import _get_tui_config
|
||||
from axolotl.tui.config import TUIConfig
|
||||
from axolotl.tui.renderer import TUIRenderer
|
||||
from axolotl.train import _get_tui_config
|
||||
|
||||
tui_config_dict = _get_tui_config(cfg)
|
||||
tui_config = TUIConfig(**tui_config_dict) if isinstance(tui_config_dict, dict) else tui_config_dict
|
||||
tui_config = (
|
||||
TUIConfig(**tui_config_dict)
|
||||
if isinstance(tui_config_dict, dict)
|
||||
else tui_config_dict
|
||||
)
|
||||
tui_queue = _queue.Queue(maxsize=4096)
|
||||
tui_renderer = TUIRenderer(config=tui_config, metric_queue=tui_queue)
|
||||
|
||||
@@ -58,12 +63,14 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
training_mode = str(cfg.rl) if cfg.rl else "sft"
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
try:
|
||||
tui_queue.put_nowait({
|
||||
tui_queue.put_nowait(
|
||||
{
|
||||
"type": "run_info",
|
||||
"model_name": model_name,
|
||||
"training_mode": training_mode,
|
||||
"world_size": world_size,
|
||||
})
|
||||
}
|
||||
)
|
||||
except _queue.Full:
|
||||
pass
|
||||
|
||||
@@ -74,7 +81,9 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
|
||||
from axolotl.tui.callback import _TUILogHandler
|
||||
|
||||
_early_log_handler = _TUILogHandler(tui_queue, min_level=tui_config.log_level)
|
||||
_early_log_handler = _TUILogHandler(
|
||||
tui_queue, min_level=tui_config.log_level
|
||||
)
|
||||
_early_log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
|
||||
# Attach to BOTH root and axolotl loggers because axolotl logger
|
||||
# has propagate=False so root handler never sees axolotl.* messages
|
||||
@@ -110,8 +119,9 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
# (e.g., error during data loading), clean up here
|
||||
if tui_renderer is not None and not tui_renderer._stop_event.is_set():
|
||||
try:
|
||||
if tui_queue is not None:
|
||||
tui_queue.put_nowait({"type": "done"})
|
||||
except Exception:
|
||||
except queue.Full:
|
||||
pass
|
||||
tui_renderer.stop()
|
||||
# Remove early log handler from both root and axolotl loggers
|
||||
|
||||
@@ -9,7 +9,6 @@ import os
|
||||
import shutil
|
||||
import signal
|
||||
import sys
|
||||
import typing
|
||||
import weakref
|
||||
from collections import OrderedDict
|
||||
from contextlib import ExitStack
|
||||
@@ -42,9 +41,6 @@ from axolotl.utils.schemas.enums import RLType
|
||||
from axolotl.utils.train import determine_last_checkpoint
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||
@@ -487,7 +483,7 @@ def handle_untrained_tokens_fix(
|
||||
def setup_model_and_trainer(
|
||||
cfg: DictDefault, dataset_meta: TrainDatasetMeta
|
||||
) -> tuple[
|
||||
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
|
||||
Trainer,
|
||||
PeftModel | PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
PeftConfig | None,
|
||||
@@ -614,7 +610,9 @@ def train(
|
||||
from axolotl.tui.config import TUIConfig
|
||||
|
||||
tui_config = _get_tui_config(cfg)
|
||||
tui_config_obj = TUIConfig(**tui_config) if isinstance(tui_config, dict) else tui_config
|
||||
tui_config_obj = (
|
||||
TUIConfig(**tui_config) if isinstance(tui_config, dict) else tui_config
|
||||
)
|
||||
|
||||
# Reuse the early-started renderer if available (started in do_train)
|
||||
early_renderer = getattr(cfg, "_tui_renderer", None)
|
||||
@@ -628,16 +626,12 @@ def train(
|
||||
tui_callback._renderer_started_early = True
|
||||
trainer.add_callback(tui_callback)
|
||||
|
||||
# Send model info to the callback
|
||||
model_name = cfg.base_model or ""
|
||||
training_mode = str(cfg.rl) if cfg.rl else "sft"
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
tui_callback._put({
|
||||
"type": "run_info",
|
||||
"model_name": model_name,
|
||||
"training_mode": training_mode,
|
||||
"world_size": world_size,
|
||||
})
|
||||
# Stash model info so on_train_begin can emit a single unified run_info event
|
||||
tui_callback._pending_run_info = {
|
||||
"model_name": cfg.base_model or "",
|
||||
"training_mode": str(cfg.rl) if cfg.rl else "sft",
|
||||
"world_size": int(os.environ.get("WORLD_SIZE", 1)),
|
||||
}
|
||||
LOG.info("TUI dashboard enabled")
|
||||
|
||||
# Handle untrained tokens if configured
|
||||
|
||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import queue
|
||||
from typing import Any
|
||||
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
|
||||
@@ -33,11 +32,13 @@ class _TUILogHandler(logging.Handler):
|
||||
try:
|
||||
level = self._LEVEL_MAP.get(record.levelno, "info")
|
||||
msg = self.format(record)
|
||||
self._queue.put_nowait({
|
||||
self._queue.put_nowait(
|
||||
{
|
||||
"type": "log_line",
|
||||
"level": level,
|
||||
"message": msg,
|
||||
})
|
||||
}
|
||||
)
|
||||
except queue.Full:
|
||||
pass
|
||||
except Exception:
|
||||
@@ -57,6 +58,7 @@ class AxolotlTUICallback(TrainerCallback):
|
||||
self._renderer = TUIRenderer(config=config, metric_queue=self._queue)
|
||||
self._log_handler: _TUILogHandler | None = None
|
||||
self._renderer_started_early: bool = False
|
||||
self._pending_run_info: dict | None = None
|
||||
|
||||
def _put(self, event: dict) -> None:
|
||||
try:
|
||||
@@ -65,25 +67,27 @@ class AxolotlTUICallback(TrainerCallback):
|
||||
pass
|
||||
|
||||
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||
# Send run info
|
||||
run_name = getattr(args, "run_name", "") or ""
|
||||
self._put(
|
||||
{
|
||||
# Send a single unified run_info event with all fields
|
||||
run_info = {
|
||||
"type": "run_info",
|
||||
"run_name": run_name,
|
||||
"run_name": getattr(args, "run_name", "") or "",
|
||||
"total_steps": state.max_steps,
|
||||
"total_epochs": int(args.num_train_epochs) if args.num_train_epochs else 1,
|
||||
"total_epochs": float(args.num_train_epochs)
|
||||
if args.num_train_epochs
|
||||
else 1.0,
|
||||
}
|
||||
)
|
||||
# Merge in model_name/training_mode/world_size if stashed by train.py
|
||||
if self._pending_run_info:
|
||||
run_info.update(self._pending_run_info)
|
||||
self._pending_run_info = None
|
||||
self._put(run_info)
|
||||
|
||||
if not self._renderer_started_early:
|
||||
# Attach a logging handler to feed log messages into the events panel
|
||||
self._log_handler = _TUILogHandler(
|
||||
self._queue, min_level=self._config.log_level
|
||||
)
|
||||
self._log_handler.setFormatter(
|
||||
logging.Formatter("[%(name)s] %(message)s")
|
||||
)
|
||||
self._log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s"))
|
||||
# Attach to both root and axolotl loggers (axolotl has propagate=False)
|
||||
logging.getLogger().addHandler(self._log_handler)
|
||||
logging.getLogger("axolotl").addHandler(self._log_handler)
|
||||
|
||||
@@ -18,15 +18,11 @@ class TUIConfig(BaseModel):
|
||||
)
|
||||
log_level: str = Field(
|
||||
default="debug",
|
||||
json_schema_extra={
|
||||
"description": "Minimum log level shown in events panel"
|
||||
},
|
||||
json_schema_extra={"description": "Minimum log level shown in events panel"},
|
||||
)
|
||||
panels: list[str] = Field(
|
||||
default_factory=lambda: ["progress", "training", "hardware", "events", "debug"],
|
||||
json_schema_extra={
|
||||
"description": "Ordered list of panels to display"
|
||||
},
|
||||
json_schema_extra={"description": "Ordered list of panels to display"},
|
||||
)
|
||||
hardware_poll_interval: int = Field(
|
||||
default=2,
|
||||
@@ -34,13 +30,9 @@ class TUIConfig(BaseModel):
|
||||
)
|
||||
stdout_log_path: str = Field(
|
||||
default="axolotl_stdout.log",
|
||||
json_schema_extra={
|
||||
"description": "File path for captured stdout/stderr log"
|
||||
},
|
||||
json_schema_extra={"description": "File path for captured stdout/stderr log"},
|
||||
)
|
||||
parser_plugins: list[str] = Field(
|
||||
default_factory=list,
|
||||
json_schema_extra={
|
||||
"description": "List of extra parser classes to load"
|
||||
},
|
||||
json_schema_extra={"description": "List of extra parser classes to load"},
|
||||
)
|
||||
|
||||
@@ -68,5 +68,5 @@ class GPUPoller:
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
LOG.debug("Error polling GPU device %d", i, exc_info=True)
|
||||
return stats
|
||||
|
||||
@@ -2,17 +2,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from typing import IO
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parser registry
|
||||
@@ -23,6 +20,7 @@ _parser_registry: list[type[LineParser]] = []
|
||||
|
||||
def register_parser(cls: type[LineParser]) -> type[LineParser]:
|
||||
"""Decorator to register a LineParser subclass."""
|
||||
if cls not in _parser_registry:
|
||||
_parser_registry.append(cls)
|
||||
return cls
|
||||
|
||||
@@ -92,7 +90,7 @@ class IOCapture:
|
||||
self._parser_chain = parser_chain
|
||||
self._queue = metric_queue
|
||||
self._log_path = log_path
|
||||
self._log_file = None
|
||||
self._log_file: IO[str] | None = None
|
||||
self._thread: threading.Thread | None = None
|
||||
self._read_fd: int | None = None
|
||||
self._write_fd: int | None = None
|
||||
@@ -129,7 +127,7 @@ class IOCapture:
|
||||
|
||||
def stop(self) -> None:
|
||||
# Restore fds — closes the write end, causing reader to see EOF
|
||||
if self._saved_stdout_fd is not None:
|
||||
if self._saved_stdout_fd is not None and self._saved_stderr_fd is not None:
|
||||
sys.stdout = sys.__stdout__
|
||||
sys.stderr = sys.__stderr__
|
||||
os.dup2(self._saved_stdout_fd, 1)
|
||||
@@ -141,6 +139,10 @@ class IOCapture:
|
||||
|
||||
if self._thread is not None:
|
||||
self._thread.join(timeout=2.0)
|
||||
if self._thread.is_alive():
|
||||
logging.getLogger(__name__).warning(
|
||||
"IO capture thread did not exit after 2s"
|
||||
)
|
||||
self._thread = None
|
||||
|
||||
if self._log_file is not None:
|
||||
@@ -150,6 +152,7 @@ class IOCapture:
|
||||
def _drain(self) -> None:
|
||||
# Read raw bytes and split on both \n and \r to handle tqdm progress bars
|
||||
# which use \r for in-place updates without \n
|
||||
assert self._read_fd is not None, "_drain called before start()"
|
||||
with os.fdopen(self._read_fd, "rb") as pipe:
|
||||
buf = b""
|
||||
while True:
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from rich.console import RenderableType
|
||||
|
||||
@@ -50,7 +49,7 @@ class BasePanel(ABC):
|
||||
"""Return a rich renderable. Called every tick."""
|
||||
...
|
||||
|
||||
def on_event(self, event: dict) -> None:
|
||||
def on_event(self, event: dict) -> None: # noqa: B027
|
||||
"""Optional: react to raw metric events before state is merged."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -22,6 +22,9 @@ class CompletionsPanel(BasePanel):
|
||||
modes = ["grpo", "dpo"]
|
||||
|
||||
def render(self, state: TUIState) -> RenderableType:
|
||||
if "*" not in self.modes and state.training_mode not in self.modes:
|
||||
return Text("")
|
||||
|
||||
if not state.completions:
|
||||
return Panel(
|
||||
Text("No completions yet...", style="dim"),
|
||||
|
||||
@@ -2,15 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
|
||||
from rich.console import RenderableType
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from axolotl.tui.panels import BasePanel, register_panel
|
||||
from axolotl.tui.state import LogLine, TUIState
|
||||
from axolotl.tui.state import TUIState
|
||||
|
||||
|
||||
@register_panel(position="bottom", weight=30)
|
||||
@@ -22,7 +19,9 @@ class DebugPanel(BasePanel):
|
||||
def render(self, state: TUIState) -> RenderableType:
|
||||
lines = Text()
|
||||
# Show last 8 debug-level log lines
|
||||
debug_lines = [l for l in state.log_lines if l.level == "debug"][-8:]
|
||||
debug_lines = [
|
||||
log_entry for log_entry in state.log_lines if log_entry.level == "debug"
|
||||
][-8:]
|
||||
for log_line in debug_lines:
|
||||
ts = log_line.timestamp.strftime("%H:%M:%S")
|
||||
lines.append(f"[{ts}] ", style="dim")
|
||||
|
||||
@@ -27,7 +27,9 @@ class EventsPanel(BasePanel):
|
||||
def render(self, state: TUIState) -> RenderableType:
|
||||
lines = Text()
|
||||
# Show last 15 non-debug log lines (debug goes to DebugPanel)
|
||||
recent = [l for l in state.log_lines if l.level != "debug"][-15:]
|
||||
recent = [
|
||||
log_entry for log_entry in state.log_lines if log_entry.level != "debug"
|
||||
][-15:]
|
||||
for log_line in recent:
|
||||
ts = log_line.timestamp.strftime("%H:%M:%S")
|
||||
level = log_line.level.upper()
|
||||
|
||||
@@ -3,5 +3,5 @@
|
||||
from axolotl.tui.parsers.deepspeed import DeepSpeedParser # noqa: F401
|
||||
from axolotl.tui.parsers.nccl import NCCLErrorParser # noqa: F401
|
||||
from axolotl.tui.parsers.raw_log import RawLogParser # noqa: F401
|
||||
from axolotl.tui.parsers.tqdm import TqdmParser # noqa: F401
|
||||
from axolotl.tui.parsers.torch_compile import TorchCompileParser # noqa: F401
|
||||
from axolotl.tui.parsers.tqdm import TqdmParser # noqa: F401
|
||||
|
||||
@@ -14,8 +14,8 @@ class RawLogParser(LineParser):
|
||||
|
||||
_LOG_RE = re.compile(
|
||||
r"^(?P<ts>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}[,\.]\d+)"
|
||||
r"\s*[-–]\s*(?P<level>DEBUG|INFO|WARNING|ERROR|CRITICAL)"
|
||||
r"\s*[-–]\s*(?P<msg>.+)$",
|
||||
r"\s*[-]\s*(?P<level>DEBUG|INFO|WARNING|ERROR|CRITICAL)"
|
||||
r"\s*[-]\s*(?P<msg>.+)$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
@@ -25,9 +25,7 @@ class TqdmParser(LineParser):
|
||||
|
||||
# Also match simpler forms like:
|
||||
# Fetching 0 files: 0it [00:00, ?it/s]
|
||||
_FETCH_RE = re.compile(
|
||||
r"(?P<desc>[\w\s]+):\s*(?P<current>\d+)(?:it)?\s*\[.*?\]"
|
||||
)
|
||||
_FETCH_RE = re.compile(r"(?P<desc>[\w\s]+):\s*(?P<current>\d+)(?:it)?\s*\[.*?\]")
|
||||
|
||||
def parse(self, line: str, source: str) -> list[dict]:
|
||||
m = self._TQDM_RE.search(line)
|
||||
@@ -41,21 +39,48 @@ class TqdmParser(LineParser):
|
||||
|
||||
# Surface as a log line with progress info
|
||||
if pct == 100 or pct == 0 or pct % 25 == 0:
|
||||
msg = f"[{desc}] {pct}% ({current}/{total})" if desc else f"{pct}% ({current}/{total})"
|
||||
events.append({
|
||||
msg = (
|
||||
f"[{desc}] {pct}% ({current}/{total})"
|
||||
if desc
|
||||
else f"{pct}% ({current}/{total})"
|
||||
)
|
||||
events.append(
|
||||
{
|
||||
"type": "log_line",
|
||||
"level": "info",
|
||||
"message": msg,
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
# Also emit as a progress metric
|
||||
events.append({
|
||||
cleaned_desc = desc.strip().lower().replace(" ", "_")
|
||||
if not cleaned_desc:
|
||||
cleaned_desc = "progress"
|
||||
events.append(
|
||||
{
|
||||
"type": "metrics",
|
||||
"logs": {
|
||||
f"progress/{desc.lower().replace(' ', '_')}": pct / 100.0,
|
||||
f"progress/{cleaned_desc}": pct / 100.0,
|
||||
},
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
return events
|
||||
|
||||
# Fallback: try simpler fetch-style progress lines
|
||||
m = self._FETCH_RE.search(line)
|
||||
if m:
|
||||
desc = m.group("desc").strip().rstrip(":")
|
||||
current = int(m.group("current"))
|
||||
cleaned_desc = desc.strip().lower().replace(" ", "_")
|
||||
if not cleaned_desc:
|
||||
cleaned_desc = "fetch"
|
||||
return [
|
||||
{
|
||||
"type": "log_line",
|
||||
"level": "info",
|
||||
"message": f"[{desc}] {current}" if desc else f"{current}",
|
||||
}
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
@@ -13,7 +12,6 @@ from typing import Any
|
||||
from rich.console import Console
|
||||
from rich.layout import Layout
|
||||
from rich.live import Live
|
||||
from rich.text import Text
|
||||
|
||||
from axolotl.tui.config import TUIConfig
|
||||
from axolotl.tui.gpu import GPUPoller
|
||||
@@ -49,6 +47,9 @@ class TUIRenderer:
|
||||
self._panels.append(registry[panel_name]())
|
||||
|
||||
def _init_parser_chain(self) -> None:
|
||||
# Ensure built-in parsers are imported so @register_parser decorators fire
|
||||
import axolotl.tui.parsers # noqa: F401
|
||||
|
||||
self._parser_chain = ParserChain()
|
||||
# Register all built-in parsers
|
||||
for parser_cls in get_registered_parsers():
|
||||
@@ -65,6 +66,8 @@ class TUIRenderer:
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"custom_parser", file_path
|
||||
)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Cannot load spec for {file_path}")
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
parser_cls = getattr(mod, class_name)
|
||||
@@ -167,10 +170,7 @@ class TUIRenderer:
|
||||
)
|
||||
now = time.time()
|
||||
self._state.elapsed_seconds = now - self._state.start_time.timestamp()
|
||||
if (
|
||||
self._state.current_step > 0
|
||||
and self._state.total_steps > 0
|
||||
):
|
||||
if self._state.current_step > 0 and self._state.total_steps > 0:
|
||||
rate = self._state.elapsed_seconds / self._state.current_step
|
||||
remaining = self._state.total_steps - self._state.current_step
|
||||
self._state.eta_seconds = rate * remaining
|
||||
@@ -210,6 +210,8 @@ class TUIRenderer:
|
||||
self._state.total_steps = event["total_steps"]
|
||||
if "total_epochs" in event:
|
||||
self._state.total_epochs = event["total_epochs"]
|
||||
if "zero_stage" in event:
|
||||
self._state.zero_stage = event["zero_stage"]
|
||||
|
||||
elif event_type == "done":
|
||||
self._stop_event.set()
|
||||
@@ -246,6 +248,7 @@ class TUIRenderer:
|
||||
self._init_parser_chain()
|
||||
|
||||
# Set up I/O capture
|
||||
assert self._parser_chain is not None, "_init_parser_chain must be called first"
|
||||
self._io_capture = IOCapture(
|
||||
log_path=self._config.stdout_log_path,
|
||||
parser_chain=self._parser_chain,
|
||||
@@ -257,29 +260,31 @@ class TUIRenderer:
|
||||
# ensures all progress events appear in the Events panel.
|
||||
self._install_tqdm_hook()
|
||||
|
||||
self._io_capture_ready = threading.Event()
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
self._io_capture_ready.wait(timeout=5.0)
|
||||
|
||||
def _install_tqdm_hook(self) -> None:
|
||||
"""Replace tqdm's display method to route updates through TUI queue."""
|
||||
try:
|
||||
import tqdm
|
||||
import tqdm.auto
|
||||
import io
|
||||
|
||||
import tqdm
|
||||
import tqdm.auto
|
||||
|
||||
q = self._queue
|
||||
parser = self._tqdm_parser = None
|
||||
self._tqdm_parser = None
|
||||
# Find our tqdm parser in the chain
|
||||
for p in (self._parser_chain._parsers if self._parser_chain else []):
|
||||
for p in self._parser_chain._parsers if self._parser_chain else []:
|
||||
if p.name == "tqdm":
|
||||
self._tqdm_parser = p
|
||||
break
|
||||
|
||||
# Save originals for restore
|
||||
self._orig_tqdm_init = tqdm.tqdm.__init__
|
||||
self._orig_tqdm_class = tqdm.auto.tqdm
|
||||
|
||||
renderer_self = self
|
||||
self._orig_tqdm_class_auto = tqdm.auto.tqdm
|
||||
self._orig_tqdm_class_tqdm = tqdm.tqdm
|
||||
self._orig_tqdm_class_std = tqdm.std.tqdm
|
||||
|
||||
class TUITqdm(tqdm.tqdm):
|
||||
"""tqdm subclass that sends progress to TUI instead of terminal."""
|
||||
@@ -296,31 +301,36 @@ class TUIRenderer:
|
||||
if self.total and self.total > 0:
|
||||
pct = self.n / self.total * 100
|
||||
desc = self.desc.rstrip(": ") if self.desc else ""
|
||||
elapsed = self.format_interval(self.elapsed) if hasattr(self, 'elapsed') and self.elapsed else "?"
|
||||
rate = self.format_sizeof(1.0 / self.avg_time) if hasattr(self, 'avg_time') and self.avg_time else "?"
|
||||
|
||||
# Emit events at milestones or at low frequency
|
||||
is_milestone = (
|
||||
self.n == 0 or
|
||||
self.n >= self.total or
|
||||
int(pct) % 25 == 0
|
||||
self.n == 0 or self.n >= self.total or int(pct) % 25 == 0
|
||||
)
|
||||
if is_milestone:
|
||||
try:
|
||||
q.put_nowait({
|
||||
q.put_nowait(
|
||||
{
|
||||
"type": "log_line",
|
||||
"level": "info",
|
||||
"message": f"[{desc}] {pct:.0f}% ({self.n}/{self.total})" if desc else f"{pct:.0f}% ({self.n}/{self.total})",
|
||||
})
|
||||
"message": f"[{desc}] {pct:.0f}% ({self.n}/{self.total})"
|
||||
if desc
|
||||
else f"{pct:.0f}% ({self.n}/{self.total})",
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
metric_key = f"progress/{desc.lower().replace(' ', '_')}" if desc else "progress/unknown"
|
||||
q.put_nowait({
|
||||
metric_key = (
|
||||
f"progress/{desc.lower().replace(' ', '_')}"
|
||||
if desc
|
||||
else "progress/unknown"
|
||||
)
|
||||
q.put_nowait(
|
||||
{
|
||||
"type": "metrics",
|
||||
"logs": {metric_key: pct / 100.0},
|
||||
})
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -329,11 +339,15 @@ class TUIRenderer:
|
||||
if self.total and self.total > 0 and self.n > 0:
|
||||
desc = self.desc.rstrip(": ") if self.desc else ""
|
||||
try:
|
||||
q.put_nowait({
|
||||
q.put_nowait(
|
||||
{
|
||||
"type": "log_line",
|
||||
"level": "info",
|
||||
"message": f"[{desc}] 100% ({self.total}/{self.total}) done" if desc else f"100% ({self.total}/{self.total}) done",
|
||||
})
|
||||
"message": f"[{desc}] 100% ({self.total}/{self.total}) done"
|
||||
if desc
|
||||
else f"100% ({self.total}/{self.total}) done",
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
super().close()
|
||||
@@ -354,11 +368,12 @@ class TUIRenderer:
|
||||
import tqdm
|
||||
import tqdm.auto
|
||||
|
||||
if hasattr(self, "_orig_tqdm_class"):
|
||||
tqdm.auto.tqdm = self._orig_tqdm_class
|
||||
if hasattr(self, "_orig_tqdm_init"):
|
||||
tqdm.tqdm.__init__ = self._orig_tqdm_init
|
||||
tqdm.std.tqdm = tqdm.tqdm
|
||||
if hasattr(self, "_orig_tqdm_class_auto"):
|
||||
tqdm.auto.tqdm = self._orig_tqdm_class_auto
|
||||
if hasattr(self, "_orig_tqdm_class_tqdm"):
|
||||
tqdm.tqdm = self._orig_tqdm_class_tqdm
|
||||
if hasattr(self, "_orig_tqdm_class_std"):
|
||||
tqdm.std.tqdm = self._orig_tqdm_class_std
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -388,6 +403,10 @@ class TUIRenderer:
|
||||
if self._io_capture:
|
||||
self._io_capture.start()
|
||||
|
||||
# Signal that IO capture is live so start() can return
|
||||
if hasattr(self, "_io_capture_ready"):
|
||||
self._io_capture_ready.set()
|
||||
|
||||
try:
|
||||
with Live(
|
||||
layout,
|
||||
|
||||
@@ -48,7 +48,7 @@ class TUIState:
|
||||
current_step: int = 0
|
||||
total_steps: int = 0
|
||||
current_epoch: float = 0.0
|
||||
total_epochs: int = 1
|
||||
total_epochs: float = 1.0
|
||||
elapsed_seconds: float = 0.0
|
||||
eta_seconds: float | None = None
|
||||
|
||||
@@ -81,5 +81,8 @@ class TUIState:
|
||||
# Loss history for sparkline
|
||||
loss_history: deque[float] = field(default_factory=lambda: deque(maxlen=50))
|
||||
|
||||
# DeepSpeed zero stage (None if not using DeepSpeed)
|
||||
zero_stage: int | None = None
|
||||
|
||||
# Arbitrary plugin state
|
||||
extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@@ -13,6 +13,7 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from axolotl.tui.config import TUIConfig
|
||||
from axolotl.utils.datasets import get_default_process_count
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.datasets import (
|
||||
@@ -52,8 +53,6 @@ from axolotl.utils.schemas.trl import TRLConfig
|
||||
from axolotl.utils.schemas.validation import ValidationMixin
|
||||
from axolotl.utils.schemas.vllm import VllmConfig
|
||||
|
||||
from axolotl.tui.config import TUIConfig
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user