Rank 0-only logging (#2608)

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
salman
2025-05-28 14:57:30 +01:00
committed by GitHub
parent 5fca214108
commit 65c5481120
135 changed files with 454 additions and 378 deletions

View File

@@ -5,7 +5,7 @@ base_model: NousResearch/Llama-3.2-1B
datasets: datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1 val_set_size: 0.1
output_dir: ./outputs/lora-out output_dir: ./outputs/lora-out
@@ -38,6 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 1 num_epochs: 1
optimizer: adamw_8bit optimizer: adamw_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002

View File

@@ -1,6 +1,5 @@
"""Various checks for Axolotl CLI.""" """Various checks for Axolotl CLI."""
import logging
import os import os
from pathlib import Path from pathlib import Path
@@ -8,7 +7,9 @@ from accelerate.commands.config import config_args
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError from huggingface_hub.utils import LocalTokenNotFoundError
LOG = logging.getLogger(__name__) from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def check_accelerate_default_config() -> None: def check_accelerate_default_config() -> None:

View File

@@ -1,7 +1,6 @@
"""Configuration loading and processing.""" """Configuration loading and processing."""
import json import json
import logging
import os import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path
@@ -22,11 +21,12 @@ from axolotl.utils.config import (
validate_config, validate_config,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = logging.getLogger(__name__) LOG = get_logger(__name__, use_environ=True)
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]: def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
@@ -119,12 +119,12 @@ def choose_config(path: Path) -> str:
) )
if len(yaml_files) == 1: if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'") LOG.info(f"Using default YAML file '{yaml_files[0]}'")
return str(yaml_files[0]) return str(yaml_files[0])
print("Choose a YAML file:") LOG.info("Choose a YAML file:")
for idx, file in enumerate(yaml_files): for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}") LOG.info(f"{idx + 1}. {file}")
chosen_file = None chosen_file = None
while chosen_file is None: while chosen_file is None:
@@ -133,9 +133,9 @@ def choose_config(path: Path) -> str:
if 1 <= choice <= len(yaml_files): if 1 <= choice <= len(yaml_files):
chosen_file = str(yaml_files[choice - 1]) chosen_file = str(yaml_files[choice - 1])
else: else:
print("Invalid choice. Please choose a number from the list.") LOG.info("Invalid choice. Please choose a number from the list.")
except ValueError: except ValueError:
print("Invalid input. Please enter a number.") LOG.info("Invalid input. Please enter a number.")
return chosen_file return chosen_file

View File

@@ -1,6 +1,5 @@
"""CLI to run evaluation on a model.""" """CLI to run evaluation on a model."""
import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@@ -17,8 +16,9 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate from axolotl.evaluate import evaluate
from axolotl.utils import patch_optimized_env from axolotl.utils import patch_optimized_env
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None: def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:

View File

@@ -1,7 +1,6 @@
"""CLI to run inference on a trained model.""" """CLI to run inference on a trained model."""
import importlib import importlib
import logging
import sys import sys
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread
@@ -22,8 +21,9 @@ from axolotl.utils.chat_templates import (
get_chat_template_from_config, get_chat_template_from_config,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
def get_multi_line_input() -> str: def get_multi_line_input() -> str:

View File

@@ -2,7 +2,6 @@
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
import logging
import os import os
import subprocess # nosec B404 import subprocess # nosec B404
import tempfile import tempfile
@@ -31,8 +30,11 @@ from axolotl.cli.utils import (
) )
from axolotl.integrations.lm_eval.cli import lm_eval from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env from axolotl.utils import patch_optimized_env
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig from axolotl.utils.schemas.config import AxolotlInputConfig
LOG = get_logger(__name__)
@click.group() @click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl") @click.version_option(version=axolotl.__version__, prog_name="axolotl")
@@ -177,7 +179,7 @@ def train(
do_cli(config=cfg_file, **kwargs) do_cli(config=cfg_file, **kwargs)
except subprocess.CalledProcessError as exc: except subprocess.CalledProcessError as exc:
logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}") LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep: if not sweep:
raise exc raise exc

View File

@@ -1,6 +1,5 @@
"""CLI to merge a trained LoRA into a base model.""" """CLI to merge a trained LoRA into a base model."""
import logging
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@@ -13,8 +12,9 @@ from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
def do_merge_lora(*, cfg: DictDefault) -> None: def do_merge_lora(*, cfg: DictDefault) -> None:

View File

@@ -1,7 +1,6 @@
"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint.""" """CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""
import json import json
import logging
import os import os
import shutil import shutil
from pathlib import Path from pathlib import Path
@@ -27,8 +26,9 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner): class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):

View File

@@ -1,6 +1,5 @@
"""CLI to run preprocessing of a dataset.""" """CLI to run preprocessing of a dataset."""
import logging
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@@ -20,9 +19,10 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import disable_datasets_caching from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:

View File

@@ -2,7 +2,6 @@
CLI to post-training quantize a model using torchao CLI to post-training quantize a model using torchao
""" """
import logging
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@@ -11,9 +10,10 @@ from transformers import AutoModelForCausalLM
from axolotl.cli.art import print_axolotl_text_art from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
def do_quantize( def do_quantize(

View File

@@ -1,7 +1,6 @@
"""CLI to run training on a model.""" """CLI to run training on a model."""
import gc import gc
import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@@ -22,8 +21,6 @@ from axolotl.utils import patch_optimized_env
from axolotl.utils.config import normalize_config, resolve_dtype from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs): def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
""" """

View File

@@ -4,7 +4,6 @@ import concurrent.futures
import dataclasses import dataclasses
import hashlib import hashlib
import json import json
import logging
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from types import NoneType from types import NoneType
@@ -23,8 +22,9 @@ from transformers import (
from axolotl.loaders import load_processor, load_tokenizer from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
def strip_optional_type(field_type: type | str | None): def strip_optional_type(field_type: type | str | None):

View File

@@ -1,6 +1,5 @@
"""Dataset loading utilities.""" """Dataset loading utilities."""
import logging
import math import math
import random import random
from dataclasses import dataclass from dataclasses import dataclass
@@ -14,10 +13,11 @@ from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_dataset from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType from axolotl.utils.schemas.enums import RLType
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
@dataclass @dataclass

View File

@@ -156,7 +156,6 @@ class Messages(BaseModel):
len(input_ids) : len(input_ids) + len(pending_input_ids) len(input_ids) : len(input_ids) + len(pending_input_ids)
] ]
if new_pending_inputs != pending_input_ids: if new_pending_inputs != pending_input_ids:
# logging.warning("tokenization mismatch from concatenation.")
pending_input_ids = new_pending_inputs pending_input_ids = new_pending_inputs
input_ids.extend(pending_input_ids) input_ids.extend(pending_input_ids)
if pending_weight: if pending_weight:

View File

@@ -19,7 +19,6 @@ import abc
import importlib import importlib
import importlib.util import importlib.util
import inspect import inspect
import logging
import math import math
import os import os
import sys import sys
@@ -88,6 +87,7 @@ from axolotl.utils.collators import (
V2BatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq,
) )
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType
try: try:
@@ -95,7 +95,7 @@ try:
except ImportError: except ImportError:
pass pass
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):

View File

@@ -4,7 +4,6 @@
from __future__ import annotations from __future__ import annotations
import logging
import os import os
from collections import defaultdict from collections import defaultdict
from functools import wraps from functools import wraps
@@ -34,9 +33,10 @@ from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging, sanitize_kwargs_for_tagging,
) )
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer): class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):

View File

@@ -2,7 +2,6 @@
import importlib import importlib
import inspect import inspect
import logging
from typing import Any from typing import Any
from trl.trainer.grpo_trainer import RewardFunc from trl.trainer.grpo_trainer import RewardFunc
@@ -13,9 +12,10 @@ from axolotl.core.trainers.grpo.trainer import (
AxolotlGRPOTrainer, AxolotlGRPOTrainer,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.trl import TRLConfig from axolotl.utils.schemas.trl import TRLConfig
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
class GRPOStrategy: class GRPOStrategy:

View File

@@ -1,18 +1,17 @@
"""Module for Axolotl trainer optimizer mixin""" """Module for Axolotl trainer optimizer mixin"""
import logging
from peft.optimizers import create_loraplus_optimizer from peft.optimizers import create_loraplus_optimizer
from torch import nn from torch import nn
from transformers.trainer import Trainer from transformers.trainer import Trainer
from transformers.utils import is_sagemaker_mp_enabled from transformers.utils import is_sagemaker_mp_enabled
from axolotl.integrations.base import BaseOptimizerFactory from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.utils.logging import get_logger
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp import smdistributed.modelparallel.torch as smp
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
class OptimizerMixin(Trainer): class OptimizerMixin(Trainer):

View File

@@ -6,7 +6,6 @@ See https://github.com/huggingface/transformers/pull/37162
TODO: Remove when upstream added PR to release TODO: Remove when upstream added PR to release
""" """
import logging
import os import os
import random import random
@@ -17,7 +16,9 @@ from transformers.trainer import safe_globals
from transformers.trainer_pt_utils import set_rng_state_for_device from transformers.trainer_pt_utils import set_rng_state_for_device
from transformers.training_args import ParallelMode from transformers.training_args import ParallelMode
LOG = logging.getLogger(__name__) from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class RngLoaderMixin(Trainer): class RngLoaderMixin(Trainer):

View File

@@ -1,12 +1,11 @@
"""Module for Axolotl trainer scheduler mixin""" """Module for Axolotl trainer scheduler mixin"""
import logging
import torch import torch
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
from transformers.trainer import Trainer from transformers.trainer import Trainer
from axolotl.integrations.base import PluginManager from axolotl.integrations.base import PluginManager
from axolotl.utils.logging import get_logger
from axolotl.utils.schedulers import ( from axolotl.utils.schedulers import (
RexLR, RexLR,
get_cosine_schedule_with_min_lr, get_cosine_schedule_with_min_lr,
@@ -14,7 +13,7 @@ from axolotl.utils.schedulers import (
get_cosine_schedule_with_warmup_decay_constant, get_cosine_schedule_with_warmup_decay_constant,
) )
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
class SchedulerMixin(Trainer): class SchedulerMixin(Trainer):
@@ -80,13 +79,15 @@ class SchedulerMixin(Trainer):
self.lr_scheduler = RexLR( self.lr_scheduler = RexLR(
optimizer=optimizer, optimizer=optimizer,
max_lr=self.args.learning_rate, max_lr=self.args.learning_rate,
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio), min_lr=0 if not use_cosine_min_lr else (
self.args.learning_rate * self.args.cosine_min_lr_ratio),
total_steps=num_training_steps, total_steps=num_training_steps,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
) )
elif use_cosine_quadratic: elif use_cosine_quadratic:
if use_cosine_min_lr: if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.") LOG.warning(
"Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer, optimizer,
@@ -115,9 +116,11 @@ class SchedulerMixin(Trainer):
return super().create_scheduler(num_training_steps, optimizer=optimizer) return super().create_scheduler(num_training_steps, optimizer=optimizer)
else: else:
if use_cosine_quadratic: if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") LOG.warning(
"axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr: if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") LOG.warning(
"axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler # type: ignore return self.lr_scheduler # type: ignore

View File

@@ -1,12 +1,13 @@
"""Module containing Dataset functionality""" """Module containing Dataset functionality"""
import logging
import os import os
from typing import List, Optional, Union from typing import List, Optional, Union
import torch import torch
from datasets import Dataset, IterableDataset from datasets import Dataset, IterableDataset
from axolotl.utils.logging import get_logger
from .prompt_tokenizers import PromptTokenizingStrategy from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded # We want this to be a wrapper for an existing dataset that we have loaded
@@ -15,7 +16,7 @@ from .prompt_tokenizers import PromptTokenizingStrategy
# let's check to ensure we don't truncate an item in the middle, we'll use # let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets # the collators later on to pad the datasets
LOG = logging.getLogger("axolotl") LOG = get_logger(__name__)
class TokenizedPromptDataset(Dataset): class TokenizedPromptDataset(Dataset):

View File

@@ -22,7 +22,6 @@ from __future__ import annotations
import collections import collections
import importlib import importlib
import logging
from typing import TYPE_CHECKING, Callable, OrderedDict, Union from typing import TYPE_CHECKING, Callable, OrderedDict, Union
from peft import PeftModel from peft import PeftModel
@@ -31,6 +30,9 @@ from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedModel, Trainer from transformers import PreTrainedModel, Trainer
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from axolotl.common.datasets import TrainDatasetMeta from axolotl.common.datasets import TrainDatasetMeta
@@ -331,12 +333,12 @@ class PluginManager:
ImportError: If the plugin module cannot be imported. ImportError: If the plugin module cannot be imported.
""" """
try: try:
logging.info(f"Attempting to load plugin: {plugin_name}") LOG.info(f"Attempting to load plugin: {plugin_name}")
plugin = load_plugin(plugin_name) plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin self.plugins[plugin_name] = plugin
logging.info(f"Plugin loaded successfully: {plugin_name}") LOG.info(f"Plugin loaded successfully: {plugin_name}")
except ImportError: except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}") LOG.error(f"Failed to load plugin: {plugin_name}")
def get_input_args(self) -> list[str]: def get_input_args(self) -> list[str]:
"""Returns a list of Pydantic classes for all registered plugins' input arguments.' """Returns a list of Pydantic classes for all registered plugins' input arguments.'

View File

@@ -19,17 +19,16 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team. from Apple's ML team.
""" """
import importlib import importlib
import logging
import torch import torch
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version from axolotl.utils import get_pytorch_version
from axolotl.utils.distributed import is_main_process from axolotl.utils.logging import get_logger
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401 from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy") LOG = get_logger(__name__, use_environ=True)
_CCE_INSTALL_MESSAGE = ( _CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using " "Please install cut_cross_entropy with transformers support using "
@@ -76,10 +75,9 @@ class CutCrossEntropyPlugin(BasePlugin):
cce_patch, cce_patch,
) )
if is_main_process(use_environ=True): LOG.info(
LOG.info( f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}" )
)
# The patch checks model_type internally # The patch checks model_type internally
cce_patch(cfg.model_config_type) cce_patch(cfg.model_config_type)

View File

@@ -15,12 +15,13 @@
""" """
Module for handling Cut Cross Entropy input arguments. Module for handling Cut Cross Entropy input arguments.
""" """
import logging
from typing import Optional from typing import Optional
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args") from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class CutCrossEntropyArgs(BaseModel): class CutCrossEntropyArgs(BaseModel):

View File

@@ -2,15 +2,15 @@
Grokfast plugin for Axolotl Grokfast plugin for Axolotl
""" """
import logging
from transformers.trainer_callback import TrainerCallback from transformers.trainer_callback import TrainerCallback
from axolotl.utils.logging import get_logger
from ..base import BasePlugin from ..base import BasePlugin
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401 from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
from .optimizer import gradfilter_ema from .optimizer import gradfilter_ema
LOG = logging.getLogger("axolotl.integrations.grokfast") LOG = get_logger(__name__)
class GrokfastCallbackHandler(TrainerCallback): class GrokfastCallbackHandler(TrainerCallback):

View File

@@ -19,16 +19,15 @@ Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight. It is designed to be performant, correct, and light-weight.
""" """
import inspect import inspect
import logging
import sys import sys
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils.distributed import is_main_process from axolotl.utils.logging import get_logger
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
from .utils import patch_with_compile_disable from .utils import patch_with_compile_disable
LOG = logging.getLogger("axolotl.integrations.liger") LOG = get_logger(__name__, use_environ=True)
class LigerPlugin(BasePlugin): class LigerPlugin(BasePlugin):
@@ -85,10 +84,7 @@ class LigerPlugin(BasePlugin):
kwargs["geglu"] = cfg.liger_glu_activation kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters: elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation kwargs["swiglu"] = cfg.liger_glu_activation
if is_main_process(use_environ=True): LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
)
apply_liger_fn(**kwargs) apply_liger_fn(**kwargs)
elif cfg.model_config_type == "jamba": elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba from transformers.models.jamba import modeling_jamba
@@ -124,9 +120,9 @@ class LigerPlugin(BasePlugin):
if cfg.liger_rope: if cfg.liger_rope:
# The DeepseekV2 version of RoPE is different than upstream LLaMA. # The DeepseekV2 version of RoPE is different than upstream LLaMA.
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528 # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
logging.warning("Fused liger_rope is not supported for DeepseekV2.") LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_glu_activation: if cfg.liger_glu_activation:
logging.warning("liger_glu_activation is not supported for DeepseekV2.") LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
if cfg.liger_rms_norm: if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_glu_activation: if cfg.liger_glu_activation:
@@ -186,6 +182,6 @@ class LigerPlugin(BasePlugin):
swiglu=cfg.liger_glu_activation, swiglu=cfg.liger_glu_activation,
) )
else: else:
logging.warning( LOG.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied." f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
) )

View File

@@ -15,12 +15,13 @@
""" """
Module for handling LIGER input arguments. Module for handling LIGER input arguments.
""" """
import logging
from typing import Optional from typing import Optional
from pydantic import BaseModel, model_validator from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.liger.args") from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class LigerArgs(BaseModel): class LigerArgs(BaseModel):

View File

@@ -3,7 +3,6 @@ Sparse Finetuning plugin for Axolotl — enables handling of sparse neural netwo
by maintaining masks for zero weights during training. by maintaining masks for zero weights during training.
""" """
import logging
from functools import wraps from functools import wraps
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar from typing import Any, Callable, Concatenate, ParamSpec, TypeVar
@@ -16,11 +15,12 @@ from transformers.trainer_callback import TrainerCallback, TrainerControl, Train
from transformers.training_args import TrainingArguments from transformers.training_args import TrainingArguments
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
P = ParamSpec("P") # Params for generic function signatures P = ParamSpec("P") # Params for generic function signatures
R = TypeVar("R") # Return type for generic function signatures R = TypeVar("R") # Return type for generic function signatures
LOG = logging.getLogger("axolotl.integrations.llm_compressor") LOG = get_logger(__name__)
class LLMCompressorCallbackHandler(TrainerCallback): class LLMCompressorCallbackHandler(TrainerCallback):

View File

@@ -17,14 +17,16 @@ Spectrum Plugin to automatically generate unfrozen parameters based on SNR data.
""" """
import json import json
import logging
import requests import requests
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger
from .args import SpectrumArgs # pylint: disable=unused-import. # noqa: F401 from .args import SpectrumArgs # pylint: disable=unused-import. # noqa: F401
LOG = get_logger(__name__)
def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5): def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5):
unfrozen_parameters = {} unfrozen_parameters = {}
@@ -83,17 +85,17 @@ class SpectrumPlugin(BasePlugin):
except FileNotFoundError: except FileNotFoundError:
pass pass
except Exception as exc: # pylint: disable=broad-exception-caught except Exception as exc: # pylint: disable=broad-exception-caught
logging.warning(f"Failed to read SNR data from {snr_path}: {exc}") LOG.warning(f"Failed to read SNR data from {snr_path}: {exc}")
if not snr_data: if not snr_data:
try: try:
snr_data = requests.get(snr_url, timeout=60).json() snr_data = requests.get(snr_url, timeout=60).json()
except requests.exceptions.RequestException as exc: except requests.exceptions.RequestException as exc:
logging.warning(f"Failed to fetch SNR data from {snr_url}: {exc}") LOG.warning(f"Failed to fetch SNR data from {snr_url}: {exc}")
return return
# also catch json parsing errors # also catch json parsing errors
except json.JSONDecodeError as exc: except json.JSONDecodeError as exc:
logging.warning(f"Failed to parse SNR data from {snr_url}: {exc}") LOG.warning(f"Failed to parse SNR data from {snr_url}: {exc}")
return return
unfrozen_parameters = _generate_unfrozen_params_yaml( unfrozen_parameters = _generate_unfrozen_params_yaml(

View File

@@ -1,6 +1,5 @@
"""Adapter loading functionality, including LoRA / QLoRA and associated utils""" """Adapter loading functionality, including LoRA / QLoRA and associated utils"""
import logging
import os import os
import types import types
from typing import Any from typing import Any
@@ -21,8 +20,9 @@ from transformers import PreTrainedModel
from axolotl.loaders.utils import get_linear_embedding_layers from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
def setup_quantized_meta_for_peft(model: torch.nn.Module): def setup_quantized_meta_for_peft(model: torch.nn.Module):

View File

@@ -3,7 +3,6 @@ models.
""" """
import gc import gc
import logging
import math import math
import os import os
from functools import cached_property from functools import cached_property
@@ -47,10 +46,11 @@ from axolotl.utils.distributed import (
get_device_count, get_device_count,
get_device_type, get_device_type,
) )
from axolotl.utils.logging import get_logger
from axolotl.utils.model_shard_quant import load_sharded_model_quant from axolotl.utils.model_shard_quant import load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance() PLUGIN_MANAGER = PluginManager.get_instance()

View File

@@ -4,7 +4,6 @@ Applies pre- and post-model load patches for various fixes and optimizations.
""" """
import importlib.util import importlib.util
import logging
from functools import cached_property from functools import cached_property
import addict import addict
@@ -17,8 +16,9 @@ from axolotl.monkeypatch.multipack import (
patch_for_multipack, patch_for_multipack,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance() PLUGIN_MANAGER = PluginManager.get_instance()

View File

@@ -1,6 +1,5 @@
"""Processor loading functionality for multi-modal models""" """Processor loading functionality for multi-modal models"""
import logging
from typing import Any from typing import Any
import transformers import transformers
@@ -10,8 +9,9 @@ from transformers import (
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase): def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):

View File

@@ -1,7 +1,6 @@
"""Tokenizer loading functionality and associated utils""" """Tokenizer loading functionality and associated utils"""
import json import json
import logging
import os import os
import transformers import transformers
@@ -19,8 +18,9 @@ from axolotl.utils.distributed import (
is_local_main_process, is_local_main_process,
is_main_process, is_main_process,
) )
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance() PLUGIN_MANAGER = PluginManager.get_instance()

View File

@@ -1,7 +1,6 @@
"""Utilities for axolotl.loaders module""" """Utilities for axolotl.loaders module"""
import contextlib import contextlib
import logging
from typing import Type from typing import Type
import addict import addict
@@ -9,8 +8,9 @@ import torch
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
def get_module_class_from_name( def get_module_class_from_name(

View File

@@ -2,12 +2,13 @@
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation, and saving full state dicts monkeypatch for accelerate fsdp2 fix when modifying ordereddict during interation, and saving full state dicts
""" """
import logging
import sys import sys
import torch import torch
LOG = logging.getLogger(__name__) from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict): def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):

View File

@@ -3,7 +3,6 @@ Flash attention monkey patch for cerebras btlm model
""" """
import importlib import importlib
import logging
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
@@ -11,7 +10,9 @@ from accelerate import init_empty_weights
from flash_attn.flash_attn_interface import flash_attn_func from flash_attn.flash_attn_interface import flash_attn_func
from transformers import AutoConfig, AutoModelForCausalLM from transformers import AutoConfig, AutoModelForCausalLM
LOG = logging.getLogger("axolotl") from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"): def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):

View File

@@ -18,7 +18,6 @@ DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching
import atexit import atexit
import concurrent.futures import concurrent.futures
import logging
import os import os
import queue import queue
import shutil import shutil
@@ -32,11 +31,13 @@ from typing import Dict
import torch import torch
from axolotl.utils.logging import get_logger
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda") torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda") torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
# Setup logger # Setup logger
logger = logging.getLogger(__name__) logger = get_logger(__name__)
class DiskOffloadManager: class DiskOffloadManager:

View File

@@ -2,7 +2,6 @@
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
import logging
import warnings import warnings
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@@ -25,6 +24,7 @@ from transformers.models.llama.modeling_llama import (
) )
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
from axolotl.utils.logging import get_logger
try: try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
@@ -41,7 +41,7 @@ except ImportError:
) )
LOG = logging.getLogger("axolotl") LOG = get_logger(__name__)
def is_xformers_available() -> bool: def is_xformers_available() -> bool:
@@ -612,9 +612,10 @@ def generate_qkv(
q, query_padding_mask q, query_padding_mask
) )
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731 def output_pad_fn(output_unpad):
output_unpad, indices_q, batch_size, seqlen_q return pad_input( # noqa: E731
) output_unpad, indices_q, batch_size, seqlen_q
)
else: else:
q_unpad = rearrange(q, "b s h d -> (b s) h d") q_unpad = rearrange(q, "b s h d -> (b s) h d")
@@ -627,9 +628,10 @@ def generate_qkv(
) )
max_seqlen_q = seqlen_q max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731 def output_pad_fn(output_unpad):
output_unpad, "(b s) h d -> b s h d", b=batch_size return rearrange( # noqa: E731
) output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None: if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)

View File

@@ -2,7 +2,6 @@
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
""" """
import logging
import warnings import warnings
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -11,10 +10,14 @@ import torch.nn.functional as F
import transformers.models.llama.modeling_llama import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
try: try:
import xformers.ops import xformers.ops
except ImportError: except ImportError:
logging.error("xformers not found! Please install it before trying to use it.") LOG.error("xformers not found! Please install it before trying to use it.")
def hijack_llama_attention(): def hijack_llama_attention():

View File

@@ -7,7 +7,6 @@ import types
from typing import Generator, Tuple, Type from typing import Generator, Tuple, Type
import torch import torch
from accelerate.logging import get_logger
from peft import PeftModelForCausalLM from peft import PeftModelForCausalLM
from torch import nn from torch import nn
from transformers import AutoConfig from transformers import AutoConfig
@@ -20,6 +19,7 @@ from axolotl.kernels.lora import (
) )
from axolotl.monkeypatch.utils import detab_code from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__) LOG = get_logger(__name__)

View File

@@ -2,7 +2,6 @@
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import logging
from functools import partial from functools import partial
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@@ -28,8 +27,9 @@ from transformers.models.mistral.modeling_mistral import (
) )
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.monkeypatch.mistral") LOG = get_logger(__name__)
def replace_mistral_attn_with_flash_attn( def replace_mistral_attn_with_flash_attn(
@@ -359,9 +359,10 @@ def generate_qkv(
q, query_padding_mask q, query_padding_mask
) )
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731 def output_pad_fn(output_unpad):
output_unpad, indices_q, batch_size, seqlen_q return pad_input( # noqa: E731
) output_unpad, indices_q, batch_size, seqlen_q
)
else: else:
q_unpad = rearrange(q, "b s h d -> (b s) h d") q_unpad = rearrange(q, "b s h d -> (b s) h d")
@@ -374,9 +375,10 @@ def generate_qkv(
) )
max_seqlen_q = seqlen_q max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731 def output_pad_fn(output_unpad):
output_unpad, "(b s) h d -> b s h d", b=batch_size return rearrange( # noqa: E731
) output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None: if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)

View File

@@ -3,14 +3,14 @@ Patch prepare_model_for_kbit_training to not upcast everything
""" """
import inspect import inspect
import logging
import peft import peft
import axolotl import axolotl
from axolotl.monkeypatch.utils import detab_code from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
ORIGINAL_PREPARE_CODE = """ ORIGINAL_PREPARE_CODE = """
for param in model.parameters(): for param in model.parameters():

View File

@@ -2,7 +2,6 @@
import glob import glob
import json import json
import logging
import os.path import os.path
import shutil import shutil
from functools import partial from functools import partial
@@ -27,8 +26,9 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import barrier, is_main_process from axolotl.utils.distributed import barrier, is_main_process
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.relora") LOG = get_logger(__name__)
@torch.no_grad() @torch.no_grad()

View File

@@ -32,11 +32,11 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
from torch import nn from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import logging
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger
logger = logging.get_logger(__name__) logger = get_logger(__name__)
def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"): def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"):

View File

@@ -2,11 +2,11 @@
monkeypatch for Trainer _get_learning_rate method monkeypatch for Trainer _get_learning_rate method
""" """
import logging
import torch import torch
LOG = logging.getLogger(__name__) from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
# TODO remove this patch once https://github.com/huggingface/transformers/pull/37881 is included in a release # TODO remove this patch once https://github.com/huggingface/transformers/pull/37881 is included in a release

View File

@@ -3,13 +3,13 @@ allow adding additional kwargs to Accelerator init
""" """
import inspect import inspect
import logging
from transformers import Trainer from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
ORIGINAL_TRAINER_CODE = """ ORIGINAL_TRAINER_CODE = """
# create accelerator object # create accelerator object

View File

@@ -3,13 +3,13 @@ fix for FSDP2 evals when using torch.compile
""" """
import inspect import inspect
import logging
from transformers import Trainer from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
ORIGINAL_TRAINER_CODE = """ ORIGINAL_TRAINER_CODE = """
model.eval() model.eval()

View File

@@ -3,13 +3,13 @@ fix for FSDP optimizer save in trainer w 4.47.0
""" """
import inspect import inspect
import logging
from transformers import Trainer from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save") LOG = get_logger(__name__)
ORIGINAL_TRAINER_CODE = """ ORIGINAL_TRAINER_CODE = """

View File

@@ -2,13 +2,14 @@
see https://github.com/huggingface/transformers/pull/35834 see https://github.com/huggingface/transformers/pull/35834
""" """
import logging
from functools import partial from functools import partial
from typing import Optional from typing import Optional
import torch import torch
logger = logging.getLogger(__name__) from axolotl.utils.logging import get_logger
logger = get_logger(__name__)
def fixed_fa_peft_integration_check( def fixed_fa_peft_integration_check(

View File

@@ -11,7 +11,7 @@ from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from axolotl.monkeypatch.utils import detab_code from axolotl.monkeypatch.utils import detab_code
LOG = get_logger("axolotl.monkeypatch.unsloth") LOG = get_logger(__name__)
ORIGINAL_QKV_CODE = """ ORIGINAL_QKV_CODE = """
query_states = self.q_proj(hidden_states) query_states = self.q_proj(hidden_states)
@@ -133,7 +133,7 @@ def patch_self_attn_lora():
) )
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
self_attn_lora_patched = True self_attn_lora_patched = True
LOG.info("patching unsloth attn lora", main_process_only=True) LOG.info("patching unsloth attn lora")
LlamaFlashAttention2.forward = ( LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821 unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
) )
@@ -153,7 +153,7 @@ def integrate_rope_embeddings():
): ):
return fast_rope_embedding(q, k, cos, sin) return fast_rope_embedding(q, k, cos, sin)
LOG.info("patching unsloth RoPE embeddings", main_process_only=True) LOG.info("patching unsloth RoPE embeddings")
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
@@ -189,7 +189,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if is_mlp_lora and mlp_no_bias and mlp_not_dora: if is_mlp_lora and mlp_no_bias and mlp_not_dora:
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp) layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
else: else:
LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx) LOG.warning(f"unable to apply unsloth lora mlp patch to layer {idx}")
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
@@ -215,7 +215,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_qkv = apply_lora_qkv layer.self_attn.apply_qkv = apply_lora_qkv
else: else:
layer.self_attn.apply_qkv = original_apply_qkv layer.self_attn.apply_qkv = original_apply_qkv
LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx) LOG.warning(f"unable to apply unsloth lora qkv patch to layer {idx}")
if cfg.unsloth_lora_o: if cfg.unsloth_lora_o:
layer_modules = [ layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"] getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
@@ -234,9 +234,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_o = apply_lora_o layer.self_attn.apply_o = apply_lora_o
else: else:
layer.self_attn.apply_o = original_apply_o layer.self_attn.apply_o = original_apply_o
LOG.warning( LOG.warning(f"unable to apply unsloth lora o_proj patch to layer {idx}")
"unable to apply unsloth lora o_proj patch to layer %d", idx
)
def patch_unsloth_layernorm(): def patch_unsloth_layernorm():

View File

@@ -1,6 +1,5 @@
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types""" """Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""
import logging
from copy import deepcopy from copy import deepcopy
from typing import Optional from typing import Optional
@@ -10,7 +9,9 @@ from torch import Tensor
from transformers import ProcessorMixin from transformers import ProcessorMixin
from transformers.image_utils import load_image from transformers.image_utils import load_image
LOG = logging.getLogger(__name__) from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class ProcessingStrategy: class ProcessingStrategy:

View File

@@ -2,11 +2,11 @@
import importlib import importlib
import inspect import inspect
import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.prompt_strategies") LOG = get_logger(__name__)
def load(strategy, tokenizer, cfg, ds_cfg, processor=None): def load(strategy, tokenizer, cfg, ds_cfg, processor=None):

View File

@@ -3,9 +3,10 @@ module for base dataset transform strategies
""" """
import importlib import importlib
import logging
LOG = logging.getLogger("axolotl") from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def load(strategy, cfg, module_base=None, **kwargs): def load(strategy, cfg, module_base=None, **kwargs):

View File

@@ -2,11 +2,11 @@
import importlib import importlib
import inspect import inspect
import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry") LOG = get_logger(__name__)
def load(strategy, tokenizer, cfg, ds_cfg): def load(strategy, tokenizer, cfg, ds_cfg):

View File

@@ -2,7 +2,6 @@
Bradley-Terry model with chat template prompt strategy. Bradley-Terry model with chat template prompt strategy.
""" """
import logging
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from axolotl.prompt_strategies.chat_template import ( from axolotl.prompt_strategies.chat_template import (
@@ -10,10 +9,11 @@ from axolotl.prompt_strategies.chat_template import (
ChatTemplateStrategy, ChatTemplateStrategy,
) )
from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.logging import get_logger
# Configure the logger # Configure the logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template") LOG = get_logger(__name__)
LOG.setLevel(logging.INFO) LOG.setLevel("INFO")
class BTChatTemplateStrategy(ChatTemplateStrategy): class BTChatTemplateStrategy(ChatTemplateStrategy):
@@ -44,7 +44,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
if len(chosen_tokenized["input_ids"]) > max_length: if len(chosen_tokenized["input_ids"]) > max_length:
LOG.warning( LOG.warning(
f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}", f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}"
) )
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length] chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
@@ -62,7 +62,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
if len(rejected_tokenized["input_ids"]) > max_length: if len(rejected_tokenized["input_ids"]) > max_length:
LOG.warning( LOG.warning(
f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}", f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}"
) )
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][ rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][

View File

@@ -2,7 +2,6 @@
HF Chat Templates prompt strategy HF Chat Templates prompt strategy
""" """
import logging
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, List, Set, Union from typing import Any, Dict, List, Set, Union
@@ -13,11 +12,12 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import DatasetConfig from axolotl.utils.schemas.datasets import DatasetConfig
# Configure the logger # Configure the logger
LOG = logging.getLogger("axolotl") LOG = get_logger(__name__)
LOG.setLevel(logging.INFO) LOG.setLevel("INFO")
class ChatTemplatePrompter(Prompter): class ChatTemplatePrompter(Prompter):
@@ -378,7 +378,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
add_generation_prompt=True, add_generation_prompt=True,
images=images, images=images,
) )
tokenized_res = self.prompter.build_prompt(turns, images=images) # type: ignore tokenized_res = self.prompter.build_prompt(
turns, images=images
) # type: ignore
tokenized_prompt = {} tokenized_prompt = {}
if isinstance(tokenized_res, list): if isinstance(tokenized_res, list):
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :] input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
@@ -555,8 +557,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
and turns[0].get("role") == "system" and turns[0].get("role") == "system"
and ( and (
"mistral" in self.tokenizer.name_or_path.lower() "mistral" in self.tokenizer.name_or_path.lower()
# gemma3 uses gemma tokenizer or "gemma"
or "gemma" in self.tokenizer.name_or_path.lower() in self.tokenizer.name_or_path.lower() # gemma3 uses gemma tokenizer
) )
): ):
return -1, -1 return -1, -1

View File

@@ -24,12 +24,14 @@ For a custom system message, the first "from" can be "system" (followed by alter
Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing! Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing!
""" """
import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Generator, List, Sequence from typing import Generator, List, Sequence
from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@dataclass @dataclass
@@ -129,7 +131,7 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
if cur_len < self.sequence_len: if cur_len < self.sequence_len:
if cur_len != total_len: if cur_len != total_len:
target[:] = IGNORE_TOKEN_ID target[:] = IGNORE_TOKEN_ID
logging.warning( LOG.warning(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)" f" (ignored)"
) )

View File

@@ -2,9 +2,10 @@
import importlib import importlib
import inspect import inspect
import logging
LOG = logging.getLogger("axolotl.prompt_strategies.messages") from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def load(tokenizer, cfg, ds_cfg, processor=None): def load(tokenizer, cfg, ds_cfg, processor=None):

View File

@@ -1,12 +1,12 @@
"""Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class""" """Module containing the MetharmenPromptTokenizingStrategy and MetharmePrompter class"""
import logging
from typing import Tuple from typing import Tuple
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter from axolotl.prompters import AlpacaPrompter
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl") LOG = get_logger(__name__)
IGNORE_TOKEN_ID = -100 IGNORE_TOKEN_ID = -100

View File

@@ -1,7 +1,6 @@
"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class""" """Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class"""
import copy import copy
import logging
from collections import defaultdict from collections import defaultdict
from typing import Generator, List, Tuple from typing import Generator, List, Tuple
@@ -10,8 +9,9 @@ from axolotl.prompt_tokenizers import (
parse_tokenized_to_result, parse_tokenized_to_result,
tokenize_prompt_default, tokenize_prompt_default,
) )
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl") LOG = get_logger(__name__)
IGNORE_TOKEN_ID = -100 IGNORE_TOKEN_ID = -100

View File

@@ -1,14 +1,14 @@
"""Module containing PromptTokenizingStrategy and Prompter classes""" """Module containing PromptTokenizingStrategy and Prompter classes"""
import abc import abc
import logging
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
from transformers import BatchEncoding, PreTrainedTokenizer from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.prompters import Prompter from axolotl.prompters import Prompter
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl") LOG = get_logger(__name__)
IGNORE_INDEX = -100 IGNORE_INDEX = -100
LLAMA_DEFAULT_PAD_TOKEN = "<pad>" # nosec LLAMA_DEFAULT_PAD_TOKEN = "<pad>" # nosec

View File

@@ -1,12 +1,13 @@
"""Module containing prompters""" """Module containing prompters"""
import logging
from enum import Enum from enum import Enum
from typing import Generator, Optional, Union from typing import Generator, Optional, Union
from colorama import Fore from colorama import Fore
LOG = logging.getLogger("axolotl") from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
IGNORE_TOKEN_ID = -100 IGNORE_TOKEN_ID = -100
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n" REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"

View File

@@ -2,7 +2,6 @@
import importlib import importlib
import inspect import inspect
import logging
import os import os
import signal import signal
import sys import sys
@@ -37,6 +36,7 @@ from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContext
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType from axolotl.utils.schemas.enums import RLType
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
@@ -45,7 +45,7 @@ try:
except ImportError: except ImportError:
BetterTransformer = None BetterTransformer = None
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
def setup_model_and_tokenizer( def setup_model_and_tokenizer(
@@ -64,9 +64,7 @@ def setup_model_and_tokenizer(
`None`), and processor (if multimodal, else `None`). `None`), and processor (if multimodal, else `None`).
""" """
# Load tokenizer # Load tokenizer
LOG.debug( LOG.debug(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
)
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
# Load processor for multimodal models if needed # Load processor for multimodal models if needed

View File

@@ -4,7 +4,6 @@ from __future__ import annotations
import gc import gc
import json import json
import logging
import os import os
import traceback import traceback
from shutil import copyfile from shutil import copyfile
@@ -43,6 +42,7 @@ from axolotl.utils.distributed import (
is_main_process, is_main_process,
zero_first, zero_first,
) )
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig from axolotl.utils.schemas.config import AxolotlInputConfig
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -50,7 +50,7 @@ if TYPE_CHECKING:
IGNORE_INDEX = -100 IGNORE_INDEX = -100
LOG = logging.getLogger("axolotl.callbacks") LOG = get_logger(__name__)
class EvalFirstStepCallback( class EvalFirstStepCallback(
@@ -753,7 +753,14 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
].append(pred_step_text) ].append(pred_step_text)
row_index += 1 row_index += 1
if logger == "wandb": if logger == "wandb":
wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined] # type: ignore[attr-defined]
wandb.run.log(
{
f"{name} - Predictions vs Ground Truth": pd.DataFrame(
table_data
)
}
)
elif logger == "mlflow" and is_mlflow_available(): elif logger == "mlflow" and is_mlflow_available():
import mlflow import mlflow

View File

@@ -1,17 +1,17 @@
"""Comet module for trainer callbacks""" """Comet module for trainer callbacks"""
import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import comet_ml import comet_ml
from transformers import TrainerCallback, TrainerControl, TrainerState from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks") LOG = get_logger(__name__)
class SaveAxolotlConfigtoCometCallback(TrainerCallback): class SaveAxolotlConfigtoCometCallback(TrainerCallback):

View File

@@ -6,17 +6,18 @@ Arxiv: https://arxiv.org/abs/2403.17919
License: Apache 2.0 License: Apache 2.0
""" """
import logging
from functools import reduce from functools import reduce
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import numpy as np import numpy as np
from transformers import TrainerCallback from transformers import TrainerCallback
from axolotl.utils.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainer from axolotl.core.trainer_builder import AxolotlTrainer
LOG = logging.getLogger("axolotl.callbacks.lisa") LOG = get_logger(__name__)
def lisa_callback_factory(trainer: "AxolotlTrainer"): def lisa_callback_factory(trainer: "AxolotlTrainer"):

View File

@@ -1,6 +1,5 @@
"""MLFlow module for trainer callbacks""" """MLFlow module for trainer callbacks"""
import logging
import os import os
from shutil import copyfile from shutil import copyfile
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
@@ -10,11 +9,12 @@ import mlflow
from transformers import TrainerCallback, TrainerControl, TrainerState from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks") LOG = get_logger(__name__)
def should_log_artifacts() -> bool: def should_log_artifacts() -> bool:

View File

@@ -1,6 +1,5 @@
"""QAT Callback for HF Causal Trainer""" """QAT Callback for HF Causal Trainer"""
import logging
from functools import partial from functools import partial
from torch import nn from torch import nn
@@ -8,9 +7,10 @@ from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear from torchao.quantization.qat.linear import FakeQuantizedLinear
from transformers import TrainerCallback from transformers import TrainerCallback
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.quantization import QATConfig from axolotl.utils.schemas.quantization import QATConfig
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
def toggle_fake_quant(mod: nn.Module, enable: bool): def toggle_fake_quant(mod: nn.Module, enable: bool):

File diff suppressed because one or more lines are too long

View File

@@ -1,11 +1,11 @@
"""Module for wandb utilities""" """Module for wandb utilities"""
import logging
import os import os
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.utils.comet_") LOG = get_logger(__name__)
COMET_ENV_MAPPING_OVERRIDE = { COMET_ENV_MAPPING_OVERRIDE = {
"comet_mode": "COMET_START_MODE", "comet_mode": "COMET_START_MODE",

View File

@@ -1,7 +1,6 @@
"""Module for working with config dicts""" """Module for working with config dicts"""
import json import json
import logging
import os import os
from typing import Optional from typing import Optional
@@ -15,13 +14,14 @@ from axolotl.loaders import MULTIMODAL_AUTO_MODEL_MAPPING
from axolotl.loaders.utils import load_model_config from axolotl.loaders.utils import load_model_config
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import ( from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
) )
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
LOG = logging.getLogger("axolotl") LOG = get_logger(__name__, use_environ=True)
def choose_device(cfg): def choose_device(cfg):

View File

@@ -1,7 +1,6 @@
"""data handling specific to pretraining""" """data handling specific to pretraining"""
import functools import functools
import logging
from collections import defaultdict from collections import defaultdict
from typing import Callable, Dict, List, Optional from typing import Callable, Dict, List, Optional
@@ -11,10 +10,11 @@ from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import process_pretraining_datasets_for_packing from axolotl.utils.trainer import process_pretraining_datasets_for_packing
LOG = logging.getLogger("axolotl") LOG = get_logger(__name__)
def encode_pretraining( def encode_pretraining(

View File

@@ -1,7 +1,6 @@
"""data handling specific to DPO""" """data handling specific to DPO"""
import inspect import inspect
import logging
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Any, List, Union from typing import Any, List, Union
@@ -18,9 +17,10 @@ from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5 from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType from axolotl.utils.schemas.enums import RLType
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
def _get_path(ds_hash, cfg): def _get_path(ds_hash, cfg):
@@ -217,7 +217,7 @@ def load_prepare_preference_datasets(cfg):
+ "|" + "|"
+ "train" + "train"
+ "|" + "|"
+ str(seed) + str(cfg.seed or 42)
) )
to_hash_test = ( to_hash_test = (
train_dataset._fingerprint # pylint: disable=protected-access train_dataset._fingerprint # pylint: disable=protected-access
@@ -226,7 +226,7 @@ def load_prepare_preference_datasets(cfg):
+ "|" + "|"
+ "test" + "test"
+ "|" + "|"
+ str(seed) + str(cfg.seed or 42)
) )
train_fingerprint = md5(to_hash_train) train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test) test_fingerprint = md5(to_hash_test)

View File

@@ -1,7 +1,6 @@
"""data handling specific to SFT""" """data handling specific to SFT"""
import functools import functools
import logging
import os import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path
@@ -54,12 +53,13 @@ from axolotl.utils.data.utils import (
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process, zero_first from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import ( from axolotl.utils.trainer import (
calculate_total_num_steps, calculate_total_num_steps,
process_datasets_for_packing, process_datasets_for_packing,
) )
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
@retry_on_request_exceptions(max_retries=3, delay=5) @retry_on_request_exceptions(max_retries=3, delay=5)
@@ -182,10 +182,9 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
total_num_steps = min( total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
) )
LOG.info(f"Maximum number of steps set at {total_num_steps}")
else: else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset) total_num_steps = calculate_total_num_steps(cfg, train_dataset)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
return train_dataset, eval_dataset, total_num_steps, prompters return train_dataset, eval_dataset, total_num_steps, prompters
@@ -331,12 +330,12 @@ def load_tokenized_prepared_datasets(
if len(datasets) == 1: if len(datasets) == 1:
dataset = datasets[0] dataset = datasets[0]
else: else:
LOG.info("merging datasets") LOG.info("Merging datasets...")
dataset = concatenate_datasets(datasets) dataset = concatenate_datasets(datasets)
if len(datasets) > 1: if len(datasets) > 1:
if cfg.shuffle_merged_datasets: if cfg.shuffle_merged_datasets:
LOG.debug("shuffle merged datasets") LOG.debug("Shuffling merged datasets...")
dataset = dataset.shuffle(seed=seed) dataset = dataset.shuffle(seed=seed)
else: else:
LOG.debug("NOT shuffling merged datasets") LOG.debug("NOT shuffling merged datasets")
@@ -426,7 +425,7 @@ def load_prepare_datasets(
+ "|" + "|"
+ "train" + "train"
+ "|" + "|"
+ str(seed) + str(cfg.seed or 42)
) )
to_hash_test = ( to_hash_test = (
dataset._fingerprint # pylint: disable=protected-access dataset._fingerprint # pylint: disable=protected-access
@@ -435,7 +434,7 @@ def load_prepare_datasets(
+ "|" + "|"
+ "test" + "test"
+ "|" + "|"
+ str(seed) + str(cfg.seed or 42)
) )
train_fingerprint = md5(to_hash_train) train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test) test_fingerprint = md5(to_hash_test)

View File

@@ -2,7 +2,6 @@
import functools import functools
import hashlib import hashlib
import logging
import time import time
from enum import Enum from enum import Enum
@@ -12,10 +11,11 @@ import requests
from datasets import Dataset, IterableDataset from datasets import Dataset, IterableDataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers.utils import get_dataset_lengths from axolotl.utils.samplers.utils import get_dataset_lengths
from axolotl.utils.trainer import drop_long_seq from axolotl.utils.trainer import drop_long_seq
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
class RetryStrategy(Enum): class RetryStrategy(Enum):

View File

@@ -0,0 +1,62 @@
"""
logging helpers to only log on main process
"""
import functools
import logging
import os
from axolotl.utils.distributed import is_main_process
# Adapted from Accelerate
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py
class MultiProcessAdapter(logging.LoggerAdapter):
    """
    Logger adapter for distributed runs that, by default, only emits records on the
    main process.

    Every logging call routed through this adapter accepts two extra keyword
    arguments. Both are consumed here and never forwarded to the wrapped logger:

    - `main_process_only` (default `True`): when true, the record is emitted only if
      `is_main_process(...)` returns true; when false, it is emitted unconditionally.
    - `use_environ` (default: the value passed to the constructor): forwarded as-is
      to `is_main_process`.
    """

    def __init__(self, logger, use_environ=False, extra=None):
        super().__init__(logger, extra)
        # Default for the `use_environ` argument handed to `is_main_process`.
        self.use_environ = use_environ

    @staticmethod
    def _should_log(main_process_only, use_environ=False):
        """
        Decide whether a record should be emitted by the current process.

        `is_main_process` is only consulted when `main_process_only` is truthy.
        """
        return not main_process_only or is_main_process(use_environ=use_environ)

    def log(self, level, msg, *args, **kwargs):
        """
        Forward a record to the wrapped logger if this process is allowed to log.

        The adapter-specific `use_environ` and `main_process_only` keyword arguments
        are removed from `kwargs` before delegating.
        """
        use_environ = kwargs.pop("use_environ", self.use_environ)
        main_process_only = kwargs.pop("main_process_only", True)
        # NOTE(review): presumably chosen so the reported caller is the code that
        # invoked the adapter rather than this method -- confirm on the supported
        # Python versions.
        kwargs.setdefault("stacklevel", 2)

        if self.isEnabledFor(level) and self._should_log(
            main_process_only, use_environ=use_environ
        ):
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)

    # The cache must be unbounded: it lives on the function and is therefore shared
    # by every adapter instance, so a small `maxsize` lets a warning be evicted and
    # then emitted a second time.
    @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
    def warning_once(self, *args, **kwargs):
        """
        Identical to `warning()`, but a given combination of adapter and arguments
        is only ever emitted once.

        Note: the cache key is built from the call arguments (including `self`), so
        all of them must be hashable, and two different call sites passing the same
        arguments to the same adapter share one cache entry. The assumption is that
        warning messages are unique across the code base; if they are not, the cache
        would need to include caller frame information in its key. Pass constant
        messages so the number of cache entries stays bounded.
        """
        self.warning(*args, **kwargs)
def get_logger(
    name: str, log_level: str | None = None, use_environ: bool = False
) -> MultiProcessAdapter:
    """
    Build a `MultiProcessAdapter` around the standard-library logger called `name`.

    Args:
        name: Name passed to `logging.getLogger`.
        log_level: Level name to apply. When `None`, the `AXOLOTL_LOG_LEVEL`
            environment variable is used instead; if that is unset too, no level is
            changed.
        use_environ: Stored on the adapter and passed along to the main-process
            check on each logging call.

    Returns:
        The adapter wrapping the named logger.
    """
    base_logger = logging.getLogger(name)

    requested = log_level
    if requested is None:
        requested = os.environ.get("AXOLOTL_LOG_LEVEL", None)

    if requested is not None:
        level_name = requested.upper()
        # The level is applied to the named logger first and then to the root logger.
        for target in (base_logger, base_logger.root):
            target.setLevel(level_name)

    return MultiProcessAdapter(base_logger, use_environ=use_environ, extra={})

View File

@@ -2,8 +2,6 @@
Utilities for quantization including QAT and PTQ using torchao. Utilities for quantization including QAT and PTQ using torchao.
""" """
import logging
import torch import torch
from torch import nn from torch import nn
from torchao.core.config import AOBaseConfig from torchao.core.config import AOBaseConfig
@@ -25,8 +23,6 @@ from torchao.quantization.quant_api import (
from axolotl.utils.schemas.enums import TorchIntDType from axolotl.utils.schemas.enums import TorchIntDType
LOG = logging.getLogger(__name__)
def get_ptq_config( def get_ptq_config(
weight_dtype: TorchIntDType, weight_dtype: TorchIntDType,

View File

@@ -3,7 +3,6 @@ Multipack Batch Sampler - An efficient batch sampler for packing variable-length
into fixed-capacity batches to optimize memory usage and training throughput. into fixed-capacity batches to optimize memory usage and training throughput.
""" """
import logging
import math import math
from concurrent.futures import ProcessPoolExecutor from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context from multiprocessing import cpu_count, get_context
@@ -14,9 +13,9 @@ import numpy as np
from torch.utils.data import BatchSampler, Sampler, SequentialSampler from torch.utils.data import BatchSampler, Sampler, SequentialSampler
from axolotl.utils.distributed import reduce_and_broadcast from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
LOG.setLevel(logging.INFO)
@numba.njit @numba.njit

View File

@@ -2,7 +2,6 @@
# pylint: disable=too-many-lines # pylint: disable=too-many-lines
import logging
import os import os
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
@@ -18,6 +17,7 @@ from pydantic import (
) )
from transformers.utils.import_utils import is_torch_npu_available from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import ( from axolotl.utils.schemas.datasets import (
DatasetConfig, DatasetConfig,
DPODataset, DPODataset,
@@ -49,7 +49,7 @@ from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.trl import TRLConfig from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.vllm import VllmConfig from axolotl.utils.schemas.vllm import VllmConfig
LOG = logging.getLogger(__name__) LOG = get_logger(__name__, use_environ=True)
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}

View File

@@ -1,11 +1,12 @@
"""Pydantic models for deprecated and remapped configuration parameters""" """Pydantic models for deprecated and remapped configuration parameters"""
import logging
from typing import Any from typing import Any
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
LOG = logging.getLogger(__name__) from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class DeprecatedParameters(BaseModel): class DeprecatedParameters(BaseModel):

View File

@@ -64,6 +64,7 @@ class ChatTemplate(str, Enum):
command_a_rag = "command_a_rag" # pylint: disable=invalid-name command_a_rag = "command_a_rag" # pylint: disable=invalid-name
aya = "aya" # pylint: disable=invalid-name aya = "aya" # pylint: disable=invalid-name
class CustomSupportedOptimizers(str, Enum): class CustomSupportedOptimizers(str, Enum):
"""Custom supported optimizers""" """Custom supported optimizers"""

View File

@@ -1,11 +1,12 @@
"""Pydantic models for Axolotl integrations""" """Pydantic models for Axolotl integrations"""
import logging
from typing import Any from typing import Any
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
LOG = logging.getLogger(__name__) from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class MLFlowConfig(BaseModel): class MLFlowConfig(BaseModel):

View File

@@ -1,10 +1,10 @@
"""Pydantic models for model input / output, etc. configuration""" """Pydantic models for model input / output, etc. configuration"""
import logging
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
LOG = logging.getLogger(__name__) from axolotl.utils.logging import get_logger
LOG = get_logger(__name__, use_environ=True)
class ModelInputConfig(BaseModel): class ModelInputConfig(BaseModel):

View File

@@ -1,15 +1,15 @@
"""Pydantic models for training hyperparameters""" """Pydantic models for training hyperparameters"""
import logging
from typing import Any, Literal from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from transformers import SchedulerType from transformers import SchedulerType
from transformers.training_args import OptimizerNames from transformers.training_args import OptimizerNames
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import CustomSupportedOptimizers from axolotl.utils.schemas.enums import CustomSupportedOptimizers
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
class LrGroup(BaseModel): class LrGroup(BaseModel):

View File

@@ -1,8 +1,8 @@
"""Utilities for Axolotl Pydantic models""" """Utilities for Axolotl Pydantic models"""
import logging from axolotl.utils.logging import get_logger
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
def handle_legacy_message_fields_logic(data: dict) -> dict: def handle_legacy_message_fields_logic(data: dict) -> dict:

View File

@@ -1,10 +1,10 @@
"""Module for tokenization utilities""" """Module for tokenization utilities"""
import logging
from termcolor import colored from termcolor import colored
LOG = logging.getLogger("axolotl") from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def check_dataset_labels( def check_dataset_labels(

View File

@@ -22,7 +22,7 @@ from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger("axolotl") LOG = get_logger(__name__)
@torch.jit.script @torch.jit.script
@@ -402,7 +402,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.apply(len) .apply(len)
.values .values
) )
LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True) LOG.debug(f"total_num_tokens: {total_num_tokens:_}")
if update: if update:
cfg.total_num_tokens = total_num_tokens cfg.total_num_tokens = total_num_tokens
@@ -420,10 +420,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
.apply(lambda x: np.sum(np.array(x) != -100)) .apply(lambda x: np.sum(np.array(x) != -100))
.sum() .sum()
) )
LOG.debug( LOG.debug(f"`total_supervised_tokens: {total_supervised_tokens:_}`")
f"`total_supervised_tokens: {total_supervised_tokens:_}`",
main_process_only=True,
)
if update: if update:
cfg.total_supervised_tokens = total_supervised_tokens cfg.total_supervised_tokens = total_supervised_tokens
@@ -448,8 +445,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
* cfg.sequence_parallel_degree * cfg.sequence_parallel_degree
) )
LOG.debug( LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}", f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
main_process_only=True,
) )
else: else:
if cfg.flash_attention and not cfg.multipack_real_batches: if cfg.flash_attention and not cfg.multipack_real_batches:
@@ -478,7 +474,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
batch_sampler=sampler, batch_sampler=sampler,
) )
data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True) LOG.debug(f"data_loader_len: {data_loader_len}")
# FIXME: is there a bug here somewhere? the total num steps depends # FIXME: is there a bug here somewhere? the total num steps depends
# on the agreed on value for sample_packing_eff_est # on the agreed on value for sample_packing_eff_est
total_num_steps = int( total_num_steps = int(
@@ -500,10 +496,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
) )
if update: if update:
cfg.sample_packing_eff_est = sample_packing_eff_est cfg.sample_packing_eff_est = sample_packing_eff_est
LOG.debug( LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
main_process_only=True,
)
else: else:
total_num_steps = int( total_num_steps = int(
math.ceil( math.ceil(
@@ -513,7 +506,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
/ cfg.batch_size / cfg.batch_size
) )
) )
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True) LOG.debug(f"total_num_steps: {total_num_steps}")
return total_num_steps return total_num_steps

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu lora tinyllama E2E tests for multigpu lora tinyllama
""" """
import logging
import os import os
from pathlib import Path from pathlib import Path
@@ -14,10 +13,11 @@ from transformers.testing_utils import get_torch_dist_unique_port
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0 from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e.multigpu") LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu eval E2E tests for multigpu eval
""" """
import logging
import os import os
from pathlib import Path from pathlib import Path
@@ -11,10 +10,11 @@ from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_tensorboard from ..utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu") LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu lora tinyllama E2E tests for multigpu lora tinyllama
""" """
import logging
import os import os
from pathlib import Path from pathlib import Path
@@ -13,10 +12,11 @@ from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.e2e.utils import check_tensorboard from tests.e2e.utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e.multigpu") LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu lora tinyllama E2E tests for multigpu lora tinyllama
""" """
import logging
import os import os
from pathlib import Path from pathlib import Path
@@ -15,10 +14,11 @@ from packaging import version
from transformers.testing_utils import get_torch_dist_unique_port from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.e2e.utils import check_tensorboard, require_torch_2_6_0 from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e.multigpu") LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu qwen2 E2E tests for multigpu qwen2
""" """
import logging
import os import os
from pathlib import Path from pathlib import Path
@@ -12,8 +11,9 @@ from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
LOG = logging.getLogger("axolotl.tests.e2e.multigpu") LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for multigpu post-training use Ray Train E2E tests for multigpu post-training use Ray Train
""" """
import logging
import os import os
from pathlib import Path from pathlib import Path
@@ -11,10 +10,11 @@ import yaml
from accelerate.test_utils import execute_subprocess_async from accelerate.test_utils import execute_subprocess_async
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0 from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0
LOG = logging.getLogger(__name__) LOG = get_logger(__name__)
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent

View File

@@ -2,7 +2,6 @@
E2E tests for multipack fft llama using 4d attention masks E2E tests for multipack fft llama using 4d attention masks
""" """
import logging
import os import os
import unittest import unittest
@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e") LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama E2E tests for lora llama
""" """
import logging
import os import os
import pytest import pytest
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, check_tensorboard from ..utils import check_model_output_exists, check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e") LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for falcon E2E tests for falcon
""" """
import logging
import os import os
import unittest import unittest
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e") LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama E2E tests for lora llama
""" """
import logging
import os import os
import unittest import unittest
@@ -14,10 +13,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e") LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for llama w/ S2 attn E2E tests for llama w/ S2 attn
""" """
import logging
import os import os
import unittest import unittest
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e") LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

View File

@@ -2,7 +2,6 @@
E2E tests for lora llama E2E tests for lora llama
""" """
import logging
import os import os
import unittest import unittest
@@ -14,10 +13,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from ..utils import check_model_output_exists, with_temp_dir from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e") LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

Some files were not shown because too many files have changed in this diff Show More