"""data handling specific to SFT"""

import functools
import logging
from pathlib import Path
from typing import List, Tuple, Union

from datasets import (
    Dataset,
    DatasetDict,
    concatenate_datasets,
    load_dataset,
    load_from_disk,
)
from transformers import PreTrainedTokenizerBase

from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies import load
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
from axolotl.prompt_tokenizers import (
    AlpacaMultipleChoicePromptTokenizingStrategy,
    AlpacaPromptTokenizingStrategy,
    AlpacaReflectionPTStrategy,
    DatasetWrappingStrategy,
    GPTeacherPromptTokenizingStrategy,
    JeopardyPromptTokenizingStrategy,
    OpenAssistantPromptTokenizingStrategy,
    SummarizeTLDRPromptTokenizingStrategy,
)
from axolotl.prompters import (
    AlpacaPrompter,
    GPTeacherPrompter,
    JeopardyPrompter,
    MultipleChoiceConcisePrompter,
    MultipleChoiceExplainPrompter,
    Prompter,
    ReflectAlpacaPrompter,
    SummarizeTLDRPrompter,
    UnsupportedPrompter,
)
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.shared import load_dataset_w_config
from axolotl.utils.data.utils import (
    deduplicate_and_log_datasets,
    md5,
    retry_on_request_exceptions,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import (
    calculate_total_num_steps,
    process_datasets_for_packing,
)

LOG = logging.getLogger("axolotl")

@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None):
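    """Load and tokenize the configured datasets for SFT or streaming pretraining.

    Returns a ``(train_dataset, eval_dataset, total_num_steps, prompters)``
    tuple. A minimal illustrative call, assuming `cfg` is a parsed axolotl
    config and `tokenizer` is already loaded:

        train_ds, eval_ds, steps, prompters = prepare_dataset(cfg, tokenizer)
    """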
    prompters = []
    if not cfg.pretraining_dataset:
        with zero_first(is_local_main_process()):
            if cfg.test_datasets:
                train_dataset, _, prompters = load_prepare_datasets(
                    tokenizer,
                    cfg,
                    DEFAULT_DATASET_PREPARED_PATH,
                    split="train",
                    processor=processor,
                )
                _, eval_dataset, _ = load_prepare_datasets(
                    tokenizer,
                    cfg,
                    DEFAULT_DATASET_PREPARED_PATH,
                    split="test",
                    processor=processor,
                )
            else:
                train_dataset, eval_dataset, prompters = load_prepare_datasets(
                    tokenizer,
                    cfg,
                    DEFAULT_DATASET_PREPARED_PATH,
                    processor=processor,
                )
    else:
        # Load streaming dataset if pretraining_dataset is given
        path = cfg.pretraining_dataset
        split = "train"
        name = None
        data_files = None
        skip = 0
        if isinstance(cfg.pretraining_dataset, list) and isinstance(
            cfg.pretraining_dataset[0], dict
        ):
            path = cfg.pretraining_dataset[0]["path"]
            name = cfg.pretraining_dataset[0]["name"]
            skip = cfg.pretraining_dataset[0]["skip"]

            if "split" in cfg.pretraining_dataset[0]:
                split = cfg.pretraining_dataset[0]["split"]

            data_files = cfg.pretraining_dataset[0].get("data_files")

        ds_wrapper_partial = functools.partial(
            get_dataset_wrapper,
            cfg.pretraining_dataset[0],
            tokenizer,
            cfg,
            cfg.pretraining_dataset[0]["type"] or "pretrain",
        )

        iter_ds = load_dataset(
            path, streaming=True, split=split, name=name, data_files=data_files
        )
        if skip:
            LOG.info(f"Skipping {skip} samples from the dataset")
            iter_ds = iter_ds.skip(skip)
        train_dataset = wrap_pretraining_dataset(
            iter_ds,
            tokenizer,
            cfg,
            ds_wrapper_partial,
            max_tokens=cfg.sequence_len,
            batch_size=cfg.micro_batch_size,
            seed=cfg.seed or 42,
            buffer_size=cfg.pretrain_multipack_buffer_size or 10_000,
        )
        # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
        train_dataset = train_dataset.with_format("torch")

        # Load eval dataset (non-streaming) if specified
        eval_dataset = None
        if cfg.test_datasets:
            _, eval_dataset, _ = load_prepare_datasets(
                tokenizer,
                cfg,
                DEFAULT_DATASET_PREPARED_PATH,
                split="test",
                processor=processor,
            )

        if cfg.dataset_exact_deduplication:
            LOG.info("Deduplication is not available for pretraining datasets")
        return train_dataset, eval_dataset, cfg.max_steps, prompters

    if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
        total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
        if total_eval_steps == 0:
            raise ValueError(
                "eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`."
            )

    if cfg.max_steps:
        total_num_steps = min(
            calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
        )
        LOG.info(f"Maximum number of steps set at {total_num_steps}")
    else:
        total_num_steps = calculate_total_num_steps(cfg, train_dataset)

    return train_dataset, eval_dataset, total_num_steps, prompters


def load_tokenized_prepared_datasets(
    tokenizer,
    cfg,
    default_dataset_prepared_path,
    split="train",
    processor=None,
) -> Tuple[DatasetDict, List[Prompter]]:
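    """Load a previously prepared dataset from the hub or local cache when
    available; otherwise tokenize the raw datasets listed in the config and
    cache the result.

    The cache key (`ds_hash` below) is an md5 over the settings that affect
    tokenization, so changing e.g. `sequence_len` or any dataset entry
    invalidates the cache.
    """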
    cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
    tokenizer_name = cfg.tokenizer_config
    ds_hash = str(
        md5(
            (
                str(cfg.sequence_len)
                + "@"
                + str(cfg.sample_packing)
                + "@"
                + str(cfg.eval_sample_packing)
                + "@"
                + str(cfg.group_by_length)
                + "@"
                + "|".join(
                    sorted(
                        [
                            f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
                            for d in cfg_datasets
                        ]
                    )
                )
                + "|"
                + tokenizer_name
            )
        )
    )
    prepared_ds_path = (
        Path(cfg.dataset_prepared_path) / ds_hash
        if cfg.dataset_prepared_path
        else Path(default_dataset_prepared_path) / ds_hash
    )
    dataset = None
    prompters = []
    use_auth_token = cfg.hf_use_auth_token
    try:
        if cfg.push_dataset_to_hub:
            LOG.info(
                f"Attempting to load prepared dataset from the Hugging Face Hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
            )
            dataset = load_dataset(
                cfg.push_dataset_to_hub,
                ds_hash,
                token=use_auth_token,
            )
            dataset = dataset[split]
    except Exception:  # pylint: disable=broad-except  # nosec
        pass

    # pylint: disable=duplicate-code
    if dataset:
        # This is for the case where we already loaded a pretokenized dataset from the hub
        ...
    elif (
        cfg.dataset_prepared_path
        and any(prepared_ds_path.glob("*"))
        and not cfg.is_preprocess
        and not cfg.skip_prepare_dataset
    ):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        dataset = load_from_disk(str(prepared_ds_path))
        LOG.info("Prepared dataset loaded from disk...")
    else:
        if cfg.push_dataset_to_hub:
            LOG.info("Unable to find prepared dataset on the Hugging Face Hub")
        if cfg.is_preprocess:
            LOG.info(
                f"Skipping prepared dataset in {prepared_ds_path} for pre-processing..."
            )
        else:
            LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
        LOG.info("Loading raw datasets...")
        if not cfg.is_preprocess:
            LOG.warning(
                "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset."
            )

        if cfg.seed:
            seed = cfg.seed
        else:
            LOG.info("No seed provided, using default seed of 42")
            seed = 42

        datasets = []

        def for_d_in_datasets(dataset_configs):
            for dataset in dataset_configs:
                if dataset.name and isinstance(dataset.name, list):
                    # load_dataset doesn't properly handle multiple named
                    # configurations at the same time for a given dataset
                    for name in dataset.name:
                        yield DictDefault({**dataset, "name": name})
                else:
                    yield dataset
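
        # Illustrative only -- a typical `datasets` entry in the YAML config
        # might look like:
        #
        #   datasets:
        #     - path: tatsu-lab/alpaca
        #       type: alpaca
        #
        # each entry arrives here as a DictDefault `config_dataset`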

        # pylint: disable=invalid-name
        for config_dataset in for_d_in_datasets(cfg_datasets):
            ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
                config_dataset, use_auth_token
            )

            d_base_type = d_prompt_style = None
            d_type = config_dataset.type
            if isinstance(d_type, str):
                d_type_split = d_type.split(":")
                d_base_type = d_type_split[0]
                d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None

            if isinstance(ds, DatasetDict):
                if config_dataset.split and config_dataset.split in ds:
                    ds = ds[config_dataset.split]
                elif split in ds:
                    ds = ds[split]
                else:
                    raise ValueError(
                        f"no {split} split found for dataset {config_dataset.path}, you may specify a split with `split: ...`"
                    )

            # support for using a subset of the data
            if config_dataset.shards:
                shards_idx = config_dataset.get("shards_idx", 0)
                ds = ds.shuffle(seed=seed).shard(
                    num_shards=config_dataset.shards, index=shards_idx
                )

            dataset_wrapper, dataset_prompter = get_dataset_wrapper(
                config_dataset=config_dataset,
                tokenizer=tokenizer,
                cfg=cfg,
                d_base_type=d_base_type,
                dataset=ds,
                d_prompt_style=d_prompt_style,
                processor=processor,
            )
            datasets.append(dataset_wrapper)
            prompters.append(dataset_prompter)

        if len(datasets) == 1:
            dataset = datasets[0]
        else:
            LOG.info("merging datasets")
            dataset = concatenate_datasets(datasets)

        if len(datasets) > 1:
            if cfg.shuffle_merged_datasets:
                LOG.debug("shuffling merged datasets")
                dataset = dataset.shuffle(seed=seed)
            else:
                LOG.debug("NOT shuffling merged datasets")

        if cfg.sample_packing and not cfg.skip_prepare_dataset:
            dataset, _ = process_datasets_for_packing(cfg, dataset, None)

        if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
            LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
            dataset.save_to_disk(str(prepared_ds_path))
            if cfg.push_dataset_to_hub:
                LOG.info(
                    f"Pushing merged prepared dataset to the Hugging Face Hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
                )
                dataset.push_to_hub(
                    cfg.push_dataset_to_hub,
                    ds_hash,
                    private=True,
                )

    return dataset, prompters


def load_prepare_datasets(
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    default_dataset_prepared_path,
    split="train",
    processor=None,
) -> Tuple[Dataset, Dataset, List[Prompter]]:
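    """Load the tokenized dataset and split it into train/eval sets, applying
    optional sharding, exact deduplication, and the `val_set_size` train/test
    split.
    """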
    dataset, prompters = load_tokenized_prepared_datasets(
        tokenizer,
        cfg,
        default_dataset_prepared_path,
        split=split,
        processor=processor,
    )

    if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
        LOG.info(
            f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
        )
        dataset = dataset.shard(
            num_shards=cfg.dataset_shard_num,
            index=cfg.dataset_shard_idx,
        )
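
    # `train_test_split` treats an int as an absolute sample count and a float
    # as a fraction of the dataset, so coerce val_set_size accordingly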
    val_set_size = (
        int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
    )

    if split == "train" and val_set_size:
        # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
        to_hash_train = (
            dataset._fingerprint  # pylint: disable=protected-access
            + "|"
            + str(val_set_size)
            + "|"
            + "train"
            + "|"
            + str(cfg.seed or 42)
        )
        to_hash_test = (
            dataset._fingerprint  # pylint: disable=protected-access
            + "|"
            + str(val_set_size)
            + "|"
            + "test"
            + "|"
            + str(cfg.seed or 42)
        )
        train_fingerprint = md5(to_hash_train)
        test_fingerprint = md5(to_hash_test)
        if cfg.dataset_exact_deduplication:
            _, _, dataset = deduplicate_and_log_datasets(dataset=dataset)
        dataset = dataset.train_test_split(
            test_size=val_set_size,
            shuffle=False,
            seed=cfg.seed or 42,
            train_new_fingerprint=train_fingerprint,
            test_new_fingerprint=test_fingerprint,
        )

        train_dataset = dataset["train"]
        eval_dataset = dataset["test"]
    elif split == "test":
        if cfg.dataset_exact_deduplication:
            _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
        else:
            eval_dataset = dataset
        train_dataset = None
    else:
        if cfg.dataset_exact_deduplication:
            train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
        else:
            train_dataset = dataset
        eval_dataset = None
    return train_dataset, eval_dataset, prompters


def get_dataset_wrapper(
    config_dataset,
    tokenizer,
    cfg,
    d_base_type,
    dataset,
    d_prompt_style=None,
    processor=None,
):
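    """Pick and apply the tokenization strategy for a single dataset config entry.

    Returns a ``(dataset_wrapper, dataset_prompter)`` pair; the prompter is None
    when the dataset is passed through untouched (e.g. `skip_prepare_dataset`).
    """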
    dataset_wrapper = None
    dataset_prompter = None

    ds_kwargs = {
        "process_count": cfg.dataset_processes,
        "keep_in_memory": cfg.dataset_keep_in_memory is True,
    }

    LOG.info(
        f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
    )

    if (
        isinstance(dataset, Dataset)
        and "input_ids" in dataset.features
        and "attention_mask" in dataset.features
        and "labels" in dataset.features
    ):
        # dataset is already tokenized, just drop it straight in
        dataset_prompter = UnsupportedPrompter()
        dataset_wrapper = dataset
    elif isinstance(config_dataset.type, DictDefault):
        ds_strategy = load(
            "user_defined", tokenizer, cfg, config_dataset.type.to_dict()
        )
        dataset_prompter = UnsupportedPrompter()
        dataset_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
    elif cfg.skip_prepare_dataset:
        dataset_wrapper = dataset
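    # NOTE: the walrus below binds ds_strategy to the value of the whole
    # `startswith(...) and bradley_terry_load(...)` expression, so it is either
    # False (not a bradley_terry type) or the loaded strategy object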
    elif ds_strategy := config_dataset.type.startswith(
        "bradley_terry"
    ) and bradley_terry_load(
        config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
    ):
        dataset_prompter = UnsupportedPrompter()
        dataset_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
    elif ds_strategy := load(
        config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
    ):
        if isinstance(ds_strategy, DatasetWrappingStrategy):
            dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
        else:
            dataset_prompter = UnsupportedPrompter()
            dataset_wrapper = TokenizedPromptDataset(
                ds_strategy,
                dataset,
                **ds_kwargs,
            )
    elif d_base_type == "alpaca":
        dataset_prompter = AlpacaPrompter(d_prompt_style)
        ds_strategy = AlpacaPromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "explainchoice":
        dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
        ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "concisechoice":
        dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
        ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "summarizetldr":
        dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
        ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "jeopardy":
        dataset_prompter = JeopardyPrompter(d_prompt_style)
        ds_strategy = JeopardyPromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "oasst":
        dataset_prompter = AlpacaPrompter(d_prompt_style)
        ds_strategy = OpenAssistantPromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "gpteacher":
        dataset_prompter = GPTeacherPrompter(d_prompt_style)
        ds_strategy = GPTeacherPromptTokenizingStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    elif d_base_type == "reflection":
        dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
        ds_strategy = AlpacaReflectionPTStrategy(
            dataset_prompter,
            tokenizer,
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
        ds_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )
        dataset_wrapper = ds_wrapper
    else:
        suffix = ""
        if ":load_" in config_dataset.type:
            suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
        error_message = (
            f"unhandled prompt tokenization strategy: {config_dataset.type}.{suffix}"
        )
        LOG.error(error_message)
        raise ValueError(error_message)

    return dataset_wrapper, dataset_prompter