Debug log, logging improvements (#3159)
* simplify logging * remove comment * progress on debug.log * add debug-level logger for file log * simplify * case insensitivity; 3rd party logging improvements * simplify * fix * tests * lint * nits * nit * Update tests/test_utils_tee.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * cleanup / comments * fix * oops --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -190,3 +190,6 @@ out/
|
|||||||
|
|
||||||
# vim
|
# vim
|
||||||
*.swp
|
*.swp
|
||||||
|
|
||||||
|
# scm auto-versioning
|
||||||
|
src/axolotl/_version.py
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ repos:
|
|||||||
rev: v0.12.12
|
rev: v0.12.12
|
||||||
hooks:
|
hooks:
|
||||||
- id: ruff
|
- id: ruff
|
||||||
args: [--fix, --select, I]
|
args: [--fix]
|
||||||
- id: ruff-format
|
- id: ruff-format
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v1.17.1
|
rev: v1.17.1
|
||||||
|
|||||||
@@ -251,10 +251,10 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from axolotl.utils import patch_optimized_env\n",
|
"from axolotl.utils import set_pytorch_cuda_alloc_conf\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# speedup downloads from HF 🤗 and set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
|
"# Set \"PYTORCH_CUDA_ALLOC_CONF\" env to save memory\n",
|
||||||
"patch_optimized_env()"
|
"set_pytorch_cuda_alloc_conf()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ line-length = 88
|
|||||||
target-version = "py310"
|
target-version = "py310"
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ["E", "F", "W", "C90", "B"]
|
select = ["E", "F", "W", "C90", "B", "I"]
|
||||||
ignore = [
|
ignore = [
|
||||||
"E203", # Whitespace before ':'
|
"E203", # Whitespace before ':'
|
||||||
"E501", # Line too long
|
"E501", # Line too long
|
||||||
|
|||||||
@@ -4,5 +4,7 @@ import os
|
|||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
|
|
||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
||||||
|
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||||
|
|
||||||
configure_logging()
|
configure_logging()
|
||||||
|
|||||||
@@ -23,7 +23,8 @@ from axolotl.utils.config import (
|
|||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
from axolotl.utils.tee import prepare_debug_log
|
||||||
|
from axolotl.utils.trainer import prepare_optim_env
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
@@ -227,8 +228,11 @@ def load_cfg(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# NOTE(djsaunde): We start outputting to output_dir/debug.log at this point since we
|
||||||
|
# have to wait for cfg.output to be resolved. We could call this earlier if we write
|
||||||
|
# to a temporary file, and then move it later.
|
||||||
|
prepare_debug_log(cfg)
|
||||||
prepare_optim_env(cfg)
|
prepare_optim_env(cfg)
|
||||||
prepare_opinionated_env(cfg)
|
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
normalize_cfg_datasets(cfg)
|
normalize_cfg_datasets(cfg)
|
||||||
setup_wandb_env_vars(cfg)
|
setup_wandb_env_vars(cfg)
|
||||||
@@ -241,7 +245,6 @@ def load_cfg(
|
|||||||
for k, v in cfg.items()
|
for k, v in cfg.items()
|
||||||
if v is not None
|
if v is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"config:\n%s",
|
"config:\n%s",
|
||||||
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
|
json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
|
||||||
|
|||||||
@@ -17,8 +17,6 @@ from axolotl.cli.utils import load_model_and_tokenizer
|
|||||||
from axolotl.cli.utils.diffusion import (
|
from axolotl.cli.utils.diffusion import (
|
||||||
diffusion_inference,
|
diffusion_inference,
|
||||||
launch_diffusion_gradio_ui,
|
launch_diffusion_gradio_ui,
|
||||||
render_html,
|
|
||||||
run_diffusion,
|
|
||||||
)
|
)
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from axolotl.cli.utils import (
|
|||||||
launch_training,
|
launch_training,
|
||||||
)
|
)
|
||||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||||
from axolotl.utils import patch_optimized_env
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
@@ -44,7 +44,7 @@ def cli():
|
|||||||
"""Axolotl CLI - Train and fine-tune large language models"""
|
"""Axolotl CLI - Train and fine-tune large language models"""
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
patch_optimized_env()
|
set_pytorch_cuda_alloc_conf()
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
|
|||||||
@@ -60,7 +60,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|||||||
config: Path to `axolotl` config YAML file.
|
config: Path to `axolotl` config YAML file.
|
||||||
kwargs: Additional keyword arguments to override config file values.
|
kwargs: Additional keyword arguments to override config file values.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
parser = HfArgumentParser(TrainerCliArgs)
|
parser = HfArgumentParser(TrainerCliArgs)
|
||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import torch
|
|
||||||
from colorama import Fore, Style
|
from colorama import Fore, Style
|
||||||
|
|
||||||
from axolotl.integrations.diffusion import generate, resolve_mask_token_id
|
from axolotl.integrations.diffusion import generate, resolve_mask_token_id
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
"""
|
"""Common logging module for axolotl."""
|
||||||
Common logging module for axolotl
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
from logging import Formatter, Logger, LogRecord
|
from logging import Formatter, Logger, LogRecord
|
||||||
from logging.config import dictConfig
|
from logging.config import dictConfig
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
@@ -17,9 +14,9 @@ DEFAULT_LOG_LEVEL = "WARNING"
|
|||||||
|
|
||||||
class AxolotlOrWarnErrorFilter(logging.Filter):
|
class AxolotlOrWarnErrorFilter(logging.Filter):
|
||||||
"""
|
"""
|
||||||
Allows ANY WARNING or higher (unless overridden by LOG_LEVEL)
|
Allows ANY WARNING or higher (unless overridden by LOG_LEVEL). Allows axolotl.* at
|
||||||
Allows axolotl.* at INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL)
|
INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL). Drops all other records
|
||||||
Drops all other records (i.e. non-axolotl.INFO, DEBUG, etc. by default)
|
(i.e. non-axolotl.INFO, DEBUG, etc. by default).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
@@ -52,13 +49,12 @@ class AxolotlOrWarnErrorFilter(logging.Filter):
|
|||||||
|
|
||||||
|
|
||||||
class AxolotlLogger(Logger):
|
class AxolotlLogger(Logger):
|
||||||
"""A Logger that automatically rejects non-axolotl INFOs."""
|
"""Logger that applies filtering to non-axolotl loggers."""
|
||||||
|
|
||||||
def __init__(self, name: str, level: int = logging.NOTSET):
|
def __init__(self, name: str, level: int = logging.NOTSET):
|
||||||
super().__init__(name, level)
|
super().__init__(name, level)
|
||||||
|
if not name.startswith("axolotl"):
|
||||||
# set global filter on the logger itself
|
self.addFilter(AxolotlOrWarnErrorFilter())
|
||||||
self.addFilter(AxolotlOrWarnErrorFilter())
|
|
||||||
|
|
||||||
|
|
||||||
class ColorfulFormatter(Formatter):
|
class ColorfulFormatter(Formatter):
|
||||||
@@ -74,6 +70,7 @@ class ColorfulFormatter(Formatter):
|
|||||||
|
|
||||||
def format(self, record):
|
def format(self, record):
|
||||||
record.rank = int(os.getenv("LOCAL_RANK", "0"))
|
record.rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||||
|
record.rank_fmt = f" [RANK:{record.rank}]" if record.rank != 0 else ""
|
||||||
log_message = super().format(record)
|
log_message = super().format(record)
|
||||||
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
|
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
|
||||||
|
|
||||||
@@ -87,32 +84,54 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
|||||||
},
|
},
|
||||||
"colorful": {
|
"colorful": {
|
||||||
"()": ColorfulFormatter,
|
"()": ColorfulFormatter,
|
||||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
|
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d]%(rank_fmt)s %(message)s",
|
||||||
|
},
|
||||||
|
"concise": {
|
||||||
|
"format": "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
|
||||||
|
},
|
||||||
|
"concise_color": {
|
||||||
|
"()": ColorfulFormatter,
|
||||||
|
"format": "[%(asctime)s] [%(levelname)s] [%(name)s]%(rank_fmt)s %(message)s",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"filters": {
|
||||||
|
"ax_or_warn": {
|
||||||
|
"()": "axolotl.logging_config.AxolotlOrWarnErrorFilter",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"filters": {},
|
|
||||||
"handlers": {
|
"handlers": {
|
||||||
"console": {
|
"console": {
|
||||||
"class": "logging.StreamHandler",
|
"class": "logging.StreamHandler",
|
||||||
"formatter": "simple",
|
"formatter": "concise",
|
||||||
"filters": [],
|
"filters": ["ax_or_warn"],
|
||||||
"stream": sys.stdout,
|
"stream": "ext://sys.stdout",
|
||||||
},
|
},
|
||||||
"color_console": {
|
"color_console": {
|
||||||
"class": "logging.StreamHandler",
|
"class": "logging.StreamHandler",
|
||||||
"formatter": "colorful",
|
"formatter": "concise_color",
|
||||||
"filters": [],
|
"filters": ["ax_or_warn"],
|
||||||
"stream": sys.stdout,
|
"stream": "ext://sys.stdout",
|
||||||
|
},
|
||||||
|
"ax_file_only": {
|
||||||
|
"class": "logging.StreamHandler",
|
||||||
|
"level": "DEBUG",
|
||||||
|
"formatter": "simple",
|
||||||
|
"stream": "ext://axolotl.utils.tee.file_only_stream",
|
||||||
|
},
|
||||||
|
"root_file_only": {
|
||||||
|
"class": "logging.StreamHandler",
|
||||||
|
"level": "DEBUG",
|
||||||
|
"formatter": "simple",
|
||||||
|
"stream": "ext://axolotl.utils.tee.file_only_stream",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
# log level will be superseded by the AxolotlLogger
|
|
||||||
"root": {
|
"root": {
|
||||||
"handlers": ["console"],
|
"handlers": ["console", "root_file_only"],
|
||||||
"level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL),
|
"level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL).upper(),
|
||||||
},
|
},
|
||||||
"loggers": {
|
"loggers": {
|
||||||
"axolotl": {
|
"axolotl": {
|
||||||
"handlers": ["color_console"],
|
"handlers": ["color_console", "ax_file_only"],
|
||||||
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL).upper(),
|
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL).upper(),
|
||||||
"propagate": False,
|
"propagate": False,
|
||||||
},
|
},
|
||||||
@@ -123,9 +142,15 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
|||||||
def configure_logging():
|
def configure_logging():
|
||||||
"""Configure with default logging"""
|
"""Configure with default logging"""
|
||||||
init() # Initialize colorama
|
init() # Initialize colorama
|
||||||
|
|
||||||
dictConfig(DEFAULT_LOGGING_CONFIG)
|
dictConfig(DEFAULT_LOGGING_CONFIG)
|
||||||
logging.setLoggerClass(AxolotlLogger)
|
logging.setLoggerClass(AxolotlLogger)
|
||||||
|
|
||||||
# set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set
|
# Route Python warnings through logging so they reach file handlers
|
||||||
|
logging.captureWarnings(True)
|
||||||
|
|
||||||
|
# Set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set
|
||||||
if "ACCELERATE_LOG_LEVEL" not in os.environ:
|
if "ACCELERATE_LOG_LEVEL" not in os.environ:
|
||||||
os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL)
|
os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv(
|
||||||
|
"LOG_LEVEL", DEFAULT_LOG_LEVEL
|
||||||
|
).upper()
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ def patch_evaluation_loop():
|
|||||||
"""Patch the evaluation_loop method."""
|
"""Patch the evaluation_loop method."""
|
||||||
# Check if already patched
|
# Check if already patched
|
||||||
if hasattr(Trainer, "_original_evaluation_loop"):
|
if hasattr(Trainer, "_original_evaluation_loop"):
|
||||||
LOG.info("Trainer.evaluation_loop already patched")
|
LOG.debug("Trainer.evaluation_loop already patched")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if the patterns exist
|
# Check if the patterns exist
|
||||||
@@ -84,7 +84,7 @@ def patch_evaluation_loop():
|
|||||||
)
|
)
|
||||||
exec(evaluation_loop_source, globals())
|
exec(evaluation_loop_source, globals())
|
||||||
|
|
||||||
LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
|
LOG.debug("Patched Trainer.evaluation_loop with nanmean loss calculation")
|
||||||
Trainer.evaluation_loop = axolotl_evaluation_loop
|
Trainer.evaluation_loop = axolotl_evaluation_loop
|
||||||
|
|
||||||
|
|
||||||
@@ -135,5 +135,5 @@ def patch_maybe_log_save_evaluate():
|
|||||||
)
|
)
|
||||||
exec(maybe_log_source, globals())
|
exec(maybe_log_source, globals())
|
||||||
|
|
||||||
LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
|
LOG.debug("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
|
||||||
Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate
|
Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate
|
||||||
|
|||||||
@@ -196,10 +196,11 @@ def execute_training(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.info("Starting trainer...")
|
|
||||||
# TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers
|
# TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers
|
||||||
# if cfg.bf16:
|
# if cfg.bf16:
|
||||||
# torch.set_default_dtype(torch.bfloat16)
|
# torch.set_default_dtype(torch.bfloat16)
|
||||||
|
|
||||||
|
LOG.info("Starting trainer...")
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
|||||||
@@ -44,15 +44,6 @@ def set_pytorch_cuda_alloc_conf():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def patch_optimized_env():
|
|
||||||
"""
|
|
||||||
Patch environment variables to improve VRAM usage and increase download speed
|
|
||||||
"""
|
|
||||||
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
|
|
||||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
|
||||||
set_pytorch_cuda_alloc_conf()
|
|
||||||
|
|
||||||
|
|
||||||
def get_not_null(value, default=None):
|
def get_not_null(value, default=None):
|
||||||
"""
|
"""
|
||||||
return the value if it's not None, otherwise return the default value
|
return the value if it's not None, otherwise return the default value
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
|
|
||||||
from axolotl.utils.distributed import is_main_process
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
@@ -40,10 +39,6 @@ class MultiProcessAdapter(logging.LoggerAdapter):
|
|||||||
|
|
||||||
|
|
||||||
def get_logger(name: str, log_level: str | None = None) -> MultiProcessAdapter:
|
def get_logger(name: str, log_level: str | None = None) -> MultiProcessAdapter:
|
||||||
if log_level is None:
|
|
||||||
log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None)
|
|
||||||
logger = logging.getLogger(name)
|
logger = logging.getLogger(name)
|
||||||
if log_level is not None:
|
logger.setLevel(logging.DEBUG)
|
||||||
logger.setLevel(log_level.upper())
|
|
||||||
logger.root.setLevel(log_level.upper())
|
|
||||||
return MultiProcessAdapter(logger, extra={})
|
return MultiProcessAdapter(logger, extra={})
|
||||||
|
|||||||
166
src/axolotl/utils/tee.py
Normal file
166
src/axolotl/utils/tee.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""
|
||||||
|
Utilities for managing the debug log file and providing a file-only stream for logging
|
||||||
|
handlers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TextIO, cast
|
||||||
|
|
||||||
|
# Guards the module-level state below and every access to the debug log file handle.
_lock = threading.Lock()
# Open handle to the debug log file; None until prepare_debug_log() opens it and again
# after close_debug_log() closes it.
_file_handle: io.TextIOWrapper | None = None
# String path of the active debug log; also acts as the "already prepared" marker that
# prepare_debug_log() checks before doing any work.
_log_path: str | None = None
# True while sys.stdout / sys.stderr are wrapped by _StreamTee.
_tee_installed: bool = False
# Streams that were installed before the tee replaced them; close_debug_log() restores
# them.
_orig_stdout: TextIO | None = None
_orig_stderr: TextIO | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class _FileOnlyWriter(io.TextIOBase):
    """Text stream whose output goes exclusively to the debug log file.

    While no debug log is open, written text is discarded; the caller is still told
    that the whole string was consumed.
    """

    def write(self, s: str) -> int:  # type: ignore[override]
        """Append ``s`` to the debug log if one is open.

        Returns:
            ``len(s)``, whether or not anything was written.
        """
        with _lock:
            target = _file_handle
            if target is not None:
                target.write(s)
        return len(s)

    def flush(self) -> None:  # type: ignore[override]
        """Flush the debug log if one is open; flush errors are ignored."""
        with _lock:
            if _file_handle is None:
                return
            try:
                _file_handle.flush()
            except Exception:
                # Best effort only: a failing flush must not surface to the caller.
                pass
|
||||||
|
|
||||||
|
|
||||||
|
# Shared module-level writer instance. The logging configuration points its file-only
# handlers at it via "ext://axolotl.utils.tee.file_only_stream".
file_only_stream: io.TextIOBase = _FileOnlyWriter()
|
||||||
|
|
||||||
|
|
||||||
|
class _StreamTee(io.TextIOBase):
|
||||||
|
"""A minimal tee that mirrors writes to the debug log file.
|
||||||
|
|
||||||
|
Installed only after the debug log is prepared; no buffering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, stream: io.TextIOBase):
|
||||||
|
self._stream = stream
|
||||||
|
|
||||||
|
def write(self, s: str) -> int: # type: ignore[override]
|
||||||
|
with _lock:
|
||||||
|
n = self._stream.write(s)
|
||||||
|
if _file_handle is not None:
|
||||||
|
_file_handle.write(s)
|
||||||
|
return n
|
||||||
|
|
||||||
|
def flush(self) -> None: # type: ignore[override]
|
||||||
|
with _lock:
|
||||||
|
self._stream.flush()
|
||||||
|
if _file_handle is not None:
|
||||||
|
try:
|
||||||
|
_file_handle.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def encoding(self): # type: ignore[override]
|
||||||
|
return getattr(self._stream, "encoding", None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def errors(self): # type: ignore[override]
|
||||||
|
return getattr(self._stream, "errors", None)
|
||||||
|
|
||||||
|
def isatty(self): # type: ignore[override]
|
||||||
|
return getattr(self._stream, "isatty", lambda: False)()
|
||||||
|
|
||||||
|
def fileno(self): # type: ignore[override]
|
||||||
|
if hasattr(self._stream, "fileno"):
|
||||||
|
return self._stream.fileno()
|
||||||
|
raise OSError("Underlying stream has no fileno")
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_debug_log(cfg, filename: str = "debug.log") -> str:
    """Open the debug log file and, unless disabled, mirror stdout/stderr into it.

    Creates ``cfg.output_dir`` if needed and opens ``<output_dir>/<filename>`` for
    appending. An existing file is removed first unless the config asks to resume
    (``resume_from_checkpoint`` or ``auto_resume_from_checkpoints`` is truthy).
    Calling this again while a log is already open is a no-op.

    Args:
        cfg: Config object exposing an ``output_dir`` attribute and a ``get`` method.
        filename: Name of the log file inside the output directory.

    Returns:
        The path of the active debug log as a string.
    """
    global _file_handle, _log_path, _tee_installed, _orig_stdout, _orig_stderr

    with _lock:
        # If already initialized, reuse existing path
        if _log_path is not None:
            return _log_path

        output_dir = cfg.output_dir
        os.makedirs(output_dir, exist_ok=True)

        log_path = Path(output_dir) / filename
        append = bool(
            cfg.get("resume_from_checkpoint") or cfg.get("auto_resume_from_checkpoints")
        )

        if not append:
            # missing_ok avoids the check-then-delete race: the file may vanish
            # between an exists() test and the unlink (e.g. removed by another
            # process), which would otherwise raise FileNotFoundError.
            log_path.unlink(missing_ok=True)

        # Deliberately not a `with` block: the handle stays open until
        # close_debug_log() is called.
        _file_handle = open(log_path, "a", encoding="utf-8")
        _log_path = str(log_path)

        # Install a tee so stdout/stderr are mirrored to the debug file
        # Allow disabling via env for testing or advanced usage.
        tee_enabled = os.getenv("AXOLOTL_TEE_STDOUT", "1").lower() not in {
            "0",
            "false",
            "no",
        }
        if tee_enabled and not _tee_installed:
            # Save originals so we can restore later (e.g., tests)
            _orig_stdout = sys.stdout
            _orig_stderr = sys.stderr
            sys.stdout = _StreamTee(cast(io.TextIOBase, sys.stdout))
            sys.stderr = _StreamTee(cast(io.TextIOBase, sys.stderr))
            _tee_installed = True

        return _log_path
|
||||||
|
|
||||||
|
|
||||||
|
def close_debug_log() -> None:
    """Undo ``prepare_debug_log``: restore stdout/stderr and close the log file.

    Does nothing harmful when no debug log is open. Errors raised while flushing or
    closing the file are ignored.
    """
    global _file_handle, _log_path, _tee_installed, _orig_stdout, _orig_stderr

    with _lock:
        # Put back whichever original streams were saved when the tee was installed.
        if _tee_installed:
            for attr, saved in (("stdout", _orig_stdout), ("stderr", _orig_stderr)):
                if saved is not None:
                    setattr(sys, attr, saved)
            _tee_installed, _orig_stdout, _orig_stderr = False, None, None

        if _file_handle is None:
            return

        try:
            _file_handle.flush()
            _file_handle.close()
        except Exception:
            # Best effort only: the module state is reset below regardless.
            pass
        finally:
            _file_handle = None
            _log_path = None
|
||||||
@@ -31,6 +31,7 @@ def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | No
|
|||||||
if checkpoints:
|
if checkpoints:
|
||||||
last_checkpoint = str(checkpoints[-1])
|
last_checkpoint = str(checkpoints[-1])
|
||||||
if not update:
|
if not update:
|
||||||
|
LOG.info(f"Resuming from last checkpoint at {last_checkpoint}")
|
||||||
return last_checkpoint
|
return last_checkpoint
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -40,6 +41,7 @@ def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | No
|
|||||||
):
|
):
|
||||||
cfg.resume_from_checkpoint = last_checkpoint
|
cfg.resume_from_checkpoint = last_checkpoint
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
"Using auto-resume functionality to resume from checkpoint at "
|
||||||
|
f"{cfg.resume_from_checkpoint}"
|
||||||
)
|
)
|
||||||
return cfg.resume_from_checkpoint
|
return cfg.resume_from_checkpoint
|
||||||
|
|||||||
@@ -655,15 +655,6 @@ def prepare_optim_env(cfg):
|
|||||||
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
|
os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
|
||||||
|
|
||||||
|
|
||||||
def prepare_opinionated_env(cfg):
|
|
||||||
if cfg.qlora_sharded_model_loading:
|
|
||||||
# model loading is forked after the tokenizer
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
||||||
if cfg.sample_packing:
|
|
||||||
# multipack parallel packing sampler defaults to using fork
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(
|
def setup_trainer(
|
||||||
cfg,
|
cfg,
|
||||||
train_dataset,
|
train_dataset,
|
||||||
|
|||||||
103
tests/test_logging_config_file_capture.py
Normal file
103
tests/test_logging_config_file_capture.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def read(path: str) -> str:
    """Return the entire contents of the UTF-8 text file at ``path``."""
    with open(path, mode="r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _reset_logging_state():
    """Give every test a clean logging setup and release the debug log afterwards.

    Root handlers are detached and ``logging.shutdown()`` is run both before and after
    the test. On teardown the debug log is also closed, so a test that fails before
    reaching its own ``close_debug_log()`` call cannot leave the log open for the
    tests that follow.
    """

    def _detach_root_handlers() -> None:
        # Iterate over a copy because removeHandler mutates the handler list.
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
        logging.shutdown()

    # Ensure a clean slate for logging between tests
    _detach_root_handlers()
    # Note: dictConfig in configure_logging will set up handlers again
    yield
    _detach_root_handlers()

    # Imported lazily, the same way the tests in this module import it.
    from axolotl.utils import tee

    # Documented as safe to call even when no debug log is open.
    tee.close_debug_log()
|
||||||
|
|
||||||
|
|
||||||
|
def test_axolotl_logs_captured_at_all_levels(monkeypatch):
    """INFO and DEBUG records from an axolotl logger must both reach the debug log."""
    from axolotl.logging_config import configure_logging
    from axolotl.utils import tee
    from axolotl.utils.logging import get_logger

    with tempfile.TemporaryDirectory() as td:
        # Avoid stdout tee in this test to simplify interaction with pytest capture
        monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")
        configure_logging()
        path = tee.prepare_debug_log(
            type("Cfg", (), {"output_dir": td, "get": lambda *_: False})
        )
        try:
            log = get_logger("axolotl.test")
            log.info("AX-INFO")
            log.debug("AX-DEBUG")
            tee.file_only_stream.flush()

            data = read(path)
            assert "AX-INFO" in data
            assert "AX-DEBUG" in data
        finally:
            # Close even when an assertion fails so the open log does not leak into
            # other tests or block removal of the temporary directory.
            tee.close_debug_log()
|
||||||
|
|
||||||
|
|
||||||
|
def test_third_party_logs_filtered_and_warning_captured(monkeypatch):
    """Non-axolotl INFO is dropped from the debug log while WARNING+ is kept."""
    from axolotl.logging_config import configure_logging
    from axolotl.utils import tee

    with tempfile.TemporaryDirectory() as td:
        monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")
        configure_logging()
        path = tee.prepare_debug_log(
            type("Cfg", (), {"output_dir": td, "get": lambda *_: False})
        )
        try:
            # Third-party logger (non-axolotl)
            other = logging.getLogger("thirdparty.lib")
            other.info("TP-INFO")
            other.warning("TP-WARN")

            # Simulate Python warnings routed through logging
            logging.getLogger("py.warnings").warning("PY-WARN")

            # Push through buffers
            tee.file_only_stream.flush()

            data = read(path)
            # INFO from non-axolotl should be filtered out (not present)
            assert "TP-INFO" not in data
            # WARNING+ should be present
            assert "TP-WARN" in data
            # Python warnings captured (via py.warnings logger)
            assert "PY-WARN" in data
        finally:
            # Single close, run even when an assertion fails, so the open log does
            # not leak into other tests.
            tee.close_debug_log()
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_debug_log_idempotent_and_no_duplicate(monkeypatch):
    """Preparing twice yields one path, and each record is written exactly once."""
    from axolotl.logging_config import configure_logging
    from axolotl.utils import tee
    from axolotl.utils.logging import get_logger

    with tempfile.TemporaryDirectory() as td:
        monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")
        configure_logging()
        cfg = type("Cfg", (), {"output_dir": td, "get": lambda *_: False})
        p1 = tee.prepare_debug_log(cfg)
        try:
            p2 = tee.prepare_debug_log(cfg)
            assert p1 == p2

            log = get_logger("axolotl.test")
            marker = "UNIQUE-MARKER-12345"
            log.info(marker)
            tee.file_only_stream.flush()

            data = read(p1)
            # Ensure the marker appears once (not duplicated via propagation)
            assert data.count(marker) == 1
        finally:
            # Close even when an assertion fails so the open log does not leak into
            # other tests.
            tee.close_debug_log()
|
||||||
107
tests/test_utils_tee.py
Normal file
107
tests/test_utils_tee.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
|
||||||
|
def _dummy_cfg(output_dir: str, append: bool = False):
|
||||||
|
# Minimal object with attributes used by prepare_debug_log
|
||||||
|
class Cfg:
|
||||||
|
def __init__(self, out, append):
|
||||||
|
self.output_dir = out
|
||||||
|
self._append = append
|
||||||
|
|
||||||
|
def get(self, key, default=None):
|
||||||
|
if key in {"resume_from_checkpoint", "auto_resume_from_checkpoints"}:
|
||||||
|
return self._append
|
||||||
|
return default
|
||||||
|
|
||||||
|
return Cfg(output_dir, append)
|
||||||
|
|
||||||
|
|
||||||
|
def read(path: str) -> str:
    """Return the entire contents of the UTF-8 text file at ``path``."""
    with open(path, mode="r", encoding="utf-8") as handle:
        text = handle.read()
    return text
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_only_stream_writes_after_prepare(monkeypatch):
    """Writes are dropped before the log is prepared and persisted afterwards."""
    from axolotl.utils import tee

    with tempfile.TemporaryDirectory() as td:
        # Avoid stdout tee in this test
        monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")
        cfg = _dummy_cfg(td, append=False)

        # before prepare: writing to file_only_stream creates no file
        tee.file_only_stream.write("before\n")
        tee.file_only_stream.flush()
        assert not os.path.exists(os.path.join(td, "debug.log"))

        # prepare and write
        path = tee.prepare_debug_log(cfg)
        try:
            assert os.path.basename(path) == "debug.log"
            tee.file_only_stream.write("hello\n")
            tee.file_only_stream.flush()

            content = read(path)
            assert "hello" in content
        finally:
            # Close even when an assertion fails, as the other tests in this module
            # do, so the open log does not leak into later tests.
            tee.close_debug_log()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stdout_is_mirrored_after_prepare(capsys, monkeypatch):
    """Text printed after the tee is installed must also land in the debug log."""
    import sys

    from axolotl.utils import tee

    with tempfile.TemporaryDirectory() as tmp_dir:
        cfg = _dummy_cfg(tmp_dir, append=False)
        try:
            # Install tee while capture is disabled so stdout tee wraps real stdout.
            with capsys.disabled():
                monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "1")
                log_file = tee.prepare_debug_log(cfg)
                print("printed-line")
                sys.stdout.flush()

            # Now verify file contains the line
            mirrored = read(log_file)
            assert "printed-line" in mirrored
        finally:
            tee.close_debug_log()
|
||||||
|
|
||||||
|
|
||||||
|
def test_truncate_vs_append_behavior(monkeypatch):
    """A fresh run replaces the previous log; a resumed run keeps it and appends."""
    from axolotl.utils import tee

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Avoid stdout tee in this test
        monkeypatch.setenv("AXOLOTL_TEE_STDOUT", "0")

        def run_once(append: bool, text: str) -> str:
            # One prepare / write / close cycle; returns the log contents as they
            # were just before the log was closed.
            log_file = tee.prepare_debug_log(_dummy_cfg(tmp_dir, append=append))
            try:
                tee.file_only_stream.write(text)
                tee.file_only_stream.flush()
                return read(log_file)
            finally:
                tee.close_debug_log()

        # First run creates file with A
        run_once(False, "A\n")

        # Second run with append=False truncates
        after_second = run_once(False, "B\n")
        assert "A\n" not in after_second and "B\n" in after_second

        # Third run with append=True preserves existing
        after_third = run_once(True, "C\n")
        assert "B\n" in after_third and "C\n" in after_third
|
||||||
Reference in New Issue
Block a user