separate streaming and pretraining

This commit is contained in:
Dan Saunders
2025-08-19 18:05:05 +00:00
parent ab4d604a8f
commit 16ff01df85
5 changed files with 139 additions and 27 deletions

View File

@@ -43,7 +43,13 @@ class TokenizedPromptDataset(Dataset):
)
def process(self, dataset):
features = dataset.features.keys()
# Handle both regular Dataset and IterableDataset
if hasattr(dataset, "features") and dataset.features:
features = dataset.features.keys()
else:
# For IterableDataset, we can't access features upfront
# We'll need to infer from the first batch
features = None
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
@@ -54,19 +60,30 @@ class TokenizedPromptDataset(Dataset):
hasattr(self.prompt_tokenizer, "filter_rows")
and self.prompt_tokenizer.filter_rows
):
filter_kwargs = {"desc": "Strategy Filtering Rows"}
# Only add num_proc for regular datasets
if features is not None:
filter_kwargs["num_proc"] = self.process_count
dataset = dataset.filter(
self.prompt_tokenizer.filter_rows,
num_proc=self.process_count,
desc="Strategy Filtering Rows",
**filter_kwargs,
)
map_kwargs_final = {
**map_kwargs,
"desc": "Tokenizing Prompts",
}
# Only add remove_columns for regular datasets
if features is not None:
map_kwargs_final["remove_columns"] = features
map_kwargs_final["num_proc"] = self.process_count
map_kwargs_final["keep_in_memory"] = self.keep_in_memory
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=self.process_count,
remove_columns=features,
keep_in_memory=self.keep_in_memory,
desc="Tokenizing Prompts",
**map_kwargs,
**map_kwargs_final,
)

View File

@@ -9,6 +9,7 @@ from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
load_dataset,
)
from transformers import PreTrainedTokenizer, ProcessorMixin
@@ -43,6 +44,18 @@ from axolotl.utils.trainer import (
LOG = get_logger(__name__)
def _determine_streaming_mode(cfg: DictDefault) -> bool:
"""Determine if we should use streaming mode based on config."""
if cfg.streaming is not None:
return cfg.streaming
# Default to streaming for pretraining datasets
if cfg.pretraining_dataset:
return True
return False
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_datasets(
cfg: DictDefault,
@@ -61,6 +74,13 @@ def prepare_datasets(
Returns:
Tuple of (train_dataset, eval_dataset, total_steps, prompters).
"""
# Determine streaming mode from config
streaming_mode = _determine_streaming_mode(cfg)
# Override preprocess_iterable parameter with streaming config
if streaming_mode:
preprocess_iterable = True
if cfg.pretraining_dataset:
return _prepare_pretraining_dataset(
cfg, tokenizer, processor, preprocess_iterable
@@ -118,12 +138,19 @@ def _prepare_standard_dataset(
)
# Calculate total number of training steps
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
)
if isinstance(train_dataset, IterableDataset):
# For streaming datasets, we must use max_steps
if not cfg.max_steps:
raise ValueError("max_steps must be set when using streaming datasets")
total_num_steps = cfg.max_steps
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
# For regular datasets, calculate from dataset size or use max_steps
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
)
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
return train_dataset, eval_dataset, total_num_steps, prompters
@@ -373,7 +400,7 @@ def _load_and_process_single_dataset(
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
# Select the appropriate split
if isinstance(dataset, DatasetDict):
if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
if dataset_config.split and dataset_config.split in dataset:
dataset = dataset[dataset_config.split]
elif split in dataset:
@@ -418,14 +445,17 @@ def _parse_dataset_type(d_type: str) -> tuple[str | None, str | None]:
def _handle_train_dataset_split(
dataset: Dataset, cfg: DictDefault
) -> tuple[Dataset, Dataset | None]:
dataset: Dataset | IterableDataset, cfg: DictDefault
) -> tuple[Dataset | IterableDataset, Dataset | IterableDataset | None]:
"""Handle processing for train split, including validation set creation."""
val_set_size = (
int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
)
if val_set_size:
if isinstance(dataset, IterableDataset):
LOG.info("Validation splits not supported for streaming datasets, skipping")
return dataset, None
# Create train/validation split
train_dataset, eval_dataset = create_train_validation_split(
dataset, cfg, val_set_size
@@ -433,27 +463,33 @@ def _handle_train_dataset_split(
return train_dataset, eval_dataset
# No validation split - apply deduplication if needed and return as train dataset
if cfg.dataset_exact_deduplication:
if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset):
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
else:
if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset):
LOG.info("Deduplication skipped for streaming datasets (not compatible)")
train_dataset = dataset
return train_dataset, None
def _handle_test_dataset_split(
dataset: Dataset, cfg: DictDefault
) -> tuple[None, Dataset | None]:
dataset: Dataset | IterableDataset, cfg: DictDefault
) -> tuple[None, Dataset | IterableDataset | None]:
"""Handle processing for test split."""
if cfg.dataset_exact_deduplication:
if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset):
eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
else:
if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset):
LOG.info("Deduplication skipped for streaming datasets (not compatible)")
eval_dataset = dataset
return None, eval_dataset
def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset:
def _apply_dataset_sharding(
dataset: Dataset | IterableDataset, cfg: DictDefault
) -> Dataset | IterableDataset:
"""Apply dataset sharding if configured.
Args:

View File

@@ -190,12 +190,18 @@ def handle_long_seq_in_dataset(
Returns:
Filtered dataset with long sequences removed.
"""
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return dataset
if hasattr(dataset, "column_names") and dataset.column_names:
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return dataset
else:
# For IterableDataset, we can't check columns upfront, so skip for streaming
if isinstance(dataset, IterableDataset):
LOG.info("Skipping drop_long_seq for streaming datasets (not compatible)")
return dataset
drop_long = functools.partial(
drop_long_seq,

View File

@@ -932,6 +932,34 @@ class AxolotlInputConfig(
fix_untrained_tokens: int | list[int] | None = None
streaming: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use streaming datasets (IterableDataset) for processing large datasets that don't fit in memory. When True, data is loaded on-demand during training without upfront preprocessing. Requires max_steps to be set. Pre-training datasets default to streaming unless explicitly set to False."
},
)
streaming_dataset_mixing_strategy: str | None = Field(
default="round_robin",
json_schema_extra={
"description": "Strategy for mixing multiple streaming datasets: 'round_robin' (equal sampling), 'weighted' (use streaming_mixing_weights), or 'random' (random sampling with equal probability)."
},
)
streaming_mixing_weights: list[float] | None = Field(
default=None,
json_schema_extra={
"description": "Weights for weighted mixing strategy when using multiple streaming datasets. Must sum to 1.0 and have same length as datasets list. Only used when streaming_dataset_mixing_strategy='weighted'."
},
)
streaming_buffer_per_dataset: int | None = Field(
default=1000,
json_schema_extra={
"description": "Buffer size per dataset when mixing multiple streaming datasets. Higher values may improve mixing quality but use more memory."
},
)
# INTERNALS - document for now, generally not set externally
is_preprocess: bool | None = None
preprocess_iterable: bool | None = None

View File

@@ -1337,6 +1337,30 @@ class GRPOVllmValidationMixin:
# pylint: disable=too-many-ancestors
class StreamingValidationMixin:
"""Validation methods related to streaming datasets."""
@model_validator(mode="after")
def check_streaming_requires_max_steps(self):
"""Ensure max_steps is set when using streaming datasets."""
# Check if streaming is explicitly enabled
streaming_enabled = getattr(self, "streaming", None) is True
# Check if pretraining dataset exists (defaults to streaming)
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
streaming_default_for_pretraining = (
has_pretraining and getattr(self, "streaming", None) is None
)
# If streaming is enabled (explicitly or by default for pretraining)
if streaming_enabled or streaming_default_for_pretraining:
max_steps = getattr(self, "max_steps", None)
if not max_steps:
raise ValueError("max_steps must be set when using streaming datasets")
return self
class ValidationMixin(
DatasetValidationMixin,
AttentionValidationMixin,
@@ -1347,6 +1371,7 @@ class ValidationMixin(
SystemValidationMixin,
ChatTemplateValidationMixin,
PretrainingValidationMixin,
StreamingValidationMixin,
ModelCompatibilityValidationMixin,
ComplexValidationMixin,
GRPOVllmValidationMixin,