From 16ff01df85c69903fa6228bb2adb573214c24cd6 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 19 Aug 2025 18:05:05 +0000 Subject: [PATCH] separate streaming and pretraining --- src/axolotl/datasets.py | 33 +++++++++---- src/axolotl/utils/data/sft.py | 62 +++++++++++++++++++------ src/axolotl/utils/data/utils.py | 18 ++++--- src/axolotl/utils/schemas/config.py | 28 +++++++++++ src/axolotl/utils/schemas/validation.py | 25 ++++++++++ 5 files changed, 139 insertions(+), 27 deletions(-) diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index c9d006ac8..2d20de4d3 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -43,7 +43,13 @@ class TokenizedPromptDataset(Dataset): ) def process(self, dataset): - features = dataset.features.keys() + # Handle both regular Dataset and IterableDataset + if hasattr(dataset, "features") and dataset.features: + features = dataset.features.keys() + else: + # For IterableDataset, we can't access features upfront + # We'll need to infer from the first batch + features = None map_kwargs = {} if self.prompt_tokenizer.supports_batched: @@ -54,19 +60,30 @@ class TokenizedPromptDataset(Dataset): hasattr(self.prompt_tokenizer, "filter_rows") and self.prompt_tokenizer.filter_rows ): + filter_kwargs = {"desc": "Strategy Filtering Rows"} + # Only add num_proc for regular datasets + if features is not None: + filter_kwargs["num_proc"] = self.process_count + dataset = dataset.filter( self.prompt_tokenizer.filter_rows, - num_proc=self.process_count, - desc="Strategy Filtering Rows", + **filter_kwargs, ) + map_kwargs_final = { + **map_kwargs, + "desc": "Tokenizing Prompts", + } + + # Only add remove_columns for regular datasets + if features is not None: + map_kwargs_final["remove_columns"] = features + map_kwargs_final["num_proc"] = self.process_count + map_kwargs_final["keep_in_memory"] = self.keep_in_memory + return dataset.map( self.prompt_tokenizer.tokenize_prompt, - num_proc=self.process_count, - remove_columns=features, - keep_in_memory=self.keep_in_memory, - desc="Tokenizing Prompts", - **map_kwargs, + **map_kwargs_final, ) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 2ae7d9052..17afba9c2 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -9,6 +9,7 @@ from datasets import ( Dataset, DatasetDict, IterableDataset, + IterableDatasetDict, load_dataset, ) from transformers import PreTrainedTokenizer, ProcessorMixin @@ -43,6 +44,18 @@ from axolotl.utils.trainer import ( LOG = get_logger(__name__) +def _determine_streaming_mode(cfg: DictDefault) -> bool: + """Determine if we should use streaming mode based on config.""" + if cfg.streaming is not None: + return cfg.streaming + + # Default to streaming for pretraining datasets + if cfg.pretraining_dataset: + return True + + return False + + @retry_on_request_exceptions(max_retries=3, delay=5) def prepare_datasets( cfg: DictDefault, @@ -61,6 +74,13 @@ def prepare_datasets( Returns: Tuple of (train_dataset, eval_dataset, total_steps, prompters). 
""" + # Determine streaming mode from config + streaming_mode = _determine_streaming_mode(cfg) + + # Override preprocess_iterable parameter with streaming config + if streaming_mode: + preprocess_iterable = True + if cfg.pretraining_dataset: return _prepare_pretraining_dataset( cfg, tokenizer, processor, preprocess_iterable @@ -118,12 +138,19 @@ def _prepare_standard_dataset( ) # Calculate total number of training steps - if cfg.max_steps: - total_num_steps = min( - calculate_total_num_steps(cfg, train_dataset), cfg.max_steps - ) + if isinstance(train_dataset, IterableDataset): + # For streaming datasets, we must use max_steps + if not cfg.max_steps: + raise ValueError("max_steps must be set when using streaming datasets") + total_num_steps = cfg.max_steps else: - total_num_steps = calculate_total_num_steps(cfg, train_dataset) + # For regular datasets, calculate from dataset size or use max_steps + if cfg.max_steps: + total_num_steps = min( + calculate_total_num_steps(cfg, train_dataset), cfg.max_steps + ) + else: + total_num_steps = calculate_total_num_steps(cfg, train_dataset) LOG.info(f"Maximum number of steps set at {total_num_steps}") return train_dataset, eval_dataset, total_num_steps, prompters @@ -373,7 +400,7 @@ def _load_and_process_single_dataset( d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type) # Select the appropriate split - if isinstance(dataset, DatasetDict): + if isinstance(dataset, (DatasetDict, IterableDatasetDict)): if dataset_config.split and dataset_config.split in dataset: dataset = dataset[dataset_config.split] elif split in dataset: @@ -418,14 +445,17 @@ def _parse_dataset_type(d_type: str) -> tuple[str | None, str | None]: def _handle_train_dataset_split( - dataset: Dataset, cfg: DictDefault -) -> tuple[Dataset, Dataset | None]: + dataset: Dataset | IterableDataset, cfg: DictDefault +) -> tuple[Dataset | IterableDataset, Dataset | IterableDataset | None]: """Handle processing for train split, including validation set creation.""" val_set_size = ( int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size) ) if val_set_size: + if isinstance(dataset, IterableDataset): + LOG.info("Validation splits not supported for streaming datasets, skipping") + return dataset, None # Create train/validation split train_dataset, eval_dataset = create_train_validation_split( dataset, cfg, val_set_size @@ -433,27 +463,33 @@ def _handle_train_dataset_split( return train_dataset, eval_dataset # No validation split - apply deduplication if needed and return as train dataset - if cfg.dataset_exact_deduplication: + if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset): train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset) else: + if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset): + LOG.info("Deduplication skipped for streaming datasets (not compatible)") train_dataset = dataset return train_dataset, None def _handle_test_dataset_split( - dataset: Dataset, cfg: DictDefault -) -> tuple[None, Dataset | None]: + dataset: Dataset | IterableDataset, cfg: DictDefault +) -> tuple[None, Dataset | IterableDataset | None]: """Handle processing for test split.""" - if cfg.dataset_exact_deduplication: + if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset): eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset) else: + if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset): + LOG.info("Deduplication skipped for streaming datasets (not 
compatible)") eval_dataset = dataset return None, eval_dataset -def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset: +def _apply_dataset_sharding( + dataset: Dataset | IterableDataset, cfg: DictDefault +) -> Dataset | IterableDataset: """Apply dataset sharding if configured. Args: diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 856a609c7..d5fa54196 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -190,12 +190,18 @@ def handle_long_seq_in_dataset( Returns: Filtered dataset with long sequences removed. """ - if "input_ids" not in dataset.column_names: - LOG.warning( - "Dataset does not contain 'input_ids' column. Skip drop long seq. This is " - "expected for reward modeling." - ) - return dataset + if hasattr(dataset, "column_names") and dataset.column_names: + if "input_ids" not in dataset.column_names: + LOG.warning( + "Dataset does not contain 'input_ids' column. Skip drop long seq. This is " + "expected for reward modeling." + ) + return dataset + else: + # For IterableDataset, we can't check columns upfront, so skip for streaming + if isinstance(dataset, IterableDataset): + LOG.info("Skipping drop_long_seq for streaming datasets (not compatible)") + return dataset drop_long = functools.partial( drop_long_seq, diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index a607b3dca..4f88ab6ea 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -932,6 +932,34 @@ class AxolotlInputConfig( fix_untrained_tokens: int | list[int] | None = None + streaming: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use streaming datasets (IterableDataset) for processing large datasets that don't fit in memory. When True, data is loaded on-demand during training without upfront preprocessing. Requires max_steps to be set. Pre-training datasets default to streaming unless explicitly set to False." + }, + ) + + streaming_dataset_mixing_strategy: str | None = Field( + default="round_robin", + json_schema_extra={ + "description": "Strategy for mixing multiple streaming datasets: 'round_robin' (equal sampling), 'weighted' (use streaming_mixing_weights), or 'random' (random sampling with equal probability)." + }, + ) + + streaming_mixing_weights: list[float] | None = Field( + default=None, + json_schema_extra={ + "description": "Weights for weighted mixing strategy when using multiple streaming datasets. Must sum to 1.0 and have same length as datasets list. Only used when streaming_dataset_mixing_strategy='weighted'." + }, + ) + + streaming_buffer_per_dataset: int | None = Field( + default=1000, + json_schema_extra={ + "description": "Buffer size per dataset when mixing multiple streaming datasets. Higher values may improve mixing quality but use more memory." 
+ }, + ) + # INTERNALS - document for now, generally not set externally is_preprocess: bool | None = None preprocess_iterable: bool | None = None diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 217244b01..8303d306a 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1337,6 +1337,30 @@ class GRPOVllmValidationMixin: # pylint: disable=too-many-ancestors +class StreamingValidationMixin: + """Validation methods related to streaming datasets.""" + + @model_validator(mode="after") + def check_streaming_requires_max_steps(self): + """Ensure max_steps is set when using streaming datasets.""" + # Check if streaming is explicitly enabled + streaming_enabled = getattr(self, "streaming", None) is True + + # Check if pretraining dataset exists (defaults to streaming) + has_pretraining = getattr(self, "pretraining_dataset", None) is not None + streaming_default_for_pretraining = ( + has_pretraining and getattr(self, "streaming", None) is None + ) + + # If streaming is enabled (explicitly or by default for pretraining) + if streaming_enabled or streaming_default_for_pretraining: + max_steps = getattr(self, "max_steps", None) + if not max_steps: + raise ValueError("max_steps must be set when using streaming datasets") + + return self + + class ValidationMixin( DatasetValidationMixin, AttentionValidationMixin, @@ -1347,6 +1371,7 @@ class ValidationMixin( SystemValidationMixin, ChatTemplateValidationMixin, PretrainingValidationMixin, + StreamingValidationMixin, ModelCompatibilityValidationMixin, ComplexValidationMixin, GRPOVllmValidationMixin,
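
Not part of the patch: a minimal usage sketch of the new streaming options. The keys mirror the fields this diff adds to AxolotlInputConfig; the step count, weights, and other values below are made-up placeholders, shown as a plain Python dict rather than the YAML config axolotl normally consumes.

    # Illustrative config fragment (placeholder values only).
    streaming_cfg = {
        "streaming": True,  # use IterableDataset loading; pretraining_dataset defaults to this
        "max_steps": 5000,  # required whenever streaming is enabled (see StreamingValidationMixin)
        "streaming_dataset_mixing_strategy": "weighted",
        "streaming_mixing_weights": [0.7, 0.3],  # one weight per dataset, must sum to 1.0
        "streaming_buffer_per_dataset": 1000,  # schema default
    }

    # The new validator rejects streaming configs without max_steps, since the
    # length of an IterableDataset is unknown up front.
    assert streaming_cfg["max_steps"], "max_steps must be set when using streaming datasets"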
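The mixing-related fields describe behavior whose implementation is not included in this diff. As a rough, assumed sketch (not the patch's actual code), the three strategies could map onto datasets.interleave_datasets along these lines; the dataset names are hypothetical:

    from datasets import interleave_datasets, load_dataset

    # Hypothetical streaming sources; streaming=True yields IterableDataset objects.
    ds_a = load_dataset("allenai/c4", "en", split="train", streaming=True)
    ds_b = load_dataset("wikitext", "wikitext-103-raw-v1", split="train", streaming=True)

    strategy = "weighted"  # streaming_dataset_mixing_strategy
    weights = [0.7, 0.3]   # streaming_mixing_weights
    buffer_size = 1000     # streaming_buffer_per_dataset

    # Shuffle each stream with its own buffer before mixing.
    streams = [ds.shuffle(seed=42, buffer_size=buffer_size) for ds in (ds_a, ds_b)]

    if strategy == "weighted":
        mixed = interleave_datasets(streams, probabilities=weights, seed=42)
    elif strategy == "random":
        # Equal-probability random sampling across streams.
        mixed = interleave_datasets(
            streams, probabilities=[1.0 / len(streams)] * len(streams), seed=42
        )
    else:
        # "round_robin": alternate one example from each stream in turn.
        mixed = interleave_datasets(streams)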