From 4065bc14c616e12c4da037c01de8a0defd9e7c10 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 17 Sep 2025 13:27:03 -0400 Subject: [PATCH] Debug log, logging improvements (#3159) * simplify logging * remove comment * progress on debug.log * add debug-level logger for file log * simplify * case insensitivity; 3rd party logging improvements * simplify * fix * tests * lint * nits * nit * Update tests/test_utils_tee.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * cleanup / comments * fix * oops --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- .gitignore | 3 + .pre-commit-config.yaml | 2 +- .../colab-axolotl-example.ipynb | 6 +- pyproject.toml | 2 +- src/axolotl/cli/__init__.py | 4 +- src/axolotl/cli/config.py | 9 +- src/axolotl/cli/inference.py | 2 - src/axolotl/cli/main.py | 4 +- src/axolotl/cli/train.py | 1 - src/axolotl/cli/utils/diffusion.py | 1 - src/axolotl/logging_config.py | 75 +++++--- .../transformers/trainer_loss_calc.py | 6 +- src/axolotl/train.py | 3 +- src/axolotl/utils/__init__.py | 9 - src/axolotl/utils/logging.py | 7 +- src/axolotl/utils/tee.py | 166 ++++++++++++++++++ src/axolotl/utils/train.py | 4 +- src/axolotl/utils/trainer.py | 9 - tests/test_logging_config_file_capture.py | 103 +++++++++++ tests/test_utils_tee.py | 107 +++++++++++ 20 files changed, 454 insertions(+), 69 deletions(-) create mode 100644 src/axolotl/utils/tee.py create mode 100644 tests/test_logging_config_file_capture.py create mode 100644 tests/test_utils_tee.py diff --git a/.gitignore b/.gitignore index 40084b408..b75becc7c 100644 --- a/.gitignore +++ b/.gitignore @@ -190,3 +190,6 @@ out/ # vim *.swp + +# scm auto-versioning +src/axolotl/_version.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9c80898ff..92ddc7f41 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: rev: v0.12.12 hooks: - id: ruff - args: [--fix, --select, I] + args: [--fix] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.17.1 diff --git a/examples/colab-notebooks/colab-axolotl-example.ipynb b/examples/colab-notebooks/colab-axolotl-example.ipynb index 0e6ba984e..774b78b82 100644 --- a/examples/colab-notebooks/colab-axolotl-example.ipynb +++ b/examples/colab-notebooks/colab-axolotl-example.ipynb @@ -251,10 +251,10 @@ }, "outputs": [], "source": [ - "from axolotl.utils import patch_optimized_env\n", + "from axolotl.utils import set_pytorch_cuda_alloc_conf\n", "\n", - "# speedup downloads from HF 🤗 and set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n", - "patch_optimized_env()" + "# Set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n", + "set_pytorch_cuda_alloc_conf()" ] }, { diff --git a/pyproject.toml b/pyproject.toml index 932219d9e..4213bc963 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ line-length = 88 target-version = "py310" [tool.ruff.lint] -select = ["E", "F", "W", "C90", "B"] +select = ["E", "F", "W", "C90", "B", "I"] ignore = [ "E203", # Whitespace before ':' "E501", # Line too long diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 8955eca3e..fa647be65 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -4,5 +4,7 @@ import os from axolotl.logging_config import configure_logging -os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") +os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") + configure_logging() diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 20e341a0b..93ac6147d 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -23,7 +23,8 @@ from axolotl.utils.config import ( from axolotl.utils.dict import DictDefault from axolotl.utils.logging import get_logger from axolotl.utils.mlflow_ import setup_mlflow_env_vars -from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env +from axolotl.utils.tee import prepare_debug_log +from axolotl.utils.trainer import prepare_optim_env from axolotl.utils.wandb_ import setup_wandb_env_vars LOG = get_logger(__name__) @@ -227,8 +228,11 @@ def load_cfg( }, ) + # NOTE(djsaunde): We start outputting to output_dir/debug.log at this point since we + # have to wait for cfg.output to be resolved. We could call this earlier if we write + # to a temporary file, and then move it later. + prepare_debug_log(cfg) prepare_optim_env(cfg) - prepare_opinionated_env(cfg) normalize_config(cfg) normalize_cfg_datasets(cfg) setup_wandb_env_vars(cfg) @@ -241,7 +245,6 @@ def load_cfg( for k, v in cfg.items() if v is not None } - LOG.info( "config:\n%s", json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True), diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py index 30d407713..3e1c01520 100644 --- a/src/axolotl/cli/inference.py +++ b/src/axolotl/cli/inference.py @@ -17,8 +17,6 @@ from axolotl.cli.utils import load_model_and_tokenizer from axolotl.cli.utils.diffusion import ( diffusion_inference, launch_diffusion_gradio_ui, - render_html, - run_diffusion, ) from axolotl.integrations.base import PluginManager from axolotl.utils.chat_templates import get_chat_template_from_config diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index acfa81389..dc6cca489 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -26,7 +26,7 @@ from axolotl.cli.utils import ( launch_training, ) from axolotl.integrations.lm_eval.cli import lm_eval -from axolotl.utils import patch_optimized_env +from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils.logging import get_logger from axolotl.utils.schemas.config import AxolotlInputConfig @@ -44,7 +44,7 @@ def cli(): """Axolotl CLI - Train and fine-tune large language models""" print_axolotl_text_art() load_dotenv() - patch_optimized_env() + set_pytorch_cuda_alloc_conf() @cli.command() diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 8d33c0b84..2332717e7 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -60,7 +60,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): config: Path to `axolotl` config YAML file. kwargs: Additional keyword arguments to override config file values. """ - parsed_cfg = load_cfg(config, **kwargs) parser = HfArgumentParser(TrainerCliArgs) parsed_cli_args, _ = parser.parse_args_into_dataclasses( diff --git a/src/axolotl/cli/utils/diffusion.py b/src/axolotl/cli/utils/diffusion.py index f83d9077b..1157bfd66 100644 --- a/src/axolotl/cli/utils/diffusion.py +++ b/src/axolotl/cli/utils/diffusion.py @@ -3,7 +3,6 @@ from __future__ import annotations import gradio as gr -import torch from colorama import Fore, Style from axolotl.integrations.diffusion import generate, resolve_mask_token_id diff --git a/src/axolotl/logging_config.py b/src/axolotl/logging_config.py index 10c5ae9dc..67b1d32f1 100644 --- a/src/axolotl/logging_config.py +++ b/src/axolotl/logging_config.py @@ -1,10 +1,7 @@ -""" -Common logging module for axolotl -""" +"""Common logging module for axolotl.""" import logging import os -import sys from logging import Formatter, Logger, LogRecord from logging.config import dictConfig from typing import Any, Dict @@ -17,9 +14,9 @@ DEFAULT_LOG_LEVEL = "WARNING" class AxolotlOrWarnErrorFilter(logging.Filter): """ - Allows ANY WARNING or higher (unless overridden by LOG_LEVEL) - Allows axolotl.* at INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL) - Drops all other records (i.e. non-axolotl.INFO, DEBUG, etc. by default) + Allows ANY WARNING or higher (unless overridden by LOG_LEVEL). Allows axolotl.* at + INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL). Drops all other records + (i.e. non-axolotl.INFO, DEBUG, etc. by default). """ def __init__(self, **kwargs): @@ -52,13 +49,12 @@ class AxolotlOrWarnErrorFilter(logging.Filter): class AxolotlLogger(Logger): - """A Logger that automatically rejects non-axolotl INFOs.""" + """Logger that applies filtering to non-axolotl loggers.""" def __init__(self, name: str, level: int = logging.NOTSET): super().__init__(name, level) - - # set global filter on the logger itself - self.addFilter(AxolotlOrWarnErrorFilter()) + if not name.startswith("axolotl"): + self.addFilter(AxolotlOrWarnErrorFilter()) class ColorfulFormatter(Formatter): @@ -74,6 +70,7 @@ class ColorfulFormatter(Formatter): def format(self, record): record.rank = int(os.getenv("LOCAL_RANK", "0")) + record.rank_fmt = f" [RANK:{record.rank}]" if record.rank != 0 else "" log_message = super().format(record) return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET @@ -87,32 +84,54 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { }, "colorful": { "()": ColorfulFormatter, - "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s", + "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d]%(rank_fmt)s %(message)s", + }, + "concise": { + "format": "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s", + }, + "concise_color": { + "()": ColorfulFormatter, + "format": "[%(asctime)s] [%(levelname)s] [%(name)s]%(rank_fmt)s %(message)s", + }, + }, + "filters": { + "ax_or_warn": { + "()": "axolotl.logging_config.AxolotlOrWarnErrorFilter", }, }, - "filters": {}, "handlers": { "console": { "class": "logging.StreamHandler", - "formatter": "simple", - "filters": [], - "stream": sys.stdout, + "formatter": "concise", + "filters": ["ax_or_warn"], + "stream": "ext://sys.stdout", }, "color_console": { "class": "logging.StreamHandler", - "formatter": "colorful", - "filters": [], - "stream": sys.stdout, + "formatter": "concise_color", + "filters": ["ax_or_warn"], + "stream": "ext://sys.stdout", + }, + "ax_file_only": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "simple", + "stream": "ext://axolotl.utils.tee.file_only_stream", + }, + "root_file_only": { + "class": "logging.StreamHandler", + "level": "DEBUG", + "formatter": "simple", + "stream": "ext://axolotl.utils.tee.file_only_stream", }, }, - # log level will be superseded by the AxolotlLogger "root": { - "handlers": ["console"], - "level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL), + "handlers": ["console", "root_file_only"], + "level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL).upper(), }, "loggers": { "axolotl": { - "handlers": ["color_console"], + "handlers": ["color_console", "ax_file_only"], "level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL).upper(), "propagate": False, }, @@ -123,9 +142,15 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { def configure_logging(): """Configure with default logging""" init() # Initialize colorama + dictConfig(DEFAULT_LOGGING_CONFIG) logging.setLoggerClass(AxolotlLogger) - # set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set + # Route Python warnings through logging so they reach file handlers + logging.captureWarnings(True) + + # Set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set if "ACCELERATE_LOG_LEVEL" not in os.environ: - os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL) + os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv( + "LOG_LEVEL", DEFAULT_LOG_LEVEL + ).upper() diff --git a/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py index c9b968d71..b8172bbe6 100644 --- a/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py +++ b/src/axolotl/monkeypatch/transformers/trainer_loss_calc.py @@ -41,7 +41,7 @@ def patch_evaluation_loop(): """Patch the evaluation_loop method.""" # Check if already patched if hasattr(Trainer, "_original_evaluation_loop"): - LOG.info("Trainer.evaluation_loop already patched") + LOG.debug("Trainer.evaluation_loop already patched") return # Check if the patterns exist @@ -84,7 +84,7 @@ def patch_evaluation_loop(): ) exec(evaluation_loop_source, globals()) - LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation") + LOG.debug("Patched Trainer.evaluation_loop with nanmean loss calculation") Trainer.evaluation_loop = axolotl_evaluation_loop @@ -135,5 +135,5 @@ def patch_maybe_log_save_evaluate(): ) exec(maybe_log_source, globals()) - LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation") + LOG.debug("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation") Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate diff --git a/src/axolotl/train.py b/src/axolotl/train.py index b0482bb1e..2a70d9712 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -196,10 +196,11 @@ def execute_training( ) ) - LOG.info("Starting trainer...") # TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers # if cfg.bf16: # torch.set_default_dtype(torch.bfloat16) + + LOG.info("Starting trainer...") trainer.train(resume_from_checkpoint=resume_from_checkpoint) plugin_manager = PluginManager.get_instance() diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index e5050116a..7256a5700 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -44,15 +44,6 @@ def set_pytorch_cuda_alloc_conf(): ) -def patch_optimized_env(): - """ - Patch environment variables to improve VRAM usage and increase download speed - """ - if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None: - os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" - set_pytorch_cuda_alloc_conf() - - def get_not_null(value, default=None): """ return the value if it's not None, otherwise return the default value diff --git a/src/axolotl/utils/logging.py b/src/axolotl/utils/logging.py index 7cc3530ae..35810897a 100644 --- a/src/axolotl/utils/logging.py +++ b/src/axolotl/utils/logging.py @@ -2,7 +2,6 @@ import functools import logging -import os from axolotl.utils.distributed import is_main_process @@ -40,10 +39,6 @@ class MultiProcessAdapter(logging.LoggerAdapter): def get_logger(name: str, log_level: str | None = None) -> MultiProcessAdapter: - if log_level is None: - log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None) logger = logging.getLogger(name) - if log_level is not None: - logger.setLevel(log_level.upper()) - logger.root.setLevel(log_level.upper()) + logger.setLevel(logging.DEBUG) return MultiProcessAdapter(logger, extra={}) diff --git a/src/axolotl/utils/tee.py b/src/axolotl/utils/tee.py new file mode 100644 index 000000000..1209ad1dd --- /dev/null +++ b/src/axolotl/utils/tee.py @@ -0,0 +1,166 @@ +""" +Utilities for managing the debug log file and providing a file-only stream for logging +handlers. +""" + +from __future__ import annotations + +import io +import os +import sys +import threading +from pathlib import Path +from typing import TextIO, cast + +_lock = threading.Lock() +_file_handle: io.TextIOWrapper | None = None +_log_path: str | None = None +_tee_installed: bool = False +_orig_stdout: TextIO | None = None +_orig_stderr: TextIO | None = None + + +class _FileOnlyWriter(io.TextIOBase): + """A stream-like object that writes only to the tee file. + + Before the file is prepared, writes are dropped (no-op). + """ + + def write(self, s: str) -> int: # type: ignore[override] + with _lock: + if _file_handle is not None: + _file_handle.write(s) + return len(s) + return len(s) + + def flush(self) -> None: # type: ignore[override] + with _lock: + if _file_handle is not None: + try: + _file_handle.flush() + except Exception: + pass + + +file_only_stream: io.TextIOBase = _FileOnlyWriter() + + +class _StreamTee(io.TextIOBase): + """A minimal tee that mirrors writes to the debug log file. + + Installed only after the debug log is prepared; no buffering. + """ + + def __init__(self, stream: io.TextIOBase): + self._stream = stream + + def write(self, s: str) -> int: # type: ignore[override] + with _lock: + n = self._stream.write(s) + if _file_handle is not None: + _file_handle.write(s) + return n + + def flush(self) -> None: # type: ignore[override] + with _lock: + self._stream.flush() + if _file_handle is not None: + try: + _file_handle.flush() + except Exception: + pass + + @property + def encoding(self): # type: ignore[override] + return getattr(self._stream, "encoding", None) + + @property + def errors(self): # type: ignore[override] + return getattr(self._stream, "errors", None) + + def isatty(self): # type: ignore[override] + return getattr(self._stream, "isatty", lambda: False)() + + def fileno(self): # type: ignore[override] + if hasattr(self._stream, "fileno"): + return self._stream.fileno() + raise OSError("Underlying stream has no fileno") + + +def prepare_debug_log(cfg, filename: str = "debug.log") -> str: + """ + Prepare the debug log. + + Creates the output directory, handles append/truncate logic based on cfg, and opens + the debug log file for subsequent writes via file-only handlers. + """ + global _file_handle, _log_path, _tee_installed + + with _lock: + # If already initialized, reuse existing path + if _log_path is not None: + return _log_path + + output_dir = cfg.output_dir + os.makedirs(output_dir, exist_ok=True) + + log_path = Path(output_dir) / filename + append = bool( + cfg.get("resume_from_checkpoint") or cfg.get("auto_resume_from_checkpoints") + ) + + if not append and log_path.exists(): + log_path.unlink() + + fh = open(log_path, "a", encoding="utf-8") + fh.flush() + + _file_handle = fh + _log_path = str(log_path) + + # Install a tee so stdout/stderr are mirrored to the debug file + # Allow disabling via env for testing or advanced usage. + tee_enabled = os.getenv("AXOLOTL_TEE_STDOUT", "1").lower() not in { + "0", + "false", + "no", + } + if tee_enabled and not _tee_installed: + # Save originals so we can restore later (e.g., tests) + global _orig_stdout, _orig_stderr + _orig_stdout = sys.stdout + _orig_stderr = sys.stderr + sys.stdout = _StreamTee(cast(io.TextIOBase, sys.stdout)) + sys.stderr = _StreamTee(cast(io.TextIOBase, sys.stderr)) + _tee_installed = True + + return _log_path + + +def close_debug_log() -> None: + """Flush and close the debug log and uninstall the stdout/stderr tee. + + Safe to call even if not initialized. + """ + global _file_handle, _log_path, _tee_installed, _orig_stdout, _orig_stderr + with _lock: + # Restore original stdout/stderr if we installed a tee + if _tee_installed: + if _orig_stdout is not None: + sys.stdout = _orig_stdout + if _orig_stderr is not None: + sys.stderr = _orig_stderr + _tee_installed = False + _orig_stdout = None + _orig_stderr = None + + # Close the file handle if open + if _file_handle is not None: + try: + _file_handle.flush() + _file_handle.close() + except Exception: + pass + finally: + _file_handle = None + _log_path = None diff --git a/src/axolotl/utils/train.py b/src/axolotl/utils/train.py index 1393459d9..ad3f72be4 100644 --- a/src/axolotl/utils/train.py +++ b/src/axolotl/utils/train.py @@ -31,6 +31,7 @@ def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | No if checkpoints: last_checkpoint = str(checkpoints[-1]) if not update: + LOG.info(f"Resuming from last checkpoint at {last_checkpoint}") return last_checkpoint if ( @@ -40,6 +41,7 @@ def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | No ): cfg.resume_from_checkpoint = last_checkpoint LOG.info( - f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}" + "Using auto-resume functionality to resume from checkpoint at " + f"{cfg.resume_from_checkpoint}" ) return cfg.resume_from_checkpoint diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index a0f4fd567..662a54655 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -655,15 +655,6 @@ def prepare_optim_env(cfg): os.environ["ACCELERATE_MIXED_PRECISION"] = "no" -def prepare_opinionated_env(cfg): - if cfg.qlora_sharded_model_loading: - # model loading is forked after the tokenizer - os.environ["TOKENIZERS_PARALLELISM"] = "false" - if cfg.sample_packing: - # multipack parallel packing sampler defaults to using fork - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - def setup_trainer( cfg, train_dataset, diff --git a/tests/test_logging_config_file_capture.py b/tests/test_logging_config_file_capture.py new file mode 100644 index 000000000..44b0ee5e6 --- /dev/null +++ b/tests/test_logging_config_file_capture.py @@ -0,0 +1,103 @@ +import logging +import tempfile + +import pytest + + +def read(path: str) -> str: + with open(path, "r", encoding="utf-8") as f: + return f.read() + + +@pytest.fixture(autouse=True) +def _reset_logging_state(): + # Ensure a clean slate for logging between tests + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + logging.shutdown() + # Note: dictConfig in configure_logging will set up handlers again + yield + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + logging.shutdown() + + +def test_axolotl_logs_captured_at_all_levels(monkeypatch): + from axolotl.logging_config import configure_logging + from axolotl.utils import tee + from axolotl.utils.logging import get_logger + + with tempfile.TemporaryDirectory() as td: + # Avoid stdout tee in this test to simplify interaction with pytest capture + monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0") + configure_logging() + path = tee.prepare_debug_log( + type("Cfg", (), {"output_dir": td, "get": lambda *_: False}) + ) + + log = get_logger("axolotl.test") + log.info("AX-INFO") + log.debug("AX-DEBUG") + tee.file_only_stream.flush() + + data = read(path) + assert "AX-INFO" in data + assert "AX-DEBUG" in data + tee.close_debug_log() + + +def test_third_party_logs_filtered_and_warning_captured(monkeypatch): + from axolotl.logging_config import configure_logging + from axolotl.utils import tee + + with tempfile.TemporaryDirectory() as td: + monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0") + configure_logging() + path = tee.prepare_debug_log( + type("Cfg", (), {"output_dir": td, "get": lambda *_: False}) + ) + + # Third-party logger (non-axolotl) + other = logging.getLogger("thirdparty.lib") + other.info("TP-INFO") + other.warning("TP-WARN") + + # Simulate Python warnings routed through logging + logging.getLogger("py.warnings").warning("PY-WARN") + + # Push through buffers + tee.file_only_stream.flush() + + data = read(path) + # INFO from non-axolotl should be filtered out (not present) + assert "TP-INFO" not in data + # WARNING+ should be present + assert "TP-WARN" in data + # Python warnings captured (via py.warnings logger) + assert "PY-WARN" in data + tee.close_debug_log() + tee.close_debug_log() + + +def test_prepare_debug_log_idempotent_and_no_duplicate(monkeypatch): + from axolotl.logging_config import configure_logging + from axolotl.utils import tee + from axolotl.utils.logging import get_logger + + with tempfile.TemporaryDirectory() as td: + monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0") + configure_logging() + cfg = type("Cfg", (), {"output_dir": td, "get": lambda *_: False}) + p1 = tee.prepare_debug_log(cfg) + p2 = tee.prepare_debug_log(cfg) + assert p1 == p2 + + log = get_logger("axolotl.test") + marker = "UNIQUE-MARKER-12345" + log.info(marker) + tee.file_only_stream.flush() + + data = read(p1) + # Ensure the marker appears once (not duplicated via propagation) + assert data.count(marker) == 1 + tee.close_debug_log() diff --git a/tests/test_utils_tee.py b/tests/test_utils_tee.py new file mode 100644 index 000000000..e2c153667 --- /dev/null +++ b/tests/test_utils_tee.py @@ -0,0 +1,107 @@ +import os +import tempfile + + +def _dummy_cfg(output_dir: str, append: bool = False): + # Minimal object with attributes used by prepare_debug_log + class Cfg: + def __init__(self, out, append): + self.output_dir = out + self._append = append + + def get(self, key, default=None): + if key in {"resume_from_checkpoint", "auto_resume_from_checkpoints"}: + return self._append + return default + + return Cfg(output_dir, append) + + +def read(path: str) -> str: + with open(path, "r", encoding="utf-8") as f: + return f.read() + + +def test_file_only_stream_writes_after_prepare(monkeypatch): + from axolotl.utils import tee + + with tempfile.TemporaryDirectory() as td: + # Avoid stdout tee in this test + monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0") + cfg = _dummy_cfg(td, append=False) + + # before prepare: writing to file_only_stream creates no file + tee.file_only_stream.write("before\n") + tee.file_only_stream.flush() + assert not os.path.exists(os.path.join(td, "debug.log")) + + # prepare and write + path = tee.prepare_debug_log(cfg) + assert os.path.basename(path) == "debug.log" + tee.file_only_stream.write("hello\n") + tee.file_only_stream.flush() + + content = read(path) + assert "hello" in content + + tee.close_debug_log() + + +def test_stdout_is_mirrored_after_prepare(capsys, monkeypatch): + from axolotl.utils import tee + + with tempfile.TemporaryDirectory() as td: + cfg = _dummy_cfg(td, append=False) + try: + # Install tee while capture is disabled so stdout tee wraps real stdout. + with capsys.disabled(): + monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "1") + path = tee.prepare_debug_log(cfg) + import sys + + print("printed-line") + sys.stdout.flush() + + # Now verify file contains the line + content = read(path) + assert "printed-line" in content + finally: + tee.close_debug_log() + + +def test_truncate_vs_append_behavior(monkeypatch): + from axolotl.utils import tee + + with tempfile.TemporaryDirectory() as td: + # Avoid stdout tee in this test + monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0") + # First run creates file with A + cfg = _dummy_cfg(td, append=False) + _ = tee.prepare_debug_log(cfg) + try: + tee.file_only_stream.write("A\n") + tee.file_only_stream.flush() + finally: + tee.close_debug_log() + + # Second run with append=False truncates + cfg2 = _dummy_cfg(td, append=False) + path2 = tee.prepare_debug_log(cfg2) + try: + tee.file_only_stream.write("B\n") + tee.file_only_stream.flush() + content = read(path2) + assert "A\n" not in content and "B\n" in content + finally: + tee.close_debug_log() + + # Third run with append=True preserves existing + cfg3 = _dummy_cfg(td, append=True) + path3 = tee.prepare_debug_log(cfg3) + try: + tee.file_only_stream.write("C\n") + tee.file_only_stream.flush() + content = read(path3) + assert "B\n" in content and "C\n" in content + finally: + tee.close_debug_log()