Rank 0-only logging (#2608)

Co-authored-by: Wing Lian <wing@axolotl.ai>
Authored by salman on 2025-05-28 14:57:30 +01:00, committed by GitHub
parent 5fca214108
commit 65c5481120
135 changed files with 454 additions and 378 deletions
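This PR migrates every module-level logger from the stdlib pattern to a rank-aware adapter, so multi-GPU runs emit each record once instead of once per rank. A minimal sketch of the before/after call pattern, using the get_logger helper and main_process_only keyword introduced below:

# Before: every rank in a distributed run emits the same record
import logging
LOG = logging.getLogger(__name__)

# After: the MultiProcessAdapter returned by get_logger drops records on non-main ranks
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)

LOG.info("emitted once, on rank 0 only")                    # default behavior
LOG.info("emitted on every rank", main_process_only=False)  # per-call opt-out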

View File

@@ -5,7 +5,7 @@ base_model: NousResearch/Llama-3.2-1B
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
@@ -38,6 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

View File

@@ -1,6 +1,5 @@
"""Various checks for Axolotl CLI."""
import logging
import os
from pathlib import Path
@@ -8,7 +7,9 @@ from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def check_accelerate_default_config() -> None:

View File

@@ -1,7 +1,6 @@
"""Configuration loading and processing."""
import json
import logging
import os
import tempfile
from pathlib import Path
@@ -22,11 +21,12 @@ from axolotl.utils.config import (
validate_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__, use_environ=True)
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
@@ -119,12 +119,12 @@ def choose_config(path: Path) -> str:
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
LOG.info(f"Using default YAML file '{yaml_files[0]}'")
return str(yaml_files[0])
print("Choose a YAML file:")
LOG.info("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
LOG.info(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
@@ -133,9 +133,9 @@ def choose_config(path: Path) -> str:
if 1 <= choice <= len(yaml_files):
chosen_file = str(yaml_files[choice - 1])
else:
print("Invalid choice. Please choose a number from the list.")
LOG.info("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
LOG.info("Invalid input. Please enter a number.")
return chosen_file

View File

@@ -1,6 +1,5 @@
"""CLI to run evaluation on a model."""
import logging
import os
from pathlib import Path
from typing import Union
@@ -17,8 +16,9 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils import patch_optimized_env
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:

View File

@@ -1,7 +1,6 @@
"""CLI to run inference on a trained model."""
import importlib
import logging
import sys
from pathlib import Path
from threading import Thread
@@ -22,8 +21,9 @@ from axolotl.utils.chat_templates import (
get_chat_template_from_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def get_multi_line_input() -> str:

View File

@@ -2,7 +2,6 @@
# pylint: disable=redefined-outer-name
import logging
import os
import subprocess # nosec B404
import tempfile
@@ -31,8 +30,11 @@ from axolotl.cli.utils import (
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig
LOG = get_logger(__name__)
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
@@ -177,7 +179,7 @@ def train(
do_cli(config=cfg_file, **kwargs)
except subprocess.CalledProcessError as exc:
logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:
raise exc

View File

@@ -1,6 +1,5 @@
"""CLI to merge a trained LoRA into a base model."""
import logging
from pathlib import Path
from typing import Union
@@ -13,8 +12,9 @@ from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def do_merge_lora(*, cfg: DictDefault) -> None:

View File

@@ -1,7 +1,6 @@
"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""
import json
import logging
import os
import shutil
from pathlib import Path
@@ -27,8 +26,9 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):

View File

@@ -1,6 +1,5 @@
"""CLI to run preprocessing of a dataset."""
import logging
import warnings
from pathlib import Path
from typing import Union
@@ -20,9 +19,10 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:

View File

@@ -2,7 +2,6 @@
CLI to post-training quantize a model using torchao
"""
import logging
from pathlib import Path
from typing import Union
@@ -11,9 +10,10 @@ from transformers import AutoModelForCausalLM
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def do_quantize(

View File

@@ -1,7 +1,6 @@
"""CLI to run training on a model."""
import gc
import logging
import os
from pathlib import Path
from typing import Union
@@ -22,8 +21,6 @@ from axolotl.utils import patch_optimized_env
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
"""

View File

@@ -4,7 +4,6 @@ import concurrent.futures
import dataclasses
import hashlib
import json
import logging
from functools import wraps
from pathlib import Path
from types import NoneType
@@ -23,8 +22,9 @@ from transformers import (
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def strip_optional_type(field_type: type | str | None):

View File

@@ -1,6 +1,5 @@
"""Dataset loading utilities."""
import logging
import math
import random
from dataclasses import dataclass
@@ -14,10 +13,11 @@ from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.tokenization import check_dataset_labels
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
@dataclass

View File

@@ -156,7 +156,6 @@ class Messages(BaseModel):
len(input_ids) : len(input_ids) + len(pending_input_ids)
]
if new_pending_inputs != pending_input_ids:
# logging.warning("tokenization mismatch from concatenation.")
pending_input_ids = new_pending_inputs
input_ids.extend(pending_input_ids)
if pending_weight:

View File

@@ -19,7 +19,6 @@ import abc
import importlib
import importlib.util
import inspect
import logging
import math
import os
import sys
@@ -88,6 +87,7 @@ from axolotl.utils.collators import (
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType
try:
@@ -95,7 +95,7 @@ try:
except ImportError:
pass
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class TrainerBuilderBase(abc.ABC):

View File

@@ -4,7 +4,6 @@
from __future__ import annotations
import logging
import os
from collections import defaultdict
from functools import wraps
@@ -34,9 +33,10 @@ from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):

View File

@@ -2,7 +2,6 @@
import importlib
import inspect
import logging
from typing import Any
from trl.trainer.grpo_trainer import RewardFunc
@@ -13,9 +12,10 @@ from axolotl.core.trainers.grpo.trainer import (
AxolotlGRPOTrainer,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.trl import TRLConfig
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class GRPOStrategy:

View File

@@ -1,18 +1,17 @@
"""Module for Axolotl trainer optimizer mixin"""
import logging
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers.trainer import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.utils.logging import get_logger
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class OptimizerMixin(Trainer):

View File

@@ -6,7 +6,6 @@ See https://github.com/huggingface/transformers/pull/37162
TODO: Remove when the upstream PR is included in a release
"""
import logging
import os
import random
@@ -17,7 +16,9 @@ from transformers.trainer import safe_globals
from transformers.trainer_pt_utils import set_rng_state_for_device
from transformers.training_args import ParallelMode
LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class RngLoaderMixin(Trainer):

View File

@@ -1,12 +1,11 @@
"""Module for Axolotl trainer scheduler mixin"""
import logging
import torch
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
from transformers.trainer import Trainer
from axolotl.integrations.base import PluginManager
from axolotl.utils.logging import get_logger
from axolotl.utils.schedulers import (
RexLR,
get_cosine_schedule_with_min_lr,
@@ -14,7 +13,7 @@ from axolotl.utils.schedulers import (
get_cosine_schedule_with_warmup_decay_constant,
)
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class SchedulerMixin(Trainer):
@@ -80,13 +79,15 @@ class SchedulerMixin(Trainer):
self.lr_scheduler = RexLR(
optimizer=optimizer,
max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
min_lr=0 if not use_cosine_min_lr else (
self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
LOG.warning(
"Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
@@ -115,9 +116,11 @@ class SchedulerMixin(Trainer):
return super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
LOG.warning(
"axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
LOG.warning(
"axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler # type: ignore

View File

@@ -1,12 +1,13 @@
"""Module containing Dataset functionality"""
import logging
import os
from typing import List, Optional, Union
import torch
from datasets import Dataset, IterableDataset
from axolotl.utils.logging import get_logger
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
@@ -15,7 +16,7 @@ from .prompt_tokenizers import PromptTokenizingStrategy
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)
class TokenizedPromptDataset(Dataset):

View File

@@ -22,7 +22,6 @@ from __future__ import annotations
import collections
import importlib
import logging
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
from peft import PeftModel
@@ -31,6 +30,9 @@ from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedModel, Trainer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
if TYPE_CHECKING:
from axolotl.common.datasets import TrainDatasetMeta
@@ -331,12 +333,12 @@ class PluginManager:
ImportError: If the plugin module cannot be imported.
"""
try:
logging.info(f"Attempting to load plugin: {plugin_name}")
LOG.info(f"Attempting to load plugin: {plugin_name}")
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
logging.info(f"Plugin loaded successfully: {plugin_name}")
LOG.info(f"Plugin loaded successfully: {plugin_name}")
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
LOG.error(f"Failed to load plugin: {plugin_name}")
def get_input_args(self) -> list[str]:
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'

View File

@@ -19,17 +19,16 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team.
"""
import importlib
import logging
import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
LOG = get_logger(__name__, use_environ=True)
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
@@ -76,10 +75,9 @@ class CutCrossEntropyPlugin(BasePlugin):
cce_patch,
)
if is_main_process(use_environ=True):
LOG.info(
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
)
LOG.info(
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
)
# The patch checks model_type internally
cce_patch(cfg.model_config_type)

View File

@@ -15,12 +15,13 @@
"""
Module for handling Cut Cross Entropy input arguments.
"""
import logging
from typing import Optional
from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args")
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class CutCrossEntropyArgs(BaseModel):

View File

@@ -2,15 +2,15 @@
Grokfast plugin for Axolotl
"""
import logging
from transformers.trainer_callback import TrainerCallback
from axolotl.utils.logging import get_logger
from ..base import BasePlugin
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
from .optimizer import gradfilter_ema
LOG = logging.getLogger("axolotl.integrations.grokfast")
LOG = get_logger(__name__)
class GrokfastCallbackHandler(TrainerCallback):

View File

@@ -19,16 +19,15 @@ Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight.
"""
import inspect
import logging
import sys
from axolotl.integrations.base import BasePlugin
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
from .utils import patch_with_compile_disable
LOG = logging.getLogger("axolotl.integrations.liger")
LOG = get_logger(__name__, use_environ=True)
class LigerPlugin(BasePlugin):
@@ -85,10 +84,7 @@ class LigerPlugin(BasePlugin):
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
if is_main_process(use_environ=True):
LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
)
LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
apply_liger_fn(**kwargs)
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba
@@ -124,9 +120,9 @@ class LigerPlugin(BasePlugin):
if cfg.liger_rope:
# The DeepseekV2 version of RoPE is different than upstream LLaMA.
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_glu_activation:
logging.warning("liger_glu_activation is not supported for DeepseekV2.")
LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
@@ -186,6 +182,6 @@ class LigerPlugin(BasePlugin):
swiglu=cfg.liger_glu_activation,
)
else:
logging.warning(
LOG.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
)

View File

@@ -15,12 +15,13 @@
"""
Module for handling LIGER input arguments.
"""
import logging
from typing import Optional
from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.liger.args")
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class LigerArgs(BaseModel):

View File

@@ -3,7 +3,6 @@ Sparse Finetuning plugin for Axolotl — enables handling of sparse neural netwo
by maintaining masks for zero weights during training.
"""
import logging
from functools import wraps
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
@@ -16,11 +15,12 @@ from transformers.trainer_callback import TrainerCallback, TrainerControl, Train
from transformers.training_args import TrainingArguments
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
P = ParamSpec("P") # Params for generic function signatures
R = TypeVar("R") # Return type for generic function signatures
LOG = logging.getLogger("axolotl.integrations.llm_compressor")
LOG = get_logger(__name__)
class LLMCompressorCallbackHandler(TrainerCallback):

View File

@@ -17,14 +17,16 @@ Spectrum Plugin to automatically generate unfrozen parameters based on SNR data.
"""
import json
import logging
import requests
from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
from .args import SpectrumArgs # pylint: disable=unused-import. # noqa: F401
LOG = get_logger(__name__)
def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5):
unfrozen_parameters = {}
@@ -83,17 +85,17 @@ class SpectrumPlugin(BasePlugin):
except FileNotFoundError:
pass
except Exception as exc: # pylint: disable=broad-exception-caught
logging.warning(f"Failed to read SNR data from {snr_path}: {exc}")
LOG.warning(f"Failed to read SNR data from {snr_path}: {exc}")
if not snr_data:
try:
snr_data = requests.get(snr_url, timeout=60).json()
except requests.exceptions.RequestException as exc:
logging.warning(f"Failed to fetch SNR data from {snr_url}: {exc}")
LOG.warning(f"Failed to fetch SNR data from {snr_url}: {exc}")
return
# also catch json parsing errors
except json.JSONDecodeError as exc:
logging.warning(f"Failed to parse SNR data from {snr_url}: {exc}")
LOG.warning(f"Failed to parse SNR data from {snr_url}: {exc}")
return
unfrozen_parameters = _generate_unfrozen_params_yaml(

View File

@@ -1,6 +1,5 @@
"""Adapter loading functionality, including LoRA / QLoRA and associated utils"""
import logging
import os
import types
from typing import Any
@@ -21,8 +20,9 @@ from transformers import PreTrainedModel
from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def setup_quantized_meta_for_peft(model: torch.nn.Module):

View File

@@ -3,7 +3,6 @@ models.
"""
import gc
import logging
import math
import os
from functools import cached_property
@@ -47,10 +46,11 @@ from axolotl.utils.distributed import (
get_device_count,
get_device_type,
)
from axolotl.utils.logging import get_logger
from axolotl.utils.model_shard_quant import load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()

View File

@@ -4,7 +4,6 @@ Applies pre- and post-model load patches for various fixes and optimizations.
"""
import importlib.util
import logging
from functools import cached_property
import addict
@@ -17,8 +16,9 @@ from axolotl.monkeypatch.multipack import (
patch_for_multipack,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()

View File

@@ -1,6 +1,5 @@
"""Processor loading functionality for multi-modal models"""
import logging
from typing import Any
import transformers
@@ -10,8 +9,9 @@ from transformers import (
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):

View File

@@ -1,7 +1,6 @@
"""Tokenizer loading functionality and associated utils"""
import json
import logging
import os
import transformers
@@ -19,8 +18,9 @@ from axolotl.utils.distributed import (
is_local_main_process,
is_main_process,
)
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()

View File

@@ -1,7 +1,6 @@
"""Utilities for axolotl.loaders module"""
import contextlib
import logging
from typing import Type
import addict
@@ -9,8 +8,9 @@ import torch
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def get_module_class_from_name(

View File

@@ -2,12 +2,13 @@
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration, and saving full state dicts
"""
import logging
import sys
import torch
LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):

View File

@@ -3,7 +3,6 @@ Flash attention monkey patch for cerebras btlm model
"""
import importlib
import logging
from typing import Optional, Tuple
import torch
@@ -11,7 +10,9 @@ from accelerate import init_empty_weights
from flash_attn.flash_attn_interface import flash_attn_func
from transformers import AutoConfig, AutoModelForCausalLM
LOG = logging.getLogger("axolotl")
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):

View File

@@ -18,7 +18,6 @@ DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching
import atexit
import concurrent.futures
import logging
import os
import queue
import shutil
@@ -32,11 +31,13 @@ from typing import Dict
import torch
from axolotl.utils.logging import get_logger
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
# Setup logger
logger = logging.getLogger(__name__)
logger = get_logger(__name__)
class DiskOffloadManager:

View File

@@ -2,7 +2,6 @@
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
import logging
import warnings
from typing import List, Optional, Tuple, Union
@@ -25,6 +24,7 @@ from transformers.models.llama.modeling_llama import (
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
from axolotl.utils.logging import get_logger
try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
@@ -41,7 +41,7 @@ except ImportError:
)
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)
def is_xformers_available() -> bool:
@@ -612,9 +612,10 @@ def generate_qkv(
q, query_padding_mask
)
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
def output_pad_fn(output_unpad):
return pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
@@ -627,9 +628,10 @@ def generate_qkv(
)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
def output_pad_fn(output_unpad):
return rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)

View File

@@ -2,7 +2,6 @@
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
"""
import logging
import warnings
from typing import Optional, Tuple
@@ -11,10 +10,14 @@ import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
try:
import xformers.ops
except ImportError:
logging.error("xformers not found! Please install it before trying to use it.")
LOG.error("xformers not found! Please install it before trying to use it.")
def hijack_llama_attention():

View File

@@ -7,7 +7,6 @@ import types
from typing import Generator, Tuple, Type
import torch
from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
from torch import nn
from transformers import AutoConfig
@@ -20,6 +19,7 @@ from axolotl.kernels.lora import (
)
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)

View File

@@ -2,7 +2,6 @@
# pylint: disable=duplicate-code
import logging
from functools import partial
from typing import List, Optional, Tuple, Union
@@ -28,8 +27,9 @@ from transformers.models.mistral.modeling_mistral import (
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
LOG = get_logger(__name__)
def replace_mistral_attn_with_flash_attn(
@@ -359,9 +359,10 @@ def generate_qkv(
q, query_padding_mask
)
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
def output_pad_fn(output_unpad):
return pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
@@ -374,9 +375,10 @@ def generate_qkv(
)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
def output_pad_fn(output_unpad):
return rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)

View File

@@ -3,14 +3,14 @@ Patch prepare_model_for_kbit_training to not upcast everything
"""
import inspect
import logging
import peft
import axolotl
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
ORIGINAL_PREPARE_CODE = """
for param in model.parameters():

View File

@@ -2,7 +2,6 @@
import glob
import json
import logging
import os.path
import shutil
from functools import partial
@@ -27,8 +26,9 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import barrier, is_main_process
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.relora")
LOG = get_logger(__name__)
@torch.no_grad()

View File

@@ -32,11 +32,11 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger
logger = logging.get_logger(__name__)
logger = get_logger(__name__)
def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"):

View File

@@ -2,11 +2,11 @@
monkeypatch for Trainer _get_learning_rate method
"""
import logging
import torch
LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# TODO remove this patch once https://github.com/huggingface/transformers/pull/37881 is included in a release

View File

@@ -3,13 +3,13 @@ allow adding additional kwargs to Accelerator init
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
ORIGINAL_TRAINER_CODE = """
# create accelerator object

View File

@@ -3,13 +3,13 @@ fix for FSDP2 evals when using torch.compile
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
ORIGINAL_TRAINER_CODE = """
model.eval()

View File

@@ -3,13 +3,13 @@ fix for FSDP optimizer save in trainer w 4.47.0
"""
import inspect
import logging
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
LOG = get_logger(__name__)
ORIGINAL_TRAINER_CODE = """

View File

@@ -2,13 +2,14 @@
see https://github.com/huggingface/transformers/pull/35834
"""
import logging
from functools import partial
from typing import Optional
import torch
logger = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger
logger = get_logger(__name__)
def fixed_fa_peft_integration_check(

View File

@@ -11,7 +11,7 @@ from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from axolotl.monkeypatch.utils import detab_code
LOG = get_logger("axolotl.monkeypatch.unsloth")
LOG = get_logger(__name__)
ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states)
@@ -133,7 +133,7 @@ def patch_self_attn_lora():
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
self_attn_lora_patched = True
LOG.info("patching unsloth attn lora", main_process_only=True)
LOG.info("patching unsloth attn lora")
LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
@@ -153,7 +153,7 @@ def integrate_rope_embeddings():
):
return fast_rope_embedding(q, k, cos, sin)
LOG.info("patching unsloth RoPE embeddings", main_process_only=True)
LOG.info("patching unsloth RoPE embeddings")
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
@@ -189,7 +189,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
else:
LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
LOG.warning(f"unable to apply unsloth lora mlp patch to layer {idx}")
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
@@ -215,7 +215,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_qkv = apply_lora_qkv
else:
layer.self_attn.apply_qkv = original_apply_qkv
LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx)
LOG.warning(f"unable to apply unsloth lora qkv patch to layer {idx}")
if cfg.unsloth_lora_o:
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
@@ -234,9 +234,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_o = apply_lora_o
else:
layer.self_attn.apply_o = original_apply_o
LOG.warning(
"unable to apply unsloth lora o_proj patch to layer %d", idx
)
LOG.warning(f"unable to apply unsloth lora o_proj patch to layer {idx}")
def patch_unsloth_layernorm():

View File

@@ -1,6 +1,5 @@
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""
import logging
from copy import deepcopy
from typing import Optional
@@ -10,7 +9,9 @@ from torch import Tensor
from transformers import ProcessorMixin
from transformers.image_utils import load_image
LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class ProcessingStrategy:

View File

@@ -2,11 +2,11 @@
import importlib
import inspect
import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.prompt_strategies")
LOG = get_logger(__name__)
def load(strategy, tokenizer, cfg, ds_cfg, processor=None):

View File

@@ -3,9 +3,10 @@ module for base dataset transform strategies
"""
import importlib
import logging
LOG = logging.getLogger("axolotl")
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def load(strategy, cfg, module_base=None, **kwargs):

View File

@@ -2,11 +2,11 @@
import importlib
import inspect
import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
LOG = get_logger(__name__)
def load(strategy, tokenizer, cfg, ds_cfg):

View File

@@ -2,7 +2,6 @@
Bradley-Terry model with chat template prompt strategy.
"""
import logging
from typing import Any, Dict, Optional
from axolotl.prompt_strategies.chat_template import (
@@ -10,10 +9,11 @@ from axolotl.prompt_strategies.chat_template import (
ChatTemplateStrategy,
)
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.logging import get_logger
# Configure the logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
LOG.setLevel(logging.INFO)
LOG = get_logger(__name__)
LOG.setLevel("INFO")
class BTChatTemplateStrategy(ChatTemplateStrategy):
@@ -44,7 +44,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
if len(chosen_tokenized["input_ids"]) > max_length:
LOG.warning(
f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}"
)
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
@@ -62,7 +62,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
if len(rejected_tokenized["input_ids"]) > max_length:
LOG.warning(
f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}"
)
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][

View File

@@ -2,7 +2,6 @@
HF Chat Templates prompt strategy
"""
import logging
from collections import defaultdict
from typing import Any, Dict, List, Set, Union
@@ -13,11 +12,12 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import DatasetConfig
# Configure the logger
LOG = logging.getLogger("axolotl")
LOG.setLevel(logging.INFO)
LOG = get_logger(__name__)
LOG.setLevel("INFO")
class ChatTemplatePrompter(Prompter):
@@ -378,7 +378,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
add_generation_prompt=True,
images=images,
)
tokenized_res = self.prompter.build_prompt(turns, images=images) # type: ignore
tokenized_res = self.prompter.build_prompt(
turns, images=images
) # type: ignore
tokenized_prompt = {}
if isinstance(tokenized_res, list):
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
@@ -555,8 +557,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
and turns[0].get("role") == "system"
and (
"mistral" in self.tokenizer.name_or_path.lower()
# gemma3 uses gemma tokenizer
or "gemma" in self.tokenizer.name_or_path.lower()
or "gemma"
in self.tokenizer.name_or_path.lower() # gemma3 uses gemma tokenizer
)
):
return -1, -1

View File

@@ -24,12 +24,14 @@ For a custom system message, the first "from" can be "system" (followed by alter
Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing!
"""
import logging
from dataclasses import dataclass, field
from typing import Generator, List, Sequence
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@dataclass
@@ -129,7 +131,7 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
if cur_len < self.sequence_len:
if cur_len != total_len:
target[:] = IGNORE_TOKEN_ID
logging.warning(
LOG.warning(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)

View File

@@ -2,9 +2,10 @@
import importlib
import inspect
import logging
LOG = logging.getLogger("axolotl.prompt_strategies.messages")
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def load(tokenizer, cfg, ds_cfg, processor=None):

View File

@@ -1,12 +1,12 @@
"""Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class"""
import logging
from typing import Tuple
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)
IGNORE_TOKEN_ID = -100

View File

@@ -1,7 +1,6 @@
"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class"""
import copy
import logging
from collections import defaultdict
from typing import Generator, List, Tuple
@@ -10,8 +9,9 @@ from axolotl.prompt_tokenizers import (
parse_tokenized_to_result,
tokenize_prompt_default,
)
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)
IGNORE_TOKEN_ID = -100

View File

@@ -1,14 +1,14 @@
"""Module containing PromptTokenizingStrategy and Prompter classes"""
import abc
import logging
from typing import Callable, Dict, List, Optional, Tuple, Union
from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.prompters import Prompter
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)
IGNORE_INDEX = -100
LLAMA_DEFAULT_PAD_TOKEN = "<pad>" # nosec

View File

@@ -1,12 +1,13 @@
"""Module containing prompters"""
import logging
from enum import Enum
from typing import Generator, Optional, Union
from colorama import Fore
LOG = logging.getLogger("axolotl")
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
IGNORE_TOKEN_ID = -100
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"

View File

@@ -2,7 +2,6 @@
import importlib
import inspect
import logging
import os
import signal
import sys
@@ -37,6 +36,7 @@ from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContext
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.trainer import setup_trainer
@@ -45,7 +45,7 @@ try:
except ImportError:
BetterTransformer = None
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def setup_model_and_tokenizer(
@@ -64,9 +64,7 @@ def setup_model_and_tokenizer(
`None`), and processor (if multimodal, else `None`).
"""
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
)
LOG.debug(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
# Load processor for multimodal models if needed

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import gc
import json
import logging
import os
import traceback
from shutil import copyfile
@@ -43,6 +42,7 @@ from axolotl.utils.distributed import (
is_main_process,
zero_first,
)
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig
if TYPE_CHECKING:
@@ -50,7 +50,7 @@ if TYPE_CHECKING:
IGNORE_INDEX = -100
LOG = logging.getLogger("axolotl.callbacks")
LOG = get_logger(__name__)
class EvalFirstStepCallback(
@@ -753,7 +753,14 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
].append(pred_step_text)
row_index += 1
if logger == "wandb":
wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined]
# type: ignore[attr-defined]
wandb.run.log(
{
f"{name} - Predictions vs Ground Truth": pd.DataFrame(
table_data
)
}
)
elif logger == "mlflow" and is_mlflow_available():
import mlflow

View File

@@ -1,17 +1,17 @@
"""Comet module for trainer callbacks"""
import logging
from typing import TYPE_CHECKING
import comet_ml
from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
LOG = get_logger(__name__)
class SaveAxolotlConfigtoCometCallback(TrainerCallback):

View File

@@ -6,17 +6,18 @@ Arxiv: https://arxiv.org/abs/2403.17919
License: Apache 2.0
"""
import logging
from functools import reduce
from typing import TYPE_CHECKING
import numpy as np
from transformers import TrainerCallback
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainer
LOG = logging.getLogger("axolotl.callbacks.lisa")
LOG = get_logger(__name__)
def lisa_callback_factory(trainer: "AxolotlTrainer"):

View File

@@ -1,6 +1,5 @@
"""MLFlow module for trainer callbacks"""
import logging
import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
@@ -10,11 +9,12 @@ import mlflow
from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
LOG = get_logger(__name__)
def should_log_artifacts() -> bool:

View File

@@ -1,6 +1,5 @@
"""QAT Callback for HF Causal Trainer"""
import logging
from functools import partial
from torch import nn
@@ -8,9 +7,10 @@ from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
from transformers import TrainerCallback
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.quantization import QATConfig
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def toggle_fake_quant(mod: nn.Module, enable: bool):

File diff suppressed because one or more lines are too long

View File

@@ -1,11 +1,11 @@
"""Module for wandb utilities"""
import logging
import os
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.utils.comet_")
LOG = get_logger(__name__)
COMET_ENV_MAPPING_OVERRIDE = {
"comet_mode": "COMET_START_MODE",

View File

@@ -1,7 +1,6 @@
"""Module for working with config dicts"""
import json
import logging
import os
from typing import Optional
@@ -15,13 +14,14 @@ from axolotl.loaders import MULTIMODAL_AUTO_MODEL_MAPPING
from axolotl.loaders.utils import load_model_config
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
)
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__, use_environ=True)
def choose_device(cfg):

View File

@@ -1,7 +1,6 @@
"""data handling specific to pretraining"""
import functools
import logging
from collections import defaultdict
from typing import Callable, Dict, List, Optional
@@ -11,10 +10,11 @@ from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)
def encode_pretraining(

View File

@@ -1,7 +1,6 @@
"""data handling specific to DPO"""
import inspect
import logging
from functools import partial
from pathlib import Path
from typing import Any, List, Union
@@ -18,9 +17,10 @@ from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def _get_path(ds_hash, cfg):
@@ -217,7 +217,7 @@ def load_prepare_preference_datasets(cfg):
+ "|"
+ "train"
+ "|"
+ str(seed)
+ str(cfg.seed or 42)
)
to_hash_test = (
train_dataset._fingerprint # pylint: disable=protected-access
@@ -226,7 +226,7 @@ def load_prepare_preference_datasets(cfg):
+ "|"
+ "test"
+ "|"
+ str(seed)
+ str(cfg.seed or 42)
)
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)

View File

@@ -1,7 +1,6 @@
"""data handling specific to SFT"""
import functools
import logging
import os
import tempfile
from pathlib import Path
@@ -54,12 +53,13 @@ from axolotl.utils.data.utils import (
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
)
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
@retry_on_request_exceptions(max_retries=3, delay=5)
@@ -182,10 +182,9 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
return train_dataset, eval_dataset, total_num_steps, prompters
@@ -331,12 +330,12 @@ def load_tokenized_prepared_datasets(
if len(datasets) == 1:
dataset = datasets[0]
else:
LOG.info("merging datasets")
LOG.info("Merging datasets...")
dataset = concatenate_datasets(datasets)
if len(datasets) > 1:
if cfg.shuffle_merged_datasets:
LOG.debug("shuffle merged datasets")
LOG.debug("Shuffling merged datasets...")
dataset = dataset.shuffle(seed=seed)
else:
LOG.debug("NOT shuffling merged datasets")
@@ -426,7 +425,7 @@ def load_prepare_datasets(
+ "|"
+ "train"
+ "|"
+ str(seed)
+ str(cfg.seed or 42)
)
to_hash_test = (
dataset._fingerprint # pylint: disable=protected-access
@@ -435,7 +434,7 @@ def load_prepare_datasets(
+ "|"
+ "test"
+ "|"
+ str(seed)
+ str(cfg.seed or 42)
)
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)

View File

@@ -2,7 +2,6 @@
import functools
import hashlib
import logging
import time
from enum import Enum
@@ -12,10 +11,11 @@ import requests
from datasets import Dataset, IterableDataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers.utils import get_dataset_lengths
from axolotl.utils.trainer import drop_long_seq
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class RetryStrategy(Enum):

View File

@@ -0,0 +1,62 @@
"""
logging helpers to only log on main process
"""
import functools
import logging
import os
from axolotl.utils.distributed import is_main_process
# Adapted from Accelerate
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py
class MultiProcessAdapter(logging.LoggerAdapter):
"""
logger adapter for distributed logging, specifically to only log on main process
"""
def __init__(self, logger, use_environ=False, extra=None):
super().__init__(logger, extra)
self.use_environ = use_environ
@staticmethod
def _should_log(main_process_only, use_environ=False):
return not main_process_only or (
main_process_only and is_main_process(use_environ=use_environ)
)
def log(self, level, msg, *args, **kwargs):
use_environ = kwargs.pop("use_environ", self.use_environ)
main_process_only = kwargs.pop("main_process_only", True)
kwargs.setdefault("stacklevel", 2)
if self.isEnabledFor(level) and self._should_log(
main_process_only, use_environ=use_environ
):
msg, kwargs = self.process(msg, kwargs)
self.logger.log(level, msg, *args, **kwargs)
@functools.lru_cache(maxsize=10)
def warning_once(self, *args, **kwargs):
"""
This method is identical to `logger.warning()`, but will emit the warning with the same message only once
Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the
cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to
switch to another type of cache that includes the caller frame information in the hashing function.
"""
self.warning(*args, **kwargs)
def get_logger(
name: str, log_level: str | None = None, use_environ: bool = False
) -> MultiProcessAdapter:
if log_level is None:
log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None)
logger = logging.getLogger(name)
if log_level is not None:
logger.setLevel(log_level.upper())
logger.root.setLevel(log_level.upper())
return MultiProcessAdapter(logger, use_environ=use_environ, extra={})
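The adapter decides per record whether to emit: main_process_only defaults to True, and use_environ picks how the rank is detected. Call sites that run before torch.distributed is initialized (config loading and validation, the Liger and Cut Cross Entropy plugins) construct their loggers with get_logger(__name__, use_environ=True) so the main-process check can fall back to launcher environment variables. A rough sketch of what that fallback could look like — an assumption for illustration, not the actual axolotl.utils.distributed.is_main_process:

# Hypothetical sketch of an env-based rank check; the real helper lives in
# axolotl.utils.distributed and may differ.
import os
import torch.distributed as dist

def is_main_process(use_environ: bool = False) -> bool:
    if use_environ or not (dist.is_available() and dist.is_initialized()):
        # torchrun/accelerate launchers export RANK before the process group exists
        return int(os.environ.get("RANK", "0")) == 0
    return dist.get_rank() == 0

Verbosity can also be steered globally: when no explicit level is passed, get_logger reads the AXOLOTL_LOG_LEVEL environment variable and applies it to both the named logger and the root logger, e.g. setting AXOLOTL_LOG_LEVEL=DEBUG ahead of a training launch.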

View File

@@ -2,8 +2,6 @@
Utilities for quantization including QAT and PTQ using torchao.
"""
import logging
import torch
from torch import nn
from torchao.core.config import AOBaseConfig
@@ -25,8 +23,6 @@ from torchao.quantization.quant_api import (
from axolotl.utils.schemas.enums import TorchIntDType
LOG = logging.getLogger(__name__)
def get_ptq_config(
weight_dtype: TorchIntDType,

View File

@@ -3,7 +3,6 @@ Multipack Batch Sampler - An efficient batch sampler for packing variable-length
into fixed-capacity batches to optimize memory usage and training throughput.
"""
import logging
import math
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context
@@ -14,9 +13,9 @@ import numpy as np
from torch.utils.data import BatchSampler, Sampler, SequentialSampler
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
LOG = get_logger(__name__)
@numba.njit

View File

@@ -2,7 +2,6 @@
# pylint: disable=too-many-lines
import logging
import os
from typing import Annotated, Any, Literal
@@ -18,6 +17,7 @@ from pydantic import (
)
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import (
DatasetConfig,
DPODataset,
@@ -49,7 +49,7 @@ from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.vllm import VllmConfig
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__, use_environ=True)
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}

View File

@@ -1,11 +1,12 @@
"""Pydantic models for deprecated and remapped configuration parameters"""
import logging
from typing import Any
from pydantic import BaseModel, Field, field_validator
LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class DeprecatedParameters(BaseModel):

View File

@@ -64,6 +64,7 @@ class ChatTemplate(str, Enum):
command_a_rag = "command_a_rag" # pylint: disable=invalid-name
aya = "aya" # pylint: disable=invalid-name
class CustomSupportedOptimizers(str, Enum):
"""Custom supported optimizers"""

View File

@@ -1,11 +1,12 @@
"""Pydantic models for Axolotl integrations"""
import logging
from typing import Any
from pydantic import BaseModel, Field, model_validator
LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class MLFlowConfig(BaseModel):

View File

@@ -1,10 +1,10 @@
"""Pydantic models for model input / output, etc. configuration"""
import logging
from pydantic import BaseModel, Field, field_validator
LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__, use_environ=True)
class ModelInputConfig(BaseModel):

View File

@@ -1,15 +1,15 @@
"""Pydantic models for training hyperparameters"""
import logging
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
class LrGroup(BaseModel):

View File

@@ -1,8 +1,8 @@
"""Utilities for Axolotl Pydantic models"""
import logging
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
def handle_legacy_message_fields_logic(data: dict) -> dict:

View File

@@ -1,10 +1,10 @@
"""Module for tokenization utilities"""
import logging
from termcolor import colored
LOG = logging.getLogger("axolotl")
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def check_dataset_labels(

View File

@@ -22,7 +22,7 @@ from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger("axolotl")
LOG = get_logger(__name__)
@torch.jit.script
@@ -402,7 +402,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.apply(len)
.values
)
LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True)
LOG.debug(f"total_num_tokens: {total_num_tokens:_}")
if update:
cfg.total_num_tokens = total_num_tokens
@@ -420,10 +420,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.apply(lambda x: np.sum(np.array(x) != -100))
.sum()
)
LOG.debug(
f"`total_supervised_tokens: {total_supervised_tokens:_}`",
main_process_only=True,
)
LOG.debug(f"`total_supervised_tokens: {total_supervised_tokens:_}`")
if update:
cfg.total_supervised_tokens = total_supervised_tokens
@@ -448,8 +445,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
* cfg.sequence_parallel_degree
)
LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
main_process_only=True,
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
)
else:
if cfg.flash_attention and not cfg.multipack_real_batches:
@@ -478,7 +474,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
batch_sampler=sampler,
)
data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
LOG.debug(f"data_loader_len: {data_loader_len}")
# FIXME: is there a bug here somewhere? the total num steps depends
# on the agreed on value for sample_packing_eff_est
total_num_steps = int(
@@ -500,10 +496,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
)
if update:
cfg.sample_packing_eff_est = sample_packing_eff_est
LOG.debug(
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
main_process_only=True,
)
LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
else:
total_num_steps = int(
math.ceil(
@@ -513,7 +506,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
/ cfg.batch_size
)
)
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
LOG.debug(f"total_num_steps: {total_num_steps}")
return total_num_steps

View File
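The trainer-utils hunk above also drops the per-call `main_process_only=True` keyword, which the previous accelerate-style logger adapter accepted. Once rank filtering lives inside the logger itself, call sites simplify to plain stdlib-style calls, for example:

# Before: rank gating requested at each call site (adapter keyword)
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)

# After: the logger itself filters to rank 0, so no keyword is needed
LOG.debug(f"data_loader_len: {data_loader_len}")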

@@ -2,7 +2,6 @@
E2E tests for multigpu lora tinyllama
"""
import logging
import os
from pathlib import Path
@@ -14,10 +13,11 @@ from transformers.testing_utils import get_torch_dist_unique_port
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

View File
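Note that the e2e tests keep explicit dotted names such as "axolotl.tests.e2e.multigpu" rather than switching to `__name__`. With the standard logging hierarchy, one parent-level configuration then governs every multigpu test logger; a hypothetical illustration:

# Hypothetical: tuning the shared parent also affects child loggers
# created as get_logger("axolotl.tests.e2e.multigpu")
import logging

logging.getLogger("axolotl.tests.e2e").setLevel(logging.DEBUG)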

@@ -2,7 +2,6 @@
E2E tests for multigpu eval
"""
import logging
import os
from pathlib import Path
@@ -11,10 +10,11 @@ from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu lora tinyllama
"""
import logging
import os
from pathlib import Path
@@ -13,10 +12,11 @@ from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.e2e.utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu lora tinyllama
"""
import logging
import os
from pathlib import Path
@@ -15,10 +14,11 @@ from packaging import version
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu qwen2
"""
import logging
import os
from pathlib import Path
@@ -12,8 +11,9 @@ from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu post-training using Ray Train
"""
import logging
import os
from pathlib import Path
@@ -11,10 +10,11 @@ import yaml
from accelerate.test_utils import execute_subprocess_async
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

View File

@@ -2,7 +2,6 @@
E2E tests for multipack fft llama using 4d attention masks
"""
import logging
import os
import unittest
@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama
"""
import logging
import os
import pytest
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for falcon
"""
import logging
import os
import unittest
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
@@ -14,10 +13,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for llama w/ S2 attn
"""
import logging
import os
import unittest
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama
"""
import logging
import os
import unittest
@@ -14,10 +13,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"

Some files were not shown because too many files have changed in this diff