Merge pull request #276 from theobjectivedad/logging_enhancement
Logging update: added PID and formatting
This commit is contained in:
@@ -15,6 +15,9 @@ from axolotl.convert import (
|
|||||||
JsonToJsonlConverter,
|
JsonToJsonlConverter,
|
||||||
StdoutWriter,
|
StdoutWriter,
|
||||||
)
|
)
|
||||||
|
from axolotl.logging_config import configure_logging
|
||||||
|
|
||||||
|
configure_logging()
|
||||||
|
|
||||||
# add src to the pythonpath so we don't need to pip install this
|
# add src to the pythonpath so we don't need to pip install this
|
||||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import yaml
|
|||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
from transformers import GenerationConfig, TextStreamer
|
from transformers import GenerationConfig, TextStreamer
|
||||||
|
|
||||||
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
@@ -29,8 +30,10 @@ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|||||||
src_dir = os.path.join(project_root, "src")
|
src_dir = os.path.join(project_root, "src")
|
||||||
sys.path.insert(0, src_dir)
|
sys.path.insert(0, src_dir)
|
||||||
|
|
||||||
|
configure_logging()
|
||||||
|
LOG = logging.getLogger("axolotl.scripts")
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
|
|
||||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||||
|
|
||||||
|
|
||||||
@@ -212,7 +215,7 @@ def train(
|
|||||||
|
|
||||||
# load the tokenizer first
|
# load the tokenizer first
|
||||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||||
logging.info(f"loading tokenizer... {tokenizer_config}")
|
LOG.info(f"loading tokenizer... {tokenizer_config}")
|
||||||
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -234,7 +237,7 @@ def train(
|
|||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
|
|
||||||
if cfg.debug or "debug" in kwargs:
|
if cfg.debug or "debug" in kwargs:
|
||||||
logging.info("check_dataset_labels...")
|
LOG.info("check_dataset_labels...")
|
||||||
check_dataset_labels(
|
check_dataset_labels(
|
||||||
train_dataset.select(
|
train_dataset.select(
|
||||||
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
|
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
|
||||||
@@ -243,11 +246,11 @@ def train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if prepare_ds_only:
|
if prepare_ds_only:
|
||||||
logging.info("Finished preparing dataset. Exiting...")
|
LOG.info("Finished preparing dataset. Exiting...")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
logging.info("loading model and peft_config...")
|
LOG.info("loading model and peft_config...")
|
||||||
model, peft_config = load_model(
|
model, peft_config = load_model(
|
||||||
cfg.base_model,
|
cfg.base_model,
|
||||||
cfg.base_model_config,
|
cfg.base_model_config,
|
||||||
@@ -258,17 +261,17 @@ def train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if "merge_lora" in kwargs and cfg.adapter is not None:
|
if "merge_lora" in kwargs and cfg.adapter is not None:
|
||||||
logging.info("running merge of LoRA with base model")
|
LOG.info("running merge of LoRA with base model")
|
||||||
model = model.merge_and_unload()
|
model = model.merge_and_unload()
|
||||||
model.to(dtype=torch.float16)
|
model.to(dtype=torch.float16)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
logging.info("saving merged model")
|
LOG.info("saving merged model")
|
||||||
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
return
|
return
|
||||||
|
|
||||||
if cfg.inference:
|
if cfg.inference:
|
||||||
logging.info("calling do_inference function")
|
LOG.info("calling do_inference function")
|
||||||
prompter: Optional[str] = "AlpacaPrompter"
|
prompter: Optional[str] = "AlpacaPrompter"
|
||||||
if "prompter" in kwargs:
|
if "prompter" in kwargs:
|
||||||
if kwargs["prompter"] == "None":
|
if kwargs["prompter"] == "None":
|
||||||
@@ -287,12 +290,12 @@ def train(
|
|||||||
model.config.use_cache = False
|
model.config.use_cache = False
|
||||||
|
|
||||||
if torch.__version__ >= "2" and sys.platform != "win32":
|
if torch.__version__ >= "2" and sys.platform != "win32":
|
||||||
logging.info("Compiling torch model")
|
LOG.info("Compiling torch model")
|
||||||
model = torch.compile(model)
|
model = torch.compile(model)
|
||||||
|
|
||||||
# go ahead and presave, so we have the adapter config available to inspect
|
# go ahead and presave, so we have the adapter config available to inspect
|
||||||
if peft_config:
|
if peft_config:
|
||||||
logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
||||||
peft_config.save_pretrained(cfg.output_dir)
|
peft_config.save_pretrained(cfg.output_dir)
|
||||||
|
|
||||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||||
@@ -308,9 +311,9 @@ def train(
|
|||||||
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
if cfg.group_by_length:
|
if cfg.group_by_length:
|
||||||
logging.info("hang tight... sorting dataset for group_by_length")
|
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||||
possible_checkpoints = [
|
possible_checkpoints = [
|
||||||
@@ -322,7 +325,7 @@ def train(
|
|||||||
key=lambda path: int(path.split("-")[-1]),
|
key=lambda path: int(path.split("-")[-1]),
|
||||||
)
|
)
|
||||||
resume_from_checkpoint = sorted_paths[-1]
|
resume_from_checkpoint = sorted_paths[-1]
|
||||||
logging.info(
|
LOG.info(
|
||||||
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
|
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -336,7 +339,7 @@ def train(
|
|||||||
else:
|
else:
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|
||||||
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||||
|
|
||||||
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
||||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
|
|||||||
# let's check to ensure we don't truncate an item in the middle, we'll use
|
# let's check to ensure we don't truncate an item in the middle, we'll use
|
||||||
# the collators later on to pad the datasets
|
# the collators later on to pad the datasets
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
class TokenizedPromptDataset(IterableDataset):
|
class TokenizedPromptDataset(IterableDataset):
|
||||||
"""
|
"""
|
||||||
@@ -115,7 +117,7 @@ class ConstantLengthDataset(IterableDataset):
|
|||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
logging.warning(
|
LOG.warning(
|
||||||
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
|
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
|
||||||
)
|
)
|
||||||
buffer = {
|
buffer = {
|
||||||
|
|||||||
30
src/axolotl/logging_config.py
Normal file
30
src/axolotl/logging_config.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""Logging configuration settings"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from logging.config import dictConfig
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
||||||
|
"version": 1,
|
||||||
|
"formatters": {
|
||||||
|
"simple": {
|
||||||
|
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"filters": {},
|
||||||
|
"handlers": {
|
||||||
|
"console": {
|
||||||
|
"class": "logging.StreamHandler",
|
||||||
|
"formatter": "simple",
|
||||||
|
"filters": [],
|
||||||
|
"stream": sys.stdout,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def configure_logging():
|
||||||
|
"""Configure with default logging"""
|
||||||
|
dictConfig(DEFAULT_LOGGING_CONFIG)
|
||||||
@@ -53,7 +53,7 @@ from transformers.utils import (
|
|||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
_CONFIG_FOR_DOC = "LlamaConfig"
|
_CONFIG_FOR_DOC = "LlamaConfig"
|
||||||
|
|
||||||
@@ -862,7 +862,7 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
if use_cache:
|
if use_cache:
|
||||||
logger.warning_once(
|
LOG.warning_once(
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
)
|
)
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ from axolotl.prompt_tokenizers import (
|
|||||||
tokenize_prompt_default,
|
tokenize_prompt_default,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
IGNORE_TOKEN_ID = -100
|
IGNORE_TOKEN_ID = -100
|
||||||
|
|
||||||
|
|
||||||
@@ -64,7 +66,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
*copy.deepcopy(res["input_ids"])
|
*copy.deepcopy(res["input_ids"])
|
||||||
][len(self.bot_prefix_token_ids) :]
|
][len(self.bot_prefix_token_ids) :]
|
||||||
else:
|
else:
|
||||||
logging.warning(f"unknown role in conversation: {role}")
|
LOG.warning(f"unknown role in conversation: {role}")
|
||||||
res = defaultdict(lambda: [])
|
res = defaultdict(lambda: [])
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ from transformers import PreTrainedTokenizer
|
|||||||
|
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec
|
LLAMA_DEFAULT_PAD_TOKEN = "[PAD]" # nosec
|
||||||
LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
|
LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
|
||||||
@@ -384,7 +386,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
# everything from this is masked out from the labels
|
# everything from this is masked out from the labels
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
else:
|
else:
|
||||||
logging.warning(f"unhandled role: {part[0]}")
|
LOG.warning(f"unhandled role: {part[0]}")
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
result, current_len = parse_tokenized_to_result(
|
result, current_len = parse_tokenized_to_result(
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import logging
|
|||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Generator, List, Optional, Tuple, Union
|
from typing import Generator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
IGNORE_TOKEN_ID = -100
|
IGNORE_TOKEN_ID = -100
|
||||||
|
|
||||||
|
|
||||||
@@ -241,7 +242,7 @@ class Conversation:
|
|||||||
if message:
|
if message:
|
||||||
yield (role + ":", " " + message)
|
yield (role + ":", " " + message)
|
||||||
else:
|
else:
|
||||||
logging.warning(f"role with empty message: {role}")
|
LOG.warning(f"role with empty message: {role}")
|
||||||
yield (role + ":", "")
|
yield (role + ":", "")
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
|
|||||||
@@ -35,6 +35,8 @@ from axolotl.prompters import (
|
|||||||
SummarizeTLDRPrompter,
|
SummarizeTLDRPrompter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def load_tokenized_prepared_datasets(
|
def load_tokenized_prepared_datasets(
|
||||||
tokenizer, cfg, default_dataset_prepared_path
|
tokenizer, cfg, default_dataset_prepared_path
|
||||||
@@ -73,17 +75,17 @@ def load_tokenized_prepared_datasets(
|
|||||||
if dataset:
|
if dataset:
|
||||||
...
|
...
|
||||||
elif any(prepared_ds_path.glob("*")):
|
elif any(prepared_ds_path.glob("*")):
|
||||||
logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
||||||
dataset = load_from_disk(str(prepared_ds_path))
|
dataset = load_from_disk(str(prepared_ds_path))
|
||||||
logging.info("Prepared dataset loaded from disk...")
|
LOG.info("Prepared dataset loaded from disk...")
|
||||||
else:
|
else:
|
||||||
logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
|
||||||
logging.info("Loading raw datasets...")
|
LOG.info("Loading raw datasets...")
|
||||||
|
|
||||||
if cfg.seed:
|
if cfg.seed:
|
||||||
seed = cfg.seed
|
seed = cfg.seed
|
||||||
else:
|
else:
|
||||||
logging.info("No seed provided, using default seed of 42")
|
LOG.info("No seed provided, using default seed of 42")
|
||||||
seed = 42
|
seed = 42
|
||||||
|
|
||||||
datasets = []
|
datasets = []
|
||||||
@@ -255,25 +257,21 @@ def load_tokenized_prepared_datasets(
|
|||||||
suffix = ""
|
suffix = ""
|
||||||
if ":load_" in d.type:
|
if ":load_" in d.type:
|
||||||
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
|
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
|
||||||
logging.error(
|
LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
|
||||||
f"unhandled prompt tokenization strategy: {d.type}. {suffix}"
|
|
||||||
)
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
||||||
)
|
)
|
||||||
logging.info("tokenizing, merging, and shuffling master dataset")
|
LOG.info("tokenizing, merging, and shuffling master dataset")
|
||||||
|
|
||||||
samples: List[int] = []
|
samples: List[int] = []
|
||||||
for d in datasets:
|
for d in datasets:
|
||||||
samples = samples + list(d)
|
samples = samples + list(d)
|
||||||
dataset = Dataset.from_list(samples).shuffle(seed=seed)
|
dataset = Dataset.from_list(samples).shuffle(seed=seed)
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
logging.info(
|
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||||
f"Saving merged prepared dataset to disk... {prepared_ds_path}"
|
|
||||||
)
|
|
||||||
dataset.save_to_disk(prepared_ds_path)
|
dataset.save_to_disk(prepared_ds_path)
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
logging.info(
|
LOG.info(
|
||||||
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||||
)
|
)
|
||||||
dataset.push_to_hub(
|
dataset.push_to_hub(
|
||||||
@@ -324,7 +322,7 @@ def load_prepare_datasets(
|
|||||||
use_auth_token = cfg.hf_use_auth_token
|
use_auth_token = cfg.hf_use_auth_token
|
||||||
try:
|
try:
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
logging.info(
|
LOG.info(
|
||||||
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||||
)
|
)
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
@@ -338,13 +336,13 @@ def load_prepare_datasets(
|
|||||||
if dataset:
|
if dataset:
|
||||||
...
|
...
|
||||||
elif any(prepared_ds_path.glob("*")):
|
elif any(prepared_ds_path.glob("*")):
|
||||||
logging.info(
|
LOG.info(
|
||||||
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
|
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
|
||||||
)
|
)
|
||||||
dataset = load_from_disk(str(prepared_ds_path))
|
dataset = load_from_disk(str(prepared_ds_path))
|
||||||
logging.info("Prepared packed dataset loaded from disk...")
|
LOG.info("Prepared packed dataset loaded from disk...")
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
logging.info(
|
LOG.info(
|
||||||
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||||
)
|
)
|
||||||
dataset.push_to_hub(
|
dataset.push_to_hub(
|
||||||
@@ -363,9 +361,7 @@ def load_prepare_datasets(
|
|||||||
[dataset],
|
[dataset],
|
||||||
seq_length=max_packed_sequence_len,
|
seq_length=max_packed_sequence_len,
|
||||||
)
|
)
|
||||||
logging.info(
|
LOG.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
|
||||||
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
|
|
||||||
)
|
|
||||||
dataset = Dataset.from_list(list(constant_len_dataset))
|
dataset = Dataset.from_list(list(constant_len_dataset))
|
||||||
|
|
||||||
# filter out bad data
|
# filter out bad data
|
||||||
@@ -381,12 +377,12 @@ def load_prepare_datasets(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
logging.info(
|
LOG.info(
|
||||||
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
|
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
|
||||||
)
|
)
|
||||||
dataset.save_to_disk(prepared_ds_path)
|
dataset.save_to_disk(prepared_ds_path)
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
logging.info(
|
LOG.info(
|
||||||
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
|
||||||
)
|
)
|
||||||
dataset.push_to_hub(
|
dataset.push_to_hub(
|
||||||
@@ -399,7 +395,7 @@ def load_prepare_datasets(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
||||||
logging.info(
|
LOG.info(
|
||||||
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
|
f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
|
||||||
)
|
)
|
||||||
dataset = dataset.shard(
|
dataset = dataset.shard(
|
||||||
@@ -520,7 +516,7 @@ def encode_pretraining(tokenizer, max_tokens, examples):
|
|||||||
"attention_mask": [seq.tolist() for seq in new_attention_mask],
|
"attention_mask": [seq.tolist() for seq in new_attention_mask],
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.debug(len(ret["input_ids"]))
|
LOG.debug(len(ret["input_ids"]))
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ from transformers import ( # noqa: F401
|
|||||||
|
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from peft import PeftConfig # noqa: F401
|
from peft import PeftConfig # noqa: F401
|
||||||
|
|
||||||
@@ -50,10 +52,10 @@ def load_tokenizer(
|
|||||||
use_fast=use_fast,
|
use_fast=use_fast,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||||
logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||||
logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||||
logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||||
|
|
||||||
if tokenizer.__class__.__name__ in [
|
if tokenizer.__class__.__name__ in [
|
||||||
"LlamaTokenizer",
|
"LlamaTokenizer",
|
||||||
@@ -92,21 +94,21 @@ def load_model(
|
|||||||
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
|
if cfg.device not in ["mps", "cpu"] and not cfg.inference:
|
||||||
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
from axolotl.flash_attn import replace_llama_attn_with_flash_attn
|
||||||
|
|
||||||
logging.info("patching with flash attention")
|
LOG.info("patching with flash attention")
|
||||||
replace_llama_attn_with_flash_attn()
|
replace_llama_attn_with_flash_attn()
|
||||||
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||||
hijack_llama_attention,
|
hijack_llama_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("patching with xformers attention")
|
LOG.info("patching with xformers attention")
|
||||||
hijack_llama_attention()
|
hijack_llama_attention()
|
||||||
elif cfg.is_llama_derived_model and cfg.sdp_attention:
|
elif cfg.is_llama_derived_model and cfg.sdp_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||||
hijack_llama_sdp_attention,
|
hijack_llama_sdp_attention,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("patching with sdp attention")
|
LOG.info("patching with sdp attention")
|
||||||
hijack_llama_sdp_attention()
|
hijack_llama_sdp_attention()
|
||||||
elif cfg.is_llama_derived_model and cfg.landmark_attention:
|
elif cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||||
@@ -114,7 +116,7 @@ def load_model(
|
|||||||
patch_llama_with_landmark_attn,
|
patch_llama_with_landmark_attn,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("patching with landmark attention")
|
LOG.info("patching with landmark attention")
|
||||||
patch_llama_with_landmark_attn()
|
patch_llama_with_landmark_attn()
|
||||||
|
|
||||||
# Note: This might overwrite previous additional_special_tokens
|
# Note: This might overwrite previous additional_special_tokens
|
||||||
@@ -125,7 +127,7 @@ def load_model(
|
|||||||
replace_llama_rope_with_xpos_rope,
|
replace_llama_rope_with_xpos_rope,
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info("patching with xpos rope")
|
LOG.info("patching with xpos rope")
|
||||||
replace_llama_rope_with_xpos_rope()
|
replace_llama_rope_with_xpos_rope()
|
||||||
|
|
||||||
if cfg.bf16 or cfg.bfloat16:
|
if cfg.bf16 or cfg.bfloat16:
|
||||||
@@ -142,7 +144,7 @@ def load_model(
|
|||||||
|
|
||||||
replace_peft_model_with_int4_lora_model()
|
replace_peft_model_with_int4_lora_model()
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logging.exception(err)
|
LOG.exception(err)
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -187,7 +189,7 @@ def load_model(
|
|||||||
if len(files) > 0:
|
if len(files) > 0:
|
||||||
model_path = str(files[0])
|
model_path = str(files[0])
|
||||||
else:
|
else:
|
||||||
logging.warning(
|
LOG.warning(
|
||||||
"unable to find a cached model file, this will likely fail..."
|
"unable to find a cached model file, this will likely fail..."
|
||||||
)
|
)
|
||||||
model_path = str(cache_model_path)
|
model_path = str(cache_model_path)
|
||||||
@@ -266,14 +268,14 @@ def load_model(
|
|||||||
and cfg.sequence_len > config.max_seq_len
|
and cfg.sequence_len > config.max_seq_len
|
||||||
):
|
):
|
||||||
config.max_seq_len = cfg.sequence_len
|
config.max_seq_len = cfg.sequence_len
|
||||||
logging.warning(f"increasing context length to {cfg.sequence_len}")
|
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||||
elif (
|
elif (
|
||||||
hasattr(config, "max_sequence_length")
|
hasattr(config, "max_sequence_length")
|
||||||
and config.max_sequence_length
|
and config.max_sequence_length
|
||||||
and cfg.sequence_len > config.max_sequence_length
|
and cfg.sequence_len > config.max_sequence_length
|
||||||
):
|
):
|
||||||
config.max_sequence_length = cfg.sequence_len
|
config.max_sequence_length = cfg.sequence_len
|
||||||
logging.warning(f"increasing context length to {cfg.sequence_len}")
|
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
config=config,
|
config=config,
|
||||||
@@ -285,10 +287,10 @@ def load_model(
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
except Exception as err: # pylint: disable=broad-exception-caught
|
except Exception as err: # pylint: disable=broad-exception-caught
|
||||||
logging.error(
|
LOG.error(
|
||||||
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
||||||
)
|
)
|
||||||
logging.exception(err)
|
LOG.exception(err)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||||
@@ -307,7 +309,7 @@ def load_model(
|
|||||||
and model.config.max_position_embeddings
|
and model.config.max_position_embeddings
|
||||||
and cfg.sequence_len >= model.config.max_position_embeddings
|
and cfg.sequence_len >= model.config.max_position_embeddings
|
||||||
):
|
):
|
||||||
logging.warning(
|
LOG.warning(
|
||||||
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
||||||
)
|
)
|
||||||
model.config.max_position_embeddings = cfg.sequence_len
|
model.config.max_position_embeddings = cfg.sequence_len
|
||||||
@@ -316,7 +318,7 @@ def load_model(
|
|||||||
(cfg.adapter == "lora" and load_in_8bit)
|
(cfg.adapter == "lora" and load_in_8bit)
|
||||||
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
or (cfg.adapter == "qlora" and cfg.load_in_4bit)
|
||||||
):
|
):
|
||||||
logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||||
model = prepare_model_for_kbit_training(
|
model = prepare_model_for_kbit_training(
|
||||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||||
)
|
)
|
||||||
@@ -328,7 +330,7 @@ def load_model(
|
|||||||
|
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
# Scales to half
|
# Scales to half
|
||||||
logging.info("Fitting 4bit scales and zeros to half")
|
LOG.info("Fitting 4bit scales and zeros to half")
|
||||||
for _, module in model.named_modules():
|
for _, module in model.named_modules():
|
||||||
if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
|
if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
|
||||||
type(module)
|
type(module)
|
||||||
@@ -354,7 +356,7 @@ def load_model(
|
|||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
requires_grad.append(f"{name}: {param.requires_grad}")
|
requires_grad.append(f"{name}: {param.requires_grad}")
|
||||||
if len(requires_grad) == 0:
|
if len(requires_grad) == 0:
|
||||||
logging.warning("there are no parameters that require gradient updates")
|
LOG.warning("there are no parameters that require gradient updates")
|
||||||
model.config.use_cache = False
|
model.config.use_cache = False
|
||||||
|
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
@@ -388,7 +390,7 @@ def load_llama_adapter(model, cfg):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if cfg.lora_model_dir:
|
if cfg.lora_model_dir:
|
||||||
logging.info("Loading pretained LORA")
|
LOG.info("Loading pretained LORA")
|
||||||
model = PeftModel.from_pretrained(
|
model = PeftModel.from_pretrained(
|
||||||
model,
|
model,
|
||||||
cfg.lora_model_dir,
|
cfg.lora_model_dir,
|
||||||
@@ -435,7 +437,7 @@ def load_lora(model, cfg):
|
|||||||
bits = 8
|
bits = 8
|
||||||
|
|
||||||
linear_names = find_all_linear_names(bits, model)
|
linear_names = find_all_linear_names(bits, model)
|
||||||
logging.info(f"found linear modules: {repr(linear_names)}")
|
LOG.info(f"found linear modules: {repr(linear_names)}")
|
||||||
lora_target_modules = list(set(lora_target_modules + linear_names))
|
lora_target_modules = list(set(lora_target_modules + linear_names))
|
||||||
|
|
||||||
lora_config = LoraConfig(
|
lora_config = LoraConfig(
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import logging
|
|||||||
|
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def check_dataset_labels(dataset, tokenizer):
|
def check_dataset_labels(dataset, tokenizer):
|
||||||
# the dataset is already shuffled, so let's just check the first 5 elements
|
# the dataset is already shuffled, so let's just check the first 5 elements
|
||||||
@@ -32,7 +34,7 @@ def check_example_labels(example, tokenizer):
|
|||||||
)
|
)
|
||||||
colored_tokens.append(colored_token)
|
colored_tokens.append(colored_token)
|
||||||
|
|
||||||
logging.info(" ".join(colored_tokens))
|
LOG.info(" ".join(colored_tokens))
|
||||||
logging.info("\n\n\n")
|
LOG.info("\n\n\n")
|
||||||
|
|
||||||
return " ".join(colored_tokens)
|
return " ".join(colored_tokens)
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ from axolotl.utils.schedulers import (
|
|||||||
get_cosine_schedule_with_quadratic_warmup,
|
get_cosine_schedule_with_quadratic_warmup,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainingArguments(TrainingArguments):
|
class AxolotlTrainingArguments(TrainingArguments):
|
||||||
"""
|
"""
|
||||||
@@ -324,7 +326,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|||||||
|
|
||||||
set_model_mem_id(model, tokenizer)
|
set_model_mem_id(model, tokenizer)
|
||||||
|
|
||||||
logging.info("Adding landmark attention tokens to dataset")
|
LOG.info("Adding landmark attention tokens to dataset")
|
||||||
|
|
||||||
for dataset in [train_dataset, eval_dataset]:
|
for dataset in [train_dataset, eval_dataset]:
|
||||||
dataset = dataset.map(
|
dataset = dataset.map(
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import logging
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def validate_config(cfg):
|
def validate_config(cfg):
|
||||||
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||||
@@ -11,7 +13,7 @@ def validate_config(cfg):
|
|||||||
"please set only one of gradient_accumulation_steps or batch_size"
|
"please set only one of gradient_accumulation_steps or batch_size"
|
||||||
)
|
)
|
||||||
if cfg.batch_size:
|
if cfg.batch_size:
|
||||||
logging.warning(
|
LOG.warning(
|
||||||
"%s\n%s",
|
"%s\n%s",
|
||||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||||
@@ -44,10 +46,10 @@ def validate_config(cfg):
|
|||||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||||
|
|
||||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||||
logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||||
|
|
||||||
if cfg.trust_remote_code:
|
if cfg.trust_remote_code:
|
||||||
logging.warning(
|
LOG.warning(
|
||||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -66,31 +68,29 @@ def validate_config(cfg):
|
|||||||
|
|
||||||
if cfg.flash_optimum is True:
|
if cfg.flash_optimum is True:
|
||||||
if cfg.adapter:
|
if cfg.adapter:
|
||||||
logging.warning(
|
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
|
||||||
"BetterTransformers probably doesn't work with PEFT adapters"
|
|
||||||
)
|
|
||||||
if cfg.fp16 or cfg.bf16:
|
if cfg.fp16 or cfg.bf16:
|
||||||
raise ValueError("AMP is not supported with BetterTransformer")
|
raise ValueError("AMP is not supported with BetterTransformer")
|
||||||
if cfg.float16 is not True and cfg.bloat16 is not True:
|
if cfg.float16 is not True and cfg.bloat16 is not True:
|
||||||
logging.warning(
|
LOG.warning(
|
||||||
"You should probably set bfloat16 or float16 to true to "
|
"You should probably set bfloat16 or float16 to true to "
|
||||||
"load the model in float16 for BetterTransformers"
|
"load the model in float16 for BetterTransformers"
|
||||||
)
|
)
|
||||||
if int(torch.__version__.split(".")[0]) < 2:
|
if int(torch.__version__.split(".")[0]) < 2:
|
||||||
logging.warning("torch>=2.0.0 required")
|
LOG.warning("torch>=2.0.0 required")
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
|
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.pretraining_dataset and cfg.group_by_length:
|
if cfg.pretraining_dataset and cfg.group_by_length:
|
||||||
logging.warning(
|
LOG.warning(
|
||||||
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
||||||
)
|
)
|
||||||
|
|
||||||
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
||||||
not cfg.optimizer or "adamw" not in cfg.optimizer
|
not cfg.optimizer or "adamw" not in cfg.optimizer
|
||||||
):
|
):
|
||||||
logging.warning("adamw hyperparameters found, but no adamw optimizer set")
|
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||||
|
|
||||||
if cfg.push_to_hub_model_id:
|
if cfg.push_to_hub_model_id:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from axolotl.prompt_tokenizers import (
|
|||||||
)
|
)
|
||||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
|
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
|
||||||
|
|
||||||
logging.basicConfig(level="INFO")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
class TestPromptTokenizationStrategies(unittest.TestCase):
|
class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user