Debug log, logging improvements (#3159)
* simplify logging * remove comment * progress on debug.log * add debug-level logger for file log * simplify * case insensitivity; 3rd party logging improvements * simplify * fix * tests * lint * nits * nit * Update tests/test_utils_tee.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * cleanup / comments * fix * oops --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -190,3 +190,6 @@ out/
|
||||
|
||||
# vim
|
||||
*.swp
|
||||
|
||||
# scm auto-versioning
|
||||
src/axolotl/_version.py
|
||||
|
||||
@@ -14,7 +14,7 @@ repos:
|
||||
rev: v0.12.12
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --select, I]
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.17.1
|
||||
|
||||
@@ -251,10 +251,10 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from axolotl.utils import patch_optimized_env\n",
|
||||
"from axolotl.utils import set_pytorch_cuda_alloc_conf\n",
|
||||
"\n",
|
||||
"# speedup downloads from HF 🤗 and set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
|
||||
"patch_optimized_env()"
|
||||
"# Set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
|
||||
"set_pytorch_cuda_alloc_conf()"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -32,7 +32,7 @@ line-length = 88
|
||||
target-version = "py310"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "W", "C90", "B"]
|
||||
select = ["E", "F", "W", "C90", "B", "I"]
|
||||
ignore = [
|
||||
"E203", # Whitespace before ':'
|
||||
"E501", # Line too long
|
||||
|
||||
@@ -4,5 +4,7 @@ import os
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
||||
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||
|
||||
configure_logging()
|
||||
|
||||
@@ -23,7 +23,8 @@ from axolotl.utils.config import (
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
||||
from axolotl.utils.tee import prepare_debug_log
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
@@ -227,8 +228,11 @@ def load_cfg(
|
||||
},
|
||||
)
|
||||
|
||||
# NOTE(djsaunde): We start outputting to output_dir/debug.log at this point since we
|
||||
# have to wait for cfg.output_dir to be resolved. We could call this earlier if we write
|
||||
# to a temporary file, and then move it later.
|
||||
prepare_debug_log(cfg)
|
||||
prepare_optim_env(cfg)
|
||||
prepare_opinionated_env(cfg)
|
||||
normalize_config(cfg)
|
||||
normalize_cfg_datasets(cfg)
|
||||
setup_wandb_env_vars(cfg)
|
||||
@@ -241,7 +245,6 @@ def load_cfg(
|
||||
for k, v in cfg.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
LOG.info(
|
||||
"config:\n%s",
|
||||
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
|
||||
|
||||
@@ -17,8 +17,6 @@ from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.cli.utils.diffusion import (
|
||||
diffusion_inference,
|
||||
launch_diffusion_gradio_ui,
|
||||
render_html,
|
||||
run_diffusion,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
|
||||
@@ -26,7 +26,7 @@ from axolotl.cli.utils import (
|
||||
launch_training,
|
||||
)
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
@@ -44,7 +44,7 @@ def cli():
|
||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||
print_axolotl_text_art()
|
||||
load_dotenv()
|
||||
patch_optimized_env()
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
||||
@@ -60,7 +60,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
config: Path to `axolotl` config YAML file.
|
||||
kwargs: Additional keyword arguments to override config file values.
|
||||
"""
|
||||
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parser = HfArgumentParser(TrainerCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from colorama import Fore, Style
|
||||
|
||||
from axolotl.integrations.diffusion import generate, resolve_mask_token_id
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
"""
|
||||
Common logging module for axolotl
|
||||
"""
|
||||
"""Common logging module for axolotl."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from logging import Formatter, Logger, LogRecord
|
||||
from logging.config import dictConfig
|
||||
from typing import Any, Dict
|
||||
@@ -17,9 +14,9 @@ DEFAULT_LOG_LEVEL = "WARNING"
|
||||
|
||||
class AxolotlOrWarnErrorFilter(logging.Filter):
|
||||
"""
|
||||
Allows ANY WARNING or higher (unless overridden by LOG_LEVEL)
|
||||
Allows axolotl.* at INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL)
|
||||
Drops all other records (i.e. non-axolotl.INFO, DEBUG, etc. by default)
|
||||
Allows ANY WARNING or higher (unless overridden by LOG_LEVEL). Allows axolotl.* at
|
||||
INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL). Drops all other records
|
||||
(i.e. non-axolotl.INFO, DEBUG, etc. by default).
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -52,13 +49,12 @@ class AxolotlOrWarnErrorFilter(logging.Filter):
|
||||
|
||||
|
||||
class AxolotlLogger(Logger):
|
||||
"""A Logger that automatically rejects non-axolotl INFOs."""
|
||||
"""Logger that applies filtering to non-axolotl loggers."""
|
||||
|
||||
def __init__(self, name: str, level: int = logging.NOTSET):
|
||||
super().__init__(name, level)
|
||||
|
||||
# set global filter on the logger itself
|
||||
self.addFilter(AxolotlOrWarnErrorFilter())
|
||||
if not name.startswith("axolotl"):
|
||||
self.addFilter(AxolotlOrWarnErrorFilter())
|
||||
|
||||
|
||||
class ColorfulFormatter(Formatter):
|
||||
@@ -74,6 +70,7 @@ class ColorfulFormatter(Formatter):
|
||||
|
||||
def format(self, record):
|
||||
record.rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||
record.rank_fmt = f" [RANK:{record.rank}]" if record.rank != 0 else ""
|
||||
log_message = super().format(record)
|
||||
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
|
||||
|
||||
@@ -87,32 +84,54 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
||||
},
|
||||
"colorful": {
|
||||
"()": ColorfulFormatter,
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d]%(rank_fmt)s %(message)s",
|
||||
},
|
||||
"concise": {
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
|
||||
},
|
||||
"concise_color": {
|
||||
"()": ColorfulFormatter,
|
||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s]%(rank_fmt)s %(message)s",
|
||||
},
|
||||
},
|
||||
"filters": {
|
||||
"ax_or_warn": {
|
||||
"()": "axolotl.logging_config.AxolotlOrWarnErrorFilter",
|
||||
},
|
||||
},
|
||||
"filters": {},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "simple",
|
||||
"filters": [],
|
||||
"stream": sys.stdout,
|
||||
"formatter": "concise",
|
||||
"filters": ["ax_or_warn"],
|
||||
"stream": "ext://sys.stdout",
|
||||
},
|
||||
"color_console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"formatter": "colorful",
|
||||
"filters": [],
|
||||
"stream": sys.stdout,
|
||||
"formatter": "concise_color",
|
||||
"filters": ["ax_or_warn"],
|
||||
"stream": "ext://sys.stdout",
|
||||
},
|
||||
"ax_file_only": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": "DEBUG",
|
||||
"formatter": "simple",
|
||||
"stream": "ext://axolotl.utils.tee.file_only_stream",
|
||||
},
|
||||
"root_file_only": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": "DEBUG",
|
||||
"formatter": "simple",
|
||||
"stream": "ext://axolotl.utils.tee.file_only_stream",
|
||||
},
|
||||
},
|
||||
# log level will be superseded by the AxolotlLogger
|
||||
"root": {
|
||||
"handlers": ["console"],
|
||||
"level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL),
|
||||
"handlers": ["console", "root_file_only"],
|
||||
"level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL).upper(),
|
||||
},
|
||||
"loggers": {
|
||||
"axolotl": {
|
||||
"handlers": ["color_console"],
|
||||
"handlers": ["color_console", "ax_file_only"],
|
||||
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL).upper(),
|
||||
"propagate": False,
|
||||
},
|
||||
@@ -123,9 +142,15 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
||||
def configure_logging():
|
||||
"""Configure with default logging"""
|
||||
init() # Initialize colorama
|
||||
|
||||
dictConfig(DEFAULT_LOGGING_CONFIG)
|
||||
logging.setLoggerClass(AxolotlLogger)
|
||||
|
||||
# set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set
|
||||
# Route Python warnings through logging so they reach file handlers
|
||||
logging.captureWarnings(True)
|
||||
|
||||
# Set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set
|
||||
if "ACCELERATE_LOG_LEVEL" not in os.environ:
|
||||
os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL)
|
||||
os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv(
|
||||
"LOG_LEVEL", DEFAULT_LOG_LEVEL
|
||||
).upper()
|
||||
|
||||
@@ -41,7 +41,7 @@ def patch_evaluation_loop():
|
||||
"""Patch the evaluation_loop method."""
|
||||
# Check if already patched
|
||||
if hasattr(Trainer, "_original_evaluation_loop"):
|
||||
LOG.info("Trainer.evaluation_loop already patched")
|
||||
LOG.debug("Trainer.evaluation_loop already patched")
|
||||
return
|
||||
|
||||
# Check if the patterns exist
|
||||
@@ -84,7 +84,7 @@ def patch_evaluation_loop():
|
||||
)
|
||||
exec(evaluation_loop_source, globals())
|
||||
|
||||
LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
|
||||
LOG.debug("Patched Trainer.evaluation_loop with nanmean loss calculation")
|
||||
Trainer.evaluation_loop = axolotl_evaluation_loop
|
||||
|
||||
|
||||
@@ -135,5 +135,5 @@ def patch_maybe_log_save_evaluate():
|
||||
)
|
||||
exec(maybe_log_source, globals())
|
||||
|
||||
LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
|
||||
LOG.debug("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
|
||||
Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate
|
||||
|
||||
@@ -196,10 +196,11 @@ def execute_training(
|
||||
)
|
||||
)
|
||||
|
||||
LOG.info("Starting trainer...")
|
||||
# TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers
|
||||
# if cfg.bf16:
|
||||
# torch.set_default_dtype(torch.bfloat16)
|
||||
|
||||
LOG.info("Starting trainer...")
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
|
||||
@@ -44,15 +44,6 @@ def set_pytorch_cuda_alloc_conf():
|
||||
)
|
||||
|
||||
|
||||
def patch_optimized_env():
|
||||
"""
|
||||
Patch environment variables to improve VRAM usage and increase download speed
|
||||
"""
|
||||
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
|
||||
|
||||
def get_not_null(value, default=None):
|
||||
"""
|
||||
return the value if it's not None, otherwise return the default value
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
@@ -40,10 +39,6 @@ class MultiProcessAdapter(logging.LoggerAdapter):
|
||||
|
||||
|
||||
def get_logger(name: str, log_level: str | None = None) -> MultiProcessAdapter:
|
||||
if log_level is None:
|
||||
log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None)
|
||||
logger = logging.getLogger(name)
|
||||
if log_level is not None:
|
||||
logger.setLevel(log_level.upper())
|
||||
logger.root.setLevel(log_level.upper())
|
||||
logger.setLevel(logging.DEBUG)
|
||||
return MultiProcessAdapter(logger, extra={})
|
||||
|
||||
166
src/axolotl/utils/tee.py
Normal file
166
src/axolotl/utils/tee.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Utilities for managing the debug log file and providing a file-only stream for logging
|
||||
handlers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import TextIO, cast
|
||||
|
||||
# Shared module state for the debug log. Every read/write of the names below is
# expected to happen while holding `_lock`.
#
# The lock is re-entrant on purpose: a write performed under the lock can trigger
# another write on the same thread (e.g. a warning routed through logging, or a
# signal handler that logs while the tee is mid-write). A plain `threading.Lock`
# would self-deadlock in that situation.
_lock = threading.RLock()

# Open handle to the debug log file; None until `prepare_debug_log` has run.
_file_handle: io.TextIOWrapper | None = None

# Path of the currently open debug log; None while no log is open.
_log_path: str | None = None

# True while sys.stdout / sys.stderr are wrapped by the tee.
_tee_installed: bool = False

# Streams that were in place before the tee was installed, kept for restoration.
_orig_stdout: TextIO | None = None
_orig_stderr: TextIO | None = None
|
||||
|
||||
|
||||
class _FileOnlyWriter(io.TextIOBase):
    """Text stream whose only sink is the debug log file.

    While no debug log file is open, written text is silently discarded; the
    reported character count is the same either way.
    """

    def write(self, s: str) -> int:  # type: ignore[override]
        with _lock:
            sink = _file_handle
            if sink is not None:
                sink.write(s)
        # The full length is reported whether or not a file was available.
        return len(s)

    def flush(self) -> None:  # type: ignore[override]
        with _lock:
            sink = _file_handle
            if sink is None:
                return
            try:
                sink.flush()
            except Exception:
                # Flushing the debug log is best-effort.
                pass


# Module-level instance of the file-only writer.
file_only_stream: io.TextIOBase = _FileOnlyWriter()
|
||||
|
||||
|
||||
class _StreamTee(io.TextIOBase):
    """Minimal tee: forwards writes to a wrapped stream and mirrors them to the debug log.

    The mirror copy is written only while a debug log file handle is open. The tee
    keeps no buffer of its own. Stream properties (`encoding`, `errors`, `isatty`,
    `fileno`) are delegated to the wrapped stream.
    """

    def __init__(self, stream: io.TextIOBase):
        # The stream being wrapped (e.g. the previous sys.stdout / sys.stderr).
        self._stream = stream

    def writable(self) -> bool:  # type: ignore[override]
        # IOBase defaults to False; this object does accept writes.
        return True

    def write(self, s: str) -> int:  # type: ignore[override]
        """Write `s` to the wrapped stream, mirror it to the log, return the count."""
        with _lock:
            written = self._stream.write(s)
            if _file_handle is not None:
                _file_handle.write(s)
            # Some stream wrappers return None from write(); honor the io contract
            # of returning an int by falling back to the input length.
            return len(s) if written is None else written

    def flush(self) -> None:  # type: ignore[override]
        """Flush the wrapped stream, then flush the debug log on a best-effort basis."""
        with _lock:
            self._stream.flush()
            if _file_handle is not None:
                try:
                    _file_handle.flush()
                except Exception:
                    # A failed flush of the mirror must not break the primary stream.
                    pass

    @property
    def encoding(self):  # type: ignore[override]
        return getattr(self._stream, "encoding", None)

    @property
    def errors(self):  # type: ignore[override]
        return getattr(self._stream, "errors", None)

    def isatty(self):  # type: ignore[override]
        # Streams without an isatty() are treated as non-interactive.
        return getattr(self._stream, "isatty", lambda: False)()

    def fileno(self):  # type: ignore[override]
        if hasattr(self._stream, "fileno"):
            return self._stream.fileno()
        raise OSError("Underlying stream has no fileno")
|
||||
|
||||
|
||||
def prepare_debug_log(cfg, filename: str = "debug.log") -> str:
    """
    Prepare the debug log.

    Creates the output directory, removes any previous log unless the run is resuming,
    opens the log file for appending, and (unless disabled) wraps `sys.stdout` and
    `sys.stderr` so that everything written to them is mirrored to the file.

    The call is idempotent: once a log is open, later calls return the existing path
    without touching the file or the streams, regardless of the arguments passed.

    Args:
        cfg: Config object exposing `output_dir` and a dict-style `get`. The log is
            appended to (rather than replaced) when `resume_from_checkpoint` or
            `auto_resume_from_checkpoints` is truthy.
        filename: Name of the log file inside `cfg.output_dir`.

    Returns:
        Path of the debug log file as a string.

    Environment:
        AXOLOTL_TEE_STDOUT: set to `0`, `false` or `no` (case-insensitive) to skip
            wrapping stdout/stderr.
    """
    global _file_handle, _log_path, _tee_installed, _orig_stdout, _orig_stderr

    with _lock:
        # If already initialized, reuse existing path
        if _log_path is not None:
            return _log_path

        output_dir = cfg.output_dir
        os.makedirs(output_dir, exist_ok=True)

        log_path = Path(output_dir) / filename
        append = bool(
            cfg.get("resume_from_checkpoint") or cfg.get("auto_resume_from_checkpoints")
        )

        if not append:
            # `missing_ok` avoids the exists()/unlink() race: another process may
            # remove the file between the check and the delete.
            # NOTE(review): if several processes share one output_dir they may still
            # delete each other's freshly created log here - confirm against callers.
            log_path.unlink(missing_ok=True)

        # Deliberately left open: the handle lives until `close_debug_log`.
        _file_handle = open(log_path, "a", encoding="utf-8")
        resolved_path = str(log_path)
        _log_path = resolved_path

        # Install a tee so stdout/stderr are mirrored to the debug file.
        # Allow disabling via env for testing or advanced usage.
        tee_enabled = os.getenv("AXOLOTL_TEE_STDOUT", "1").lower() not in {
            "0",
            "false",
            "no",
        }
        if tee_enabled and not _tee_installed:
            # Save originals so we can restore later (e.g., tests)
            _orig_stdout = sys.stdout
            _orig_stderr = sys.stderr
            sys.stdout = _StreamTee(cast(io.TextIOBase, sys.stdout))
            sys.stderr = _StreamTee(cast(io.TextIOBase, sys.stderr))
            _tee_installed = True

        return resolved_path
|
||||
|
||||
|
||||
def close_debug_log() -> None:
    """Undo `prepare_debug_log`: put back stdout/stderr and close the log file.

    Calling this when nothing was prepared is a no-op.
    """
    global _file_handle, _log_path, _tee_installed, _orig_stdout, _orig_stderr

    with _lock:
        # Step 1: uninstall the tee, if one is in place.
        if _tee_installed:
            for attr_name, saved_stream in (
                ("stdout", _orig_stdout),
                ("stderr", _orig_stderr),
            ):
                if saved_stream is not None:
                    setattr(sys, attr_name, saved_stream)
            _tee_installed = False
            _orig_stdout = None
            _orig_stderr = None

        # Step 2: release the file handle, if one is open.
        handle = _file_handle
        if handle is None:
            return
        try:
            handle.flush()
            handle.close()
        except Exception:
            # Closing is best-effort; the state is reset below regardless.
            pass
        finally:
            _file_handle = None
            _log_path = None
|
||||
@@ -31,6 +31,7 @@ def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | No
|
||||
if checkpoints:
|
||||
last_checkpoint = str(checkpoints[-1])
|
||||
if not update:
|
||||
LOG.info(f"Resuming from last checkpoint at {last_checkpoint}")
|
||||
return last_checkpoint
|
||||
|
||||
if (
|
||||
@@ -40,6 +41,7 @@ def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | No
|
||||
):
|
||||
cfg.resume_from_checkpoint = last_checkpoint
|
||||
LOG.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||
"Using auto-resume functionality to resume from checkpoint at "
|
||||
f"{cfg.resume_from_checkpoint}"
|
||||
)
|
||||
return cfg.resume_from_checkpoint
|
||||
|
||||
@@ -655,15 +655,6 @@ def prepare_optim_env(cfg):
|
||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
|
||||
|
||||
|
||||
def prepare_opinionated_env(cfg):
|
||||
if cfg.qlora_sharded_model_loading:
|
||||
# model loading is forked after the tokenizer
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
if cfg.sample_packing:
|
||||
# multipack parallel packing sampler defaults to using fork
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
def setup_trainer(
|
||||
cfg,
|
||||
train_dataset,
|
||||
|
||||
103
tests/test_logging_config_file_capture.py
Normal file
103
tests/test_logging_config_file_capture.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import logging
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def read(path: str) -> str:
    """Return the full UTF-8 decoded contents of the file at `path`."""
    with open(path, encoding="utf-8") as handle:
        contents = handle.read()
    return contents
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_logging_state():
    """Strip the root logger's handlers and shut logging down around every test."""

    def _clear_root_logging():
        # Iterate over a copy: removeHandler mutates the list being walked.
        for attached in list(logging.root.handlers):
            logging.root.removeHandler(attached)
        logging.shutdown()

    # Start from a clean slate; dictConfig in configure_logging will set up
    # handlers again inside the test.
    _clear_root_logging()
    yield
    _clear_root_logging()
|
||||
|
||||
|
||||
def test_axolotl_logs_captured_at_all_levels(monkeypatch):
    """INFO and DEBUG records from an `axolotl.*` logger both land in the debug log."""
    from axolotl.logging_config import configure_logging
    from axolotl.utils import tee
    from axolotl.utils.logging import get_logger

    with tempfile.TemporaryDirectory() as td:
        # Avoid stdout tee in this test to simplify interaction with pytest capture
        monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")
        configure_logging()
        cfg = type("Cfg", (), {"output_dir": td, "get": lambda *_: False})

        # Always close the log, even on assertion failure, so the open handle and
        # the module-level state do not leak into later tests.
        try:
            path = tee.prepare_debug_log(cfg)

            log = get_logger("axolotl.test")
            log.info("AX-INFO")
            log.debug("AX-DEBUG")
            tee.file_only_stream.flush()

            data = read(path)
            assert "AX-INFO" in data
            assert "AX-DEBUG" in data
        finally:
            tee.close_debug_log()
|
||||
|
||||
|
||||
def test_third_party_logs_filtered_and_warning_captured(monkeypatch):
    """Non-axolotl INFO is dropped from the debug log; WARNING records are kept."""
    from axolotl.logging_config import configure_logging
    from axolotl.utils import tee

    with tempfile.TemporaryDirectory() as td:
        monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")
        configure_logging()
        cfg = type("Cfg", (), {"output_dir": td, "get": lambda *_: False})

        # Always close the log (once), even on assertion failure, so the open handle
        # and the module-level state do not leak into later tests.
        try:
            path = tee.prepare_debug_log(cfg)

            # Third-party logger (non-axolotl)
            other = logging.getLogger("thirdparty.lib")
            other.info("TP-INFO")
            other.warning("TP-WARN")

            # Simulate Python warnings routed through logging
            logging.getLogger("py.warnings").warning("PY-WARN")

            # Push through buffers
            tee.file_only_stream.flush()

            data = read(path)
            # INFO from non-axolotl should be filtered out (not present)
            assert "TP-INFO" not in data
            # WARNING+ should be present
            assert "TP-WARN" in data
            # Python warnings captured (via py.warnings logger)
            assert "PY-WARN" in data
        finally:
            tee.close_debug_log()
|
||||
|
||||
|
||||
def test_prepare_debug_log_idempotent_and_no_duplicate(monkeypatch):
    """A second `prepare_debug_log` returns the same path and records are written once."""
    from axolotl.logging_config import configure_logging
    from axolotl.utils import tee
    from axolotl.utils.logging import get_logger

    with tempfile.TemporaryDirectory() as td:
        monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")
        configure_logging()
        cfg = type("Cfg", (), {"output_dir": td, "get": lambda *_: False})

        # Always close the log, even on assertion failure, so the open handle and
        # the module-level state do not leak into later tests.
        try:
            p1 = tee.prepare_debug_log(cfg)
            p2 = tee.prepare_debug_log(cfg)
            assert p1 == p2

            log = get_logger("axolotl.test")
            marker = "UNIQUE-MARKER-12345"
            log.info(marker)
            tee.file_only_stream.flush()

            data = read(p1)
            # Ensure the marker appears once (not duplicated via propagation)
            assert data.count(marker) == 1
        finally:
            tee.close_debug_log()
|
||||
107
tests/test_utils_tee.py
Normal file
107
tests/test_utils_tee.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
|
||||
def _dummy_cfg(output_dir: str, append: bool = False):
    """Build a stand-in config exposing only `output_dir` and `get`.

    `get` answers with `append` for the two resume-related keys and with the
    supplied default for every other key.
    """
    resume_keys = {"resume_from_checkpoint", "auto_resume_from_checkpoints"}

    class Cfg:
        def __init__(self, out, append):
            self.output_dir = out
            self._append = append

        def get(self, key, default=None):
            return self._append if key in resume_keys else default

    return Cfg(output_dir, append)
|
||||
|
||||
|
||||
def read(path: str) -> str:
    """Return the full UTF-8 decoded contents of the file at `path`."""
    with open(path, encoding="utf-8") as handle:
        contents = handle.read()
    return contents
|
||||
|
||||
|
||||
def test_file_only_stream_writes_after_prepare(monkeypatch):
    """`file_only_stream` drops writes before prepare and persists them afterwards."""
    from axolotl.utils import tee

    with tempfile.TemporaryDirectory() as td:
        # Avoid stdout tee in this test
        monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")
        cfg = _dummy_cfg(td, append=False)

        # before prepare: writing to file_only_stream creates no file
        tee.file_only_stream.write("before\n")
        tee.file_only_stream.flush()
        assert not os.path.exists(os.path.join(td, "debug.log"))

        # Always close the log, even on assertion failure, matching the other tests
        # in this file and keeping module-level state from leaking.
        try:
            # prepare and write
            path = tee.prepare_debug_log(cfg)
            assert os.path.basename(path) == "debug.log"
            tee.file_only_stream.write("hello\n")
            tee.file_only_stream.flush()

            content = read(path)
            assert "hello" in content
        finally:
            tee.close_debug_log()
|
||||
|
||||
|
||||
def test_stdout_is_mirrored_after_prepare(capsys, monkeypatch):
    """Text printed after the tee is installed also shows up in the debug log."""
    import sys

    from axolotl.utils import tee

    with tempfile.TemporaryDirectory() as workdir:
        dummy = _dummy_cfg(workdir, append=False)
        try:
            # The tee has to wrap the real stdout, so it is installed (and used)
            # while pytest's capture is switched off.
            with capsys.disabled():
                monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "1")
                log_file = tee.prepare_debug_log(dummy)

                print("printed-line")
                sys.stdout.flush()

            # The printed text must have been mirrored into the file.
            mirrored = read(log_file)
            assert "printed-line" in mirrored
        finally:
            tee.close_debug_log()
|
||||
|
||||
|
||||
def test_truncate_vs_append_behavior(monkeypatch):
    """A non-resuming run starts a fresh log; a resuming run keeps earlier content."""
    from axolotl.utils import tee

    def emit(text: str) -> None:
        # Write through the file-only stream and push it to disk.
        tee.file_only_stream.write(text)
        tee.file_only_stream.flush()

    with tempfile.TemporaryDirectory() as workdir:
        # Keep stdout/stderr untouched for this test.
        monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")

        # Run 1 (append=False): the log is created and receives "A".
        tee.prepare_debug_log(_dummy_cfg(workdir, append=False))
        try:
            emit("A\n")
        finally:
            tee.close_debug_log()

        # Run 2 (append=False): earlier content is gone, only "B" remains.
        second_path = tee.prepare_debug_log(_dummy_cfg(workdir, append=False))
        try:
            emit("B\n")
            after_second = read(second_path)
            assert "A\n" not in after_second and "B\n" in after_second
        finally:
            tee.close_debug_log()

        # Run 3 (append=True): "B" is preserved and "C" is added.
        third_path = tee.prepare_debug_log(_dummy_cfg(workdir, append=True))
        try:
            emit("C\n")
            after_third = read(third_path)
            assert "B\n" in after_third and "C\n" in after_third
        finally:
            tee.close_debug_log()
|
||||
Reference in New Issue
Block a user