Rank 0-only logging (#2608)
Co-authored-by: Wing Lian <wing@axolotl.ai>
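This PR replaces ad-hoc `logging.getLogger(...)` calls across the codebase with a shared `get_logger` helper from `axolotl.utils.logging`, so that in multi-GPU training only the main process (rank 0) emits log records. A few call sites pass `use_environ=True`, presumably so the rank can be read from environment variables (e.g. `RANK` set by torchrun) before `torch.distributed` is initialized. The implementation of `axolotl.utils.logging` itself is not visible in the hunks below, so the following is only a minimal sketch of the pattern under those assumptions; the `RankZeroFilter` name and the fallback behavior are illustrative, not the actual code:

```python
# Illustrative sketch only -- axolotl.utils.logging is not shown in this diff,
# so the class name, signature details, and fallbacks here are assumptions.
import logging
import os


class RankZeroFilter(logging.Filter):
    """Drop log records on non-zero ranks so only the main process logs."""

    def __init__(self, use_environ: bool = False):
        super().__init__()
        self.use_environ = use_environ

    def filter(self, record: logging.LogRecord) -> bool:
        if self.use_environ:
            # torchrun/accelerate export RANK before dist init; default to rank 0.
            return int(os.environ.get("RANK", "0")) == 0
        import torch.distributed as dist

        if dist.is_available() and dist.is_initialized():
            return dist.get_rank() == 0
        return True


def get_logger(name: str, use_environ: bool = False) -> logging.Logger:
    """Return a stdlib logger that emits records only on rank 0."""
    logger = logging.getLogger(name)
    logger.addFilter(RankZeroFilter(use_environ=use_environ))
    return logger
```

With a helper of this shape in place, each module can switch from `LOG = logging.getLogger(__name__)` to `LOG = get_logger(__name__)` without touching individual log calls, which is exactly the shape of the changes below.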
@@ -5,7 +5,7 @@ base_model: NousResearch/Llama-3.2-1B
datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
dataset_prepared_path: last_run_prepared

val_set_size: 0.1
output_dir: ./outputs/lora-out

@@ -38,6 +38,7 @@ wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1

optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

@@ -1,6 +1,5 @@
"""Various checks for Axolotl CLI."""

import logging
import os
from pathlib import Path

@@ -8,7 +7,9 @@ from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError

LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def check_accelerate_default_config() -> None:

@@ -1,7 +1,6 @@
"""Configuration loading and processing."""

import json
import logging
import os
import tempfile
from pathlib import Path
@@ -22,11 +21,12 @@ from axolotl.utils.config import (
    validate_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__, use_environ=True)


def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
@@ -119,12 +119,12 @@ def choose_config(path: Path) -> str:
    )

    if len(yaml_files) == 1:
        print(f"Using default YAML file '{yaml_files[0]}'")
        LOG.info(f"Using default YAML file '{yaml_files[0]}'")
        return str(yaml_files[0])

    print("Choose a YAML file:")
    LOG.info("Choose a YAML file:")
    for idx, file in enumerate(yaml_files):
        print(f"{idx + 1}. {file}")
        LOG.info(f"{idx + 1}. {file}")

    chosen_file = None
    while chosen_file is None:
@@ -133,9 +133,9 @@ def choose_config(path: Path) -> str:
            if 1 <= choice <= len(yaml_files):
                chosen_file = str(yaml_files[choice - 1])
            else:
                print("Invalid choice. Please choose a number from the list.")
                LOG.info("Invalid choice. Please choose a number from the list.")
        except ValueError:
            print("Invalid input. Please enter a number.")
            LOG.info("Invalid input. Please enter a number.")

    return chosen_file

@@ -1,6 +1,5 @@
"""CLI to run evaluation on a model."""

import logging
import os
from pathlib import Path
from typing import Union
@@ -17,8 +16,9 @@ from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils import patch_optimized_env
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:

@@ -1,7 +1,6 @@
"""CLI to run inference on a trained model."""

import importlib
import logging
import sys
from pathlib import Path
from threading import Thread
@@ -22,8 +21,9 @@ from axolotl.utils.chat_templates import (
    get_chat_template_from_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def get_multi_line_input() -> str:

@@ -2,7 +2,6 @@

# pylint: disable=redefined-outer-name

import logging
import os
import subprocess  # nosec B404
import tempfile
@@ -31,8 +30,11 @@ from axolotl.cli.utils import (
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig

LOG = get_logger(__name__)


@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
@@ -177,7 +179,7 @@ def train(

            do_cli(config=cfg_file, **kwargs)
        except subprocess.CalledProcessError as exc:
            logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
            LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
            if not sweep:
                raise exc

@@ -1,6 +1,5 @@
"""CLI to merge a trained LoRA into a base model."""

import logging
from pathlib import Path
from typing import Union

@@ -13,8 +12,9 @@ from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def do_merge_lora(*, cfg: DictDefault) -> None:

@@ -1,7 +1,6 @@
"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""

import json
import logging
import os
import shutil
from pathlib import Path
@@ -27,8 +26,9 @@ from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):

@@ -1,6 +1,5 @@
"""CLI to run preprocessing of a dataset."""

import logging
import warnings
from pathlib import Path
from typing import Union
@@ -20,9 +19,10 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import disable_datasets_caching

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:

@@ -2,7 +2,6 @@
CLI to post-training quantize a model using torchao
"""

import logging
from pathlib import Path
from typing import Union

@@ -11,9 +10,10 @@ from transformers import AutoModelForCausalLM
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def do_quantize(

@@ -1,7 +1,6 @@
"""CLI to run training on a model."""

import gc
import logging
import os
from pathlib import Path
from typing import Union
@@ -22,8 +21,6 @@ from axolotl.utils import patch_optimized_env
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault

LOG = logging.getLogger(__name__)


def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
    """

@@ -4,7 +4,6 @@ import concurrent.futures
import dataclasses
import hashlib
import json
import logging
from functools import wraps
from pathlib import Path
from types import NoneType
@@ -23,8 +22,9 @@ from transformers import (
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.loaders.model import ModelLoader
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def strip_optional_type(field_type: type | str | None):

@@ -1,6 +1,5 @@
"""Dataset loading utilities."""

import logging
import math
import random
from dataclasses import dataclass
@@ -14,10 +13,11 @@ from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.tokenization import check_dataset_labels

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


@dataclass

@@ -156,7 +156,6 @@ class Messages(BaseModel):
                    len(input_ids) : len(input_ids) + len(pending_input_ids)
                ]
                if new_pending_inputs != pending_input_ids:
                    # logging.warning("tokenization mismatch from concatenation.")
                    pending_input_ids = new_pending_inputs
                input_ids.extend(pending_input_ids)
                if pending_weight:

@@ -19,7 +19,6 @@ import abc
import importlib
import importlib.util
import inspect
import logging
import math
import os
import sys
@@ -88,6 +87,7 @@ from axolotl.utils.collators import (
    V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import CustomSupportedOptimizers, RLType

try:
@@ -95,7 +95,7 @@ try:
except ImportError:
    pass

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


class TrainerBuilderBase(abc.ABC):

@@ -4,7 +4,6 @@

from __future__ import annotations

import logging
import os
from collections import defaultdict
from functools import wraps
@@ -34,9 +33,10 @@ from axolotl.core.trainers.utils import (
    sanitize_kwargs_for_ds_tagging,
    sanitize_kwargs_for_tagging,
)
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):

@@ -2,7 +2,6 @@

import importlib
import inspect
import logging
from typing import Any

from trl.trainer.grpo_trainer import RewardFunc
@@ -13,9 +12,10 @@ from axolotl.core.trainers.grpo.trainer import (
    AxolotlGRPOTrainer,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.trl import TRLConfig

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


class GRPOStrategy:

@@ -1,18 +1,17 @@
"""Module for Axolotl trainer optimizer mixin"""

import logging

from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers.trainer import Trainer
from transformers.utils import is_sagemaker_mp_enabled

from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.utils.logging import get_logger

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


class OptimizerMixin(Trainer):

@@ -6,7 +6,6 @@ See https://github.com/huggingface/transformers/pull/37162
TODO: Remove when upstream added PR to release
"""

import logging
import os
import random

@@ -17,7 +16,9 @@ from transformers.trainer import safe_globals
from transformers.trainer_pt_utils import set_rng_state_for_device
from transformers.training_args import ParallelMode

LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


class RngLoaderMixin(Trainer):

@@ -1,12 +1,11 @@
"""Module for Axolotl trainer scheduler mixin"""

import logging

import torch
from torch.optim.lr_scheduler import LRScheduler, OneCycleLR
from transformers.trainer import Trainer

from axolotl.integrations.base import PluginManager
from axolotl.utils.logging import get_logger
from axolotl.utils.schedulers import (
    RexLR,
    get_cosine_schedule_with_min_lr,
@@ -14,7 +13,7 @@ from axolotl.utils.schedulers import (
    get_cosine_schedule_with_warmup_decay_constant,
)

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


class SchedulerMixin(Trainer):
@@ -80,13 +79,15 @@ class SchedulerMixin(Trainer):
            self.lr_scheduler = RexLR(
                optimizer=optimizer,
                max_lr=self.args.learning_rate,
                min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
                min_lr=0 if not use_cosine_min_lr else (
                    self.args.learning_rate * self.args.cosine_min_lr_ratio),
                total_steps=num_training_steps,
                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
            )
        elif use_cosine_quadratic:
            if use_cosine_min_lr:
                LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
                LOG.warning(
                    "Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")

            self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
                optimizer,
@@ -115,9 +116,11 @@ class SchedulerMixin(Trainer):
            return super().create_scheduler(num_training_steps, optimizer=optimizer)
        else:
            if use_cosine_quadratic:
                LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
                LOG.warning(
                    "axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")

            if use_cosine_min_lr:
                LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
                LOG.warning(
                    "axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")

            return self.lr_scheduler  # type: ignore

@@ -1,12 +1,13 @@
"""Module containing Dataset functionality"""

import logging
import os
from typing import List, Optional, Union

import torch
from datasets import Dataset, IterableDataset

from axolotl.utils.logging import get_logger

from .prompt_tokenizers import PromptTokenizingStrategy

# We want this to be a wrapper for an existing dataset that we have loaded
@@ -15,7 +16,7 @@ from .prompt_tokenizers import PromptTokenizingStrategy
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets

LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)


class TokenizedPromptDataset(Dataset):

@@ -22,7 +22,6 @@ from __future__ import annotations

import collections
import importlib
import logging
from typing import TYPE_CHECKING, Callable, OrderedDict, Union

from peft import PeftModel
@@ -31,6 +30,9 @@ from torch.optim.lr_scheduler import LRScheduler
from transformers import PreTrainedModel, Trainer

from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

if TYPE_CHECKING:
    from axolotl.common.datasets import TrainDatasetMeta
@@ -331,12 +333,12 @@ class PluginManager:
            ImportError: If the plugin module cannot be imported.
        """
        try:
            logging.info(f"Attempting to load plugin: {plugin_name}")
            LOG.info(f"Attempting to load plugin: {plugin_name}")
            plugin = load_plugin(plugin_name)
            self.plugins[plugin_name] = plugin
            logging.info(f"Plugin loaded successfully: {plugin_name}")
            LOG.info(f"Plugin loaded successfully: {plugin_name}")
        except ImportError:
            logging.error(f"Failed to load plugin: {plugin_name}")
            LOG.error(f"Failed to load plugin: {plugin_name}")

    def get_input_args(self) -> list[str]:
        """Returns a list of Pydantic classes for all registered plugins' input arguments.'

@@ -19,17 +19,16 @@ Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team.
"""
import importlib
import logging

import torch

from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger

from .args import CutCrossEntropyArgs  # pylint: disable=unused-import. # noqa: F401

LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
LOG = get_logger(__name__, use_environ=True)

_CCE_INSTALL_MESSAGE = (
    "Please install cut_cross_entropy with transformers support using "
@@ -76,10 +75,9 @@ class CutCrossEntropyPlugin(BasePlugin):
            cce_patch,
        )

        if is_main_process(use_environ=True):
            LOG.info(
                f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
            )
        LOG.info(
            f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
        )

        # The patch checks model_type internally
        cce_patch(cfg.model_config_type)

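Note the pattern in the Cut Cross Entropy hunk above (the LIGER hunk further down does the same): explicit `if is_main_process(use_environ=True):` guards around `LOG.info(...)` are dropped, because rank-0 filtering now happens inside the logger itself. A sketch of the before/after shape, using the lines from this diff:

```python
# Before: each call site guarded itself so only the main process logged.
if is_main_process(use_environ=True):
    LOG.info(f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}")

# After: the guard lives inside get_logger(__name__, use_environ=True),
# so the call site is unconditional and still logs only on rank 0.
LOG.info(f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}")
```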
@@ -15,12 +15,13 @@
"""
Module for handling Cut Cross Entropy input arguments.
"""
import logging
from typing import Optional

from pydantic import BaseModel, model_validator

LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args")
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


class CutCrossEntropyArgs(BaseModel):

@@ -2,15 +2,15 @@
Grokfast plugin for Axolotl
"""

import logging

from transformers.trainer_callback import TrainerCallback

from axolotl.utils.logging import get_logger

from ..base import BasePlugin
from .args import GrokfastArgs  # pylint: disable=unused-import. # noqa: F401
from .optimizer import gradfilter_ema

LOG = logging.getLogger("axolotl.integrations.grokfast")
LOG = get_logger(__name__)


class GrokfastCallbackHandler(TrainerCallback):

@@ -19,16 +19,15 @@ Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and light-weight.
"""
import inspect
import logging
import sys

from axolotl.integrations.base import BasePlugin
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger

from .args import LigerArgs  # pylint: disable=unused-import. # noqa: F401
from .utils import patch_with_compile_disable

LOG = logging.getLogger("axolotl.integrations.liger")
LOG = get_logger(__name__, use_environ=True)


class LigerPlugin(BasePlugin):
@@ -85,10 +84,7 @@ class LigerPlugin(BasePlugin):
                kwargs["geglu"] = cfg.liger_glu_activation
            elif "swiglu" in liger_fn_sig.parameters:
                kwargs["swiglu"] = cfg.liger_glu_activation
            if is_main_process(use_environ=True):
                LOG.info(
                    f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
                )
            LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
            apply_liger_fn(**kwargs)
        elif cfg.model_config_type == "jamba":
            from transformers.models.jamba import modeling_jamba
@@ -124,9 +120,9 @@ class LigerPlugin(BasePlugin):
            if cfg.liger_rope:
                # The DeepseekV2 version of RoPE is different than upstream LLaMA.
                # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
                logging.warning("Fused liger_rope is not supported for DeepseekV2.")
                LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
            if cfg.liger_glu_activation:
                logging.warning("liger_glu_activation is not supported for DeepseekV2.")
                LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
            if cfg.liger_rms_norm:
                modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
            if cfg.liger_glu_activation:
@@ -186,6 +182,6 @@ class LigerPlugin(BasePlugin):
                swiglu=cfg.liger_glu_activation,
            )
        else:
            logging.warning(
            LOG.warning(
                f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
            )

@@ -15,12 +15,13 @@
"""
Module for handling LIGER input arguments.
"""
import logging
from typing import Optional

from pydantic import BaseModel, model_validator

LOG = logging.getLogger("axolotl.integrations.liger.args")
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


class LigerArgs(BaseModel):

@@ -3,7 +3,6 @@ Sparse Finetuning plugin for Axolotl — enables handling of sparse neural networks
by maintaining masks for zero weights during training.
"""

import logging
from functools import wraps
from typing import Any, Callable, Concatenate, ParamSpec, TypeVar

@@ -16,11 +15,12 @@ from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments

from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger

P = ParamSpec("P")  # Params for generic function signatures
R = TypeVar("R")  # Return type for generic function signatures

LOG = logging.getLogger("axolotl.integrations.llm_compressor")
LOG = get_logger(__name__)


class LLMCompressorCallbackHandler(TrainerCallback):

@@ -17,14 +17,16 @@ Spectrum Plugin to automatically generate unfrozen parameters based on SNR data.
"""

import json
import logging

import requests

from axolotl.integrations.base import BasePlugin
from axolotl.utils.logging import get_logger

from .args import SpectrumArgs  # pylint: disable=unused-import. # noqa: F401

LOG = get_logger(__name__)


def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5):
    unfrozen_parameters = {}
@@ -83,17 +85,17 @@ class SpectrumPlugin(BasePlugin):
        except FileNotFoundError:
            pass
        except Exception as exc:  # pylint: disable=broad-exception-caught
            logging.warning(f"Failed to read SNR data from {snr_path}: {exc}")
            LOG.warning(f"Failed to read SNR data from {snr_path}: {exc}")

        if not snr_data:
            try:
                snr_data = requests.get(snr_url, timeout=60).json()
            except requests.exceptions.RequestException as exc:
                logging.warning(f"Failed to fetch SNR data from {snr_url}: {exc}")
                LOG.warning(f"Failed to fetch SNR data from {snr_url}: {exc}")
                return
            # also catch json parsing errors
            except json.JSONDecodeError as exc:
                logging.warning(f"Failed to parse SNR data from {snr_url}: {exc}")
                LOG.warning(f"Failed to parse SNR data from {snr_url}: {exc}")
                return

        unfrozen_parameters = _generate_unfrozen_params_yaml(

@@ -1,6 +1,5 @@
"""Adapter loading functionality, including LoRA / QLoRA and associated utils"""

import logging
import os
import types
from typing import Any
@@ -21,8 +20,9 @@ from transformers import PreTrainedModel

from axolotl.loaders.utils import get_linear_embedding_layers
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def setup_quantized_meta_for_peft(model: torch.nn.Module):

@@ -3,7 +3,6 @@ models.
"""

import gc
import logging
import math
import os
from functools import cached_property
@@ -47,10 +46,11 @@ from axolotl.utils.distributed import (
    get_device_count,
    get_device_type,
)
from axolotl.utils.logging import get_logger
from axolotl.utils.model_shard_quant import load_sharded_model_quant
from axolotl.utils.schemas.enums import RLType

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()

@@ -4,7 +4,6 @@ Applies pre- and post-model load patches for various fixes and optimizations.
"""

import importlib.util
import logging
from functools import cached_property

import addict
@@ -17,8 +16,9 @@ from axolotl.monkeypatch.multipack import (
    patch_for_multipack,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()

@@ -1,6 +1,5 @@
"""Processor loading functionality for multi-modal models"""

import logging
from typing import Any

import transformers
@@ -10,8 +9,9 @@ from transformers import (
)

from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):

@@ -1,7 +1,6 @@
"""Tokenizer loading functionality and associated utils"""

import json
import logging
import os

import transformers
@@ -19,8 +18,9 @@ from axolotl.utils.distributed import (
    is_local_main_process,
    is_main_process,
)
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
PLUGIN_MANAGER = PluginManager.get_instance()

@@ -1,7 +1,6 @@
"""Utilities for axolotl.loaders module"""

import contextlib
import logging
from typing import Type

import addict
@@ -9,8 +8,9 @@ import torch
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel

from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def get_module_class_from_name(

@@ -2,12 +2,13 @@
monkeypatch for accelerate fsdp2 fix when modifying ordereddict during iteration, and saving full state dicts
"""

import logging
import sys

import torch

LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):

@@ -3,7 +3,6 @@ Flash attention monkey patch for cerebras btlm model
"""

import importlib
import logging
from typing import Optional, Tuple

import torch
@@ -11,7 +10,9 @@ from accelerate import init_empty_weights
from flash_attn.flash_attn_interface import flash_attn_func
from transformers import AutoConfig, AutoModelForCausalLM

LOG = logging.getLogger("axolotl")
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):

@@ -18,7 +18,6 @@ DISCO - DIsk-based Storage and Checkpointing with Optimized prefetching

import atexit
import concurrent.futures
import logging
import os
import queue
import shutil
@@ -32,11 +31,13 @@ from typing import Dict

import torch

from axolotl.utils.logging import get_logger

torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")

# Setup logger
logger = logging.getLogger(__name__)
logger = get_logger(__name__)


class DiskOffloadManager:

@@ -2,7 +2,6 @@

# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py

import logging
import warnings
from typing import List, Optional, Tuple, Union

@@ -25,6 +24,7 @@ from transformers.models.llama.modeling_llama import (
)

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
from axolotl.utils.logging import get_logger

try:
    from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
@@ -41,7 +41,7 @@ except ImportError:
    )


LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)


def is_xformers_available() -> bool:
@@ -612,9 +612,10 @@ def generate_qkv(
            q, query_padding_mask
        )

        output_pad_fn = lambda output_unpad: pad_input(  # noqa: E731
            output_unpad, indices_q, batch_size, seqlen_q
        )
        def output_pad_fn(output_unpad):
            return pad_input(  # noqa: E731
                output_unpad, indices_q, batch_size, seqlen_q
            )

    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
@@ -627,9 +628,10 @@ def generate_qkv(
        )
        max_seqlen_q = seqlen_q

        output_pad_fn = lambda output_unpad: rearrange(  # noqa: E731
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )
        def output_pad_fn(output_unpad):
            return rearrange(  # noqa: E731
                output_unpad, "(b s) h d -> b s h d", b=batch_size
            )

    if key_padding_mask is not None:
        k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)

@@ -2,7 +2,6 @@
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
"""

import logging
import warnings
from typing import Optional, Tuple

@@ -11,10 +10,14 @@ import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

try:
    import xformers.ops
except ImportError:
    logging.error("xformers not found! Please install it before trying to use it.")
    LOG.error("xformers not found! Please install it before trying to use it.")


def hijack_llama_attention():

@@ -7,7 +7,6 @@ import types
from typing import Generator, Tuple, Type

import torch
from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
from torch import nn
from transformers import AutoConfig
@@ -20,6 +19,7 @@ from axolotl.kernels.lora import (
)
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

@@ -2,7 +2,6 @@

# pylint: disable=duplicate-code

import logging
from functools import partial
from typing import List, Optional, Tuple, Union

@@ -28,8 +27,9 @@ from transformers.models.mistral.modeling_mistral import (
)

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger

LOG = logging.getLogger("axolotl.monkeypatch.mistral")
LOG = get_logger(__name__)


def replace_mistral_attn_with_flash_attn(
@@ -359,9 +359,10 @@ def generate_qkv(
            q, query_padding_mask
        )

        output_pad_fn = lambda output_unpad: pad_input(  # noqa: E731
            output_unpad, indices_q, batch_size, seqlen_q
        )
        def output_pad_fn(output_unpad):
            return pad_input(  # noqa: E731
                output_unpad, indices_q, batch_size, seqlen_q
            )

    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
@@ -374,9 +375,10 @@ def generate_qkv(
        )
        max_seqlen_q = seqlen_q

        output_pad_fn = lambda output_unpad: rearrange(  # noqa: E731
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )
        def output_pad_fn(output_unpad):
            return rearrange(  # noqa: E731
                output_unpad, "(b s) h d -> b s h d", b=batch_size
            )

    if key_padding_mask is not None:
        k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)

@@ -3,14 +3,14 @@ Patch prepare_model_for_kbit_training to not upcast everything
"""

import inspect
import logging

import peft

import axolotl
from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)

ORIGINAL_PREPARE_CODE = """
    for param in model.parameters():

@@ -2,7 +2,6 @@

import glob
import json
import logging
import os.path
import shutil
from functools import partial
@@ -27,8 +26,9 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import barrier, is_main_process
from axolotl.utils.logging import get_logger

LOG = logging.getLogger("axolotl.relora")
LOG = get_logger(__name__)


@torch.no_grad()

@@ -32,11 +32,11 @@ from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import logging

from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger

logger = logging.get_logger(__name__)
logger = get_logger(__name__)


def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"):

@@ -2,11 +2,11 @@
monkeypatch for Trainer _get_learning_rate method
"""

import logging

import torch

LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


# TODO remove this patch once https://github.com/huggingface/transformers/pull/37881 is included in a release

@@ -3,13 +3,13 @@ allow adding additional kwargs to Accelerator init
"""

import inspect
import logging

from transformers import Trainer

from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)

ORIGINAL_TRAINER_CODE = """
    # create accelerator object

@@ -3,13 +3,13 @@ fix for FSDP2 evals when using torch.compile
"""

import inspect
import logging

from transformers import Trainer

from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)

ORIGINAL_TRAINER_CODE = """
    model.eval()

@@ -3,13 +3,13 @@ fix for FSDP optimizer save in trainer w 4.47.0
"""

import inspect
import logging

from transformers import Trainer

from axolotl.monkeypatch.utils import detab_code
from axolotl.utils.logging import get_logger

LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
LOG = get_logger(__name__)

ORIGINAL_TRAINER_CODE = """

@@ -2,13 +2,14 @@
see https://github.com/huggingface/transformers/pull/35834
"""

import logging
from functools import partial
from typing import Optional

import torch

logger = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger

logger = get_logger(__name__)


def fixed_fa_peft_integration_check(

@@ -11,7 +11,7 @@ from transformers.models.llama.modeling_llama import LlamaFlashAttention2

from axolotl.monkeypatch.utils import detab_code

LOG = get_logger("axolotl.monkeypatch.unsloth")
LOG = get_logger(__name__)

ORIGINAL_QKV_CODE = """
    query_states = self.q_proj(hidden_states)
@@ -133,7 +133,7 @@ def patch_self_attn_lora():
    )
    exec(self_attn_forward, globals())  # pylint: disable=exec-used # nosec B102
    self_attn_lora_patched = True
    LOG.info("patching unsloth attn lora", main_process_only=True)
    LOG.info("patching unsloth attn lora")
    LlamaFlashAttention2.forward = (
        unsloth_attn_forward  # pylint: disable=undefined-variable # noqa: F821
    )
@@ -153,7 +153,7 @@ def integrate_rope_embeddings():
    ):
        return fast_rope_embedding(q, k, cos, sin)

    LOG.info("patching unsloth RoPE embeddings", main_process_only=True)
    LOG.info("patching unsloth RoPE embeddings")
    transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb


@@ -189,7 +189,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
        if is_mlp_lora and mlp_no_bias and mlp_not_dora:
            layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
        else:
            LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
            LOG.warning(f"unable to apply unsloth lora mlp patch to layer {idx}")


def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
@@ -215,7 +215,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
                layer.self_attn.apply_qkv = apply_lora_qkv
            else:
                layer.self_attn.apply_qkv = original_apply_qkv
                LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx)
                LOG.warning(f"unable to apply unsloth lora qkv patch to layer {idx}")
        if cfg.unsloth_lora_o:
            layer_modules = [
                getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
@@ -234,9 +234,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
                layer.self_attn.apply_o = apply_lora_o
            else:
                layer.self_attn.apply_o = original_apply_o
                LOG.warning(
                    "unable to apply unsloth lora o_proj patch to layer %d", idx
                )
                LOG.warning(f"unable to apply unsloth lora o_proj patch to layer {idx}")


def patch_unsloth_layernorm():

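The unsloth hunks above also drop the `main_process_only=True` keyword from `LOG.info(...)` calls. That keyword belongs to `accelerate.logging`'s process-aware logger adapter; if `get_logger` returns a plain stdlib logger (as sketched earlier, an assumption), it would not accept that keyword, and with rank-0 filtering built into the logger it is redundant anyway:

```python
# Before: accelerate-style logger, per-call opt-in to main-process-only output.
LOG.info("patching unsloth RoPE embeddings", main_process_only=True)

# After: rank filtering is applied once inside get_logger, so the call is plain.
LOG.info("patching unsloth RoPE embeddings")
```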
@@ -1,6 +1,5 @@
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""

import logging
from copy import deepcopy
from typing import Optional

@@ -10,7 +9,9 @@ from torch import Tensor
from transformers import ProcessorMixin
from transformers.image_utils import load_image

LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


class ProcessingStrategy:

@@ -2,11 +2,11 @@

import importlib
import inspect
import logging

from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
from axolotl.utils.logging import get_logger

LOG = logging.getLogger("axolotl.prompt_strategies")
LOG = get_logger(__name__)


def load(strategy, tokenizer, cfg, ds_cfg, processor=None):

@@ -3,9 +3,10 @@ module for base dataset transform strategies
"""

import importlib
import logging

LOG = logging.getLogger("axolotl")
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def load(strategy, cfg, module_base=None, **kwargs):

@@ -2,11 +2,11 @@

import importlib
import inspect
import logging

from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
from axolotl.utils.logging import get_logger

LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
LOG = get_logger(__name__)


def load(strategy, tokenizer, cfg, ds_cfg):

@@ -2,7 +2,6 @@
Bradley-Terry model with chat template prompt strategy.
"""

import logging
from typing import Any, Dict, Optional

from axolotl.prompt_strategies.chat_template import (
@@ -10,10 +9,11 @@ from axolotl.prompt_strategies.chat_template import (
    ChatTemplateStrategy,
)
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.logging import get_logger

# Configure the logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
LOG.setLevel(logging.INFO)
LOG = get_logger(__name__)
LOG.setLevel("INFO")


class BTChatTemplateStrategy(ChatTemplateStrategy):
@@ -44,7 +44,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):

        if len(chosen_tokenized["input_ids"]) > max_length:
            LOG.warning(
                f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
                f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}"
            )

        chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
@@ -62,7 +62,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):

        if len(rejected_tokenized["input_ids"]) > max_length:
            LOG.warning(
                f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
                f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}"
            )

        rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][

@@ -2,7 +2,6 @@
HF Chat Templates prompt strategy
"""

import logging
from collections import defaultdict
from typing import Any, Dict, List, Set, Union

@@ -13,11 +12,12 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import DatasetConfig

# Configure the logger
LOG = logging.getLogger("axolotl")
LOG.setLevel(logging.INFO)
LOG = get_logger(__name__)
LOG.setLevel("INFO")


class ChatTemplatePrompter(Prompter):
@@ -378,7 +378,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
            add_generation_prompt=True,
            images=images,
        )
        tokenized_res = self.prompter.build_prompt(turns, images=images)  # type: ignore
        tokenized_res = self.prompter.build_prompt(
            turns, images=images
        )  # type: ignore
        tokenized_prompt = {}
        if isinstance(tokenized_res, list):
            input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
@@ -555,8 +557,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
            and turns[0].get("role") == "system"
            and (
                "mistral" in self.tokenizer.name_or_path.lower()
                # gemma3 uses gemma tokenizer
                or "gemma" in self.tokenizer.name_or_path.lower()
                or "gemma"
                in self.tokenizer.name_or_path.lower()  # gemma3 uses gemma tokenizer
            )
        ):
            return -1, -1

@@ -24,12 +24,14 @@ For a custom system message, the first "from" can be "system" (followed by alter
Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing!
"""

import logging
from dataclasses import dataclass, field
from typing import Generator, List, Sequence

from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


@dataclass
@@ -129,7 +131,7 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
        if cur_len < self.sequence_len:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                logging.warning(
                LOG.warning(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

@@ -2,9 +2,10 @@

import importlib
import inspect
import logging

LOG = logging.getLogger("axolotl.prompt_strategies.messages")
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def load(tokenizer, cfg, ds_cfg, processor=None):

@@ -1,12 +1,12 @@
"""Module containing the MetharmePromptTokenizingStrategy and MetharmePrompter class"""

import logging
from typing import Tuple

from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
from axolotl.utils.logging import get_logger

LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)

IGNORE_TOKEN_ID = -100

@@ -1,7 +1,6 @@
"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class"""

import copy
import logging
from collections import defaultdict
from typing import Generator, List, Tuple

@@ -10,8 +9,9 @@ from axolotl.prompt_tokenizers import (
    parse_tokenized_to_result,
    tokenize_prompt_default,
)
from axolotl.utils.logging import get_logger

LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)

IGNORE_TOKEN_ID = -100

@@ -1,14 +1,14 @@
"""Module containing PromptTokenizingStrategy and Prompter classes"""

import abc
import logging
from typing import Callable, Dict, List, Optional, Tuple, Union

from transformers import BatchEncoding, PreTrainedTokenizer

from axolotl.prompters import Prompter
from axolotl.utils.logging import get_logger

LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)

IGNORE_INDEX = -100
LLAMA_DEFAULT_PAD_TOKEN = "<pad>"  # nosec

@@ -1,12 +1,13 @@
"""Module containing prompters"""

import logging
from enum import Enum
from typing import Generator, Optional, Union

from colorama import Fore

LOG = logging.getLogger("axolotl")
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)
IGNORE_TOKEN_ID = -100
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"

@@ -2,7 +2,6 @@

import importlib
import inspect
import logging
import os
import signal
import sys
@@ -37,6 +36,7 @@ from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
from axolotl.utils.trainer import setup_trainer

@@ -45,7 +45,7 @@ try:
except ImportError:
    BetterTransformer = None

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def setup_model_and_tokenizer(
@@ -64,9 +64,7 @@ def setup_model_and_tokenizer(
    `None`), and processor (if multimodal, else `None`).
    """
    # Load tokenizer
    LOG.debug(
        f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
    )
    LOG.debug(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
    tokenizer = load_tokenizer(cfg)

    # Load processor for multimodal models if needed

@@ -4,7 +4,6 @@ from __future__ import annotations

import gc
import json
import logging
import os
import traceback
from shutil import copyfile
@@ -43,6 +42,7 @@ from axolotl.utils.distributed import (
    is_main_process,
    zero_first,
)
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import AxolotlInputConfig

if TYPE_CHECKING:
@@ -50,7 +50,7 @@ if TYPE_CHECKING:


IGNORE_INDEX = -100
LOG = logging.getLogger("axolotl.callbacks")
LOG = get_logger(__name__)


class EvalFirstStepCallback(
@@ -753,7 +753,14 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
                ].append(pred_step_text)
                row_index += 1
        if logger == "wandb":
            wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)})  # type: ignore[attr-defined]
            # type: ignore[attr-defined]
            wandb.run.log(
                {
                    f"{name} - Predictions vs Ground Truth": pd.DataFrame(
                        table_data
                    )
                }
            )
        elif logger == "mlflow" and is_mlflow_available():
            import mlflow

@@ -1,17 +1,17 @@
"""Comet module for trainer callbacks"""

import logging
from typing import TYPE_CHECKING

import comet_ml
from transformers import TrainerCallback, TrainerControl, TrainerState

from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger

if TYPE_CHECKING:
    from axolotl.core.trainer_builder import AxolotlTrainingArguments

LOG = logging.getLogger("axolotl.callbacks")
LOG = get_logger(__name__)


class SaveAxolotlConfigtoCometCallback(TrainerCallback):

@@ -6,17 +6,18 @@ Arxiv: https://arxiv.org/abs/2403.17919
License: Apache 2.0
"""

import logging
from functools import reduce
from typing import TYPE_CHECKING

import numpy as np
from transformers import TrainerCallback

from axolotl.utils.logging import get_logger

if TYPE_CHECKING:
    from axolotl.core.trainer_builder import AxolotlTrainer

LOG = logging.getLogger("axolotl.callbacks.lisa")
LOG = get_logger(__name__)


def lisa_callback_factory(trainer: "AxolotlTrainer"):

@@ -1,6 +1,5 @@
"""MLFlow module for trainer callbacks"""

import logging
import os
from shutil import copyfile
from tempfile import NamedTemporaryFile
@@ -10,11 +9,12 @@ import mlflow
from transformers import TrainerCallback, TrainerControl, TrainerState

from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger

if TYPE_CHECKING:
    from axolotl.core.trainer_builder import AxolotlTrainingArguments

LOG = logging.getLogger("axolotl.callbacks")
LOG = get_logger(__name__)


def should_log_artifacts() -> bool:

@@ -1,6 +1,5 @@
"""QAT Callback for HF Causal Trainer"""

import logging
from functools import partial

from torch import nn
@@ -8,9 +7,10 @@ from torchao.quantization.qat.embedding import FakeQuantizedEmbedding
from torchao.quantization.qat.linear import FakeQuantizedLinear
from transformers import TrainerCallback

from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.quantization import QATConfig

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def toggle_fake_quant(mod: nn.Module, enable: bool):

File diff suppressed because one or more lines are too long
@@ -1,11 +1,11 @@
"""Module for wandb utilities"""

import logging
import os

from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = logging.getLogger("axolotl.utils.comet_")
LOG = get_logger(__name__)

COMET_ENV_MAPPING_OVERRIDE = {
    "comet_mode": "COMET_START_MODE",

@@ -1,7 +1,6 @@
"""Module for working with config dicts"""

import json
import logging
import os
from typing import Optional

@@ -15,13 +14,14 @@ from axolotl.loaders import MULTIMODAL_AUTO_MODEL_MAPPING
from axolotl.loaders.utils import load_model_config
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.config import (
    AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
)
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset

LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__, use_environ=True)


def choose_device(cfg):

@@ -1,7 +1,6 @@
"""data handling specific to pretraining"""

import functools
import logging
from collections import defaultdict
from typing import Callable, Dict, List, Optional

@@ -11,10 +10,11 @@ from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase

from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import process_pretraining_datasets_for_packing

LOG = logging.getLogger("axolotl")
LOG = get_logger(__name__)


def encode_pretraining(
@@ -1,7 +1,6 @@
"""data handling specific to DPO"""

import inspect
import logging
from functools import partial
from pathlib import Path
from typing import Any, List, Union
@@ -18,9 +17,10 @@ from axolotl.utils.data.shared import datasets_w_name_generator, load_dataset_w_
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def _get_path(ds_hash, cfg):
@@ -217,7 +217,7 @@ def load_prepare_preference_datasets(cfg):
        + "|"
        + "train"
        + "|"
        + str(seed)
        + str(cfg.seed or 42)
    )
    to_hash_test = (
        train_dataset._fingerprint  # pylint: disable=protected-access
@@ -226,7 +226,7 @@ def load_prepare_preference_datasets(cfg):
        + "|"
        + "test"
        + "|"
        + str(seed)
        + str(cfg.seed or 42)
    )
    train_fingerprint = md5(to_hash_train)
    test_fingerprint = md5(to_hash_test)
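The substantive change in these two hunks is the seed used in the dataset-cache fingerprint: `str(seed)` becomes `str(cfg.seed or 42)`, presumably so an unset seed no longer leaks into the cache key. A minimal sketch of how the fingerprint is assembled (values are illustrative; the real `md5` helper is the one imported from `axolotl.utils.data.utils`):

import hashlib


def md5(to_hash: str) -> str:
    # Stand-in for the helper imported from axolotl.utils.data.utils.
    return hashlib.md5(to_hash.encode("utf-8")).hexdigest()


fingerprint = "abc123"  # illustrative dataset._fingerprint
cfg_seed = None  # seed left unset in the config

# New behavior: fall back to the default seed of 42, keeping the key stable.
to_hash_train = fingerprint + "|" + "train" + "|" + str(cfg_seed or 42)
train_fingerprint = md5(to_hash_train)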
@@ -1,7 +1,6 @@
"""data handling specific to SFT"""

import functools
import logging
import os
import tempfile
from pathlib import Path
@@ -54,12 +53,13 @@ from axolotl.utils.data.utils import (
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.logging import get_logger
from axolotl.utils.trainer import (
    calculate_total_num_steps,
    process_datasets_for_packing,
)

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


@retry_on_request_exceptions(max_retries=3, delay=5)
@@ -182,10 +182,9 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
        total_num_steps = min(
            calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
        )
        LOG.info(f"Maximum number of steps set at {total_num_steps}")
    else:
        total_num_steps = calculate_total_num_steps(cfg, train_dataset)

    LOG.info(f"Maximum number of steps set at {total_num_steps}")
    return train_dataset, eval_dataset, total_num_steps, prompters


@@ -331,12 +330,12 @@ def load_tokenized_prepared_datasets(
    if len(datasets) == 1:
        dataset = datasets[0]
    else:
        LOG.info("merging datasets")
        LOG.info("Merging datasets...")
        dataset = concatenate_datasets(datasets)

    if len(datasets) > 1:
        if cfg.shuffle_merged_datasets:
            LOG.debug("shuffle merged datasets")
            LOG.debug("Shuffling merged datasets...")
            dataset = dataset.shuffle(seed=seed)
        else:
            LOG.debug("NOT shuffling merged datasets")
@@ -426,7 +425,7 @@ def load_prepare_datasets(
        + "|"
        + "train"
        + "|"
        + str(seed)
        + str(cfg.seed or 42)
    )
    to_hash_test = (
        dataset._fingerprint  # pylint: disable=protected-access
@@ -435,7 +434,7 @@ def load_prepare_datasets(
        + "|"
        + "test"
        + "|"
        + str(seed)
        + str(cfg.seed or 42)
    )
    train_fingerprint = md5(to_hash_train)
    test_fingerprint = md5(to_hash_test)
@@ -2,7 +2,6 @@

import functools
import hashlib
import logging
import time
from enum import Enum

@@ -12,10 +11,11 @@ import requests
from datasets import Dataset, IterableDataset

from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers.utils import get_dataset_lengths
from axolotl.utils.trainer import drop_long_seq

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


class RetryStrategy(Enum):
src/axolotl/utils/logging.py (new file, 62 lines)
@@ -0,0 +1,62 @@
"""
logging helpers to only log on main process
"""

import functools
import logging
import os

from axolotl.utils.distributed import is_main_process

# Adapted from Accelerate
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/logging.py


class MultiProcessAdapter(logging.LoggerAdapter):
    """
    logger adapter for distributed logging, specifically to only log on main process
    """

    def __init__(self, logger, use_environ=False, extra=None):
        super().__init__(logger, extra)
        self.use_environ = use_environ

    @staticmethod
    def _should_log(main_process_only, use_environ=False):
        return not main_process_only or (
            main_process_only and is_main_process(use_environ=use_environ)
        )

    def log(self, level, msg, *args, **kwargs):
        use_environ = kwargs.pop("use_environ", self.use_environ)
        main_process_only = kwargs.pop("main_process_only", True)
        kwargs.setdefault("stacklevel", 2)

        if self.isEnabledFor(level) and self._should_log(
            main_process_only, use_environ=use_environ
        ):
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)

    @functools.lru_cache(maxsize=10)
    def warning_once(self, *args, **kwargs):
        """
        This method is identical to `logger.warning()`, but will emit the warning with the same message only once

        Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the
        cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to
        switch to another type of cache that includes the caller frame information in the hashing function.
        """
        self.warning(*args, **kwargs)


def get_logger(
    name: str, log_level: str | None = None, use_environ: bool = False
) -> MultiProcessAdapter:
    if log_level is None:
        log_level = os.environ.get("AXOLOTL_LOG_LEVEL", None)
    logger = logging.getLogger(name)
    if log_level is not None:
        logger.setLevel(log_level.upper())
        logger.root.setLevel(log_level.upper())
    return MultiProcessAdapter(logger, use_environ=use_environ, extra={})
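Putting the new helper to work looks roughly like this; a minimal sketch (logger name and messages are illustrative):

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__, log_level="DEBUG")

# By default only the main process (rank 0) emits the record.
LOG.info("visible on rank 0 only")

# Opt out per call when a record should appear on every rank.
LOG.debug("visible on every rank", main_process_only=False)

# Before torch.distributed is initialized, decide "main process" from the
# launcher's environment variables instead.
LOG.warning("early-startup message", use_environ=True)

# Deduplicated warning: repeated calls with the same arguments log once.
LOG.warning_once("this shows up a single time")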
@@ -2,8 +2,6 @@
Utilities for quantization including QAT and PTQ using torchao.
"""

import logging

import torch
from torch import nn
from torchao.core.config import AOBaseConfig
@@ -25,8 +23,6 @@ from torchao.quantization.quant_api import (

from axolotl.utils.schemas.enums import TorchIntDType

LOG = logging.getLogger(__name__)


def get_ptq_config(
    weight_dtype: TorchIntDType,
@@ -3,7 +3,6 @@ Multipack Batch Sampler - An efficient batch sampler for packing variable-length
into fixed-capacity batches to optimize memory usage and training throughput.
"""

import logging
import math
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count, get_context
@@ -14,9 +13,9 @@ import numpy as np
from torch.utils.data import BatchSampler, Sampler, SequentialSampler

from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
LOG = get_logger(__name__)


@numba.njit
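Note the dropped per-module `LOG.setLevel(logging.INFO)`: verbosity is now driven centrally through `get_logger`, which reads `AXOLOTL_LOG_LEVEL` when no explicit level is passed. A small sketch:

import os

# Set before importing axolotl modules; any stdlib level name works.
os.environ["AXOLOTL_LOG_LEVEL"] = "DEBUG"

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)  # picks up AXOLOTL_LOG_LEVEL automatically
LOG.debug("now visible without a per-module setLevel call")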
@@ -2,7 +2,6 @@

# pylint: disable=too-many-lines

import logging
import os
from typing import Annotated, Any, Literal

@@ -18,6 +17,7 @@ from pydantic import (
)
from transformers.utils.import_utils import is_torch_npu_available

from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import (
    DatasetConfig,
    DPODataset,
@@ -49,7 +49,7 @@ from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.vllm import VllmConfig

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__, use_environ=True)

SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
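The config schemas pass `use_environ=True` because they run during validation, before `torch.distributed` is initialized; the adapter then has to decide which process is "main" from launcher-provided environment variables. A hedged sketch of that fallback (the exact variable names consulted by `is_main_process` are an assumption here):

import os


def is_main_process_from_env() -> bool:
    # Assumption: approximates is_main_process(use_environ=True) by reading
    # the RANK variable set by torchrun/accelerate, since torch.distributed
    # may not be initialized this early in startup.
    return int(os.environ.get("RANK", "0")) == 0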
@@ -1,11 +1,12 @@
"""Pydantic models for deprecated and remapped configuration parameters"""

import logging
from typing import Any

from pydantic import BaseModel, Field, field_validator

LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


class DeprecatedParameters(BaseModel):
@@ -64,6 +64,7 @@ class ChatTemplate(str, Enum):
    command_a_rag = "command_a_rag"  # pylint: disable=invalid-name
    aya = "aya"  # pylint: disable=invalid-name


class CustomSupportedOptimizers(str, Enum):
    """Custom supported optimizers"""
@@ -1,11 +1,12 @@
"""Pydantic models for Axolotl integrations"""

import logging
from typing import Any

from pydantic import BaseModel, Field, model_validator

LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


class MLFlowConfig(BaseModel):
@@ -1,10 +1,10 @@
"""Pydantic models for model input / output, etc. configuration"""

import logging

from pydantic import BaseModel, Field, field_validator

LOG = logging.getLogger(__name__)
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__, use_environ=True)


class ModelInputConfig(BaseModel):
@@ -1,15 +1,15 @@
"""Pydantic models for training hyperparameters"""

import logging
from typing import Any, Literal

from pydantic import BaseModel, Field, field_validator
from transformers import SchedulerType
from transformers.training_args import OptimizerNames

from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import CustomSupportedOptimizers

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


class LrGroup(BaseModel):
@@ -1,8 +1,8 @@
"""Utilities for Axolotl Pydantic models"""

import logging
from axolotl.utils.logging import get_logger

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)


def handle_legacy_message_fields_logic(data: dict) -> dict:
@@ -1,10 +1,10 @@
"""Module for tokenization utilities"""

import logging

from termcolor import colored

LOG = logging.getLogger("axolotl")
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def check_dataset_labels(
@@ -22,7 +22,7 @@ from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

LOG = get_logger("axolotl")
LOG = get_logger(__name__)


@torch.jit.script
@@ -402,7 +402,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
        .apply(len)
        .values
    )
    LOG.debug(f"total_num_tokens: {total_num_tokens:_}", main_process_only=True)
    LOG.debug(f"total_num_tokens: {total_num_tokens:_}")
    if update:
        cfg.total_num_tokens = total_num_tokens

@@ -420,10 +420,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
        .apply(lambda x: np.sum(np.array(x) != -100))
        .sum()
    )
    LOG.debug(
        f"`total_supervised_tokens: {total_supervised_tokens:_}`",
        main_process_only=True,
    )
    LOG.debug(f"`total_supervised_tokens: {total_supervised_tokens:_}`")
    if update:
        cfg.total_supervised_tokens = total_supervised_tokens

@@ -448,8 +445,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
            * cfg.sequence_parallel_degree
        )
        LOG.debug(
            f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}",
            main_process_only=True,
            f"total_num_tokens: {cfg.total_num_tokens:_}, total_num_steps: {total_num_steps:_}"
        )
    else:
        if cfg.flash_attention and not cfg.multipack_real_batches:
@@ -478,7 +474,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                batch_sampler=sampler,
            )
            data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
            LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
            LOG.debug(f"data_loader_len: {data_loader_len}")
            # FIXME: is there a bug here somewhere? the total num steps depends
            # on the agreed on value for sample_packing_eff_est
            total_num_steps = int(
@@ -500,10 +496,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
            )
            if update:
                cfg.sample_packing_eff_est = sample_packing_eff_est
            LOG.debug(
                f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
                main_process_only=True,
            )
            LOG.debug(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
        else:
            total_num_steps = int(
                math.ceil(
@@ -513,7 +506,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
                    / cfg.batch_size
                )
            )
    LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
    LOG.debug(f"total_num_steps: {total_num_steps}")
    return total_num_steps
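All of the `main_process_only=True` keywords above could be dropped because that is now the adapter's default; the keyword survives only for opting out. In sketch form:

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

total_num_steps = 1234  # illustrative

# Equivalent to the old LOG.debug(..., main_process_only=True).
LOG.debug(f"total_num_steps: {total_num_steps}")

# Pass the keyword only when every rank should emit the record.
LOG.debug("per-rank detail", main_process_only=False)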
@@ -2,7 +2,6 @@
E2E tests for multigpu lora tinyllama
"""

import logging
import os
from pathlib import Path

@@ -14,10 +13,11 @@ from transformers.testing_utils import get_torch_dist_unique_port
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from tests.e2e.utils import check_tensorboard, require_torch_2_6_0

LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"

AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
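The test modules below keep their explicit dotted logger names rather than switching to `__name__`; `get_logger` accepts either form, and the shared prefix keeps all e2e output under one logger hierarchy:

from axolotl.utils.logging import get_logger

# Same adapter either way; the explicit name groups every e2e test module
# under the "axolotl.tests.e2e" logger subtree.
LOG = get_logger("axolotl.tests.e2e.multigpu")
LOG.info("multigpu test output also goes through the rank-0 adapter")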
@@ -2,7 +2,6 @@
E2E tests for multigpu eval
"""

import logging
import os
from pathlib import Path

@@ -11,10 +10,11 @@ from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port

from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from ..utils import check_tensorboard

LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"

AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -2,7 +2,6 @@
E2E tests for multigpu lora tinyllama
"""

import logging
import os
from pathlib import Path

@@ -13,10 +12,11 @@ from huggingface_hub import snapshot_download
from transformers.testing_utils import get_torch_dist_unique_port

from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from tests.e2e.utils import check_tensorboard

LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"

AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -2,7 +2,6 @@
E2E tests for multigpu lora tinyllama
"""

import logging
import os
from pathlib import Path

@@ -15,10 +14,11 @@ from packaging import version
from transformers.testing_utils import get_torch_dist_unique_port

from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from tests.e2e.utils import check_tensorboard, require_torch_2_6_0

LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"

AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -2,7 +2,6 @@
E2E tests for multigpu qwen2
"""

import logging
import os
from pathlib import Path

@@ -12,8 +11,9 @@ from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port

from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
LOG = get_logger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
@@ -2,7 +2,6 @@
E2E tests for multigpu post-training use Ray Train
"""

import logging
import os
from pathlib import Path

@@ -11,10 +10,11 @@ import yaml
from accelerate.test_utils import execute_subprocess_async

from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from tests.e2e.utils import check_tensorboard, require_torch_lt_2_6_0

LOG = logging.getLogger(__name__)
LOG = get_logger(__name__)
os.environ["WANDB_DISABLED"] = "true"

AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
@@ -2,7 +2,6 @@
E2E tests for multipack fft llama using 4d attention masks
"""

import logging
import os
import unittest

@@ -11,10 +10,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -2,7 +2,6 @@
E2E tests for lora llama
"""

import logging
import os

import pytest
@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from ..utils import check_model_output_exists, check_tensorboard

LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -2,7 +2,6 @@
E2E tests for falcon
"""

import logging
import os
import unittest

@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -2,7 +2,6 @@
E2E tests for lora llama
"""

import logging
import os
import unittest

@@ -14,10 +13,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -2,7 +2,6 @@
E2E tests for llama w/ S2 attn
"""

import logging
import os
import unittest

@@ -13,10 +12,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -2,7 +2,6 @@
E2E tests for lora llama
"""

import logging
import os
import unittest

@@ -14,10 +13,11 @@ from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

from ..utils import check_model_output_exists, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
LOG = get_logger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"