|
|
|
|
@@ -44,24 +44,11 @@ from axolotl.utils.trainer import (
|
|
|
|
|
LOG = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _determine_streaming_mode(cfg: DictDefault) -> bool:
    """Resolve whether datasets should be handled in streaming mode.

    An explicit `cfg.streaming` value (anything other than `None`) is returned
    unchanged. When it is unset, streaming is enabled only if
    `cfg.pretraining_dataset` is truthy.

    Args:
        cfg: Config object exposing `streaming` and `pretraining_dataset`.

    Returns:
        The explicit `cfg.streaming` setting when present; otherwise `True`
        for pretraining datasets and `False` for everything else.
    """
    explicit_setting = cfg.streaming
    if explicit_setting is None:
        # No explicit choice was made: pretraining datasets stream by default.
        return True if cfg.pretraining_dataset else False
    return explicit_setting
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@retry_on_request_exceptions(max_retries=3, delay=5)
|
|
|
|
|
def prepare_datasets(
|
|
|
|
|
cfg: DictDefault,
|
|
|
|
|
tokenizer: PreTrainedTokenizer,
|
|
|
|
|
processor: ProcessorMixin | None = None,
|
|
|
|
|
preprocess_iterable: bool = False,
|
|
|
|
|
) -> tuple[IterableDataset | Dataset, Dataset | None, int, list[Prompter | None]]:
|
|
|
|
|
"""Prepare training and evaluation datasets based on configuration.
|
|
|
|
|
|
|
|
|
|
@@ -69,30 +56,19 @@ def prepare_datasets(
|
|
|
|
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
|
|
|
|
tokenizer: Tokenizer to use for processing text.
|
|
|
|
|
processor: Optional processor for multimodal datasets.
|
|
|
|
|
preprocess_iterable: Whether to use iterable preprocessing.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tuple of (train_dataset, eval_dataset, total_steps, prompters).
|
|
|
|
|
"""
|
|
|
|
|
# Determine streaming mode from config
|
|
|
|
|
streaming_mode = _determine_streaming_mode(cfg)
|
|
|
|
|
|
|
|
|
|
# Override preprocess_iterable parameter with streaming config
|
|
|
|
|
if streaming_mode:
|
|
|
|
|
preprocess_iterable = True
|
|
|
|
|
|
|
|
|
|
if cfg.pretraining_dataset:
|
|
|
|
|
return _prepare_pretraining_dataset(
|
|
|
|
|
cfg, tokenizer, processor, preprocess_iterable
|
|
|
|
|
)
|
|
|
|
|
return _prepare_standard_dataset(cfg, tokenizer, processor, preprocess_iterable)
|
|
|
|
|
return _prepare_pretraining_dataset(cfg, tokenizer, processor)
|
|
|
|
|
return _prepare_standard_dataset(cfg, tokenizer, processor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _prepare_standard_dataset(
|
|
|
|
|
cfg: DictDefault,
|
|
|
|
|
tokenizer: PreTrainedTokenizer,
|
|
|
|
|
processor: ProcessorMixin | None,
|
|
|
|
|
preprocess_iterable: bool,
|
|
|
|
|
) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]:
|
|
|
|
|
"""Prepare standard (non-pretraining) datasets."""
|
|
|
|
|
|
|
|
|
|
@@ -103,7 +79,6 @@ def _prepare_standard_dataset(
|
|
|
|
|
cfg,
|
|
|
|
|
split="train",
|
|
|
|
|
processor=processor,
|
|
|
|
|
preprocess_iterable=preprocess_iterable,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Overwrite eval_dataset if test data exists
|
|
|
|
|
@@ -113,7 +88,6 @@ def _prepare_standard_dataset(
|
|
|
|
|
cfg,
|
|
|
|
|
split="test",
|
|
|
|
|
processor=processor,
|
|
|
|
|
preprocess_iterable=preprocess_iterable,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return train_dataset, eval_dataset, prompters
|
|
|
|
|
@@ -159,7 +133,6 @@ def _prepare_pretraining_dataset(
|
|
|
|
|
cfg: DictDefault,
|
|
|
|
|
tokenizer: PreTrainedTokenizer,
|
|
|
|
|
processor: ProcessorMixin | None,
|
|
|
|
|
preprocess_iterable: bool,
|
|
|
|
|
) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:
|
|
|
|
|
"""
|
|
|
|
|
Prepare dataset for pretraining mode.
|
|
|
|
|
@@ -180,7 +153,6 @@ def _prepare_pretraining_dataset(
|
|
|
|
|
cfg,
|
|
|
|
|
split="test",
|
|
|
|
|
processor=processor,
|
|
|
|
|
preprocess_iterable=preprocess_iterable,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if cfg.dataset_exact_deduplication:
|
|
|
|
|
@@ -283,7 +255,6 @@ def _load_tokenized_prepared_datasets(
|
|
|
|
|
cfg: DictDefault,
|
|
|
|
|
split: Literal["train", "test"] = "train",
|
|
|
|
|
processor: ProcessorMixin | None = None,
|
|
|
|
|
preprocess_iterable: bool = False,
|
|
|
|
|
) -> tuple[Dataset | DatasetDict, list[Prompter | None]]:
|
|
|
|
|
"""Load or create tokenized and prepared datasets for training or testing.
|
|
|
|
|
|
|
|
|
|
@@ -292,7 +263,6 @@ def _load_tokenized_prepared_datasets(
|
|
|
|
|
cfg: Configuration object.
|
|
|
|
|
split: Dataset split to load ('train' or 'test').
|
|
|
|
|
processor: Optional processor for multimodal datasets.
|
|
|
|
|
preprocess_iterable: Whether to use iterable preprocessing.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tuple of (dataset, prompters list).
|
|
|
|
|
@@ -323,7 +293,6 @@ def _load_tokenized_prepared_datasets(
|
|
|
|
|
tokenizer,
|
|
|
|
|
split,
|
|
|
|
|
processor,
|
|
|
|
|
preprocess_iterable,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return dataset, prompters
|
|
|
|
|
@@ -335,7 +304,6 @@ def _load_raw_datasets(
|
|
|
|
|
tokenizer: PreTrainedTokenizer,
|
|
|
|
|
split: str,
|
|
|
|
|
processor: ProcessorMixin | None = None,
|
|
|
|
|
preprocess_iterable: bool = False,
|
|
|
|
|
) -> tuple[Dataset, list[Prompter | None]]:
|
|
|
|
|
"""Load, process, merge, and save raw datasets."""
|
|
|
|
|
LOG.info("Loading raw datasets...", main_process_only=False)
|
|
|
|
|
@@ -356,7 +324,6 @@ def _load_raw_datasets(
|
|
|
|
|
split=split,
|
|
|
|
|
seed=cfg.seed,
|
|
|
|
|
processor=processor,
|
|
|
|
|
preprocess_iterable=preprocess_iterable,
|
|
|
|
|
)
|
|
|
|
|
datasets.append(dataset_wrapper)
|
|
|
|
|
prompters.append(dataset_prompter)
|
|
|
|
|
@@ -388,15 +355,11 @@ def _load_and_process_single_dataset(
|
|
|
|
|
split: str,
|
|
|
|
|
seed: int,
|
|
|
|
|
processor: ProcessorMixin | None = None,
|
|
|
|
|
preprocess_iterable: bool = False,
|
|
|
|
|
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
|
|
|
|
"""Load and process a single dataset based on the passed config."""
|
|
|
|
|
# Load the dataset
|
|
|
|
|
dataset = load_dataset_with_config(
|
|
|
|
|
dataset_config, cfg.hf_use_auth_token, streaming=preprocess_iterable
|
|
|
|
|
dataset_config, cfg.hf_use_auth_token, cfg.streaming
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Parse dataset type
|
|
|
|
|
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
|
|
|
|
|
|
|
|
|
|
# Select the appropriate split
|
|
|
|
|
@@ -515,7 +478,6 @@ def _load_and_prepare_datasets(
|
|
|
|
|
cfg: DictDefault,
|
|
|
|
|
split: Literal["train", "test"] = "train",
|
|
|
|
|
processor: ProcessorMixin | None = None,
|
|
|
|
|
preprocess_iterable: bool = False,
|
|
|
|
|
) -> tuple[Dataset | None, Dataset | None, list[Prompter | None]]:
|
|
|
|
|
"""Load and prepare datasets with optional validation split and sharding.
|
|
|
|
|
|
|
|
|
|
@@ -524,7 +486,6 @@ def _load_and_prepare_datasets(
|
|
|
|
|
cfg: Configuration object.
|
|
|
|
|
split: Dataset split to load ('train' or 'test').
|
|
|
|
|
processor: Optional processor for multimodal datasets.
|
|
|
|
|
preprocess_iterable: Whether to use iterable preprocessing.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tuple of (train_dataset, eval_dataset, prompters).
|
|
|
|
|
@@ -535,7 +496,6 @@ def _load_and_prepare_datasets(
|
|
|
|
|
cfg,
|
|
|
|
|
split=split,
|
|
|
|
|
processor=processor,
|
|
|
|
|
preprocess_iterable=preprocess_iterable,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Apply dataset sharding if configured using shared function
|
|
|
|
|
|