From b1f4f7a34dd4282f1a3532e2cb156ab0c99c8fa6 Mon Sep 17 00:00:00 2001 From: theobjectivedad Date: Sat, 15 Jul 2023 12:29:35 +0000 Subject: [PATCH] Fixed pre-commit problems, fixed small bug in logging_config to handle LOG_LEVEL env var --- scripts/finetune.py | 2 +- src/axolotl/datasets.py | 1 + src/axolotl/logging_config.py | 5 ++++- src/axolotl/monkeypatch/llama_landmark_attn.py | 3 ++- src/axolotl/prompt_strategies/pygmalion.py | 2 ++ src/axolotl/prompters.py | 1 + src/axolotl/utils/data.py | 12 +++--------- tests/test_prompt_tokenizers.py | 3 --- 8 files changed, 14 insertions(+), 15 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 19f270aa3..8696d3c9a 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -17,6 +17,7 @@ import yaml from optimum.bettertransformer import BetterTransformer from transformers import GenerationConfig, TextStreamer +from axolotl.logging_config import configure_logging from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer @@ -24,7 +25,6 @@ from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.trainer import setup_trainer from axolotl.utils.validation import validate_config from axolotl.utils.wandb import setup_wandb_env_vars -from axolotl.logging_config import configure_logging project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index e70af4d27..911df8f50 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -16,6 +16,7 @@ from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy LOG = logging.getLogger("axolotl") + class TokenizedPromptDataset(IterableDataset): """ Iterable dataset that returns tokenized prompts from a stream of text files. diff --git a/src/axolotl/logging_config.py b/src/axolotl/logging_config.py index a26c35e10..1df272d5c 100644 --- a/src/axolotl/logging_config.py +++ b/src/axolotl/logging_config.py @@ -1,3 +1,6 @@ +"""Logging configuration settings""" + +import os import sys from logging.config import dictConfig from typing import Any, Dict @@ -18,7 +21,7 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { "stream": sys.stdout, }, }, - "root": {"handlers": ["console"], "level": "INFO"}, + "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")}, } diff --git a/src/axolotl/monkeypatch/llama_landmark_attn.py b/src/axolotl/monkeypatch/llama_landmark_attn.py index b83614fbb..24a98305f 100644 --- a/src/axolotl/monkeypatch/llama_landmark_attn.py +++ b/src/axolotl/monkeypatch/llama_landmark_attn.py @@ -52,6 +52,7 @@ from transformers.utils import ( logging, replace_return_docstrings, ) + LOG = logging.getLogger("axolotl") _CONFIG_FOR_DOC = "LlamaConfig" @@ -861,7 +862,7 @@ class LlamaModel(LlamaPreTrainedModel): if self.gradient_checkpointing and self.training: if use_cache: - logger.warning_once( + LOG.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/axolotl/prompt_strategies/pygmalion.py b/src/axolotl/prompt_strategies/pygmalion.py index 6714ecd4b..88208f6ec 100644 --- a/src/axolotl/prompt_strategies/pygmalion.py +++ b/src/axolotl/prompt_strategies/pygmalion.py @@ -11,6 +11,8 @@ from axolotl.prompt_tokenizers import ( tokenize_prompt_default, ) +LOG = logging.getLogger("axolotl") + IGNORE_TOKEN_ID = -100 diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 5a4f9d3d0..a304bd137 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -5,6 +5,7 @@ import logging from enum import Enum, auto from typing import Generator, List, Optional, Tuple, Union +LOG = logging.getLogger("axolotl") IGNORE_TOKEN_ID = -100 diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index f34f71d4f..adfdb94e1 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -258,9 +258,7 @@ def load_tokenized_prepared_datasets( suffix = "" if ":load_" in d.type: suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?" - LOG.error( - f"unhandled prompt tokenization strategy: {d.type}. {suffix}" - ) + LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}") raise ValueError( f"unhandled prompt tokenization strategy: {d.type} {suffix}" ) @@ -271,9 +269,7 @@ def load_tokenized_prepared_datasets( samples = samples + list(d) dataset = Dataset.from_list(samples).shuffle(seed=seed) if cfg.local_rank == 0: - LOG.info( - f"Saving merged prepared dataset to disk... {prepared_ds_path}" - ) + LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") dataset.save_to_disk(prepared_ds_path) if cfg.push_dataset_to_hub: LOG.info( @@ -366,9 +362,7 @@ def load_prepare_datasets( [dataset], seq_length=max_packed_sequence_len, ) - LOG.info( - f"packing master dataset to len: {cfg.max_packed_sequence_len}" - ) + LOG.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}") dataset = Dataset.from_list(list(constant_len_dataset)) # filter out bad data diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py index 61935bf54..a3e4cdbdf 100644 --- a/tests/test_prompt_tokenizers.py +++ b/tests/test_prompt_tokenizers.py @@ -16,9 +16,6 @@ from axolotl.prompt_tokenizers import ( ShareGPTPromptTokenizingStrategy, ) from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter -from axolotl.logging_config import configure_logging - -configure_logging() LOG = logging.getLogger("axolotl")