Debug log, logging improvements (#3159)
* simplify logging
* remove comment
* progress on debug.log
* add debug-level logger for file log
* simplify
* case insensitivity; 3rd party logging improvements
* simplify
* fix
* tests
* lint
* nits
* nit
* Update tests/test_utils_tee.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* cleanup / comments
* fix
* oops

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
@@ -4,5 +4,7 @@ import os
 
 from axolotl.logging_config import configure_logging
 
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
 
 configure_logging()
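Worth noting about the hunk above: switching from direct assignment to `os.environ.setdefault` means values exported by the user before launch now take precedence. A minimal sketch of the difference:

```python
import os

# Old behavior: clobbers whatever the user exported before launch.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# New behavior: fills the value in only when unset, so e.g.
# `HF_HUB_ENABLE_HF_TRANSFER=0 axolotl train ...` is respected.
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
```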
@@ -23,7 +23,8 @@ from axolotl.utils.config import (
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.logging import get_logger
 from axolotl.utils.mlflow_ import setup_mlflow_env_vars
+from axolotl.utils.tee import prepare_debug_log
-from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
+from axolotl.utils.trainer import prepare_optim_env
 from axolotl.utils.wandb_ import setup_wandb_env_vars
 
 LOG = get_logger(__name__)
@@ -227,8 +228,11 @@ def load_cfg(
         },
     )
 
+    # NOTE(djsaunde): We start outputting to output_dir/debug.log at this point since
+    # we have to wait for cfg.output_dir to be resolved. We could call this earlier if
+    # we write to a temporary file, and then move it later.
+    prepare_debug_log(cfg)
     prepare_optim_env(cfg)
-    prepare_opinionated_env(cfg)
     normalize_config(cfg)
     normalize_cfg_datasets(cfg)
     setup_wandb_env_vars(cfg)
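As the NOTE above says, `prepare_debug_log` can only run once `cfg.output_dir` is resolved. A hedged sketch of the resulting behavior (config values are illustrative):

```python
from axolotl.utils.dict import DictDefault
from axolotl.utils.tee import prepare_debug_log

# Hypothetical minimal config; in load_cfg() this is fully resolved by now.
cfg = DictDefault({"output_dir": "./outputs/my-run"})

path = prepare_debug_log(cfg)  # e.g. "outputs/my-run/debug.log"
assert prepare_debug_log(cfg) == path  # later calls reuse the same path
```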
@@ -241,7 +245,6 @@ def load_cfg(
         for k, v in cfg.items()
         if v is not None
     }
-
     LOG.info(
         "config:\n%s",
         json.dumps(cfg_to_log, indent=2, default=str, sort_keys=True),
@@ -17,8 +17,6 @@ from axolotl.cli.utils import load_model_and_tokenizer
 from axolotl.cli.utils.diffusion import (
     diffusion_inference,
     launch_diffusion_gradio_ui,
     render_html,
     run_diffusion,
 )
 from axolotl.integrations.base import PluginManager
 from axolotl.utils.chat_templates import get_chat_template_from_config
@@ -26,7 +26,7 @@ from axolotl.cli.utils import (
     launch_training,
 )
 from axolotl.integrations.lm_eval.cli import lm_eval
-from axolotl.utils import patch_optimized_env
+from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.logging import get_logger
 from axolotl.utils.schemas.config import AxolotlInputConfig
@@ -44,7 +44,7 @@ def cli():
     """Axolotl CLI - Train and fine-tune large language models"""
     print_axolotl_text_art()
     load_dotenv()
-    patch_optimized_env()
+    set_pytorch_cuda_alloc_conf()
 
 
 @cli.command()
@@ -60,7 +60,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         config: Path to `axolotl` config YAML file.
         kwargs: Additional keyword arguments to override config file values.
     """
-
     parsed_cfg = load_cfg(config, **kwargs)
     parser = HfArgumentParser(TrainerCliArgs)
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
@@ -3,7 +3,6 @@
 from __future__ import annotations
 
 import gradio as gr
 import torch
-from colorama import Fore, Style
 
 from axolotl.integrations.diffusion import generate, resolve_mask_token_id
@@ -1,10 +1,7 @@
-"""
-Common logging module for axolotl
-"""
+"""Common logging module for axolotl."""
 
 import logging
 import os
-import sys
 from logging import Formatter, Logger, LogRecord
 from logging.config import dictConfig
 from typing import Any, Dict
@@ -17,9 +14,9 @@ DEFAULT_LOG_LEVEL = "WARNING"
 class AxolotlOrWarnErrorFilter(logging.Filter):
     """
-    Allows ANY WARNING or higher (unless overridden by LOG_LEVEL)
-    Allows axolotl.* at INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL)
-    Drops all other records (i.e. non-axolotl.INFO, DEBUG, etc. by default)
+    Allows ANY WARNING or higher (unless overridden by LOG_LEVEL). Allows axolotl.* at
+    INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL). Drops all other records
+    (i.e. non-axolotl INFO, DEBUG, etc. by default).
     """
 
     def __init__(self, **kwargs):
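To make the policy in that docstring concrete, here is a minimal standalone sketch of the same accept/reject rule, assuming the default thresholds (the real filter derives them from `LOG_LEVEL` and `AXOLOTL_LOG_LEVEL`):

```python
import logging

def ax_or_warn(record: logging.LogRecord) -> bool:
    # Any logger may emit WARNING and above.
    if record.levelno >= logging.WARNING:
        return True
    # axolotl.* loggers may additionally emit INFO and above.
    if record.name.startswith("axolotl") and record.levelno >= logging.INFO:
        return True
    # Everything else (third-party INFO/DEBUG noise) is dropped.
    return False
```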
@@ -52,13 +49,12 @@ class AxolotlOrWarnErrorFilter(logging.Filter):
 
 
 class AxolotlLogger(Logger):
-    """A Logger that automatically rejects non-axolotl INFOs."""
+    """Logger that applies filtering to non-axolotl loggers."""
 
     def __init__(self, name: str, level: int = logging.NOTSET):
         super().__init__(name, level)
-
-        # set global filter on the logger itself
-        self.addFilter(AxolotlOrWarnErrorFilter())
+        if not name.startswith("axolotl"):
+            self.addFilter(AxolotlOrWarnErrorFilter())
 
 
 class ColorfulFormatter(Formatter):
@@ -74,6 +70,7 @@ class ColorfulFormatter(Formatter):
 
     def format(self, record):
         record.rank = int(os.getenv("LOCAL_RANK", "0"))
+        record.rank_fmt = f" [RANK:{record.rank}]" if record.rank != 0 else ""
        log_message = super().format(record)
        return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
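The new `rank_fmt` attribute keeps single-process logs clean: the rank tag is rendered only on nonzero ranks. A tiny illustration of the rule:

```python
import os

for local_rank in ("0", "1"):
    os.environ["LOCAL_RANK"] = local_rank
    rank = int(os.getenv("LOCAL_RANK", "0"))
    rank_fmt = f" [RANK:{rank}]" if rank != 0 else ""
    print(f"[INFO] [axolotl]{rank_fmt} hello")

# [INFO] [axolotl] hello
# [INFO] [axolotl] [RANK:1] hello
```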
@@ -87,32 +84,54 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
         },
         "colorful": {
             "()": ColorfulFormatter,
-            "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
+            "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d]%(rank_fmt)s %(message)s",
         },
+        "concise": {
+            "format": "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
+        },
+        "concise_color": {
+            "()": ColorfulFormatter,
+            "format": "[%(asctime)s] [%(levelname)s] [%(name)s]%(rank_fmt)s %(message)s",
+        },
     },
-    "filters": {},
+    "filters": {
+        "ax_or_warn": {
+            "()": "axolotl.logging_config.AxolotlOrWarnErrorFilter",
+        },
+    },
     "handlers": {
         "console": {
             "class": "logging.StreamHandler",
-            "formatter": "simple",
-            "filters": [],
-            "stream": sys.stdout,
+            "formatter": "concise",
+            "filters": ["ax_or_warn"],
+            "stream": "ext://sys.stdout",
         },
         "color_console": {
             "class": "logging.StreamHandler",
-            "formatter": "colorful",
-            "filters": [],
-            "stream": sys.stdout,
+            "formatter": "concise_color",
+            "filters": ["ax_or_warn"],
+            "stream": "ext://sys.stdout",
         },
+        "ax_file_only": {
+            "class": "logging.StreamHandler",
+            "level": "DEBUG",
+            "formatter": "simple",
+            "stream": "ext://axolotl.utils.tee.file_only_stream",
+        },
+        "root_file_only": {
+            "class": "logging.StreamHandler",
+            "level": "DEBUG",
+            "formatter": "simple",
+            "stream": "ext://axolotl.utils.tee.file_only_stream",
+        },
     },
-    # log level will be superseded by the AxolotlLogger
     "root": {
-        "handlers": ["console"],
-        "level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL),
+        "handlers": ["console", "root_file_only"],
+        "level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL).upper(),
     },
     "loggers": {
         "axolotl": {
-            "handlers": ["color_console"],
+            "handlers": ["color_console", "ax_file_only"],
             "level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL).upper(),
             "propagate": False,
         },
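A detail that is easy to miss in the config above: handler streams are now `ext://` strings instead of live objects, which lets `dictConfig` import them at configuration time; that is how the module-level `axolotl.utils.tee.file_only_stream` singleton gets wired in. A self-contained sketch of the mechanism:

```python
import logging
from logging.config import dictConfig

dictConfig(
    {
        "version": 1,
        "formatters": {"simple": {"format": "%(levelname)s %(message)s"}},
        "handlers": {
            "console": {
                "class": "logging.StreamHandler",
                "formatter": "simple",
                # "ext://" tells dictConfig to import this object rather
                # than treat the value as a literal string.
                "stream": "ext://sys.stdout",
            },
        },
        "root": {"handlers": ["console"], "level": "INFO"},
    }
)

logging.getLogger("demo").info("stream resolved via ext:// at config time")
```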
@@ -123,9 +142,15 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
 def configure_logging():
     """Configure with default logging"""
     init()  # Initialize colorama
 
     dictConfig(DEFAULT_LOGGING_CONFIG)
     logging.setLoggerClass(AxolotlLogger)
 
-    # set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set
+    # Route Python warnings through logging so they reach file handlers
+    logging.captureWarnings(True)
+
+    # Set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set
     if "ACCELERATE_LOG_LEVEL" not in os.environ:
-        os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL)
+        os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv(
+            "LOG_LEVEL", DEFAULT_LOG_LEVEL
+        ).upper()
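For readers unfamiliar with `logging.captureWarnings`: it reroutes `warnings.warn(...)` output onto the `py.warnings` logger, which is how those messages end up in debug.log under the config above. A quick illustration:

```python
import logging
import warnings

logging.basicConfig(level=logging.WARNING)  # stand-in for the dictConfig above
logging.captureWarnings(True)

# No longer written to stderr directly; emitted as a WARNING record on the
# "py.warnings" logger, so it flows through every configured handler.
warnings.warn("deprecated flag, use the new one instead")
```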
@@ -41,7 +41,7 @@ def patch_evaluation_loop():
     """Patch the evaluation_loop method."""
     # Check if already patched
     if hasattr(Trainer, "_original_evaluation_loop"):
-        LOG.info("Trainer.evaluation_loop already patched")
+        LOG.debug("Trainer.evaluation_loop already patched")
         return
 
     # Check if the patterns exist
@@ -84,7 +84,7 @@ def patch_evaluation_loop():
     )
     exec(evaluation_loop_source, globals())
 
-    LOG.info("Patched Trainer.evaluation_loop with nanmean loss calculation")
+    LOG.debug("Patched Trainer.evaluation_loop with nanmean loss calculation")
     Trainer.evaluation_loop = axolotl_evaluation_loop
@@ -135,5 +135,5 @@ def patch_maybe_log_save_evaluate():
     )
     exec(maybe_log_source, globals())
 
-    LOG.info("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
+    LOG.debug("Patched Trainer._maybe_log_save_evaluate with nanmean loss calculation")
     Trainer._maybe_log_save_evaluate = axolotl_maybe_log_save_evaluate
@@ -196,10 +196,11 @@ def execute_training(
             )
         )
 
-    LOG.info("Starting trainer...")
     # TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers
     # if cfg.bf16:
     #     torch.set_default_dtype(torch.bfloat16)
 
+    LOG.info("Starting trainer...")
     trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
     plugin_manager = PluginManager.get_instance()
@@ -44,15 +44,6 @@ def set_pytorch_cuda_alloc_conf():
     )
 
 
-def patch_optimized_env():
-    """
-    Patch environment variables to improve VRAM usage and increase download speed
-    """
-    if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
-        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-    set_pytorch_cuda_alloc_conf()
-
-
 def get_not_null(value, default=None):
     """
     return the value if it's not None, otherwise return the default value
@@ -2,7 +2,6 @@
 
 import functools
 import logging
-import os
 
 from axolotl.utils.distributed import is_main_process
@@ -40,10 +39,6 @@ class MultiProcessAdapter(logging.LoggerAdapter):
 
 
 def get_logger(name: str, log_level: str | None = None) -> MultiProcessAdapter:
-    if log_level is None:
-        log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None)
     logger = logging.getLogger(name)
-    if log_level is not None:
-        logger.setLevel(log_level.upper())
-        logger.root.setLevel(log_level.upper())
+    logger.setLevel(logging.DEBUG)
     return MultiProcessAdapter(logger, extra={})
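The effect of pinning every logger returned by `get_logger` to DEBUG is that level policy now lives entirely in handlers and filters: DEBUG records reach the file-only handlers while the console handlers still gate on `LOG_LEVEL` / `AXOLOTL_LOG_LEVEL`. A sketch of what a caller sees, assuming the module lives under the `axolotl` package and `configure_logging()` plus `prepare_debug_log()` have already run:

```python
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

LOG.debug("lands in output_dir/debug.log only; the console filter drops it")
LOG.info("lands on the console *and* in debug.log")
```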
src/axolotl/utils/tee.py (new file, 166 lines)
@@ -0,0 +1,166 @@
"""
|
||||
Utilities for managing the debug log file and providing a file-only stream for logging
|
||||
handlers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import TextIO, cast
|
||||
|
||||
_lock = threading.Lock()
|
||||
_file_handle: io.TextIOWrapper | None = None
|
||||
_log_path: str | None = None
|
||||
_tee_installed: bool = False
|
||||
_orig_stdout: TextIO | None = None
|
||||
_orig_stderr: TextIO | None = None
|
||||
|
||||
|
||||
class _FileOnlyWriter(io.TextIOBase):
|
||||
"""A stream-like object that writes only to the tee file.
|
||||
|
||||
Before the file is prepared, writes are dropped (no-op).
|
||||
"""
|
||||
|
||||
def write(self, s: str) -> int: # type: ignore[override]
|
||||
with _lock:
|
||||
if _file_handle is not None:
|
||||
_file_handle.write(s)
|
||||
return len(s)
|
||||
return len(s)
|
||||
|
||||
def flush(self) -> None: # type: ignore[override]
|
||||
with _lock:
|
||||
if _file_handle is not None:
|
||||
try:
|
||||
_file_handle.flush()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
file_only_stream: io.TextIOBase = _FileOnlyWriter()
|
||||
|
||||
|
||||
class _StreamTee(io.TextIOBase):
|
||||
"""A minimal tee that mirrors writes to the debug log file.
|
||||
|
||||
Installed only after the debug log is prepared; no buffering.
|
||||
"""
|
||||
|
||||
def __init__(self, stream: io.TextIOBase):
|
||||
self._stream = stream
|
||||
|
||||
def write(self, s: str) -> int: # type: ignore[override]
|
||||
with _lock:
|
||||
n = self._stream.write(s)
|
||||
if _file_handle is not None:
|
||||
_file_handle.write(s)
|
||||
return n
|
||||
|
||||
def flush(self) -> None: # type: ignore[override]
|
||||
with _lock:
|
||||
self._stream.flush()
|
||||
if _file_handle is not None:
|
||||
try:
|
||||
_file_handle.flush()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@property
|
||||
def encoding(self): # type: ignore[override]
|
||||
return getattr(self._stream, "encoding", None)
|
||||
|
||||
@property
|
||||
def errors(self): # type: ignore[override]
|
||||
return getattr(self._stream, "errors", None)
|
||||
|
||||
def isatty(self): # type: ignore[override]
|
||||
return getattr(self._stream, "isatty", lambda: False)()
|
||||
|
||||
def fileno(self): # type: ignore[override]
|
||||
if hasattr(self._stream, "fileno"):
|
||||
return self._stream.fileno()
|
||||
raise OSError("Underlying stream has no fileno")
|
||||
|
||||
|
||||
def prepare_debug_log(cfg, filename: str = "debug.log") -> str:
|
||||
"""
|
||||
Prepare the debug log.
|
||||
|
||||
Creates the output directory, handles append/truncate logic based on cfg, and opens
|
||||
the debug log file for subsequent writes via file-only handlers.
|
||||
"""
|
||||
global _file_handle, _log_path, _tee_installed
|
||||
|
||||
with _lock:
|
||||
# If already initialized, reuse existing path
|
||||
if _log_path is not None:
|
||||
return _log_path
|
||||
|
||||
output_dir = cfg.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
log_path = Path(output_dir) / filename
|
||||
append = bool(
|
||||
cfg.get("resume_from_checkpoint") or cfg.get("auto_resume_from_checkpoints")
|
||||
)
|
||||
|
||||
if not append and log_path.exists():
|
||||
log_path.unlink()
|
||||
|
||||
fh = open(log_path, "a", encoding="utf-8")
|
||||
fh.flush()
|
||||
|
||||
_file_handle = fh
|
||||
_log_path = str(log_path)
|
||||
|
||||
# Install a tee so stdout/stderr are mirrored to the debug file
|
||||
# Allow disabling via env for testing or advanced usage.
|
||||
tee_enabled = os.getenv("AXOLOTL_TEE_STDOUT", "1").lower() not in {
|
||||
"0",
|
||||
"false",
|
||||
"no",
|
||||
}
|
||||
if tee_enabled and not _tee_installed:
|
||||
# Save originals so we can restore later (e.g., tests)
|
||||
global _orig_stdout, _orig_stderr
|
||||
_orig_stdout = sys.stdout
|
||||
_orig_stderr = sys.stderr
|
||||
sys.stdout = _StreamTee(cast(io.TextIOBase, sys.stdout))
|
||||
sys.stderr = _StreamTee(cast(io.TextIOBase, sys.stderr))
|
||||
_tee_installed = True
|
||||
|
||||
return _log_path
|
||||
|
||||
|
||||
def close_debug_log() -> None:
|
||||
"""Flush and close the debug log and uninstall the stdout/stderr tee.
|
||||
|
||||
Safe to call even if not initialized.
|
||||
"""
|
||||
global _file_handle, _log_path, _tee_installed, _orig_stdout, _orig_stderr
|
||||
with _lock:
|
||||
# Restore original stdout/stderr if we installed a tee
|
||||
if _tee_installed:
|
||||
if _orig_stdout is not None:
|
||||
sys.stdout = _orig_stdout
|
||||
if _orig_stderr is not None:
|
||||
sys.stderr = _orig_stderr
|
||||
_tee_installed = False
|
||||
_orig_stdout = None
|
||||
_orig_stderr = None
|
||||
|
||||
# Close the file handle if open
|
||||
if _file_handle is not None:
|
||||
try:
|
||||
_file_handle.flush()
|
||||
_file_handle.close()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
_file_handle = None
|
||||
_log_path = None
|
||||
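A hedged usage sketch for the new module (paths and config values are illustrative; set `AXOLOTL_TEE_STDOUT=0` beforehand to skip installing the tee):

```python
from axolotl.utils.dict import DictDefault
from axolotl.utils.tee import close_debug_log, prepare_debug_log

# Hypothetical minimal config; real runs get this from load_cfg().
cfg = DictDefault({"output_dir": "./outputs/demo"})

log_path = prepare_debug_log(cfg)  # no resume flags set -> truncate, then append
print("this line is mirrored into", log_path)  # stdout is now teed to debug.log

close_debug_log()  # restores sys.stdout/sys.stderr and closes the file
```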
@@ -31,6 +31,7 @@ def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | None:
     if checkpoints:
         last_checkpoint = str(checkpoints[-1])
-        LOG.info(f"Resuming from last checkpoint at {last_checkpoint}")
+        if not update:
+            LOG.info(f"Resuming from last checkpoint at {last_checkpoint}")
         return last_checkpoint
 
     if (
@@ -40,6 +41,7 @@ def determine_last_checkpoint(cfg: DictDefault, update: bool = True) -> str | None:
     ):
         cfg.resume_from_checkpoint = last_checkpoint
         LOG.info(
-            f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
+            "Using auto-resume functionality to resume from checkpoint at "
+            f"{cfg.resume_from_checkpoint}"
         )
         return cfg.resume_from_checkpoint
@@ -655,15 +655,6 @@ def prepare_optim_env(cfg):
         os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
 
 
-def prepare_opinionated_env(cfg):
-    if cfg.qlora_sharded_model_loading:
-        # model loading is forked after the tokenizer
-        os.environ["TOKENIZERS_PARALLELISM"] = "false"
-    if cfg.sample_packing:
-        # multipack parallel packing sampler defaults to using fork
-        os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-
 def setup_trainer(
     cfg,
     train_dataset,