diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index ac7354cd9..648bdf8f7 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -2,6 +2,7 @@ import gc import os +import queue from pathlib import Path from typing import Union @@ -36,7 +37,7 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): # Start TUI early (before data loading) so it captures preprocessing events tui_renderer = None - tui_queue = None + tui_queue: queue.Queue | None = None is_rank_0 = int(os.getenv("LOCAL_RANK", "0")) == 0 if is_rank_0: from axolotl.train import _is_tui_enabled @@ -44,12 +45,16 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): if _is_tui_enabled(cfg): import queue as _queue + from axolotl.train import _get_tui_config from axolotl.tui.config import TUIConfig from axolotl.tui.renderer import TUIRenderer - from axolotl.train import _get_tui_config tui_config_dict = _get_tui_config(cfg) - tui_config = TUIConfig(**tui_config_dict) if isinstance(tui_config_dict, dict) else tui_config_dict + tui_config = ( + TUIConfig(**tui_config_dict) + if isinstance(tui_config_dict, dict) + else tui_config_dict + ) tui_queue = _queue.Queue(maxsize=4096) tui_renderer = TUIRenderer(config=tui_config, metric_queue=tui_queue) @@ -58,12 +63,14 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): training_mode = str(cfg.rl) if cfg.rl else "sft" world_size = int(os.environ.get("WORLD_SIZE", 1)) try: - tui_queue.put_nowait({ - "type": "run_info", - "model_name": model_name, - "training_mode": training_mode, - "world_size": world_size, - }) + tui_queue.put_nowait( + { + "type": "run_info", + "model_name": model_name, + "training_mode": training_mode, + "world_size": world_size, + } + ) except _queue.Full: pass @@ -74,7 +81,9 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): from axolotl.tui.callback import _TUILogHandler - _early_log_handler = _TUILogHandler(tui_queue, min_level=tui_config.log_level) + _early_log_handler = 
_TUILogHandler( + tui_queue, min_level=tui_config.log_level + ) _early_log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s")) # Attach to BOTH root and axolotl loggers because axolotl logger # has propagate=False so root handler never sees axolotl.* messages @@ -110,8 +119,9 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): # (e.g., error during data loading), clean up here if tui_renderer is not None and not tui_renderer._stop_event.is_set(): try: - tui_queue.put_nowait({"type": "done"}) - except Exception: + if tui_queue is not None: + tui_queue.put_nowait({"type": "done"}) + except queue.Full: pass tui_renderer.stop() # Remove early log handler from both root and axolotl loggers diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 85629bca2..108dfd140 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -9,7 +9,6 @@ import os import shutil import signal import sys -import typing import weakref from collections import OrderedDict from contextlib import ExitStack @@ -42,9 +41,6 @@ from axolotl.utils.schemas.enums import RLType from axolotl.utils.train import determine_last_checkpoint from axolotl.utils.trainer import setup_trainer -if typing.TYPE_CHECKING: - from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder - LOG = get_logger(__name__) TELEMETRY_MANAGER = TelemetryManager.get_instance() @@ -487,7 +483,7 @@ def handle_untrained_tokens_fix( def setup_model_and_trainer( cfg: DictDefault, dataset_meta: TrainDatasetMeta ) -> tuple[ - "HFRLTrainerBuilder" | "HFCausalTrainerBuilder", + Trainer, PeftModel | PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, @@ -614,7 +610,9 @@ def train( from axolotl.tui.config import TUIConfig tui_config = _get_tui_config(cfg) - tui_config_obj = TUIConfig(**tui_config) if isinstance(tui_config, dict) else tui_config + tui_config_obj = ( + TUIConfig(**tui_config) if isinstance(tui_config, dict) else tui_config + ) # Reuse the early-started renderer if 
available (started in do_train) early_renderer = getattr(cfg, "_tui_renderer", None) @@ -628,16 +626,12 @@ def train( tui_callback._renderer_started_early = True trainer.add_callback(tui_callback) - # Send model info to the callback - model_name = cfg.base_model or "" - training_mode = str(cfg.rl) if cfg.rl else "sft" - world_size = int(os.environ.get("WORLD_SIZE", 1)) - tui_callback._put({ - "type": "run_info", - "model_name": model_name, - "training_mode": training_mode, - "world_size": world_size, - }) + # Stash model info so on_train_begin can emit a single unified run_info event + tui_callback._pending_run_info = { + "model_name": cfg.base_model or "", + "training_mode": str(cfg.rl) if cfg.rl else "sft", + "world_size": int(os.environ.get("WORLD_SIZE", 1)), + } LOG.info("TUI dashboard enabled") # Handle untrained tokens if configured diff --git a/src/axolotl/tui/callback.py b/src/axolotl/tui/callback.py index 57db5c519..51c47da19 100644 --- a/src/axolotl/tui/callback.py +++ b/src/axolotl/tui/callback.py @@ -4,7 +4,6 @@ from __future__ import annotations import logging import queue -from typing import Any from transformers.trainer_callback import TrainerCallback @@ -33,11 +32,13 @@ class _TUILogHandler(logging.Handler): try: level = self._LEVEL_MAP.get(record.levelno, "info") msg = self.format(record) - self._queue.put_nowait({ - "type": "log_line", - "level": level, - "message": msg, - }) + self._queue.put_nowait( + { + "type": "log_line", + "level": level, + "message": msg, + } + ) except queue.Full: pass except Exception: @@ -57,6 +58,7 @@ class AxolotlTUICallback(TrainerCallback): self._renderer = TUIRenderer(config=config, metric_queue=self._queue) self._log_handler: _TUILogHandler | None = None self._renderer_started_early: bool = False + self._pending_run_info: dict | None = None def _put(self, event: dict) -> None: try: @@ -65,25 +67,27 @@ class AxolotlTUICallback(TrainerCallback): pass def on_train_begin(self, args, state, control, model=None, 
**kwargs): - # Send run info - run_name = getattr(args, "run_name", "") or "" - self._put( - { - "type": "run_info", - "run_name": run_name, - "total_steps": state.max_steps, - "total_epochs": int(args.num_train_epochs) if args.num_train_epochs else 1, - } - ) + # Send a single unified run_info event with all fields + run_info = { + "type": "run_info", + "run_name": getattr(args, "run_name", "") or "", + "total_steps": state.max_steps, + "total_epochs": float(args.num_train_epochs) + if args.num_train_epochs + else 1.0, + } + # Merge in model_name/training_mode/world_size if stashed by train.py + if self._pending_run_info: + run_info.update(self._pending_run_info) + self._pending_run_info = None + self._put(run_info) if not self._renderer_started_early: # Attach a logging handler to feed log messages into the events panel self._log_handler = _TUILogHandler( self._queue, min_level=self._config.log_level ) - self._log_handler.setFormatter( - logging.Formatter("[%(name)s] %(message)s") - ) + self._log_handler.setFormatter(logging.Formatter("[%(name)s] %(message)s")) # Attach to both root and axolotl loggers (axolotl has propagate=False) logging.getLogger().addHandler(self._log_handler) logging.getLogger("axolotl").addHandler(self._log_handler) diff --git a/src/axolotl/tui/config.py b/src/axolotl/tui/config.py index ba32f36d8..9d6cc22aa 100644 --- a/src/axolotl/tui/config.py +++ b/src/axolotl/tui/config.py @@ -18,15 +18,11 @@ class TUIConfig(BaseModel): ) log_level: str = Field( default="debug", - json_schema_extra={ - "description": "Minimum log level shown in events panel" - }, + json_schema_extra={"description": "Minimum log level shown in events panel"}, ) panels: list[str] = Field( default_factory=lambda: ["progress", "training", "hardware", "events", "debug"], - json_schema_extra={ - "description": "Ordered list of panels to display" - }, + json_schema_extra={"description": "Ordered list of panels to display"}, ) hardware_poll_interval: int = Field( default=2, @@ 
-34,13 +30,9 @@ class TUIConfig(BaseModel): ) stdout_log_path: str = Field( default="axolotl_stdout.log", - json_schema_extra={ - "description": "File path for captured stdout/stderr log" - }, + json_schema_extra={"description": "File path for captured stdout/stderr log"}, ) parser_plugins: list[str] = Field( default_factory=list, - json_schema_extra={ - "description": "List of extra parser classes to load" - }, + json_schema_extra={"description": "List of extra parser classes to load"}, ) diff --git a/src/axolotl/tui/gpu.py b/src/axolotl/tui/gpu.py index 8ea05b9ed..3fa2cfff2 100644 --- a/src/axolotl/tui/gpu.py +++ b/src/axolotl/tui/gpu.py @@ -68,5 +68,5 @@ class GPUPoller: ) ) except Exception: - pass + LOG.debug("Error polling GPU device %d", i, exc_info=True) return stats diff --git a/src/axolotl/tui/io_capture.py b/src/axolotl/tui/io_capture.py index 9ae4b080b..e497c4c72 100644 --- a/src/axolotl/tui/io_capture.py +++ b/src/axolotl/tui/io_capture.py @@ -2,17 +2,14 @@ from __future__ import annotations +import logging import os import queue -import re import sys import threading from abc import ABC, abstractmethod from datetime import datetime -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - pass +from typing import IO # --------------------------------------------------------------------------- # Parser registry @@ -23,7 +20,8 @@ _parser_registry: list[type[LineParser]] = [] def register_parser(cls: type[LineParser]) -> type[LineParser]: """Decorator to register a LineParser subclass.""" - _parser_registry.append(cls) + if cls not in _parser_registry: + _parser_registry.append(cls) return cls @@ -92,7 +90,7 @@ class IOCapture: self._parser_chain = parser_chain self._queue = metric_queue self._log_path = log_path - self._log_file = None + self._log_file: IO[str] | None = None self._thread: threading.Thread | None = None self._read_fd: int | None = None self._write_fd: int | None = None @@ -129,7 +127,7 @@ class IOCapture: def stop(self) -> None: # Restore 
fds — closes the write end, causing reader to see EOF - if self._saved_stdout_fd is not None: + if self._saved_stdout_fd is not None and self._saved_stderr_fd is not None: sys.stdout = sys.__stdout__ sys.stderr = sys.__stderr__ os.dup2(self._saved_stdout_fd, 1) @@ -141,6 +139,10 @@ class IOCapture: if self._thread is not None: self._thread.join(timeout=2.0) + if self._thread.is_alive(): + logging.getLogger(__name__).warning( + "IO capture thread did not exit after 2s" + ) self._thread = None if self._log_file is not None: @@ -150,6 +152,7 @@ class IOCapture: def _drain(self) -> None: # Read raw bytes and split on both \n and \r to handle tqdm progress bars # which use \r for in-place updates without \n + assert self._read_fd is not None, "_drain called before start()" with os.fdopen(self._read_fd, "rb") as pipe: buf = b"" while True: diff --git a/src/axolotl/tui/panels/__init__.py b/src/axolotl/tui/panels/__init__.py index 46c790d37..d97012b43 100644 --- a/src/axolotl/tui/panels/__init__.py +++ b/src/axolotl/tui/panels/__init__.py @@ -3,7 +3,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any from rich.console import RenderableType @@ -50,7 +49,7 @@ class BasePanel(ABC): """Return a rich renderable. Called every tick.""" ... 
- def on_event(self, event: dict) -> None: + def on_event(self, event: dict) -> None: # noqa: B027 """Optional: react to raw metric events before state is merged.""" pass diff --git a/src/axolotl/tui/panels/completions.py b/src/axolotl/tui/panels/completions.py index 43379ad7b..c4a106168 100644 --- a/src/axolotl/tui/panels/completions.py +++ b/src/axolotl/tui/panels/completions.py @@ -22,6 +22,9 @@ class CompletionsPanel(BasePanel): modes = ["grpo", "dpo"] def render(self, state: TUIState) -> RenderableType: + if "*" not in self.modes and state.training_mode not in self.modes: + return Text("") + if not state.completions: return Panel( Text("No completions yet...", style="dim"), diff --git a/src/axolotl/tui/panels/debug.py b/src/axolotl/tui/panels/debug.py index ed73ae261..2e29d1a7a 100644 --- a/src/axolotl/tui/panels/debug.py +++ b/src/axolotl/tui/panels/debug.py @@ -2,15 +2,12 @@ from __future__ import annotations -from collections import deque -from datetime import datetime - from rich.console import RenderableType from rich.panel import Panel from rich.text import Text from axolotl.tui.panels import BasePanel, register_panel -from axolotl.tui.state import LogLine, TUIState +from axolotl.tui.state import TUIState @register_panel(position="bottom", weight=30) @@ -22,7 +19,9 @@ class DebugPanel(BasePanel): def render(self, state: TUIState) -> RenderableType: lines = Text() # Show last 8 debug-level log lines - debug_lines = [l for l in state.log_lines if l.level == "debug"][-8:] + debug_lines = [ + log_entry for log_entry in state.log_lines if log_entry.level == "debug" + ][-8:] for log_line in debug_lines: ts = log_line.timestamp.strftime("%H:%M:%S") lines.append(f"[{ts}] ", style="dim") diff --git a/src/axolotl/tui/panels/events.py b/src/axolotl/tui/panels/events.py index a16de25d0..1ecaf7259 100644 --- a/src/axolotl/tui/panels/events.py +++ b/src/axolotl/tui/panels/events.py @@ -27,7 +27,9 @@ class EventsPanel(BasePanel): def render(self, state: TUIState) -> 
RenderableType: lines = Text() # Show last 15 non-debug log lines (debug goes to DebugPanel) - recent = [l for l in state.log_lines if l.level != "debug"][-15:] + recent = [ + log_entry for log_entry in state.log_lines if log_entry.level != "debug" + ][-15:] for log_line in recent: ts = log_line.timestamp.strftime("%H:%M:%S") level = log_line.level.upper() diff --git a/src/axolotl/tui/parsers/__init__.py b/src/axolotl/tui/parsers/__init__.py index 8e4fbdb85..bc1026d12 100644 --- a/src/axolotl/tui/parsers/__init__.py +++ b/src/axolotl/tui/parsers/__init__.py @@ -3,5 +3,5 @@ from axolotl.tui.parsers.deepspeed import DeepSpeedParser # noqa: F401 from axolotl.tui.parsers.nccl import NCCLErrorParser # noqa: F401 from axolotl.tui.parsers.raw_log import RawLogParser # noqa: F401 -from axolotl.tui.parsers.tqdm import TqdmParser # noqa: F401 from axolotl.tui.parsers.torch_compile import TorchCompileParser # noqa: F401 +from axolotl.tui.parsers.tqdm import TqdmParser # noqa: F401 diff --git a/src/axolotl/tui/parsers/raw_log.py b/src/axolotl/tui/parsers/raw_log.py index 8f3f11d07..a806f7335 100644 --- a/src/axolotl/tui/parsers/raw_log.py +++ b/src/axolotl/tui/parsers/raw_log.py @@ -14,8 +14,8 @@ class RawLogParser(LineParser): _LOG_RE = re.compile( r"^(?P<ts>\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}[,\.]\d+)" - r"\s*[-–]\s*(?P<level>DEBUG|INFO|WARNING|ERROR|CRITICAL)" - r"\s*[-–]\s*(?P<msg>.+)$", + r"\s*[-]\s*(?P<level>DEBUG|INFO|WARNING|ERROR|CRITICAL)" + r"\s*[-]\s*(?P<msg>.+)$", re.IGNORECASE, ) diff --git a/src/axolotl/tui/parsers/tqdm.py b/src/axolotl/tui/parsers/tqdm.py index 2ad30ae04..278452c54 100644 --- a/src/axolotl/tui/parsers/tqdm.py +++ b/src/axolotl/tui/parsers/tqdm.py @@ -25,9 +25,7 @@ class TqdmParser(LineParser): # Also match simpler forms like: # Fetching 0 files: 0it [00:00, ?it/s] - _FETCH_RE = re.compile( - r"(?P<desc>[\w\s]+):\s*(?P<current>\d+)(?:it)?\s*\[.*?\]" - ) + _FETCH_RE = re.compile(r"(?P<desc>[\w\s]+):\s*(?P<current>\d+)(?:it)?\s*\[.*?\]") def parse(self, line: str, source: str) -> list[dict]: m = 
self._TQDM_RE.search(line) @@ -41,21 +39,48 @@ class TqdmParser(LineParser): # Surface as a log line with progress info if pct == 100 or pct == 0 or pct % 25 == 0: - msg = f"[{desc}] {pct}% ({current}/{total})" if desc else f"{pct}% ({current}/{total})" - events.append({ - "type": "log_line", - "level": "info", - "message": msg, - }) + msg = ( + f"[{desc}] {pct}% ({current}/{total})" + if desc + else f"{pct}% ({current}/{total})" + ) + events.append( + { + "type": "log_line", + "level": "info", + "message": msg, + } + ) # Also emit as a progress metric - events.append({ - "type": "metrics", - "logs": { - f"progress/{desc.lower().replace(' ', '_')}": pct / 100.0, - }, - }) + cleaned_desc = desc.strip().lower().replace(" ", "_") + if not cleaned_desc: + cleaned_desc = "progress" + events.append( + { + "type": "metrics", + "logs": { + f"progress/{cleaned_desc}": pct / 100.0, + }, + } + ) return events + # Fallback: try simpler fetch-style progress lines + m = self._FETCH_RE.search(line) + if m: + desc = m.group("desc").strip().rstrip(":") + current = int(m.group("current")) + cleaned_desc = desc.strip().lower().replace(" ", "_") + if not cleaned_desc: + cleaned_desc = "fetch" + return [ + { + "type": "log_line", + "level": "info", + "message": f"[{desc}] {current}" if desc else f"{current}", + } + ] + return [] diff --git a/src/axolotl/tui/renderer.py b/src/axolotl/tui/renderer.py index 8384ce75a..ceec15776 100644 --- a/src/axolotl/tui/renderer.py +++ b/src/axolotl/tui/renderer.py @@ -2,7 +2,6 @@ from __future__ import annotations -import importlib import logging import queue import threading @@ -13,7 +12,6 @@ from typing import Any from rich.console import Console from rich.layout import Layout from rich.live import Live -from rich.text import Text from axolotl.tui.config import TUIConfig from axolotl.tui.gpu import GPUPoller @@ -49,6 +47,9 @@ class TUIRenderer: self._panels.append(registry[panel_name]()) def _init_parser_chain(self) -> None: + # Ensure built-in 
parsers are imported so @register_parser decorators fire + import axolotl.tui.parsers # noqa: F401 + self._parser_chain = ParserChain() # Register all built-in parsers for parser_cls in get_registered_parsers(): @@ -65,6 +66,8 @@ class TUIRenderer: spec = importlib.util.spec_from_file_location( "custom_parser", file_path ) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot load spec for {file_path}") mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) parser_cls = getattr(mod, class_name) @@ -167,10 +170,7 @@ class TUIRenderer: ) now = time.time() self._state.elapsed_seconds = now - self._state.start_time.timestamp() - if ( - self._state.current_step > 0 - and self._state.total_steps > 0 - ): + if self._state.current_step > 0 and self._state.total_steps > 0: rate = self._state.elapsed_seconds / self._state.current_step remaining = self._state.total_steps - self._state.current_step self._state.eta_seconds = rate * remaining @@ -210,6 +210,8 @@ class TUIRenderer: self._state.total_steps = event["total_steps"] if "total_epochs" in event: self._state.total_epochs = event["total_epochs"] + if "zero_stage" in event: + self._state.zero_stage = event["zero_stage"] elif event_type == "done": self._stop_event.set() @@ -246,6 +248,7 @@ class TUIRenderer: self._init_parser_chain() # Set up I/O capture + assert self._parser_chain is not None, "_init_parser_chain must be called first" self._io_capture = IOCapture( log_path=self._config.stdout_log_path, parser_chain=self._parser_chain, @@ -257,29 +260,31 @@ class TUIRenderer: # ensures all progress events appear in the Events panel. 
self._install_tqdm_hook() + self._io_capture_ready = threading.Event() self._thread = threading.Thread(target=self._run, daemon=True) self._thread.start() + self._io_capture_ready.wait(timeout=5.0) def _install_tqdm_hook(self) -> None: """Replace tqdm's display method to route updates through TUI queue.""" try: - import tqdm - import tqdm.auto import io + import tqdm + import tqdm.auto + q = self._queue - parser = self._tqdm_parser = None + self._tqdm_parser = None # Find our tqdm parser in the chain - for p in (self._parser_chain._parsers if self._parser_chain else []): + for p in self._parser_chain._parsers if self._parser_chain else []: if p.name == "tqdm": self._tqdm_parser = p break # Save originals for restore - self._orig_tqdm_init = tqdm.tqdm.__init__ - self._orig_tqdm_class = tqdm.auto.tqdm - - renderer_self = self + self._orig_tqdm_class_auto = tqdm.auto.tqdm + self._orig_tqdm_class_tqdm = tqdm.tqdm + self._orig_tqdm_class_std = tqdm.std.tqdm class TUITqdm(tqdm.tqdm): """tqdm subclass that sends progress to TUI instead of terminal.""" @@ -296,31 +301,36 @@ class TUIRenderer: if self.total and self.total > 0: pct = self.n / self.total * 100 desc = self.desc.rstrip(": ") if self.desc else "" - elapsed = self.format_interval(self.elapsed) if hasattr(self, 'elapsed') and self.elapsed else "?" - rate = self.format_sizeof(1.0 / self.avg_time) if hasattr(self, 'avg_time') and self.avg_time else "?" 
- # Emit events at milestones or at low frequency is_milestone = ( - self.n == 0 or - self.n >= self.total or - int(pct) % 25 == 0 + self.n == 0 or self.n >= self.total or int(pct) % 25 == 0 ) if is_milestone: try: - q.put_nowait({ - "type": "log_line", - "level": "info", - "message": f"[{desc}] {pct:.0f}% ({self.n}/{self.total})" if desc else f"{pct:.0f}% ({self.n}/{self.total})", - }) + q.put_nowait( + { + "type": "log_line", + "level": "info", + "message": f"[{desc}] {pct:.0f}% ({self.n}/{self.total})" + if desc + else f"{pct:.0f}% ({self.n}/{self.total})", + } + ) except Exception: pass try: - metric_key = f"progress/{desc.lower().replace(' ', '_')}" if desc else "progress/unknown" - q.put_nowait({ - "type": "metrics", - "logs": {metric_key: pct / 100.0}, - }) + metric_key = ( + f"progress/{desc.lower().replace(' ', '_')}" + if desc + else "progress/unknown" + ) + q.put_nowait( + { + "type": "metrics", + "logs": {metric_key: pct / 100.0}, + } + ) except Exception: pass @@ -329,11 +339,15 @@ class TUIRenderer: if self.total and self.total > 0 and self.n > 0: desc = self.desc.rstrip(": ") if self.desc else "" try: - q.put_nowait({ - "type": "log_line", - "level": "info", - "message": f"[{desc}] 100% ({self.total}/{self.total}) done" if desc else f"100% ({self.total}/{self.total}) done", - }) + q.put_nowait( + { + "type": "log_line", + "level": "info", + "message": f"[{desc}] 100% ({self.total}/{self.total}) done" + if desc + else f"100% ({self.total}/{self.total}) done", + } + ) except Exception: pass super().close() @@ -354,11 +368,12 @@ class TUIRenderer: import tqdm import tqdm.auto - if hasattr(self, "_orig_tqdm_class"): - tqdm.auto.tqdm = self._orig_tqdm_class - if hasattr(self, "_orig_tqdm_init"): - tqdm.tqdm.__init__ = self._orig_tqdm_init - tqdm.std.tqdm = tqdm.tqdm + if hasattr(self, "_orig_tqdm_class_auto"): + tqdm.auto.tqdm = self._orig_tqdm_class_auto + if hasattr(self, "_orig_tqdm_class_tqdm"): + tqdm.tqdm = self._orig_tqdm_class_tqdm + if 
hasattr(self, "_orig_tqdm_class_std"): + tqdm.std.tqdm = self._orig_tqdm_class_std except Exception: pass @@ -388,6 +403,10 @@ class TUIRenderer: if self._io_capture: self._io_capture.start() + # Signal that IO capture is live so start() can return + if hasattr(self, "_io_capture_ready"): + self._io_capture_ready.set() + try: with Live( layout, diff --git a/src/axolotl/tui/state.py b/src/axolotl/tui/state.py index f67445570..db1d62f84 100644 --- a/src/axolotl/tui/state.py +++ b/src/axolotl/tui/state.py @@ -48,7 +48,7 @@ class TUIState: current_step: int = 0 total_steps: int = 0 current_epoch: float = 0.0 - total_epochs: int = 1 + total_epochs: float = 1.0 elapsed_seconds: float = 0.0 eta_seconds: float | None = None @@ -81,5 +81,8 @@ class TUIState: # Loss history for sparkline loss_history: deque[float] = field(default_factory=lambda: deque(maxlen=50)) + # DeepSpeed zero stage (None if not using DeepSpeed) + zero_stage: int | None = None + # Arbitrary plugin state extra: dict[str, Any] = field(default_factory=dict) diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 930ee4c7f..81ce3f75e 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -13,6 +13,7 @@ from pydantic import ( model_validator, ) +from axolotl.tui.config import TUIConfig from axolotl.utils.datasets import get_default_process_count from axolotl.utils.logging import get_logger from axolotl.utils.schemas.datasets import ( @@ -52,8 +53,6 @@ from axolotl.utils.schemas.trl import TRLConfig from axolotl.utils.schemas.validation import ValidationMixin from axolotl.utils.schemas.vllm import VllmConfig -from axolotl.tui.config import TUIConfig - LOG = get_logger(__name__)