progress on streaming
This commit is contained in:
@@ -94,10 +94,8 @@ def wrap_dataset_for_tokenized_prompt(
|
||||
map_kwargs = {}
|
||||
if prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
features = list(dataset.features.keys())
|
||||
return dataset.map(
|
||||
prompt_tokenizer.tokenize_prompt,
|
||||
remove_columns=features,
|
||||
**map_kwargs,
|
||||
)
|
||||
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
|
||||
|
||||
@@ -111,14 +111,11 @@ def _prepare_standard_dataset(
|
||||
"You should set `eval_sample_packing: False` in your config."
|
||||
)
|
||||
|
||||
# Calculate total number of training steps
|
||||
if isinstance(train_dataset, IterableDataset):
|
||||
# For streaming datasets, we must use max_steps
|
||||
if not cfg.max_steps:
|
||||
raise ValueError("max_steps must be set when using streaming datasets")
|
||||
# Streaming case
|
||||
total_num_steps = cfg.max_steps
|
||||
else:
|
||||
# For regular datasets, calculate from dataset size or use max_steps
|
||||
# Non-streaming case
|
||||
if cfg.max_steps:
|
||||
total_num_steps = min(
|
||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||
@@ -267,26 +264,11 @@ def _load_tokenized_prepared_datasets(
|
||||
Returns:
|
||||
Tuple of (dataset, prompters list).
|
||||
"""
|
||||
# Select correct dataset configuration based on split
|
||||
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
|
||||
|
||||
# Generate dataset hash for caching
|
||||
dataset_hash = generate_dataset_hash_from_config(
|
||||
cfg, datasets_configs, tokenizer.name_or_path
|
||||
)
|
||||
|
||||
# Try loading from hub if push_dataset_to_hub is configured
|
||||
dataset = None
|
||||
if cfg.push_dataset_to_hub:
|
||||
dataset = try_load_from_hub(cfg, dataset_hash, split)
|
||||
|
||||
# If not found on hub, try loading from disk
|
||||
if dataset is None:
|
||||
dataset = load_preprocessed_dataset(cfg, dataset_hash)
|
||||
|
||||
# If not found on disk or skipping prepared dataset, load and process raw datasets
|
||||
prompters: list[Prompter | None] = []
|
||||
if dataset is None:
|
||||
|
||||
# For streaming datasets, skip caching and load raw datasets directly
|
||||
if cfg.streaming:
|
||||
dataset, prompters = _load_raw_datasets(
|
||||
cfg,
|
||||
datasets_configs,
|
||||
@@ -294,6 +276,31 @@ def _load_tokenized_prepared_datasets(
|
||||
split,
|
||||
processor,
|
||||
)
|
||||
else:
|
||||
# Generate dataset hash for caching
|
||||
dataset_hash = generate_dataset_hash_from_config(
|
||||
cfg, datasets_configs, tokenizer.name_or_path
|
||||
)
|
||||
|
||||
# Try loading from hub if push_dataset_to_hub is configured
|
||||
dataset = None
|
||||
if cfg.push_dataset_to_hub:
|
||||
dataset = try_load_from_hub(cfg, dataset_hash, split)
|
||||
|
||||
# If not found on hub, try loading from disk
|
||||
if dataset is None:
|
||||
dataset = load_preprocessed_dataset(cfg, dataset_hash)
|
||||
|
||||
# If not found on disk or skipping prepared dataset, load and process raw
|
||||
# datasets
|
||||
if dataset is None:
|
||||
dataset, prompters = _load_raw_datasets(
|
||||
cfg,
|
||||
datasets_configs,
|
||||
tokenizer,
|
||||
split,
|
||||
processor,
|
||||
)
|
||||
|
||||
return dataset, prompters
|
||||
|
||||
@@ -339,7 +346,6 @@ def _load_raw_datasets(
|
||||
if cfg.sample_packing:
|
||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||
|
||||
# Save the prepared dataset
|
||||
dataset_hash = generate_dataset_hash_from_config(
|
||||
cfg, datasets_configs, tokenizer.name_or_path
|
||||
)
|
||||
@@ -412,13 +418,12 @@ def _handle_train_dataset_split(
|
||||
) -> tuple[Dataset | IterableDataset, Dataset | IterableDataset | None]:
|
||||
"""Handle processing for train split, including validation set creation."""
|
||||
val_set_size = (
|
||||
int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
|
||||
int(cfg.val_set_size)
|
||||
if cfg.val_set_size and cfg.val_set_size > 1
|
||||
else float(cfg.val_set_size or 0.0)
|
||||
)
|
||||
|
||||
if val_set_size:
|
||||
if isinstance(dataset, IterableDataset):
|
||||
LOG.info("Validation splits not supported for streaming datasets, skipping")
|
||||
return dataset, None
|
||||
# Create train/validation split
|
||||
train_dataset, eval_dataset = create_train_validation_split(
|
||||
dataset, cfg, val_set_size
|
||||
|
||||
@@ -13,6 +13,7 @@ from datasets import (
|
||||
IterableDataset,
|
||||
IterableDatasetDict,
|
||||
concatenate_datasets,
|
||||
interleave_datasets,
|
||||
load_dataset,
|
||||
load_from_disk,
|
||||
)
|
||||
@@ -524,7 +525,9 @@ def generate_dataset_hash_from_config(
|
||||
return str(md5(config_str))
|
||||
|
||||
|
||||
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
||||
def merge_datasets(
|
||||
datasets: list[Dataset | IterableDataset], cfg: DictDefault
|
||||
) -> Dataset | IterableDataset:
|
||||
"""Merge multiple datasets into one with optional shuffling.
|
||||
|
||||
Args:
|
||||
@@ -542,18 +545,28 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
||||
if cfg.curriculum_sampling or not cfg.shuffle_merged_datasets:
|
||||
return ds
|
||||
|
||||
# Only shuffle regular datasets, not IterableDatasets
|
||||
if isinstance(ds, IterableDataset):
|
||||
return ds
|
||||
return ds.shuffle(seed=cfg.seed)
|
||||
|
||||
# If enabled, shuffle each dataset independently before merging.
|
||||
# This allows curriculum learning strategies to be applied at the dataset level.
|
||||
if cfg.shuffle_before_merging_datasets:
|
||||
LOG.info("Shuffling each dataset individually before merging...")
|
||||
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
|
||||
# Check if we have any IterableDatasets
|
||||
has_iterable = any(isinstance(ds, IterableDataset) for ds in datasets)
|
||||
|
||||
LOG.info("Merging datasets...")
|
||||
merged_dataset = concatenate_datasets(datasets)
|
||||
if has_iterable:
|
||||
LOG.info("Merging streaming datasets...")
|
||||
merged_dataset = interleave_datasets(datasets, seed=cfg.seed)
|
||||
else:
|
||||
# If enabled, shuffle each dataset independently before merging.
|
||||
# This allows curriculum learning strategies to be applied at the dataset level.
|
||||
if cfg.shuffle_before_merging_datasets:
|
||||
LOG.info("Shuffling each dataset individually before merging...")
|
||||
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
|
||||
|
||||
if cfg.shuffle_merged_datasets:
|
||||
LOG.info("Merging datasets...")
|
||||
merged_dataset = concatenate_datasets(datasets)
|
||||
|
||||
if cfg.shuffle_merged_datasets and not isinstance(merged_dataset, IterableDataset):
|
||||
LOG.debug("Shuffling merged datasets...")
|
||||
if cfg.curriculum_sampling:
|
||||
LOG.warning(
|
||||
@@ -562,6 +575,9 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
||||
)
|
||||
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
|
||||
else:
|
||||
LOG.debug("Not shuffling merged datasets.")
|
||||
if isinstance(merged_dataset, IterableDataset):
|
||||
LOG.debug("Skipping shuffle for streaming datasets.")
|
||||
else:
|
||||
LOG.debug("Not shuffling merged datasets.")
|
||||
|
||||
return merged_dataset
|
||||
|
||||
@@ -100,10 +100,6 @@ def get_dataset_wrapper(
|
||||
dataset_config, tokenizer, cfg, dataset, dataset_kwargs
|
||||
)
|
||||
|
||||
# Skip preparation if configured
|
||||
if cfg.skip_prepare_dataset:
|
||||
return dataset, None
|
||||
|
||||
# Bradley-Terry dataset
|
||||
if dataset_config.type.startswith("bradley_terry"):
|
||||
return _handle_bradley_terry_dataset(
|
||||
|
||||
@@ -161,7 +161,12 @@ class HyperparametersConfig(BaseModel):
|
||||
max_grad_norm: float | None = Field(
|
||||
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
|
||||
)
|
||||
num_epochs: float = Field(default=1.0)
|
||||
num_epochs: float = Field(
|
||||
default=1.0,
|
||||
json_schema_extra={
|
||||
"description": "Number of iterations over dataset for training"
|
||||
},
|
||||
)
|
||||
|
||||
@field_validator("batch_size")
|
||||
@classmethod
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
@@ -192,6 +193,7 @@ class AttentionValidationMixin:
|
||||
return data
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
class TrainingValidationMixin:
|
||||
"""Validation methods related to training configuration."""
|
||||
|
||||
@@ -508,11 +510,58 @@ class TrainingValidationMixin:
|
||||
# combining these would raise `TypeError: cannot pickle 'dict_keys' object`
|
||||
# due to trying to count the number of tokens total in the dataset
|
||||
raise ValueError(
|
||||
"pretraining_dataset and include_tokens_per_second cannot be used together."
|
||||
"pretraining_dataset and include_tokens_per_second cannot be used "
|
||||
"together."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_max_steps_num_epochs_conflict(cls, data):
|
||||
"""Handle max_steps and num_epochs configuration and auto-set defaults."""
|
||||
max_steps = data.get("max_steps")
|
||||
num_epochs = data.get("num_epochs")
|
||||
|
||||
# Auto-set num_epochs to 1 if neither max_steps nor num_epochs are set
|
||||
if max_steps is None and num_epochs is None:
|
||||
data["num_epochs"] = 1.0
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_saves_per_epoch_conflicts(cls, data):
|
||||
"""Ensure saves_per_epoch is compatible with training configuration."""
|
||||
saves_per_epoch = data.get("saves_per_epoch")
|
||||
num_epochs = data.get("num_epochs")
|
||||
|
||||
if saves_per_epoch is not None:
|
||||
# Check if saves_per_epoch is set but num_epochs is unset
|
||||
if num_epochs is None:
|
||||
raise ValueError(
|
||||
"saves_per_epoch requires num_epochs to be set to calculate save "
|
||||
"intervals."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_evals_per_epoch_conflicts(cls, data):
|
||||
"""Ensure evals_per_epoch is compatible with training configuration."""
|
||||
evals_per_epoch = data.get("evals_per_epoch")
|
||||
num_epochs = data.get("num_epochs")
|
||||
|
||||
if evals_per_epoch is not None:
|
||||
if num_epochs is None:
|
||||
raise ValueError(
|
||||
"evals_per_epoch requires num_epochs to be set to calculate "
|
||||
"evaluation intervals."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class LoRAValidationMixin:
|
||||
"""Validation methods related to LoRA/QLoRA configuration."""
|
||||
@@ -1336,7 +1385,6 @@ class GRPOVllmValidationMixin:
|
||||
return self
|
||||
|
||||
|
||||
# pylint: disable=too-many-ancestors
|
||||
class StreamingValidationMixin:
|
||||
"""Validation methods related to streaming datasets."""
|
||||
|
||||
@@ -1360,7 +1408,73 @@ class StreamingValidationMixin:
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_streaming_validation_splits_conflict(self):
|
||||
"""Ensure validation splits are not used with streaming datasets."""
|
||||
# Check if streaming is explicitly enabled
|
||||
streaming_enabled = getattr(self, "streaming", None) is True
|
||||
|
||||
# Check if pretraining dataset exists (defaults to streaming)
|
||||
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
|
||||
streaming_default_for_pretraining = (
|
||||
has_pretraining and getattr(self, "streaming", None) is None
|
||||
)
|
||||
|
||||
# If streaming is enabled (explicitly or by default for pretraining)
|
||||
if streaming_enabled or streaming_default_for_pretraining:
|
||||
val_set_size = getattr(self, "val_set_size", 0.0)
|
||||
if val_set_size and val_set_size > 0:
|
||||
raise ValueError(
|
||||
"Validation splits not supported for streaming datasets, skipping"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_streaming_preprocessing_conflict(self):
|
||||
"""Ensure preprocessing is not enabled with streaming datasets."""
|
||||
# Check if streaming is explicitly enabled
|
||||
streaming_enabled = getattr(self, "streaming", None) is True
|
||||
|
||||
# Check if pretraining dataset exists (defaults to streaming)
|
||||
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
|
||||
streaming_default_for_pretraining = (
|
||||
has_pretraining and getattr(self, "streaming", None) is None
|
||||
)
|
||||
|
||||
# If streaming is enabled (explicitly or by default for pretraining)
|
||||
if streaming_enabled or streaming_default_for_pretraining:
|
||||
if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
|
||||
raise ValueError("preprocess is not supported for streaming datasets")
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_streaming_skip_prepare_dataset(self):
|
||||
"""Ensure skip_prepare_dataset is set for streaming datasets."""
|
||||
# Check if streaming is explicitly enabled
|
||||
streaming_enabled = getattr(self, "streaming", None) is True
|
||||
|
||||
# Check if pretraining dataset exists (defaults to streaming)
|
||||
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
|
||||
streaming_default_for_pretraining = (
|
||||
has_pretraining and getattr(self, "streaming", None) is None
|
||||
)
|
||||
|
||||
# If streaming is enabled (explicitly or by default for pretraining)
|
||||
if streaming_enabled or streaming_default_for_pretraining:
|
||||
skip_prepare = getattr(self, "skip_prepare_dataset", None)
|
||||
if skip_prepare is False:
|
||||
LOG.warning(
|
||||
"skip_prepare_dataset=False is not compatible with streaming "
|
||||
"datasets. Setting skip_prepare_dataset=True."
|
||||
)
|
||||
self.skip_prepare_dataset = True
|
||||
|
||||
return self
|
||||
|
||||
|
||||
# pylint: disable=too-many-ancestors
|
||||
class ValidationMixin(
|
||||
DatasetValidationMixin,
|
||||
AttentionValidationMixin,
|
||||
|
||||
@@ -547,7 +547,7 @@ def setup_deepspeed_env(cfg, stage=None):
|
||||
if stage == 3:
|
||||
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
|
||||
|
||||
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
|
||||
# NOTE(djsaunde): The distributed state cannot be initialized prior to the
|
||||
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
|
||||
# to model load.
|
||||
if (
|
||||
|
||||
@@ -7,13 +7,13 @@ from typing import Any, Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
from datasets import Dataset, IterableDataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.loaders.tokenizer import load_tokenizer
|
||||
from axolotl.utils.data.rl import prepare_preference_datasets
|
||||
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
|
||||
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets, prepare_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from tests.constants import (
|
||||
@@ -46,6 +46,24 @@ class TestDatasetPreparation:
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def streaming_dataset_fixture(self):
|
||||
"""Create a streaming dataset fixture for testing."""
|
||||
|
||||
def generator():
|
||||
yield {
|
||||
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
|
||||
"input": "He finnished his meal and left the resturant",
|
||||
"output": "He finished his meal and left the restaurant.",
|
||||
}
|
||||
yield {
|
||||
"instruction": "What is the capital of France?",
|
||||
"input": "",
|
||||
"output": "The capital of France is Paris.",
|
||||
}
|
||||
|
||||
return IterableDataset.from_generator(generator)
|
||||
|
||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
||||
@enable_hf_offline
|
||||
def test_load_hub(self, tokenizer):
|
||||
@@ -486,3 +504,45 @@ class TestDatasetPreparation:
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
|
||||
def test_streaming_sft_dataset(self, tokenizer, streaming_dataset_fixture):
|
||||
"""Test streaming SFT dataset preparation with IterableDataset."""
|
||||
with patch("axolotl.utils.data.sft.load_dataset_with_config") as mock_load:
|
||||
mock_load.return_value = streaming_dataset_fixture
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 256,
|
||||
"streaming": True,
|
||||
"max_steps": 100, # Required for streaming datasets
|
||||
"datasets": [
|
||||
{
|
||||
"path": "dummy/path",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
|
||||
cfg, tokenizer
|
||||
)
|
||||
|
||||
# Verify it returns an IterableDataset
|
||||
assert isinstance(train_dataset, IterableDataset)
|
||||
assert eval_dataset is None # No eval split for streaming
|
||||
assert total_num_steps == 100 # Should use max_steps
|
||||
assert len(prompters) == 1
|
||||
|
||||
# Test that we can iterate through the dataset
|
||||
sample_count = 0
|
||||
for sample in train_dataset:
|
||||
assert "input_ids" in sample
|
||||
assert "attention_mask" in sample
|
||||
assert "labels" in sample
|
||||
sample_count += 1
|
||||
if sample_count >= 2: # Just test first few samples
|
||||
break
|
||||
|
||||
assert sample_count == 2
|
||||
|
||||
Reference in New Issue
Block a user