separate streaming and pretraining
This commit is contained in:
@@ -43,7 +43,13 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def process(self, dataset):
|
def process(self, dataset):
|
||||||
features = dataset.features.keys()
|
# Handle both regular Dataset and IterableDataset
|
||||||
|
if hasattr(dataset, "features") and dataset.features:
|
||||||
|
features = dataset.features.keys()
|
||||||
|
else:
|
||||||
|
# For IterableDataset, we can't access features upfront
|
||||||
|
# We'll need to infer from the first batch
|
||||||
|
features = None
|
||||||
|
|
||||||
map_kwargs = {}
|
map_kwargs = {}
|
||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
@@ -54,19 +60,30 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
hasattr(self.prompt_tokenizer, "filter_rows")
|
hasattr(self.prompt_tokenizer, "filter_rows")
|
||||||
and self.prompt_tokenizer.filter_rows
|
and self.prompt_tokenizer.filter_rows
|
||||||
):
|
):
|
||||||
|
filter_kwargs = {"desc": "Strategy Filtering Rows"}
|
||||||
|
# Only add num_proc for regular datasets
|
||||||
|
if features is not None:
|
||||||
|
filter_kwargs["num_proc"] = self.process_count
|
||||||
|
|
||||||
dataset = dataset.filter(
|
dataset = dataset.filter(
|
||||||
self.prompt_tokenizer.filter_rows,
|
self.prompt_tokenizer.filter_rows,
|
||||||
num_proc=self.process_count,
|
**filter_kwargs,
|
||||||
desc="Strategy Filtering Rows",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
map_kwargs_final = {
|
||||||
|
**map_kwargs,
|
||||||
|
"desc": "Tokenizing Prompts",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Only add remove_columns for regular datasets
|
||||||
|
if features is not None:
|
||||||
|
map_kwargs_final["remove_columns"] = features
|
||||||
|
map_kwargs_final["num_proc"] = self.process_count
|
||||||
|
map_kwargs_final["keep_in_memory"] = self.keep_in_memory
|
||||||
|
|
||||||
return dataset.map(
|
return dataset.map(
|
||||||
self.prompt_tokenizer.tokenize_prompt,
|
self.prompt_tokenizer.tokenize_prompt,
|
||||||
num_proc=self.process_count,
|
**map_kwargs_final,
|
||||||
remove_columns=features,
|
|
||||||
keep_in_memory=self.keep_in_memory,
|
|
||||||
desc="Tokenizing Prompts",
|
|
||||||
**map_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from datasets import (
|
|||||||
Dataset,
|
Dataset,
|
||||||
DatasetDict,
|
DatasetDict,
|
||||||
IterableDataset,
|
IterableDataset,
|
||||||
|
IterableDatasetDict,
|
||||||
load_dataset,
|
load_dataset,
|
||||||
)
|
)
|
||||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||||
@@ -43,6 +44,18 @@ from axolotl.utils.trainer import (
|
|||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _determine_streaming_mode(cfg: DictDefault) -> bool:
|
||||||
|
"""Determine if we should use streaming mode based on config."""
|
||||||
|
if cfg.streaming is not None:
|
||||||
|
return cfg.streaming
|
||||||
|
|
||||||
|
# Default to streaming for pretraining datasets
|
||||||
|
if cfg.pretraining_dataset:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||||
def prepare_datasets(
|
def prepare_datasets(
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
@@ -61,6 +74,13 @@ def prepare_datasets(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (train_dataset, eval_dataset, total_steps, prompters).
|
Tuple of (train_dataset, eval_dataset, total_steps, prompters).
|
||||||
"""
|
"""
|
||||||
|
# Determine streaming mode from config
|
||||||
|
streaming_mode = _determine_streaming_mode(cfg)
|
||||||
|
|
||||||
|
# Override preprocess_iterable parameter with streaming config
|
||||||
|
if streaming_mode:
|
||||||
|
preprocess_iterable = True
|
||||||
|
|
||||||
if cfg.pretraining_dataset:
|
if cfg.pretraining_dataset:
|
||||||
return _prepare_pretraining_dataset(
|
return _prepare_pretraining_dataset(
|
||||||
cfg, tokenizer, processor, preprocess_iterable
|
cfg, tokenizer, processor, preprocess_iterable
|
||||||
@@ -118,12 +138,19 @@ def _prepare_standard_dataset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Calculate total number of training steps
|
# Calculate total number of training steps
|
||||||
if cfg.max_steps:
|
if isinstance(train_dataset, IterableDataset):
|
||||||
total_num_steps = min(
|
# For streaming datasets, we must use max_steps
|
||||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
if not cfg.max_steps:
|
||||||
)
|
raise ValueError("max_steps must be set when using streaming datasets")
|
||||||
|
total_num_steps = cfg.max_steps
|
||||||
else:
|
else:
|
||||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
# For regular datasets, calculate from dataset size or use max_steps
|
||||||
|
if cfg.max_steps:
|
||||||
|
total_num_steps = min(
|
||||||
|
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
||||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||||
return train_dataset, eval_dataset, total_num_steps, prompters
|
return train_dataset, eval_dataset, total_num_steps, prompters
|
||||||
|
|
||||||
@@ -373,7 +400,7 @@ def _load_and_process_single_dataset(
|
|||||||
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
|
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
|
||||||
|
|
||||||
# Select the appropriate split
|
# Select the appropriate split
|
||||||
if isinstance(dataset, DatasetDict):
|
if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
|
||||||
if dataset_config.split and dataset_config.split in dataset:
|
if dataset_config.split and dataset_config.split in dataset:
|
||||||
dataset = dataset[dataset_config.split]
|
dataset = dataset[dataset_config.split]
|
||||||
elif split in dataset:
|
elif split in dataset:
|
||||||
@@ -418,14 +445,17 @@ def _parse_dataset_type(d_type: str) -> tuple[str | None, str | None]:
|
|||||||
|
|
||||||
|
|
||||||
def _handle_train_dataset_split(
|
def _handle_train_dataset_split(
|
||||||
dataset: Dataset, cfg: DictDefault
|
dataset: Dataset | IterableDataset, cfg: DictDefault
|
||||||
) -> tuple[Dataset, Dataset | None]:
|
) -> tuple[Dataset | IterableDataset, Dataset | IterableDataset | None]:
|
||||||
"""Handle processing for train split, including validation set creation."""
|
"""Handle processing for train split, including validation set creation."""
|
||||||
val_set_size = (
|
val_set_size = (
|
||||||
int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
|
int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
if val_set_size:
|
if val_set_size:
|
||||||
|
if isinstance(dataset, IterableDataset):
|
||||||
|
LOG.info("Validation splits not supported for streaming datasets, skipping")
|
||||||
|
return dataset, None
|
||||||
# Create train/validation split
|
# Create train/validation split
|
||||||
train_dataset, eval_dataset = create_train_validation_split(
|
train_dataset, eval_dataset = create_train_validation_split(
|
||||||
dataset, cfg, val_set_size
|
dataset, cfg, val_set_size
|
||||||
@@ -433,27 +463,33 @@ def _handle_train_dataset_split(
|
|||||||
return train_dataset, eval_dataset
|
return train_dataset, eval_dataset
|
||||||
|
|
||||||
# No validation split - apply deduplication if needed and return as train dataset
|
# No validation split - apply deduplication if needed and return as train dataset
|
||||||
if cfg.dataset_exact_deduplication:
|
if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset):
|
||||||
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
||||||
else:
|
else:
|
||||||
|
if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset):
|
||||||
|
LOG.info("Deduplication skipped for streaming datasets (not compatible)")
|
||||||
train_dataset = dataset
|
train_dataset = dataset
|
||||||
|
|
||||||
return train_dataset, None
|
return train_dataset, None
|
||||||
|
|
||||||
|
|
||||||
def _handle_test_dataset_split(
|
def _handle_test_dataset_split(
|
||||||
dataset: Dataset, cfg: DictDefault
|
dataset: Dataset | IterableDataset, cfg: DictDefault
|
||||||
) -> tuple[None, Dataset | None]:
|
) -> tuple[None, Dataset | IterableDataset | None]:
|
||||||
"""Handle processing for test split."""
|
"""Handle processing for test split."""
|
||||||
if cfg.dataset_exact_deduplication:
|
if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset):
|
||||||
eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
||||||
else:
|
else:
|
||||||
|
if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset):
|
||||||
|
LOG.info("Deduplication skipped for streaming datasets (not compatible)")
|
||||||
eval_dataset = dataset
|
eval_dataset = dataset
|
||||||
|
|
||||||
return None, eval_dataset
|
return None, eval_dataset
|
||||||
|
|
||||||
|
|
||||||
def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
def _apply_dataset_sharding(
|
||||||
|
dataset: Dataset | IterableDataset, cfg: DictDefault
|
||||||
|
) -> Dataset | IterableDataset:
|
||||||
"""Apply dataset sharding if configured.
|
"""Apply dataset sharding if configured.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -190,12 +190,18 @@ def handle_long_seq_in_dataset(
|
|||||||
Returns:
|
Returns:
|
||||||
Filtered dataset with long sequences removed.
|
Filtered dataset with long sequences removed.
|
||||||
"""
|
"""
|
||||||
if "input_ids" not in dataset.column_names:
|
if hasattr(dataset, "column_names") and dataset.column_names:
|
||||||
LOG.warning(
|
if "input_ids" not in dataset.column_names:
|
||||||
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
|
LOG.warning(
|
||||||
"expected for reward modeling."
|
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
|
||||||
)
|
"expected for reward modeling."
|
||||||
return dataset
|
)
|
||||||
|
return dataset
|
||||||
|
else:
|
||||||
|
# For IterableDataset, we can't check columns upfront, so skip for streaming
|
||||||
|
if isinstance(dataset, IterableDataset):
|
||||||
|
LOG.info("Skipping drop_long_seq for streaming datasets (not compatible)")
|
||||||
|
return dataset
|
||||||
|
|
||||||
drop_long = functools.partial(
|
drop_long = functools.partial(
|
||||||
drop_long_seq,
|
drop_long_seq,
|
||||||
|
|||||||
@@ -932,6 +932,34 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
fix_untrained_tokens: int | list[int] | None = None
|
fix_untrained_tokens: int | list[int] | None = None
|
||||||
|
|
||||||
|
streaming: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Whether to use streaming datasets (IterableDataset) for processing large datasets that don't fit in memory. When True, data is loaded on-demand during training without upfront preprocessing. Requires max_steps to be set. Pre-training datasets default to streaming unless explicitly set to False."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
streaming_dataset_mixing_strategy: str | None = Field(
|
||||||
|
default="round_robin",
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Strategy for mixing multiple streaming datasets: 'round_robin' (equal sampling), 'weighted' (use streaming_mixing_weights), or 'random' (random sampling with equal probability)."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
streaming_mixing_weights: list[float] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Weights for weighted mixing strategy when using multiple streaming datasets. Must sum to 1.0 and have same length as datasets list. Only used when streaming_dataset_mixing_strategy='weighted'."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
streaming_buffer_per_dataset: int | None = Field(
|
||||||
|
default=1000,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Buffer size per dataset when mixing multiple streaming datasets. Higher values may improve mixing quality but use more memory."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# INTERNALS - document for now, generally not set externally
|
# INTERNALS - document for now, generally not set externally
|
||||||
is_preprocess: bool | None = None
|
is_preprocess: bool | None = None
|
||||||
preprocess_iterable: bool | None = None
|
preprocess_iterable: bool | None = None
|
||||||
|
|||||||
@@ -1337,6 +1337,30 @@ class GRPOVllmValidationMixin:
|
|||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-ancestors
|
# pylint: disable=too-many-ancestors
|
||||||
|
class StreamingValidationMixin:
|
||||||
|
"""Validation methods related to streaming datasets."""
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_streaming_requires_max_steps(self):
|
||||||
|
"""Ensure max_steps is set when using streaming datasets."""
|
||||||
|
# Check if streaming is explicitly enabled
|
||||||
|
streaming_enabled = getattr(self, "streaming", None) is True
|
||||||
|
|
||||||
|
# Check if pretraining dataset exists (defaults to streaming)
|
||||||
|
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
|
||||||
|
streaming_default_for_pretraining = (
|
||||||
|
has_pretraining and getattr(self, "streaming", None) is None
|
||||||
|
)
|
||||||
|
|
||||||
|
# If streaming is enabled (explicitly or by default for pretraining)
|
||||||
|
if streaming_enabled or streaming_default_for_pretraining:
|
||||||
|
max_steps = getattr(self, "max_steps", None)
|
||||||
|
if not max_steps:
|
||||||
|
raise ValueError("max_steps must be set when using streaming datasets")
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class ValidationMixin(
|
class ValidationMixin(
|
||||||
DatasetValidationMixin,
|
DatasetValidationMixin,
|
||||||
AttentionValidationMixin,
|
AttentionValidationMixin,
|
||||||
@@ -1347,6 +1371,7 @@ class ValidationMixin(
|
|||||||
SystemValidationMixin,
|
SystemValidationMixin,
|
||||||
ChatTemplateValidationMixin,
|
ChatTemplateValidationMixin,
|
||||||
PretrainingValidationMixin,
|
PretrainingValidationMixin,
|
||||||
|
StreamingValidationMixin,
|
||||||
ModelCompatibilityValidationMixin,
|
ModelCompatibilityValidationMixin,
|
||||||
ComplexValidationMixin,
|
ComplexValidationMixin,
|
||||||
GRPOVllmValidationMixin,
|
GRPOVllmValidationMixin,
|
||||||
|
|||||||
Reference in New Issue
Block a user