progress on streaming

This commit is contained in:
Dan Saunders
2025-08-20 03:33:59 +00:00
parent 3b2dd05798
commit 7bb52d00bb
8 changed files with 244 additions and 50 deletions

View File

@@ -94,10 +94,8 @@ def wrap_dataset_for_tokenized_prompt(
map_kwargs = {}
if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
features = list(dataset.features.keys())
return dataset.map(
prompt_tokenizer.tokenize_prompt,
remove_columns=features,
**map_kwargs,
)
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)

View File

@@ -111,14 +111,11 @@ def _prepare_standard_dataset(
"You should set `eval_sample_packing: False` in your config."
)
# Calculate total number of training steps
if isinstance(train_dataset, IterableDataset):
# For streaming datasets, we must use max_steps
if not cfg.max_steps:
raise ValueError("max_steps must be set when using streaming datasets")
# Streaming case
total_num_steps = cfg.max_steps
else:
# For regular datasets, calculate from dataset size or use max_steps
# Non-streaming case
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
@@ -267,26 +264,11 @@ def _load_tokenized_prepared_datasets(
Returns:
Tuple of (dataset, prompters list).
"""
# Select correct dataset configuration based on split
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
# Generate dataset hash for caching
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
# Try loading from hub if push_dataset_to_hub is configured
dataset = None
if cfg.push_dataset_to_hub:
dataset = try_load_from_hub(cfg, dataset_hash, split)
# If not found on hub, try loading from disk
if dataset is None:
dataset = load_preprocessed_dataset(cfg, dataset_hash)
# If not found on disk or skipping prepared dataset, load and process raw datasets
prompters: list[Prompter | None] = []
if dataset is None:
# For streaming datasets, skip caching and load raw datasets directly
if cfg.streaming:
dataset, prompters = _load_raw_datasets(
cfg,
datasets_configs,
@@ -294,6 +276,31 @@ def _load_tokenized_prepared_datasets(
split,
processor,
)
else:
# Generate dataset hash for caching
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
# Try loading from hub if push_dataset_to_hub is configured
dataset = None
if cfg.push_dataset_to_hub:
dataset = try_load_from_hub(cfg, dataset_hash, split)
# If not found on hub, try loading from disk
if dataset is None:
dataset = load_preprocessed_dataset(cfg, dataset_hash)
# If not found on disk or skipping prepared dataset, load and process raw
# datasets
if dataset is None:
dataset, prompters = _load_raw_datasets(
cfg,
datasets_configs,
tokenizer,
split,
processor,
)
return dataset, prompters
@@ -339,7 +346,6 @@ def _load_raw_datasets(
if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
# Save the prepared dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
@@ -412,13 +418,12 @@ def _handle_train_dataset_split(
) -> tuple[Dataset | IterableDataset, Dataset | IterableDataset | None]:
"""Handle processing for train split, including validation set creation."""
val_set_size = (
int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
int(cfg.val_set_size)
if cfg.val_set_size and cfg.val_set_size > 1
else float(cfg.val_set_size or 0.0)
)
if val_set_size:
if isinstance(dataset, IterableDataset):
LOG.info("Validation splits not supported for streaming datasets, skipping")
return dataset, None
# Create train/validation split
train_dataset, eval_dataset = create_train_validation_split(
dataset, cfg, val_set_size

View File

@@ -13,6 +13,7 @@ from datasets import (
IterableDataset,
IterableDatasetDict,
concatenate_datasets,
interleave_datasets,
load_dataset,
load_from_disk,
)
@@ -524,7 +525,9 @@ def generate_dataset_hash_from_config(
return str(md5(config_str))
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
def merge_datasets(
datasets: list[Dataset | IterableDataset], cfg: DictDefault
) -> Dataset | IterableDataset:
"""Merge multiple datasets into one with optional shuffling.
Args:
@@ -542,18 +545,28 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
if cfg.curriculum_sampling or not cfg.shuffle_merged_datasets:
return ds
# Only shuffle regular datasets, not IterableDatasets
if isinstance(ds, IterableDataset):
return ds
return ds.shuffle(seed=cfg.seed)
# If enabled, shuffle each dataset independently before merging.
# This allows curriculum learning strategies to be applied at the dataset level.
if cfg.shuffle_before_merging_datasets:
LOG.info("Shuffling each dataset individually before merging...")
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
# Check if we have any IterableDatasets
has_iterable = any(isinstance(ds, IterableDataset) for ds in datasets)
LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets)
if has_iterable:
LOG.info("Merging streaming datasets...")
merged_dataset = interleave_datasets(datasets, seed=cfg.seed)
else:
# If enabled, shuffle each dataset independently before merging.
# This allows curriculum learning strategies to be applied at the dataset level.
if cfg.shuffle_before_merging_datasets:
LOG.info("Shuffling each dataset individually before merging...")
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
if cfg.shuffle_merged_datasets:
LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets)
if cfg.shuffle_merged_datasets and not isinstance(merged_dataset, IterableDataset):
LOG.debug("Shuffling merged datasets...")
if cfg.curriculum_sampling:
LOG.warning(
@@ -562,6 +575,9 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
)
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
else:
LOG.debug("Not shuffling merged datasets.")
if isinstance(merged_dataset, IterableDataset):
LOG.debug("Skipping shuffle for streaming datasets.")
else:
LOG.debug("Not shuffling merged datasets.")
return merged_dataset

View File

@@ -100,10 +100,6 @@ def get_dataset_wrapper(
dataset_config, tokenizer, cfg, dataset, dataset_kwargs
)
# Skip preparation if configured
if cfg.skip_prepare_dataset:
return dataset, None
# Bradley-Terry dataset
if dataset_config.type.startswith("bradley_terry"):
return _handle_bradley_terry_dataset(

View File

@@ -161,7 +161,12 @@ class HyperparametersConfig(BaseModel):
max_grad_norm: float | None = Field(
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
)
num_epochs: float = Field(default=1.0)
num_epochs: float = Field(
default=1.0,
json_schema_extra={
"description": "Number of iterations over dataset for training"
},
)
@field_validator("batch_size")
@classmethod

View File

@@ -3,6 +3,7 @@
# pylint: disable=too-many-boolean-expressions
import json
import os
import sys
import tempfile
from pathlib import Path
@@ -192,6 +193,7 @@ class AttentionValidationMixin:
return data
# pylint: disable=too-many-public-methods
class TrainingValidationMixin:
"""Validation methods related to training configuration."""
@@ -508,11 +510,58 @@ class TrainingValidationMixin:
# combining these would raise `TypeError: cannot pickle 'dict_keys' object`
# due to trying to count the number of tokens total in the dataset
raise ValueError(
"pretraining_dataset and include_tokens_per_second cannot be used together."
"pretraining_dataset and include_tokens_per_second cannot be used "
"together."
)
return data
@model_validator(mode="before")
@classmethod
def check_max_steps_num_epochs_conflict(cls, data):
"""Handle max_steps and num_epochs configuration and auto-set defaults."""
max_steps = data.get("max_steps")
num_epochs = data.get("num_epochs")
# Auto-set num_epochs to 1 if neither max_steps nor num_epochs are set
if max_steps is None and num_epochs is None:
data["num_epochs"] = 1.0
return data
@model_validator(mode="before")
@classmethod
def check_saves_per_epoch_conflicts(cls, data):
"""Ensure saves_per_epoch is compatible with training configuration."""
saves_per_epoch = data.get("saves_per_epoch")
num_epochs = data.get("num_epochs")
if saves_per_epoch is not None:
# Check if saves_per_epoch is set but num_epochs is unset
if num_epochs is None:
raise ValueError(
"saves_per_epoch requires num_epochs to be set to calculate save "
"intervals."
)
return data
@model_validator(mode="before")
@classmethod
def check_evals_per_epoch_conflicts(cls, data):
"""Ensure evals_per_epoch is compatible with training configuration."""
evals_per_epoch = data.get("evals_per_epoch")
num_epochs = data.get("num_epochs")
if evals_per_epoch is not None:
if num_epochs is None:
raise ValueError(
"evals_per_epoch requires num_epochs to be set to calculate "
"evaluation intervals."
)
return data
class LoRAValidationMixin:
"""Validation methods related to LoRA/QLoRA configuration."""
@@ -1336,7 +1385,6 @@ class GRPOVllmValidationMixin:
return self
# pylint: disable=too-many-ancestors
class StreamingValidationMixin:
"""Validation methods related to streaming datasets."""
@@ -1360,7 +1408,73 @@ class StreamingValidationMixin:
return self
@model_validator(mode="after")
def check_streaming_validation_splits_conflict(self):
"""Ensure validation splits are not used with streaming datasets."""
# Check if streaming is explicitly enabled
streaming_enabled = getattr(self, "streaming", None) is True
# Check if pretraining dataset exists (defaults to streaming)
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
streaming_default_for_pretraining = (
has_pretraining and getattr(self, "streaming", None) is None
)
# If streaming is enabled (explicitly or by default for pretraining)
if streaming_enabled or streaming_default_for_pretraining:
val_set_size = getattr(self, "val_set_size", 0.0)
if val_set_size and val_set_size > 0:
raise ValueError(
"Validation splits not supported for streaming datasets, skipping"
)
return self
@model_validator(mode="after")
def check_streaming_preprocessing_conflict(self):
"""Ensure preprocessing is not enabled with streaming datasets."""
# Check if streaming is explicitly enabled
streaming_enabled = getattr(self, "streaming", None) is True
# Check if pretraining dataset exists (defaults to streaming)
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
streaming_default_for_pretraining = (
has_pretraining and getattr(self, "streaming", None) is None
)
# If streaming is enabled (explicitly or by default for pretraining)
if streaming_enabled or streaming_default_for_pretraining:
if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
raise ValueError("preprocess is not supported for streaming datasets")
return self
@model_validator(mode="after")
def check_streaming_skip_prepare_dataset(self):
"""Ensure skip_prepare_dataset is set for streaming datasets."""
# Check if streaming is explicitly enabled
streaming_enabled = getattr(self, "streaming", None) is True
# Check if pretraining dataset exists (defaults to streaming)
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
streaming_default_for_pretraining = (
has_pretraining and getattr(self, "streaming", None) is None
)
# If streaming is enabled (explicitly or by default for pretraining)
if streaming_enabled or streaming_default_for_pretraining:
skip_prepare = getattr(self, "skip_prepare_dataset", None)
if skip_prepare is False:
LOG.warning(
"skip_prepare_dataset=False is not compatible with streaming "
"datasets. Setting skip_prepare_dataset=True."
)
self.skip_prepare_dataset = True
return self
# pylint: disable=too-many-ancestors
class ValidationMixin(
DatasetValidationMixin,
AttentionValidationMixin,

View File

@@ -547,7 +547,7 @@ def setup_deepspeed_env(cfg, stage=None):
if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
# NOTE(djsaunde): The distributed state cannot be initialized prior to the
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
# to model load.
if (

View File

@@ -7,13 +7,13 @@ from typing import Any, Generator
from unittest.mock import patch
import pytest
from datasets import Dataset
from datasets import Dataset, IterableDataset
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizer
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets, prepare_datasets
from axolotl.utils.dict import DictDefault
from tests.constants import (
@@ -46,6 +46,24 @@ class TestDatasetPreparation:
]
)
@pytest.fixture
def streaming_dataset_fixture(self):
"""Create a streaming dataset fixture for testing."""
def generator():
yield {
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
"input": "He finnished his meal and left the resturant",
"output": "He finished his meal and left the restaurant.",
}
yield {
"instruction": "What is the capital of France?",
"input": "",
"output": "The capital of France is Paris.",
}
return IterableDataset.from_generator(generator)
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_load_hub(self, tokenizer):
@@ -486,3 +504,45 @@ class TestDatasetPreparation:
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
def test_streaming_sft_dataset(self, tokenizer, streaming_dataset_fixture):
"""Test streaming SFT dataset preparation with IterableDataset."""
with patch("axolotl.utils.data.sft.load_dataset_with_config") as mock_load:
mock_load.return_value = streaming_dataset_fixture
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 256,
"streaming": True,
"max_steps": 100, # Required for streaming datasets
"datasets": [
{
"path": "dummy/path",
"type": "alpaca",
},
],
}
)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg, tokenizer
)
# Verify it returns an IterableDataset
assert isinstance(train_dataset, IterableDataset)
assert eval_dataset is None # No eval split for streaming
assert total_num_steps == 100 # Should use max_steps
assert len(prompters) == 1
# Test that we can iterate through the dataset
sample_count = 0
for sample in train_dataset:
assert "input_ids" in sample
assert "attention_mask" in sample
assert "labels" in sample
sample_count += 1
if sample_count >= 2: # Just test first few samples
break
assert sample_count == 2