progress on streaming

This commit is contained in:
Dan Saunders
2025-08-20 03:33:59 +00:00
parent 3b2dd05798
commit 7bb52d00bb
8 changed files with 244 additions and 50 deletions

View File

@@ -94,10 +94,8 @@ def wrap_dataset_for_tokenized_prompt(
map_kwargs = {}
if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
features = list(dataset.features.keys())
return dataset.map(
prompt_tokenizer.tokenize_prompt,
remove_columns=features,
**map_kwargs,
)
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
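
For context (not part of this diff): a streamed `IterableDataset` typically reports `features` as `None` until the first example is read, so eagerly computing the column list to remove does not work. A minimal sketch of lazy tokenization on a streamed dataset, with a hypothetical `tokenize_fn` standing in for `prompt_tokenizer.tokenize_prompt`:

```python
from datasets import IterableDataset

def gen():
    yield {"text": "hello world"}
    yield {"text": "streaming datasets are lazy"}

# Hypothetical stand-in for prompt_tokenizer.tokenize_prompt.
def tokenize_fn(example):
    return {"input_ids": [ord(c) for c in example["text"]]}

stream = IterableDataset.from_generator(gen)
print(stream.features)  # None: column names are unknown before iteration

# map() is lazy on an IterableDataset; columns to drop must be named explicitly.
tokenized = stream.map(tokenize_fn, remove_columns=["text"])
print(next(iter(tokenized)))  # {'input_ids': [...]}
```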

View File

@@ -111,14 +111,11 @@ def _prepare_standard_dataset(
"You should set `eval_sample_packing: False` in your config."
)
# Calculate total number of training steps
if isinstance(train_dataset, IterableDataset):
# For streaming datasets, we must use max_steps
if not cfg.max_steps:
raise ValueError("max_steps must be set when using streaming datasets")
# Streaming case
total_num_steps = cfg.max_steps
else:
# For regular datasets, calculate from dataset size or use max_steps
# Non-streaming case
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
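
For reference, a small sketch (hypothetical helper, not the library's API) of the step-count logic this hunk changes: streaming datasets have no length, so an explicit `max_steps` budget is required, while map-style datasets can derive the step count from their size:

```python
from datasets import Dataset, IterableDataset

def estimate_total_steps(train_dataset, max_steps=None, num_epochs=1.0,
                         micro_batch_size=1, grad_accum=1, world_size=1):
    # Streaming case: no len() is available, so the caller must supply a budget.
    if isinstance(train_dataset, IterableDataset):
        if not max_steps:
            raise ValueError("max_steps must be set when using streaming datasets")
        return max_steps
    # Non-streaming case: derive steps from dataset size, capped by max_steps.
    steps_per_epoch = len(train_dataset) // (micro_batch_size * grad_accum * world_size)
    total = int(steps_per_epoch * num_epochs)
    return min(total, max_steps) if max_steps else total

ds = Dataset.from_dict({"input_ids": [[1]] * 100})
print(estimate_total_steps(ds, micro_batch_size=4))                    # 25
print(estimate_total_steps(ds.to_iterable_dataset(), max_steps=500))   # 500
```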
@@ -267,26 +264,11 @@ def _load_tokenized_prepared_datasets(
Returns:
Tuple of (dataset, prompters list).
"""
# Select correct dataset configuration based on split
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
# Generate dataset hash for caching
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
# Try loading from hub if push_dataset_to_hub is configured
dataset = None
if cfg.push_dataset_to_hub:
dataset = try_load_from_hub(cfg, dataset_hash, split)
# If not found on hub, try loading from disk
if dataset is None:
dataset = load_preprocessed_dataset(cfg, dataset_hash)
# If not found on disk or skipping prepared dataset, load and process raw datasets
prompters: list[Prompter | None] = []
if dataset is None:
# For streaming datasets, skip caching and load raw datasets directly
if cfg.streaming:
dataset, prompters = _load_raw_datasets(
cfg,
datasets_configs,
@@ -294,6 +276,31 @@ def _load_tokenized_prepared_datasets(
split,
processor,
)
else:
# Generate dataset hash for caching
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
# Try loading from hub if push_dataset_to_hub is configured
dataset = None
if cfg.push_dataset_to_hub:
dataset = try_load_from_hub(cfg, dataset_hash, split)
# If not found on hub, try loading from disk
if dataset is None:
dataset = load_preprocessed_dataset(cfg, dataset_hash)
# If not found on disk or skipping prepared dataset, load and process raw
# datasets
if dataset is None:
dataset, prompters = _load_raw_datasets(
cfg,
datasets_configs,
tokenizer,
split,
processor,
)
return dataset, prompters
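
For context, a hypothetical sketch (simplified names, not the module's actual helpers) of the lookup order this hunk introduces: streaming bypasses hashing and the prepared-dataset cache entirely, while the non-streaming path keeps the hub → disk → raw fallback chain:

```python
def load_or_prepare(cfg, make_hash, try_hub, try_disk, load_raw):
    # Streaming: no cache round-trip; process examples on the fly.
    if cfg.get("streaming"):
        return load_raw()
    # Non-streaming: hub first (if configured), then local disk, then raw.
    dataset_hash = make_hash()
    dataset = try_hub(dataset_hash) if cfg.get("push_dataset_to_hub") else None
    if dataset is None:
        dataset = try_disk(dataset_hash)
    if dataset is None:
        dataset = load_raw()
    return dataset
```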
@@ -339,7 +346,6 @@ def _load_raw_datasets(
if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
# Save the prepared dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
@@ -412,13 +418,12 @@ def _handle_train_dataset_split(
) -> tuple[Dataset | IterableDataset, Dataset | IterableDataset | None]:
"""Handle processing for train split, including validation set creation."""
val_set_size = (
int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
int(cfg.val_set_size)
if cfg.val_set_size and cfg.val_set_size > 1
else float(cfg.val_set_size or 0.0)
)
if val_set_size:
if isinstance(dataset, IterableDataset):
LOG.info("Validation splits not supported for streaming datasets, skipping")
return dataset, None
# Create train/validation split
train_dataset, eval_dataset = create_train_validation_split(
dataset, cfg, val_set_size
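
For reference, a minimal sketch of the split logic above (hypothetical helper; the real code delegates to `create_train_validation_split`): values above 1 are treated as an absolute example count, values at or below 1 as a fraction, and streamed datasets skip the split because they cannot be sliced:

```python
from datasets import Dataset, IterableDataset

def split_train_val(dataset, val_set_size, seed=42):
    if not val_set_size:
        return dataset, None
    # Streamed data can't be indexed, so no validation split is created.
    if isinstance(dataset, IterableDataset):
        return dataset, None
    # > 1 means an absolute example count, <= 1 a fraction of the dataset.
    test_size = int(val_set_size) if val_set_size > 1 else float(val_set_size)
    split = dataset.train_test_split(test_size=test_size, seed=seed)
    return split["train"], split["test"]

ds = Dataset.from_dict({"input_ids": [[i] for i in range(10)]})
train, val = split_train_val(ds, 0.2)
print(len(train), len(val))  # 8 2
```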

View File

@@ -13,6 +13,7 @@ from datasets import (
IterableDataset,
IterableDatasetDict,
concatenate_datasets,
interleave_datasets,
load_dataset,
load_from_disk,
)
@@ -524,7 +525,9 @@ def generate_dataset_hash_from_config(
return str(md5(config_str))
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
def merge_datasets(
datasets: list[Dataset | IterableDataset], cfg: DictDefault
) -> Dataset | IterableDataset:
"""Merge multiple datasets into one with optional shuffling.
Args:
@@ -542,18 +545,28 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
if cfg.curriculum_sampling or not cfg.shuffle_merged_datasets:
return ds
# Only shuffle regular datasets, not IterableDatasets
if isinstance(ds, IterableDataset):
return ds
return ds.shuffle(seed=cfg.seed)
# If enabled, shuffle each dataset independently before merging.
# This allows curriculum learning strategies to be applied at the dataset level.
if cfg.shuffle_before_merging_datasets:
LOG.info("Shuffling each dataset individually before merging...")
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
# Check if we have any IterableDatasets
has_iterable = any(isinstance(ds, IterableDataset) for ds in datasets)
LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets)
if has_iterable:
LOG.info("Merging streaming datasets...")
merged_dataset = interleave_datasets(datasets, seed=cfg.seed)
else:
# If enabled, shuffle each dataset independently before merging.
# This allows curriculum learning strategies to be applied at the dataset level.
if cfg.shuffle_before_merging_datasets:
LOG.info("Shuffling each dataset individually before merging...")
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
if cfg.shuffle_merged_datasets:
LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets)
if cfg.shuffle_merged_datasets and not isinstance(merged_dataset, IterableDataset):
LOG.debug("Shuffling merged datasets...")
if cfg.curriculum_sampling:
LOG.warning(
@@ -562,6 +575,9 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
)
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
else:
LOG.debug("Not shuffling merged datasets.")
if isinstance(merged_dataset, IterableDataset):
LOG.debug("Skipping shuffle for streaming datasets.")
else:
LOG.debug("Not shuffling merged datasets.")
return merged_dataset
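
For context, a short sketch contrasting the two merge paths (toy datasets; the real function also handles curriculum sampling and the shuffle flags): map-style datasets are concatenated and then shuffled, while streamed datasets are interleaved lazily and left unshuffled:

```python
from datasets import Dataset, concatenate_datasets, interleave_datasets

a = Dataset.from_dict({"x": [1, 2, 3]})
b = Dataset.from_dict({"x": [10, 20, 30]})

# Map-style path: concatenate, then shuffle the merged result.
merged = concatenate_datasets([a, b]).shuffle(seed=42)
print(merged["x"])

# Streaming path: interleave lazily; no global shuffle is applied.
streamed = interleave_datasets(
    [a.to_iterable_dataset(), b.to_iterable_dataset()], seed=42
)
print([ex["x"] for ex in streamed])  # alternates between the two sources
```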

View File

@@ -100,10 +100,6 @@ def get_dataset_wrapper(
dataset_config, tokenizer, cfg, dataset, dataset_kwargs
)
# Skip preparation if configured
if cfg.skip_prepare_dataset:
return dataset, None
# Bradley-Terry dataset
if dataset_config.type.startswith("bradley_terry"):
return _handle_bradley_terry_dataset(

View File

@@ -161,7 +161,12 @@ class HyperparametersConfig(BaseModel):
max_grad_norm: float | None = Field(
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
)
num_epochs: float = Field(default=1.0)
num_epochs: float = Field(
default=1.0,
json_schema_extra={
"description": "Number of iterations over dataset for training"
},
)
@field_validator("batch_size")
@classmethod
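
For reference, a minimal sketch (standalone model, not the real config class) showing how the added `json_schema_extra` description surfaces in the generated schema:

```python
from pydantic import BaseModel, Field

class HyperparamsSketch(BaseModel):
    num_epochs: float = Field(
        default=1.0,
        json_schema_extra={
            "description": "Number of iterations over dataset for training"
        },
    )

schema = HyperparamsSketch.model_json_schema()
print(schema["properties"]["num_epochs"]["description"])
```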

View File

@@ -3,6 +3,7 @@
# pylint: disable=too-many-boolean-expressions
import json
import os
import sys
import tempfile
from pathlib import Path
@@ -192,6 +193,7 @@ class AttentionValidationMixin:
return data
# pylint: disable=too-many-public-methods
class TrainingValidationMixin:
"""Validation methods related to training configuration."""
@@ -508,11 +510,58 @@ class TrainingValidationMixin:
# combining these would raise `TypeError: cannot pickle 'dict_keys' object`
# due to trying to count the number of tokens total in the dataset
raise ValueError(
"pretraining_dataset and include_tokens_per_second cannot be used together."
"pretraining_dataset and include_tokens_per_second cannot be used "
"together."
)
return data
@model_validator(mode="before")
@classmethod
def check_max_steps_num_epochs_conflict(cls, data):
"""Handle max_steps and num_epochs configuration and auto-set defaults."""
max_steps = data.get("max_steps")
num_epochs = data.get("num_epochs")
# Auto-set num_epochs to 1 if neither max_steps nor num_epochs are set
if max_steps is None and num_epochs is None:
data["num_epochs"] = 1.0
return data
@model_validator(mode="before")
@classmethod
def check_saves_per_epoch_conflicts(cls, data):
"""Ensure saves_per_epoch is compatible with training configuration."""
saves_per_epoch = data.get("saves_per_epoch")
num_epochs = data.get("num_epochs")
if saves_per_epoch is not None:
# Check if saves_per_epoch is set but num_epochs is unset
if num_epochs is None:
raise ValueError(
"saves_per_epoch requires num_epochs to be set to calculate save "
"intervals."
)
return data
@model_validator(mode="before")
@classmethod
def check_evals_per_epoch_conflicts(cls, data):
"""Ensure evals_per_epoch is compatible with training configuration."""
evals_per_epoch = data.get("evals_per_epoch")
num_epochs = data.get("num_epochs")
if evals_per_epoch is not None:
if num_epochs is None:
raise ValueError(
"evals_per_epoch requires num_epochs to be set to calculate "
"evaluation intervals."
)
return data
class LoRAValidationMixin:
"""Validation methods related to LoRA/QLoRA configuration."""
@@ -1336,7 +1385,6 @@ class GRPOVllmValidationMixin:
return self
# pylint: disable=too-many-ancestors
class StreamingValidationMixin:
"""Validation methods related to streaming datasets."""
@@ -1360,7 +1408,73 @@ class StreamingValidationMixin:
return self
@model_validator(mode="after")
def check_streaming_validation_splits_conflict(self):
"""Ensure validation splits are not used with streaming datasets."""
# Check if streaming is explicitly enabled
streaming_enabled = getattr(self, "streaming", None) is True
# Check if pretraining dataset exists (defaults to streaming)
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
streaming_default_for_pretraining = (
has_pretraining and getattr(self, "streaming", None) is None
)
# If streaming is enabled (explicitly or by default for pretraining)
if streaming_enabled or streaming_default_for_pretraining:
val_set_size = getattr(self, "val_set_size", 0.0)
if val_set_size and val_set_size > 0:
raise ValueError(
"Validation splits not supported for streaming datasets, skipping"
)
return self
@model_validator(mode="after")
def check_streaming_preprocessing_conflict(self):
"""Ensure preprocessing is not enabled with streaming datasets."""
# Check if streaming is explicitly enabled
streaming_enabled = getattr(self, "streaming", None) is True
# Check if pretraining dataset exists (defaults to streaming)
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
streaming_default_for_pretraining = (
has_pretraining and getattr(self, "streaming", None) is None
)
# If streaming is enabled (explicitly or by default for pretraining)
if streaming_enabled or streaming_default_for_pretraining:
if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
raise ValueError("preprocess is not supported for streaming datasets")
return self
@model_validator(mode="after")
def check_streaming_skip_prepare_dataset(self):
"""Ensure skip_prepare_dataset is set for streaming datasets."""
# Check if streaming is explicitly enabled
streaming_enabled = getattr(self, "streaming", None) is True
# Check if pretraining dataset exists (defaults to streaming)
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
streaming_default_for_pretraining = (
has_pretraining and getattr(self, "streaming", None) is None
)
# If streaming is enabled (explicitly or by default for pretraining)
if streaming_enabled or streaming_default_for_pretraining:
skip_prepare = getattr(self, "skip_prepare_dataset", None)
if skip_prepare is False:
LOG.warning(
"skip_prepare_dataset=False is not compatible with streaming "
"datasets. Setting skip_prepare_dataset=True."
)
self.skip_prepare_dataset = True
return self
# pylint: disable=too-many-ancestors
class ValidationMixin(
DatasetValidationMixin,
AttentionValidationMixin,
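
For context, a standalone sketch of the after-mode streaming validators added above (hypothetical field subset): streaming counts as active when explicitly enabled or when a pretraining dataset is present with `streaming` unset, and the model then rejects validation splits and forces `skip_prepare_dataset`:

```python
from pydantic import BaseModel, model_validator

class StreamingSketch(BaseModel):
    streaming: bool | None = None
    pretraining_dataset: str | None = None
    val_set_size: float = 0.0
    skip_prepare_dataset: bool | None = None

    def _streaming_active(self) -> bool:
        # Explicitly enabled, or implied by a pretraining dataset with streaming unset.
        return self.streaming is True or (
            self.pretraining_dataset is not None and self.streaming is None
        )

    @model_validator(mode="after")
    def no_val_split_when_streaming(self):
        if self._streaming_active() and self.val_set_size > 0:
            raise ValueError("Validation splits are not supported for streaming datasets")
        return self

    @model_validator(mode="after")
    def force_skip_prepare_when_streaming(self):
        if self._streaming_active() and self.skip_prepare_dataset is False:
            self.skip_prepare_dataset = True  # streaming bypasses dataset preparation
        return self

print(StreamingSketch(streaming=True, skip_prepare_dataset=False).skip_prepare_dataset)  # True
print(StreamingSketch(pretraining_dataset="c4").skip_prepare_dataset)                    # None
```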

View File

@@ -547,7 +547,7 @@ def setup_deepspeed_env(cfg, stage=None):
if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
# NOTE(djsaunde): The distributed state cannot be initialized prior to the
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
# to model load.
if (