From 3a350765131fc570a4e67db466251c1560f6a706 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Sun, 24 Aug 2025 00:49:13 +0000 Subject: [PATCH] seems to be working? --- examples/streaming/streaming-pretrain.yml | 61 +++++ examples/streaming/streaming-sft.yml | 52 +++++ src/axolotl/utils/data/sft.py | 100 ++++++-- src/axolotl/utils/data/shared.py | 39 +++- src/axolotl/utils/data/streaming.py | 150 ++++++++++++ src/axolotl/utils/data/streaming_multipack.py | 219 ++++++++++++++++++ src/axolotl/utils/data/streaming_sft.py | 158 +++++++++++++ .../utils/data/streaming_sft_simple.py | 211 +++++++++++++++++ src/axolotl/utils/data/utils.py | 13 +- src/axolotl/utils/schemas/config.py | 6 + src/axolotl/utils/schemas/validation.py | 18 ++ 11 files changed, 1004 insertions(+), 23 deletions(-) create mode 100644 examples/streaming/streaming-pretrain.yml create mode 100644 examples/streaming/streaming-sft.yml create mode 100644 src/axolotl/utils/data/streaming.py create mode 100644 src/axolotl/utils/data/streaming_multipack.py create mode 100644 src/axolotl/utils/data/streaming_sft.py create mode 100644 src/axolotl/utils/data/streaming_sft_simple.py diff --git a/examples/streaming/streaming-pretrain.yml b/examples/streaming/streaming-pretrain.yml new file mode 100644 index 000000000..fdfba8621 --- /dev/null +++ b/examples/streaming/streaming-pretrain.yml @@ -0,0 +1,61 @@ +# Example configuration for streaming pretraining +# This demonstrates how to pretrain on large datasets that don't fit in memory + +base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 +model_type: LlamaForCausalLM +tokenizer_type: LlamaTokenizer + +# Required: max_steps for streaming pretraining +max_steps: 10000 + +# Pretraining dataset configuration +# These are automatically streamed +pretraining_dataset: + - path: allenai/c4 + name: en + type: pretrain + # Optional: skip N samples (useful for resuming) + # skip: 1000000 + +# Can also use multiple pretraining datasets +# pretraining_dataset: +# - path: allenai/c4 +# name: en +# type: pretrain +# - path: HuggingFaceFW/fineweb +# type: pretrain + +val_set_size: 0.0 + +# Sequence and packing configuration +sequence_len: 2048 +sample_packing: true +pretrain_multipack_attn: true +pretrain_multipack_buffer_size: 10000 # Buffer size for multipack batching + +# Training hyperparameters +gradient_accumulation_steps: 8 +micro_batch_size: 4 +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 3e-4 + +# Memory optimizations +bf16: auto +tf32: false +gradient_checkpointing: true +flash_attention: true + +# Checkpointing and logging +output_dir: ./outputs/pretrain-streaming +logging_steps: 10 +save_steps: 500 +save_total_limit: 3 # Keep only last 3 checkpoints + +# Warmup +warmup_ratio: 0.1 + +# Optional: enable wandb for monitoring +# wandb_project: streaming-pretrain +# wandb_entity: your-entity +# wandb_name: c4-pretrain diff --git a/examples/streaming/streaming-sft.yml b/examples/streaming/streaming-sft.yml new file mode 100644 index 000000000..10b1d812a --- /dev/null +++ b/examples/streaming/streaming-sft.yml @@ -0,0 +1,52 @@ +# Example configuration for streaming SFT training +# This enables training on datasets larger than memory by streaming them from HuggingFace Hub + +base_model: HuggingFaceTB/SmolLM2-135M + +# Enable streaming mode for datasets +streaming: true + +# When using streaming, max_steps is required +max_steps: 3 # Just test a few steps + +# Training datasets - these will be streamed +# datasets: +# - path: tatsu-lab/alpaca +# type: alpaca +# split: train + 
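+# NOTE: the commented-out datasets block above is the streaming SFT entry
+# point this patch adds; this example instead routes the same data through
+# pretraining_dataset, which reuses the existing streaming pipeline.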
+pretraining_dataset: + - path: tatsu-lab/alpaca + type: alpaca + split: train + +# Dataset configuration +sequence_len: 2048 +sample_packing: true # Enable multipack batching +pretrain_multipack_attn: true # Enable multipack attention masking +pretrain_multipack_buffer_size: 1000 # Buffer size for packing (smaller for streaming SFT) +special_tokens: + pad_token: <|endoftext|> + +# Training hyperparameters +gradient_accumulation_steps: 4 +micro_batch_size: 1 # Always 1 for multipack - sequences are packed into single samples +num_epochs: 1 # With streaming, typically use max_steps instead +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 2e-5 + +# Enable efficient training +bf16: auto +tf32: false +gradient_checkpointing: true +flash_attention: true # Enable flash attention with multipack + +# Logging and checkpointing +logging_steps: 10 +eval_steps: 100 +save_steps: 200 +output_dir: ./outputs/streaming-model + +# Warmup +warmup_steps: 100 diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 2ae7d9052..326d7943e 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -26,6 +26,8 @@ from axolotl.utils.data.shared import ( save_preprocessed_dataset, try_load_from_hub, ) +from axolotl.utils.data.streaming import wrap_streaming_sft_dataset +from axolotl.utils.data.streaming_sft import wrap_streaming_sft_dataset_optimized from axolotl.utils.data.utils import ( deduplicate_and_log_datasets, handle_long_seq_in_dataset, @@ -73,7 +75,7 @@ def _prepare_standard_dataset( tokenizer: PreTrainedTokenizer, processor: ProcessorMixin | None, preprocess_iterable: bool, -) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]: +) -> tuple[Dataset | IterableDataset, Dataset | None, int, list[Prompter | None]]: """Prepare standard (non-pretraining) datasets.""" def _load_datasets(): @@ -118,7 +120,14 @@ def _prepare_standard_dataset( ) # Calculate total number of training steps - if cfg.max_steps: + # For streaming datasets, we must use max_steps + if isinstance(train_dataset, IterableDataset): + if not cfg.max_steps: + raise ValueError( + "When using streaming datasets, you must set max_steps in your config" + ) + total_num_steps = cfg.max_steps + elif cfg.max_steps: total_num_steps = min( calculate_total_num_steps(cfg, train_dataset), cfg.max_steps ) @@ -342,14 +351,18 @@ def _load_raw_datasets( dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg) else: dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg) - if cfg.sample_packing: + + # Skip packing processing for streaming datasets - they handle it differently + if cfg.sample_packing and not isinstance(dataset, IterableDataset): dataset, _ = process_datasets_for_packing(cfg, dataset, None) - # Save the prepared dataset - dataset_hash = generate_dataset_hash_from_config( - cfg, datasets_configs, tokenizer.name_or_path - ) - save_preprocessed_dataset(cfg, dataset, dataset_hash, split) + # Skip saving for streaming datasets as they can't be cached + if not isinstance(dataset, IterableDataset): + # Save the prepared dataset + dataset_hash = generate_dataset_hash_from_config( + cfg, datasets_configs, tokenizer.name_or_path + ) + save_preprocessed_dataset(cfg, dataset, dataset_hash, split) return dataset, prompters @@ -365,8 +378,10 @@ def _load_and_process_single_dataset( ) -> tuple[Dataset | IterableDataset, Prompter | None]: """Load and process a single dataset based on the passed config.""" # Load the dataset + # Use streaming if enabled in config or 
if using iterable preprocessing + use_streaming = cfg.streaming or preprocess_iterable dataset = load_dataset_with_config( - dataset_config, cfg.hf_use_auth_token, streaming=preprocess_iterable + dataset_config, cfg.hf_use_auth_token, streaming=use_streaming ) # Parse dataset type @@ -391,16 +406,63 @@ def _load_and_process_single_dataset( num_shards=dataset_config.shards, index=shards_idx ) - # Apply dataset wrapper - dataset_wrapper, dataset_prompter = get_dataset_wrapper( - dataset_config=dataset_config, - tokenizer=tokenizer, - cfg=cfg, - dataset_base_type=d_base_type, - dataset=dataset, - dataset_prompt_style=d_prompt_style, - processor=processor, - ) + # For streaming datasets, we need to handle tokenization differently + if isinstance(dataset, IterableDataset): + # Use pretraining's approach for multipack streaming + if cfg.sample_packing: + # Create the dataset wrapper function once + def ds_wrapper_fn(dataset=None): + wrapped_dataset, prompter = get_dataset_wrapper( + dataset_config=dataset_config, + tokenizer=tokenizer, + cfg=cfg, + dataset_base_type=d_base_type, + dataset=dataset, + dataset_prompt_style=d_prompt_style, + processor=processor, + ) + return wrapped_dataset, prompter + + # Use optimized streaming wrapper to avoid repeated preprocessing logs + dataset_wrapper = wrap_streaming_sft_dataset_optimized( + dataset, + tokenizer, + cfg, + ds_wrapper_fn, + max_tokens=cfg.sequence_len, + batch_size=max( + 1, cfg.sequence_len // 512 + ), # Estimate sequences per pack + seed=cfg.seed or 42, + buffer_size=cfg.pretrain_multipack_buffer_size or 1_000, + ) + else: + # Use regular streaming wrapper + dataset_wrapper = wrap_streaming_sft_dataset( + dataset, + tokenizer, + cfg, + dataset_config, + d_base_type, + d_prompt_style, + processor, + max_tokens=cfg.sequence_len, + buffer_size=10_000, + ) + + # For streaming, we don't have a specific prompter + dataset_prompter = None + else: + # Apply dataset wrapper for regular datasets + dataset_wrapper, dataset_prompter = get_dataset_wrapper( + dataset_config=dataset_config, + tokenizer=tokenizer, + cfg=cfg, + dataset_base_type=d_base_type, + dataset=dataset, + dataset_prompt_style=d_prompt_style, + processor=processor, + ) return dataset_wrapper, dataset_prompter diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index 1d7d37f15..7b27130f7 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -524,7 +524,9 @@ def generate_dataset_hash_from_config( return str(md5(config_str)) -def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset: +def merge_datasets( + datasets: list[Dataset | IterableDataset], cfg: DictDefault +) -> Dataset | IterableDataset: """Merge multiple datasets into one with optional shuffling. Args: @@ -534,6 +536,41 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset: Returns: Merged dataset. """ + # Check if we're dealing with streaming datasets + if any(isinstance(ds, IterableDataset) for ds in datasets): + # All datasets must be streaming for merging + if not all(isinstance(ds, IterableDataset) for ds in datasets): + raise ValueError( + "Cannot mix streaming and non-streaming datasets. " + "Either all datasets must be streaming or none." 
+            )
+
+        if len(datasets) == 1:
+            ds = datasets[0]
+            # Streaming datasets handle shuffling differently
+            if cfg.shuffle_merged_datasets and not cfg.curriculum_sampling:
+                return ds.shuffle(seed=cfg.seed, buffer_size=10_000)
+            return ds
+
+        # Merge streaming datasets
+        LOG.info("Merging streaming datasets...")
+        from datasets import interleave_datasets
+
+        # For streaming, we interleave datasets instead of concatenating
+        merged_dataset = interleave_datasets(datasets)
+
+        if cfg.shuffle_merged_datasets:
+            LOG.debug("Shuffling merged streaming datasets...")
+            if cfg.curriculum_sampling:
+                LOG.warning(
+                    "Shuffling merged datasets with curriculum sampling is not recommended. "
+                    "This will randomize the order of samples."
+                )
+            merged_dataset = merged_dataset.shuffle(seed=cfg.seed, buffer_size=10_000)
+
+        return merged_dataset
+
+    # Original logic for non-streaming datasets
     if len(datasets) == 1:
         ds = datasets[0]
diff --git a/src/axolotl/utils/data/streaming.py b/src/axolotl/utils/data/streaming.py
new file mode 100644
index 000000000..ad8564347
--- /dev/null
+++ b/src/axolotl/utils/data/streaming.py
@@ -0,0 +1,150 @@
+"""Utilities for handling streaming datasets."""
+
+from typing import Any, Dict, List
+
+from datasets import Dataset, IterableDataset
+from transformers import PreTrainedTokenizerBase
+
+from axolotl.utils.logging import get_logger
+from axolotl.utils.trainer import add_position_ids
+
+LOG = get_logger(__name__)
+
+
+def wrap_streaming_sft_dataset(
+    dataset: IterableDataset,
+    tokenizer: PreTrainedTokenizerBase,
+    cfg,
+    dataset_config,
+    d_base_type: str,
+    d_prompt_style: str | None,
+    processor: Any | None,
+    max_tokens: int = 2048,
+    buffer_size: int = 10_000,
+) -> IterableDataset:
+    """
+    Wrap a streaming SFT dataset with tokenization and optional packing.
+
+    This is similar to wrap_pretraining_dataset but for SFT datasets.
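+    The streamed rows are tokenized in batches inside dataset.map, and
+    per-sample position_ids are added so downstream packing and collation
+    match the non-streaming path.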
+
+    Args:
+        dataset: The streaming dataset to wrap
+        tokenizer: Tokenizer to use
+        cfg: Configuration object
+        dataset_config: Dataset configuration
+        d_base_type: Base dataset type
+        d_prompt_style: Prompt style
+        processor: Optional processor for multimodal
+        max_tokens: Maximum sequence length
+        buffer_size: Buffer size for shuffling
+
+    Returns:
+        Wrapped streaming dataset ready for training
+    """
+
+    # Import here to avoid circular imports
+    from axolotl.utils.data.wrappers import get_dataset_wrapper
+
+    # Apply shuffling if configured
+    if cfg.shuffle_merged_datasets:
+        LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
+        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
+
+    # For streaming datasets, we need to get column names from the first sample
+    remove_columns = []
+    for first_row in dataset:
+        remove_columns = list(first_row.keys())
+        break
+
+    # Re-apply shuffling after peeking at the first row (peeking does not
+    # consume the stream)
+    if cfg.shuffle_merged_datasets:
+        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
+
+    # Define the encoding function - always add position_ids for compatibility
+    if cfg.sample_packing:
+        # For sample packing, we need to handle position_ids
+        def encode_streaming_packed(examples: Dict[str, List]) -> Dict[str, List]:
+            """Encode examples for streaming with sample packing."""
+            # Convert the batch dict to a temporary Dataset for processing
+            temp_dataset = Dataset.from_dict(examples)
+
+            # Apply the dataset wrapper to tokenize
+            wrapped_dataset, _ = get_dataset_wrapper(
+                dataset_config=dataset_config,
+                tokenizer=tokenizer,
+                cfg=cfg,
+                dataset_base_type=d_base_type,
+                dataset=temp_dataset,
+                dataset_prompt_style=d_prompt_style,
+                processor=processor,
+            )
+
+            # Convert to dict for processing
+            result = {}
+            if hasattr(wrapped_dataset, "to_dict"):
+                result = wrapped_dataset.to_dict()
+            else:
+                for key in wrapped_dataset.column_names:
+                    result[key] = wrapped_dataset[key]
+
+            # add_position_ids expects a single example, so apply it row by
+            # row rather than to the whole batched dict
+            if result.get("input_ids"):
+                rows = [
+                    add_position_ids({k: v[i] for k, v in result.items()})
+                    for i in range(len(result["input_ids"]))
+                ]
+                result = {k: [row[k] for row in rows] for k in rows[0]}
+
+            # For multipack attention, we may need to drop attention_mask
+            if cfg.pretrain_multipack_attn and "attention_mask" in result:
+                del result["attention_mask"]
+
+            return result
+
+        encode_fn = encode_streaming_packed
+    else:
+        # Regular encoding without packing - still add position_ids for compatibility
+        def encode_streaming(examples: Dict[str, List]) -> Dict[str, List]:
+            """Encode examples for streaming."""
+            # Convert the batch dict to a temporary Dataset for processing
+            temp_dataset = Dataset.from_dict(examples)
+
+            # Apply the dataset wrapper to tokenize
+            wrapped_dataset, _ = get_dataset_wrapper(
+                dataset_config=dataset_config,
+                tokenizer=tokenizer,
+                cfg=cfg,
+                dataset_base_type=d_base_type,
+                dataset=temp_dataset,
+                dataset_prompt_style=d_prompt_style,
+                processor=processor,
+            )
+
+            # Convert to dict format
+            result = {}
+            if hasattr(wrapped_dataset, "to_dict"):
+                result = wrapped_dataset.to_dict()
+            else:
+                for key in wrapped_dataset.column_names:
+                    result[key] = wrapped_dataset[key]
+
+            # Add position_ids even without packing for compatibility;
+            # add_position_ids works on one example at a time
+            if result.get("input_ids"):
+                rows = [
+                    add_position_ids({k: v[i] for k, v in result.items()})
+                    for i in range(len(result["input_ids"]))
+                ]
+                result = {k: [row[k] for row in rows] for k in rows[0]}
+
+            return result
+
+        encode_fn = encode_streaming
+
+    # Map the encoding function over the streaming dataset
+    dataset = dataset.map(
+        encode_fn,
+        batched=True,
+        batch_size=buffer_size,
+        remove_columns=remove_columns,
+    )
+
+    # Set format for PyTorch
+    dataset = dataset.with_format("torch")
+
+    return dataset
diff --git a/src/axolotl/utils/data/streaming_multipack.py
b/src/axolotl/utils/data/streaming_multipack.py new file mode 100644 index 000000000..a67d4f815 --- /dev/null +++ b/src/axolotl/utils/data/streaming_multipack.py @@ -0,0 +1,219 @@ +"""Streaming dataset with multipack support for SFT.""" + +import functools +from collections import defaultdict +from typing import Any, Dict, List, Optional + +import numpy as np +from datasets import Dataset, IterableDataset +from torch.utils.data import RandomSampler +from transformers import PreTrainedTokenizerBase + +from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq +from axolotl.utils.logging import get_logger +from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths +from axolotl.utils.trainer import process_pretraining_datasets_for_packing + +LOG = get_logger(__name__) + + +def encode_packed_sft_streaming( + collate_fn, + ds_wrapper_fn, + examples: Dict[str, List], + dataset_config, + tokenizer, + cfg, + d_base_type: str, + d_prompt_style: str | None, + processor: Any | None, + max_seq_length: int = 2048, + batch_size: int = 4, + multipack_attn: Optional[bool] = True, +) -> Dict[str, List]: + """ + Encode and pack streaming SFT data similar to how pretraining does it. + + This function: + 1. Tokenizes the examples using the dataset wrapper + 2. Adds position_ids for each sequence + 3. Uses MultipackBatchSampler to efficiently pack sequences + 4. Applies the collator to handle attention masks properly + + Args: + collate_fn: Collator function for handling batches + ds_wrapper_fn: Function to get the dataset wrapper for tokenization + examples: Dict of lists containing the raw examples + dataset_config: Configuration for the dataset + tokenizer: Tokenizer to use + cfg: Main configuration + d_base_type: Dataset base type + d_prompt_style: Prompt style + processor: Optional processor for multimodal + max_seq_length: Maximum sequence length + batch_size: Batch size for packing + multipack_attn: Whether to use multipack attention + + Returns: + Dict of packed and processed data ready for training + """ + # Import here to avoid circular imports + from axolotl.utils.data.wrappers import get_dataset_wrapper + + # Convert examples to Dataset for processing + temp_dataset = Dataset.from_dict(examples) + + # Apply the dataset wrapper to tokenize + train_dataset, _ = get_dataset_wrapper( + dataset_config=dataset_config, + tokenizer=tokenizer, + cfg=cfg, + dataset_base_type=d_base_type, + dataset=temp_dataset, + dataset_prompt_style=d_prompt_style, + processor=processor, + ) + + # Process for packing - add position_ids and filter long sequences + train_dataset = process_pretraining_datasets_for_packing( + train_dataset, + max_seq_length, + skip_position_ids=not multipack_attn, + drop_attention_mask=multipack_attn, + ) + + # Use MultipackBatchSampler to create efficient packed batches + sampler = MultipackBatchSampler( + sampler=RandomSampler(train_dataset), + lengths=get_dataset_lengths(train_dataset), + batch_size=1, # We pack multiple sequences into one "batch" + batch_max_len=batch_size * max_seq_length, # Total tokens in packed batch + drop_last=True, + num_processes=1, # Single process for streaming + ) + + # Collect packed data + chunked_data = defaultdict(list) + + for batch in sampler: + for data in batch: + features = train_dataset[data] + + # Clean up unnecessary fields + for field in ["num_truncated_tokens", "overflow_to_sample_mapping"]: + if field in features: + del features[field] + + # Ensure labels exist + if "labels" not in features: + 
features["labels"] = features["input_ids"].copy() + + # Apply collator to handle padding and attention masks + collated_features = collate_fn(features) + + # Collect features + for feature in features.keys(): + if feature == "length": + continue + chunked_data[feature].append(collated_features[feature].squeeze(0)) + + return chunked_data + + +def wrap_streaming_sft_dataset_with_packing( + dataset: IterableDataset, + tokenizer: PreTrainedTokenizerBase, + cfg, + dataset_config, + d_base_type: str, + d_prompt_style: str | None, + processor: Any | None, + max_tokens: int = 2048, + buffer_size: int = 10_000, +) -> IterableDataset: + """ + Wrap a streaming SFT dataset with tokenization and multipack batching. + + This creates properly packed batches with: + - Multiple sequences concatenated together + - Position IDs that reset for each sequence + - Attention masks that prevent cross-attention between sequences + + Args: + dataset: The streaming dataset to wrap + tokenizer: Tokenizer to use + cfg: Configuration object + dataset_config: Dataset configuration + d_base_type: Base dataset type + d_prompt_style: Prompt style + processor: Optional processor for multimodal + max_tokens: Maximum sequence length + buffer_size: Buffer size for streaming/shuffling + + Returns: + Wrapped streaming dataset with multipack batching + """ + + # Apply shuffling if configured + if cfg.shuffle_merged_datasets: + LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}") + dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size) + + # Get column names from first sample + remove_columns = [] + for first_row in dataset: + remove_columns = list(first_row.keys()) + break + + # Reset dataset after peeking + if cfg.shuffle_merged_datasets: + dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size) + + # Create the collator for multipack + collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( + tokenizer, + return_tensors="pt", + padding=True, + pad_to_multiple_of=max_tokens, + multipack_attn=cfg.pretrain_multipack_attn, + ) + + # Create the encoding function + # batch_size here refers to how many sequences to pack together to fill max_tokens + # The actual batching happens at the DataLoader level with micro_batch_size=1 + pack_batch_size = max( + 1, max_tokens // 512 + ) # Estimate based on typical sequence lengths + + encode_fn = functools.partial( + encode_packed_sft_streaming, + collate_fn, + None, # ds_wrapper_fn will be created inside + dataset_config=dataset_config, + tokenizer=tokenizer, + cfg=cfg, + d_base_type=d_base_type, + d_prompt_style=d_prompt_style, + processor=processor, + max_seq_length=max_tokens, + batch_size=pack_batch_size, + multipack_attn=cfg.pretrain_multipack_attn, + ) + + # Map the encoding function over the streaming dataset + # This will process data in batches and apply packing + dataset = dataset.map( + encode_fn, + batched=True, + batch_size=buffer_size, # Process large batches for efficiency + remove_columns=remove_columns, + ) + + # Set format for PyTorch + dataset = dataset.with_format("torch") + + # IMPORTANT: Set micro_batch_size to 1 since we've already packed + # This prevents the trainer from trying to batch our packed sequences + cfg.micro_batch_size = 1 + + return dataset diff --git a/src/axolotl/utils/data/streaming_sft.py b/src/axolotl/utils/data/streaming_sft.py new file mode 100644 index 000000000..74bcc90e1 --- /dev/null +++ b/src/axolotl/utils/data/streaming_sft.py @@ -0,0 +1,158 @@ +"""Optimized streaming SFT with multipack support that 
avoids repeated preprocessing."""
+
+import functools
+from collections import defaultdict
+from typing import Callable, Dict, List, Optional
+
+from datasets import Dataset, IterableDataset
+from torch.utils.data import RandomSampler
+from transformers import PreTrainedTokenizerBase
+
+from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.logging import get_logger
+from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
+from axolotl.utils.trainer import add_position_ids, drop_long_seq
+
+LOG = get_logger(__name__)
+
+
+def encode_packed_streaming_sft(
+    collate_fn,
+    ds_wrapper: Callable,
+    examples: Dict[str, List],
+    max_seq_length: int = 2048,
+    batch_size: int = 4,
+    multipack_attn: Optional[bool] = True,
+) -> Dict[str, List]:
+    """
+    Encode streaming SFT data with packing, avoiding repeated preprocessing logs.
+
+    This is similar to encode_packed_pretraining but skips the verbose
+    process_pretraining_datasets_for_packing call that logs repeatedly.
+    """
+    # Tokenize all the examples
+    train_dataset = ds_wrapper(dataset=Dataset.from_dict(examples))[0]
+
+    # Apply filtering and preprocessing directly without verbose logging
+    # Filter out long sequences
+    def should_keep(sample):
+        return drop_long_seq(sample, sequence_len=max_seq_length)
+
+    # Convert to list for filtering (since we need to iterate anyway)
+    filtered_samples = []
+    for i in range(len(train_dataset)):
+        sample = train_dataset[i]
+        if should_keep(sample):
+            # process_pretraining_datasets_for_packing passes
+            # skip_position_ids=not multipack_attn, so position_ids ARE added
+            # when multipack_attn is enabled; mirror that behavior here
+            if multipack_attn:
+                sample = add_position_ids(sample)
+            filtered_samples.append(sample)
+
+    # Convert back to dataset
+    if not filtered_samples:
+        return {"input_ids": [], "labels": [], "attention_mask": []}
+
+    train_dataset = Dataset.from_list(filtered_samples)
+
+    # Remove attention_mask if needed for multipack
+    if multipack_attn and "attention_mask" in train_dataset.column_names:
+        train_dataset = train_dataset.remove_columns("attention_mask")
+
+    # Use MultipackBatchSampler to create efficient packed batches
+    sampler = MultipackBatchSampler(
+        sampler=RandomSampler(train_dataset),
+        lengths=get_dataset_lengths(train_dataset),
+        batch_size=1,
+        batch_max_len=batch_size * max_seq_length,
+        drop_last=True,
+        num_processes=1,
+    )
+
+    # Collect packed data
+    chunked_data = defaultdict(list)
+
+    for batch in sampler:
+        for data in batch:
+            features = train_dataset[data]
+            # Clean up unnecessary fields
+            for field in ["num_truncated_tokens", "overflow_to_sample_mapping"]:
+                if field in features:
+                    del features[field]
+            # Ensure labels exist
+            if "labels" not in features:
+                features["labels"] = features["input_ids"].copy()
+            # Apply collator
+            collated_features = collate_fn(features)
+
+            # Collect features
+            for feature in features.keys():
+                if feature == "length":
+                    continue
+                chunked_data[feature].append(collated_features[feature].squeeze(0))
+
+    return chunked_data
+
+
+def wrap_streaming_sft_dataset_optimized(
+    dataset: IterableDataset,
+    tokenizer: PreTrainedTokenizerBase,
+    cfg,
+    ds_wrapper_fn: Callable,
+    max_tokens: int = 2048,
+    batch_size: int = 4,
+    seed: int = 42,
+    buffer_size: int = 1000,
+) -> IterableDataset:
+    """
+    Wrap a streaming SFT dataset with optimized multipack batching.
+
+    This avoids the repeated preprocessing logs by using a custom encoder
+    that applies filtering and position_ids directly.
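+
+    Packed batches are produced inside dataset.map, so each yielded sample is
+    already a fully packed sequence; cfg.micro_batch_size is forced to 1 below
+    to keep the trainer from batching packed samples a second time.
+
+    Example (sketch, mirroring the call site in sft.py):
+
+        wrapped = wrap_streaming_sft_dataset_optimized(
+            dataset,
+            tokenizer,
+            cfg,
+            ds_wrapper_fn,
+            max_tokens=cfg.sequence_len,
+            batch_size=max(1, cfg.sequence_len // 512),
+            seed=cfg.seed or 42,
+            buffer_size=cfg.pretrain_multipack_buffer_size or 1_000,
+        )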
+ """ + # Create collator for packing + collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( + tokenizer, + return_tensors="pt", + padding=True, + pad_to_multiple_of=max_tokens, + multipack_attn=cfg.pretrain_multipack_attn, + ) + + # Create optimized encode function + encode = functools.partial( + encode_packed_streaming_sft, + collate_fn, + ds_wrapper_fn, + max_seq_length=max_tokens, + batch_size=batch_size, + multipack_attn=cfg.pretrain_multipack_attn, + ) + + # Apply shuffling if configured + if cfg.shuffle_merged_datasets: + dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size) + else: + LOG.debug("NOT shuffling merged streaming datasets") + + # Get column names to remove + remove_columns = [] + for first_row in dataset: + remove_columns = list(first_row.keys()) + break + + # Reset dataset after peeking + if cfg.shuffle_merged_datasets: + dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size) + + # Map the optimized encoding function + dataset = dataset.map( + encode, + batched=True, + batch_size=buffer_size, + remove_columns=remove_columns, + ) + + # Set micro_batch_size to 1 since we've already packed + cfg.micro_batch_size = 1 + + return dataset diff --git a/src/axolotl/utils/data/streaming_sft_simple.py b/src/axolotl/utils/data/streaming_sft_simple.py new file mode 100644 index 000000000..045ded991 --- /dev/null +++ b/src/axolotl/utils/data/streaming_sft_simple.py @@ -0,0 +1,211 @@ +"""Simple streaming SFT with multipack support.""" + +import functools +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional + +import numpy as np +from datasets import Dataset, IterableDataset +from torch.utils.data import RandomSampler +from transformers import PreTrainedTokenizerBase + +from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq +from axolotl.utils.logging import get_logger +from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths +from axolotl.utils.trainer import process_pretraining_datasets_for_packing + +LOG = get_logger(__name__) + + +class StreamingMultipackDataset: + """Dataset that handles streaming with multipack on-the-fly.""" + + def __init__( + self, + base_dataset: IterableDataset, + tokenizer: PreTrainedTokenizerBase, + cfg, + dataset_config, + d_base_type: str, + d_prompt_style: str | None, + processor: Any | None, + max_tokens: int = 2048, + pack_length: int = 4, # How many sequences to collect before packing + ): + self.base_dataset = base_dataset + self.tokenizer = tokenizer + self.cfg = cfg + self.max_tokens = max_tokens + self.pack_length = pack_length + + # Create the dataset wrapper once + from axolotl.utils.data.wrappers import get_dataset_wrapper + + # Create a dummy dataset to get the wrapper + dummy_data = {"text": ["dummy"], "instruction": ["dummy"], "output": ["dummy"]} + dummy_dataset = Dataset.from_dict(dummy_data) + + self.dataset_wrapper, _ = get_dataset_wrapper( + dataset_config=dataset_config, + tokenizer=tokenizer, + cfg=cfg, + dataset_base_type=d_base_type, + dataset=dummy_dataset, + dataset_prompt_style=d_prompt_style, + processor=processor, + ) + + # Create collator for packing + self.collator = PretrainingBatchSamplerDataCollatorForSeq2Seq( + tokenizer, + return_tensors="pt", + padding=True, + pad_to_multiple_of=max_tokens, + multipack_attn=cfg.pretrain_multipack_attn, + ) + + def __iter__(self): + """Iterator that yields packed samples.""" + buffer = [] + + for sample in self.base_dataset: + # Convert single sample to dataset for processing + 
temp_dataset = Dataset.from_dict({k: [v] for k, v in sample.items()}) + + # Tokenize using the dataset wrapper + try: + tokenized = self.dataset_wrapper.__class__( + temp_dataset, + **{ + k: v + for k, v in self.dataset_wrapper.__dict__.items() + if not k.startswith("_") + }, + ) + + # Get the tokenized sample + if len(tokenized) > 0: + tokenized_sample = tokenized[0] + + # Add to buffer + buffer.append(tokenized_sample) + + # When buffer is full, pack and yield + if len(buffer) >= self.pack_length: + packed_sample = self._pack_buffer(buffer) + if packed_sample: + yield packed_sample + buffer = [] + + except Exception as e: + LOG.warning(f"Failed to process sample: {e}") + continue + + # Process remaining buffer + if buffer: + packed_sample = self._pack_buffer(buffer) + if packed_sample: + yield packed_sample + + def _pack_buffer(self, buffer: List[Dict]) -> Optional[Dict]: + """Pack a buffer of tokenized samples.""" + if not buffer: + return None + + try: + # Create dataset from buffer + temp_dataset = Dataset.from_list(buffer) + + # Add position_ids and process for packing + temp_dataset = process_pretraining_datasets_for_packing( + temp_dataset, + self.max_tokens, + skip_position_ids=not self.cfg.pretrain_multipack_attn, + drop_attention_mask=self.cfg.pretrain_multipack_attn, + ) + + # Use MultipackBatchSampler to create packed batches + sampler = MultipackBatchSampler( + sampler=RandomSampler(temp_dataset), + lengths=get_dataset_lengths(temp_dataset), + batch_size=1, + batch_max_len=self.max_tokens, + drop_last=False, + num_processes=1, + ) + + # Get packed data + for batch in sampler: + if batch and batch[0]: # Check if we have data + features = [] + for idx in batch[0]: # batch[0] contains the indices + sample = temp_dataset[idx] + if "labels" not in sample: + sample["labels"] = sample["input_ids"].copy() + features.append(sample) + + # Apply collator to create final packed sample + if features: + packed = self.collator(features) + return { + k: v.squeeze(0) if v.dim() > 1 else v + for k, v in packed.items() + } + + return None + + except Exception as e: + LOG.warning(f"Failed to pack buffer: {e}") + return None + + +def wrap_streaming_sft_dataset_simple( + dataset: IterableDataset, + tokenizer: PreTrainedTokenizerBase, + cfg, + dataset_config, + d_base_type: str, + d_prompt_style: str | None, + processor: Any | None, + max_tokens: int = 2048, + buffer_size: int = 10_000, +) -> IterableDataset: + """ + Wrap a streaming SFT dataset with simple multipack batching. + + This approach processes samples in small groups rather than large batches, + avoiding the repeated processing issue. 
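+
+    Note: packing happens inside StreamingMultipackDataset.__iter__, and the
+    return value is a thin iterable wrapper around it rather than a true
+    datasets.IterableDataset.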
+ """ + + # Apply shuffling if configured + if cfg.shuffle_merged_datasets: + LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}") + dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size) + + # Create the streaming multipack dataset + multipack_dataset = StreamingMultipackDataset( + dataset, + tokenizer, + cfg, + dataset_config, + d_base_type, + d_prompt_style, + processor, + max_tokens, + pack_length=max(1, max_tokens // 512), # Estimate sequences per pack + ) + + # Convert back to IterableDataset + class IterableWrapper: + def __init__(self, dataset): + self.dataset = dataset + + def __iter__(self): + return iter(self.dataset) + + wrapped = IterableWrapper(multipack_dataset) + + # Set micro_batch_size to 1 since sequences are already packed + cfg.micro_batch_size = 1 + + return wrapped diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 4868576a0..5fab82299 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -178,8 +178,8 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2): def handle_long_seq_in_dataset( - dataset: Dataset, sequence_len: int, cfg: DictDefault -) -> Dataset: + dataset: Dataset | IterableDataset, sequence_len: int, cfg: DictDefault +) -> Dataset | IterableDataset: """Remove sequences longer than configured maximum from dataset. Args: @@ -190,7 +190,14 @@ def handle_long_seq_in_dataset( Returns: Filtered dataset with long sequences removed. """ - if "input_ids" not in dataset.column_names: + # Streaming datasets don't support filtering the same way + if isinstance(dataset, IterableDataset): + LOG.info( + "Streaming dataset detected - long sequence filtering will be done on-the-fly" + ) + return dataset + + if not hasattr(dataset, "column_names") or "input_ids" not in dataset.column_names: LOG.warning( "Dataset does not contain 'input_ids' column. Skip drop long seq. This is " "expected for reward modeling." 
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 4d660d4b7..9bf77a4e4 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -244,6 +244,12 @@ class AxolotlInputConfig(
     dataloader_num_workers: int | None = None
     dataloader_prefetch_factor: int | None = None
     dataloader_drop_last: bool | None = None
+    streaming: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Stream training datasets instead of fully loading them into memory; enables training on datasets larger than available RAM"
+        },
+    )
     accelerator_config: dict[str, Any] | None = None

diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 791894990..68b093b07 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -1074,6 +1074,24 @@ class PretrainingValidationMixin:
                 data["accelerator_config"]["dispatch_batches"] = False
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_streaming_split_batches_accelerate(cls, data):
+        # alternatively set ACCELERATE_SPLIT_BATCHES=False
+        if data.get("streaming"):
+            accelerator_config = data.get("accelerator_config", {})
+            if not accelerator_config:
+                data["accelerator_config"] = {
+                    "split_batches": False,
+                    "dispatch_batches": False,
+                }
+            else:
+                if accelerator_config.get("split_batches") is None:
+                    data["accelerator_config"]["split_batches"] = False
+                if accelerator_config.get("dispatch_batches") is None:
+                    data["accelerator_config"]["dispatch_batches"] = False
+        return data
+
 
 class ModelCompatibilityValidationMixin:
     """Validation methods for specific model compatibility."""
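
Usage note: with this patch applied, the example configs above can be driven
through the standard axolotl CLI; the exact launcher may vary by setup, so
treat this as a sketch:

    accelerate launch -m axolotl.cli.train examples/streaming/streaming-sft.yml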