Fixed pre-commit problems, fixed small bug in logging_config to handle LOG_LEVEL env var
This commit is contained in:
@@ -17,6 +17,7 @@ import yaml
|
|||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
from transformers import GenerationConfig, TextStreamer
|
from transformers import GenerationConfig, TextStreamer
|
||||||
|
|
||||||
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
@@ -24,7 +25,6 @@ from axolotl.utils.tokenization import check_dataset_labels
|
|||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
from axolotl.utils.validation import validate_config
|
from axolotl.utils.validation import validate_config
|
||||||
from axolotl.utils.wandb import setup_wandb_env_vars
|
from axolotl.utils.wandb import setup_wandb_env_vars
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
|
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
src_dir = os.path.join(project_root, "src")
|
src_dir = os.path.join(project_root, "src")
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
|
|||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
class TokenizedPromptDataset(IterableDataset):
|
class TokenizedPromptDataset(IterableDataset):
|
||||||
"""
|
"""
|
||||||
Iterable dataset that returns tokenized prompts from a stream of text files.
|
Iterable dataset that returns tokenized prompts from a stream of text files.
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
"""Logging configuration settings"""
|
||||||
|
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
from logging.config import dictConfig
|
from logging.config import dictConfig
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
@@ -18,7 +21,7 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
|||||||
"stream": sys.stdout,
|
"stream": sys.stdout,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"root": {"handlers": ["console"], "level": "INFO"},
|
"root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ from transformers.utils import (
|
|||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "LlamaConfig"
|
_CONFIG_FOR_DOC = "LlamaConfig"
|
||||||
@@ -861,7 +862,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
if use_cache:
|
if use_cache:
|
||||||
logger.warning_once(
|
LOG.warning_once(
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
)
|
)
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ from axolotl.prompt_tokenizers import (
|
|||||||
tokenize_prompt_default,
|
tokenize_prompt_default,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
IGNORE_TOKEN_ID = -100
|
IGNORE_TOKEN_ID = -100
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import logging
|
|||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Generator, List, Optional, Tuple, Union
|
from typing import Generator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
IGNORE_TOKEN_ID = -100
|
IGNORE_TOKEN_ID = -100
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -258,9 +258,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
suffix = ""
|
suffix = ""
|
||||||
if ":load_" in d.type:
|
if ":load_" in d.type:
|
||||||
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
|
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
|
||||||
LOG.error(
|
LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
|
||||||
f"unhandled prompt tokenization strategy: {d.type}. {suffix}"
|
|
||||||
)
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
||||||
)
|
)
|
||||||
@@ -271,9 +269,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
samples = samples + list(d)
|
samples = samples + list(d)
|
||||||
dataset = Dataset.from_list(samples).shuffle(seed=seed)
|
dataset = Dataset.from_list(samples).shuffle(seed=seed)
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
LOG.info(
|
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||||
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
|
|
||||||
)
|
|
||||||
dataset.save_to_disk(prepared_ds_path)
|
dataset.save_to_disk(prepared_ds_path)
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
@@ -366,9 +362,7 @@ def load_prepare_datasets(
|
|||||||
[dataset],
|
[dataset],
|
||||||
seq_length=max_packed_sequence_len,
|
seq_length=max_packed_sequence_len,
|
||||||
)
|
)
|
||||||
LOG.info(
|
LOG.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
|
||||||
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
|
|
||||||
)
|
|
||||||
dataset = Dataset.from_list(list(constant_len_dataset))
|
dataset = Dataset.from_list(list(constant_len_dataset))
|
||||||
|
|
||||||
# filter out bad data
|
# filter out bad data
|
||||||
|
|||||||
@@ -16,9 +16,6 @@ from axolotl.prompt_tokenizers import (
|
|||||||
ShareGPTPromptTokenizingStrategy,
|
ShareGPTPromptTokenizingStrategy,
|
||||||
)
|
)
|
||||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
|
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
|
||||||
from axolotl.logging_config import configure_logging
|
|
||||||
|
|
||||||
configure_logging()
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user