From 7bb52d00bb5af9e13a465d99f0a19919f1f8db92 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 20 Aug 2025 03:33:59 +0000 Subject: [PATCH] progress on streaming --- src/axolotl/datasets.py | 2 - src/axolotl/utils/data/sft.py | 61 ++++++------ src/axolotl/utils/data/shared.py | 36 ++++++-- src/axolotl/utils/data/wrappers.py | 4 - src/axolotl/utils/schemas/training.py | 7 +- src/axolotl/utils/schemas/validation.py | 118 +++++++++++++++++++++++- src/axolotl/utils/trainer.py | 2 +- tests/test_datasets.py | 64 ++++++++++++- 8 files changed, 244 insertions(+), 50 deletions(-) diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index e3bbc8cf7..0979171f7 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -94,10 +94,8 @@ def wrap_dataset_for_tokenized_prompt( map_kwargs = {} if prompt_tokenizer.supports_batched: map_kwargs["batched"] = True - features = list(dataset.features.keys()) return dataset.map( prompt_tokenizer.tokenize_prompt, - remove_columns=features, **map_kwargs, ) return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index ca69e1d9f..38f8d6d4a 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -111,14 +111,11 @@ def _prepare_standard_dataset( "You should set `eval_sample_packing: False` in your config." ) - # Calculate total number of training steps if isinstance(train_dataset, IterableDataset): - # For streaming datasets, we must use max_steps - if not cfg.max_steps: - raise ValueError("max_steps must be set when using streaming datasets") + # Streaming case total_num_steps = cfg.max_steps else: - # For regular datasets, calculate from dataset size or use max_steps + # Non-streaming case if cfg.max_steps: total_num_steps = min( calculate_total_num_steps(cfg, train_dataset), cfg.max_steps @@ -267,26 +264,11 @@ def _load_tokenized_prepared_datasets( Returns: Tuple of (dataset, prompters list). """ - # Select correct dataset configuration based on split datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets - - # Generate dataset hash for caching - dataset_hash = generate_dataset_hash_from_config( - cfg, datasets_configs, tokenizer.name_or_path - ) - - # Try loading from hub if push_dataset_to_hub is configured - dataset = None - if cfg.push_dataset_to_hub: - dataset = try_load_from_hub(cfg, dataset_hash, split) - - # If not found on hub, try loading from disk - if dataset is None: - dataset = load_preprocessed_dataset(cfg, dataset_hash) - - # If not found on disk or skipping prepared dataset, load and process raw datasets prompters: list[Prompter | None] = [] - if dataset is None: + + # For streaming datasets, skip caching and load raw datasets directly + if cfg.streaming: dataset, prompters = _load_raw_datasets( cfg, datasets_configs, @@ -294,6 +276,31 @@ def _load_tokenized_prepared_datasets( split, processor, ) + else: + # Generate dataset hash for caching + dataset_hash = generate_dataset_hash_from_config( + cfg, datasets_configs, tokenizer.name_or_path + ) + + # Try loading from hub if push_dataset_to_hub is configured + dataset = None + if cfg.push_dataset_to_hub: + dataset = try_load_from_hub(cfg, dataset_hash, split) + + # If not found on hub, try loading from disk + if dataset is None: + dataset = load_preprocessed_dataset(cfg, dataset_hash) + + # If not found on disk or skipping prepared dataset, load and process raw + # datasets + if dataset is None: + dataset, prompters = _load_raw_datasets( + cfg, + datasets_configs, + tokenizer, + split, + processor, + ) return dataset, prompters @@ -339,7 +346,6 @@ def _load_raw_datasets( if cfg.sample_packing: dataset, _ = process_datasets_for_packing(cfg, dataset, None) - # Save the prepared dataset dataset_hash = generate_dataset_hash_from_config( cfg, datasets_configs, tokenizer.name_or_path ) @@ -412,13 +418,12 @@ def _handle_train_dataset_split( ) -> tuple[Dataset | IterableDataset, Dataset | IterableDataset | None]: """Handle processing for train split, including validation set creation.""" val_set_size = ( - int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size) + int(cfg.val_set_size) + if cfg.val_set_size and cfg.val_set_size > 1 + else float(cfg.val_set_size or 0.0) ) if val_set_size: - if isinstance(dataset, IterableDataset): - LOG.info("Validation splits not supported for streaming datasets, skipping") - return dataset, None # Create train/validation split train_dataset, eval_dataset = create_train_validation_split( dataset, cfg, val_set_size diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index 21c8e472b..bde9e94e1 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -13,6 +13,7 @@ from datasets import ( IterableDataset, IterableDatasetDict, concatenate_datasets, + interleave_datasets, load_dataset, load_from_disk, ) @@ -524,7 +525,9 @@ def generate_dataset_hash_from_config( return str(md5(config_str)) -def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset: +def merge_datasets( + datasets: list[Dataset | IterableDataset], cfg: DictDefault +) -> Dataset | IterableDataset: """Merge multiple datasets into one with optional shuffling. Args: @@ -542,18 +545,28 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset: if cfg.curriculum_sampling or not cfg.shuffle_merged_datasets: return ds + # Only shuffle regular datasets, not IterableDatasets + if isinstance(ds, IterableDataset): + return ds return ds.shuffle(seed=cfg.seed) - # If enabled, shuffle each dataset independently before merging. - # This allows curriculum learning strategies to be applied at the dataset level. - if cfg.shuffle_before_merging_datasets: - LOG.info("Shuffling each dataset individually before merging...") - datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets] + # Check if we have any IterableDatasets + has_iterable = any(isinstance(ds, IterableDataset) for ds in datasets) - LOG.info("Merging datasets...") - merged_dataset = concatenate_datasets(datasets) + if has_iterable: + LOG.info("Merging streaming datasets...") + merged_dataset = interleave_datasets(datasets, seed=cfg.seed) + else: + # If enabled, shuffle each dataset independently before merging. + # This allows curriculum learning strategies to be applied at the dataset level. + if cfg.shuffle_before_merging_datasets: + LOG.info("Shuffling each dataset individually before merging...") + datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets] - if cfg.shuffle_merged_datasets: + LOG.info("Merging datasets...") + merged_dataset = concatenate_datasets(datasets) + + if cfg.shuffle_merged_datasets and not isinstance(merged_dataset, IterableDataset): LOG.debug("Shuffling merged datasets...") if cfg.curriculum_sampling: LOG.warning( @@ -562,6 +575,9 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset: ) merged_dataset = merged_dataset.shuffle(seed=cfg.seed) else: - LOG.debug("Not shuffling merged datasets.") + if isinstance(merged_dataset, IterableDataset): + LOG.debug("Skipping shuffle for streaming datasets.") + else: + LOG.debug("Not shuffling merged datasets.") return merged_dataset diff --git a/src/axolotl/utils/data/wrappers.py b/src/axolotl/utils/data/wrappers.py index b6dc42c71..0636d6dd9 100644 --- a/src/axolotl/utils/data/wrappers.py +++ b/src/axolotl/utils/data/wrappers.py @@ -100,10 +100,6 @@ def get_dataset_wrapper( dataset_config, tokenizer, cfg, dataset, dataset_kwargs ) - # Skip preparation if configured - if cfg.skip_prepare_dataset: - return dataset, None - # Bradley-Terry dataset if dataset_config.type.startswith("bradley_terry"): return _handle_bradley_terry_dataset( diff --git a/src/axolotl/utils/schemas/training.py b/src/axolotl/utils/schemas/training.py index b1788dcaa..f282f9533 100644 --- a/src/axolotl/utils/schemas/training.py +++ b/src/axolotl/utils/schemas/training.py @@ -161,7 +161,12 @@ class HyperparametersConfig(BaseModel): max_grad_norm: float | None = Field( default=None, json_schema_extra={"description": "Gradient clipping max norm"} ) - num_epochs: float = Field(default=1.0) + num_epochs: float = Field( + default=1.0, + json_schema_extra={ + "description": "Number of iterations over dataset for training" + }, + ) @field_validator("batch_size") @classmethod diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 8303d306a..64fca363c 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -3,6 +3,7 @@ # pylint: disable=too-many-boolean-expressions import json +import os import sys import tempfile from pathlib import Path @@ -192,6 +193,7 @@ class AttentionValidationMixin: return data +# pylint: disable=too-many-public-methods class TrainingValidationMixin: """Validation methods related to training configuration.""" @@ -508,11 +510,58 @@ class TrainingValidationMixin: # combining these would raise `TypeError: cannot pickle 'dict_keys' object` # due to trying to count the number of tokens total in the dataset raise ValueError( - "pretraining_dataset and include_tokens_per_second cannot be used together." + "pretraining_dataset and include_tokens_per_second cannot be used " + "together." ) return data + @model_validator(mode="before") + @classmethod + def check_max_steps_num_epochs_conflict(cls, data): + """Handle max_steps and num_epochs configuration and auto-set defaults.""" + max_steps = data.get("max_steps") + num_epochs = data.get("num_epochs") + + # Auto-set num_epochs to 1 if neither max_steps nor num_epochs are set + if max_steps is None and num_epochs is None: + data["num_epochs"] = 1.0 + + return data + + @model_validator(mode="before") + @classmethod + def check_saves_per_epoch_conflicts(cls, data): + """Ensure saves_per_epoch is compatible with training configuration.""" + saves_per_epoch = data.get("saves_per_epoch") + num_epochs = data.get("num_epochs") + + if saves_per_epoch is not None: + # Check if saves_per_epoch is set but num_epochs is unset + if num_epochs is None: + raise ValueError( + "saves_per_epoch requires num_epochs to be set to calculate save " + "intervals." + ) + + return data + + @model_validator(mode="before") + @classmethod + def check_evals_per_epoch_conflicts(cls, data): + """Ensure evals_per_epoch is compatible with training configuration.""" + evals_per_epoch = data.get("evals_per_epoch") + num_epochs = data.get("num_epochs") + + if evals_per_epoch is not None: + if num_epochs is None: + raise ValueError( + "evals_per_epoch requires num_epochs to be set to calculate " + "evaluation intervals." + ) + + return data + class LoRAValidationMixin: """Validation methods related to LoRA/QLoRA configuration.""" @@ -1336,7 +1385,6 @@ class GRPOVllmValidationMixin: return self -# pylint: disable=too-many-ancestors class StreamingValidationMixin: """Validation methods related to streaming datasets.""" @@ -1360,7 +1408,73 @@ class StreamingValidationMixin: return self + @model_validator(mode="after") + def check_streaming_validation_splits_conflict(self): + """Ensure validation splits are not used with streaming datasets.""" + # Check if streaming is explicitly enabled + streaming_enabled = getattr(self, "streaming", None) is True + # Check if pretraining dataset exists (defaults to streaming) + has_pretraining = getattr(self, "pretraining_dataset", None) is not None + streaming_default_for_pretraining = ( + has_pretraining and getattr(self, "streaming", None) is None + ) + + # If streaming is enabled (explicitly or by default for pretraining) + if streaming_enabled or streaming_default_for_pretraining: + val_set_size = getattr(self, "val_set_size", 0.0) + if val_set_size and val_set_size > 0: + raise ValueError( + "Validation splits not supported for streaming datasets, skipping" + ) + + return self + + @model_validator(mode="after") + def check_streaming_preprocessing_conflict(self): + """Ensure preprocessing is not enabled with streaming datasets.""" + # Check if streaming is explicitly enabled + streaming_enabled = getattr(self, "streaming", None) is True + + # Check if pretraining dataset exists (defaults to streaming) + has_pretraining = getattr(self, "pretraining_dataset", None) is not None + streaming_default_for_pretraining = ( + has_pretraining and getattr(self, "streaming", None) is None + ) + + # If streaming is enabled (explicitly or by default for pretraining) + if streaming_enabled or streaming_default_for_pretraining: + if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1": + raise ValueError("preprocess is not supported for streaming datasets") + + return self + + @model_validator(mode="after") + def check_streaming_skip_prepare_dataset(self): + """Ensure skip_prepare_dataset is set for streaming datasets.""" + # Check if streaming is explicitly enabled + streaming_enabled = getattr(self, "streaming", None) is True + + # Check if pretraining dataset exists (defaults to streaming) + has_pretraining = getattr(self, "pretraining_dataset", None) is not None + streaming_default_for_pretraining = ( + has_pretraining and getattr(self, "streaming", None) is None + ) + + # If streaming is enabled (explicitly or by default for pretraining) + if streaming_enabled or streaming_default_for_pretraining: + skip_prepare = getattr(self, "skip_prepare_dataset", None) + if skip_prepare is False: + LOG.warning( + "skip_prepare_dataset=False is not compatible with streaming " + "datasets. Setting skip_prepare_dataset=True." + ) + self.skip_prepare_dataset = True + + return self + + +# pylint: disable=too-many-ancestors class ValidationMixin( DatasetValidationMixin, AttentionValidationMixin, diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index e424cb55a..005312733 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -547,7 +547,7 @@ def setup_deepspeed_env(cfg, stage=None): if stage == 3: os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true" - # NOTE(djsaunde): The distribued state cannot be initialized prior to the + # NOTE(djsaunde): The distributed state cannot be initialized prior to the # ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior # to model load. if ( diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 719dfdc19..7d371bb41 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -7,13 +7,13 @@ from typing import Any, Generator from unittest.mock import patch import pytest -from datasets import Dataset +from datasets import Dataset, IterableDataset from huggingface_hub import snapshot_download from transformers import PreTrainedTokenizer from axolotl.loaders.tokenizer import load_tokenizer from axolotl.utils.data.rl import prepare_preference_datasets -from axolotl.utils.data.sft import _load_tokenized_prepared_datasets +from axolotl.utils.data.sft import _load_tokenized_prepared_datasets, prepare_datasets from axolotl.utils.dict import DictDefault from tests.constants import ( @@ -46,6 +46,24 @@ class TestDatasetPreparation: ] ) + @pytest.fixture + def streaming_dataset_fixture(self): + """Create a streaming dataset fixture for testing.""" + + def generator(): + yield { + "instruction": "Evaluate this sentence for spelling and grammar mistakes", + "input": "He finnished his meal and left the resturant", + "output": "He finished his meal and left the restaurant.", + } + yield { + "instruction": "What is the capital of France?", + "input": "", + "output": "The capital of France is Paris.", + } + + return IterableDataset.from_generator(generator) + @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits") @enable_hf_offline def test_load_hub(self, tokenizer): @@ -486,3 +504,45 @@ class TestDatasetPreparation: assert "attention_mask" in dataset.features assert "labels" in dataset.features shutil.rmtree(tmp_ds_path) + + def test_streaming_sft_dataset(self, tokenizer, streaming_dataset_fixture): + """Test streaming SFT dataset preparation with IterableDataset.""" + with patch("axolotl.utils.data.sft.load_dataset_with_config") as mock_load: + mock_load.return_value = streaming_dataset_fixture + + cfg = DictDefault( + { + "tokenizer_config": "huggyllama/llama-7b", + "sequence_len": 256, + "streaming": True, + "max_steps": 100, # Required for streaming datasets + "datasets": [ + { + "path": "dummy/path", + "type": "alpaca", + }, + ], + } + ) + + train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets( + cfg, tokenizer + ) + + # Verify it returns an IterableDataset + assert isinstance(train_dataset, IterableDataset) + assert eval_dataset is None # No eval split for streaming + assert total_num_steps == 100 # Should use max_steps + assert len(prompters) == 1 + + # Test that we can iterate through the dataset + sample_count = 0 + for sample in train_dataset: + assert "input_ids" in sample + assert "attention_mask" in sample + assert "labels" in sample + sample_count += 1 + if sample_count >= 2: # Just test first few samples + break + + assert sample_count == 2