Compare commits
3 Commits
accelerato
...
streaming-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d3bea3a2eb | ||
|
|
2e2302aae3 | ||
|
|
3a35076513 |
@@ -26,6 +26,7 @@ from axolotl.utils.data.shared import (
|
||||
save_preprocessed_dataset,
|
||||
try_load_from_hub,
|
||||
)
|
||||
from axolotl.utils.data.streaming import wrap_streaming_sft_dataset
|
||||
from axolotl.utils.data.utils import (
|
||||
deduplicate_and_log_datasets,
|
||||
handle_long_seq_in_dataset,
|
||||
@@ -73,7 +74,7 @@ def _prepare_standard_dataset(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
processor: ProcessorMixin | None,
|
||||
preprocess_iterable: bool,
|
||||
) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]:
|
||||
) -> tuple[Dataset | IterableDataset, Dataset | None, int, list[Prompter | None]]:
|
||||
"""Prepare standard (non-pretraining) datasets."""
|
||||
|
||||
def _load_datasets():
|
||||
@@ -118,7 +119,14 @@ def _prepare_standard_dataset(
|
||||
)
|
||||
|
||||
# Calculate total number of training steps
|
||||
if cfg.max_steps:
|
||||
# For streaming datasets, we must use max_steps
|
||||
if isinstance(train_dataset, IterableDataset):
|
||||
if not cfg.max_steps:
|
||||
raise ValueError(
|
||||
"When using streaming datasets, you must set max_steps in your config"
|
||||
)
|
||||
total_num_steps = cfg.max_steps
|
||||
elif cfg.max_steps:
|
||||
total_num_steps = min(
|
||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||
)
|
||||
@@ -342,14 +350,18 @@ def _load_raw_datasets(
|
||||
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
|
||||
else:
|
||||
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
|
||||
if cfg.sample_packing:
|
||||
|
||||
# Skip packing processing for streaming datasets - they handle it differently
|
||||
if cfg.sample_packing and not isinstance(dataset, IterableDataset):
|
||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||
|
||||
# Save the prepared dataset
|
||||
dataset_hash = generate_dataset_hash_from_config(
|
||||
cfg, datasets_configs, tokenizer.name_or_path
|
||||
)
|
||||
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
|
||||
# Skip saving for streaming datasets as they can't be cached
|
||||
if not isinstance(dataset, IterableDataset):
|
||||
# Save the prepared dataset
|
||||
dataset_hash = generate_dataset_hash_from_config(
|
||||
cfg, datasets_configs, tokenizer.name_or_path
|
||||
)
|
||||
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
|
||||
|
||||
return dataset, prompters
|
||||
|
||||
@@ -365,8 +377,10 @@ def _load_and_process_single_dataset(
|
||||
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
||||
"""Load and process a single dataset based on the passed config."""
|
||||
# Load the dataset
|
||||
# Use streaming if enabled in config or if using iterable preprocessing
|
||||
use_streaming = cfg.streaming or preprocess_iterable
|
||||
dataset = load_dataset_with_config(
|
||||
dataset_config, cfg.hf_use_auth_token, streaming=preprocess_iterable
|
||||
dataset_config, cfg.hf_use_auth_token, streaming=use_streaming
|
||||
)
|
||||
|
||||
# Parse dataset type
|
||||
@@ -391,16 +405,63 @@ def _load_and_process_single_dataset(
|
||||
num_shards=dataset_config.shards, index=shards_idx
|
||||
)
|
||||
|
||||
# Apply dataset wrapper
|
||||
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
||||
dataset_config=dataset_config,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
dataset_base_type=d_base_type,
|
||||
dataset=dataset,
|
||||
dataset_prompt_style=d_prompt_style,
|
||||
processor=processor,
|
||||
)
|
||||
# For streaming datasets, we need to handle tokenization differently
|
||||
if isinstance(dataset, IterableDataset):
|
||||
# Use pretraining's approach for multipack streaming
|
||||
if cfg.sample_packing:
|
||||
# Create the dataset wrapper function once
|
||||
def ds_wrapper_fn(dataset=None):
|
||||
wrapped_dataset, prompter = get_dataset_wrapper(
|
||||
dataset_config=dataset_config,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
dataset_base_type=d_base_type,
|
||||
dataset=dataset,
|
||||
dataset_prompt_style=d_prompt_style,
|
||||
processor=processor,
|
||||
)
|
||||
return wrapped_dataset, prompter
|
||||
|
||||
# Use pretraining wrapper for efficient streaming SFT with packing
|
||||
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||
|
||||
dataset_wrapper = wrap_pretraining_dataset(
|
||||
dataset,
|
||||
tokenizer,
|
||||
cfg,
|
||||
ds_wrapper_fn,
|
||||
max_tokens=cfg.sequence_len,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
seed=cfg.seed,
|
||||
buffer_size=cfg.pretrain_multipack_buffer_size,
|
||||
)
|
||||
else:
|
||||
# Use regular streaming wrapper
|
||||
dataset_wrapper = wrap_streaming_sft_dataset(
|
||||
dataset,
|
||||
tokenizer,
|
||||
cfg,
|
||||
dataset_config,
|
||||
d_base_type,
|
||||
d_prompt_style,
|
||||
processor,
|
||||
max_tokens=cfg.sequence_len,
|
||||
buffer_size=10_000,
|
||||
)
|
||||
|
||||
# For streaming, we don't have a specific prompter
|
||||
dataset_prompter = None
|
||||
else:
|
||||
# Apply dataset wrapper for regular datasets
|
||||
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
||||
dataset_config=dataset_config,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
dataset_base_type=d_base_type,
|
||||
dataset=dataset,
|
||||
dataset_prompt_style=d_prompt_style,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
return dataset_wrapper, dataset_prompter
|
||||
|
||||
|
||||
@@ -524,7 +524,9 @@ def generate_dataset_hash_from_config(
|
||||
return str(md5(config_str))
|
||||
|
||||
|
||||
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
||||
def merge_datasets(
|
||||
datasets: list[Dataset | IterableDataset], cfg: DictDefault
|
||||
) -> Dataset | IterableDataset:
|
||||
"""Merge multiple datasets into one with optional shuffling.
|
||||
|
||||
Args:
|
||||
@@ -534,6 +536,41 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
||||
Returns:
|
||||
Merged dataset.
|
||||
"""
|
||||
# Check if we're dealing with streaming datasets
|
||||
if any(isinstance(ds, IterableDataset) for ds in datasets):
|
||||
# All datasets must be streaming for merging
|
||||
if not all(isinstance(ds, IterableDataset) for ds in datasets):
|
||||
raise ValueError(
|
||||
"Cannot mix streaming and non-streaming datasets. "
|
||||
"Either all datasets must be streaming or none."
|
||||
)
|
||||
|
||||
if len(datasets) == 1:
|
||||
ds = datasets[0]
|
||||
# Streaming datasets handle shuffling differently
|
||||
if cfg.shuffle_merged_datasets and not cfg.curriculum_sampling:
|
||||
return ds.shuffle(seed=cfg.seed, buffer_size=10_000)
|
||||
return ds
|
||||
|
||||
# Merge streaming datasets
|
||||
LOG.info("Merging streaming datasets...")
|
||||
from datasets import interleave_datasets
|
||||
|
||||
# For streaming, we interleave datasets instead of concatenating
|
||||
merged_dataset = interleave_datasets(datasets)
|
||||
|
||||
if cfg.shuffle_merged_datasets:
|
||||
LOG.debug("Shuffling merged streaming datasets...")
|
||||
if cfg.curriculum_sampling:
|
||||
LOG.warning(
|
||||
"Shuffling merged datasets with curriculum sampling is not recommended. "
|
||||
"This will randomize the order of samples."
|
||||
)
|
||||
merged_dataset = merged_dataset.shuffle(seed=cfg.seed, buffer_size=10_000)
|
||||
|
||||
return merged_dataset
|
||||
|
||||
# Original logic for non-streaming datasets
|
||||
if len(datasets) == 1:
|
||||
ds = datasets[0]
|
||||
|
||||
|
||||
150
src/axolotl/utils/data/streaming.py
Normal file
150
src/axolotl/utils/data/streaming.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Utilities for handling streaming datasets."""
|
||||
|
||||
import functools
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
from datasets import Dataset, IterableDataset
|
||||
from torch.utils.data import RandomSampler
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
from axolotl.utils.trainer import add_position_ids
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def wrap_streaming_sft_dataset(
    dataset: IterableDataset,
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    dataset_config,
    d_base_type: str,
    d_prompt_style: str | None,
    processor: Any | None,
    max_tokens: int = 2048,
    buffer_size: int = 10_000,
) -> IterableDataset:
    """
    Wrap a streaming SFT dataset with on-the-fly tokenization.

    Each batch pulled from the stream is converted to a temporary in-memory
    ``Dataset``, passed through ``get_dataset_wrapper`` to tokenize it, and
    then augmented via ``add_position_ids``. When both ``cfg.sample_packing``
    and ``cfg.pretrain_multipack_attn`` are truthy, ``attention_mask`` is
    dropped from the encoded batch.

    Args:
        dataset: The streaming dataset to wrap
        tokenizer: Tokenizer to use
        cfg: Configuration object; this function reads
            ``shuffle_merged_datasets``, ``seed``, ``sample_packing`` and
            ``pretrain_multipack_attn`` from it
        dataset_config: Dataset configuration
        d_base_type: Base dataset type
        d_prompt_style: Prompt style
        processor: Optional processor for multimodal
        max_tokens: Maximum sequence length. Accepted for interface
            compatibility; not read by this function.
        buffer_size: Buffer size for shuffling; also used as the batch size
            of the tokenization ``map`` call

    Returns:
        Wrapped streaming dataset ready for training
    """
    # Import here to avoid circular imports
    from axolotl.utils.data.wrappers import get_dataset_wrapper

    # Discover the raw column names from the first sample so they can be
    # removed after encoding. This is done on the un-shuffled stream: every
    # iteration of an iterable dataset starts from the beginning, so nothing
    # needs to be "reset" afterwards, and the peek does not have to go through
    # a shuffle stage just to read one row.
    first_row = next(iter(dataset), None)
    remove_columns = list(first_row.keys()) if first_row is not None else []

    # Apply shuffling exactly once. Shuffling before the peek and again after
    # it would stack two shuffle stages on top of each other.
    if cfg.shuffle_merged_datasets:
        LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)

    # For multipack attention with sample packing, attention_mask is dropped.
    drop_attention_mask = bool(cfg.sample_packing and cfg.pretrain_multipack_attn)

    def encode_streaming(examples: dict[str, list]) -> dict[str, list]:
        """Tokenize one batch of raw examples and add position_ids."""
        # Convert the batch dict to a temporary Dataset for processing
        temp_dataset = Dataset.from_dict(examples)

        # Apply the dataset wrapper to tokenize
        wrapped_dataset, _ = get_dataset_wrapper(
            dataset_config=dataset_config,
            tokenizer=tokenizer,
            cfg=cfg,
            dataset_base_type=d_base_type,
            dataset=temp_dataset,
            dataset_prompt_style=d_prompt_style,
            processor=processor,
        )

        # Convert back to a column-oriented dict
        if hasattr(wrapped_dataset, "to_dict"):
            result = wrapped_dataset.to_dict()
        else:
            result = {key: wrapped_dataset[key] for key in wrapped_dataset.column_names}

        # position_ids are always added, with or without packing, for
        # compatibility
        result = add_position_ids(result)

        if drop_attention_mask and "attention_mask" in result:
            del result["attention_mask"]

        return result

    # Map the encoding function over the streaming dataset
    dataset = dataset.map(
        encode_streaming,
        batched=True,
        batch_size=buffer_size,
        remove_columns=remove_columns,
    )

    # Set format for PyTorch
    return dataset.with_format("torch")
|
||||
@@ -178,8 +178,8 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||
|
||||
|
||||
def handle_long_seq_in_dataset(
|
||||
dataset: Dataset, sequence_len: int, cfg: DictDefault
|
||||
) -> Dataset:
|
||||
dataset: Dataset | IterableDataset, sequence_len: int, cfg: DictDefault
|
||||
) -> Dataset | IterableDataset:
|
||||
"""Remove sequences longer than configured maximum from dataset.
|
||||
|
||||
Args:
|
||||
@@ -190,7 +190,14 @@ def handle_long_seq_in_dataset(
|
||||
Returns:
|
||||
Filtered dataset with long sequences removed.
|
||||
"""
|
||||
if "input_ids" not in dataset.column_names:
|
||||
# Streaming datasets don't support filtering the same way
|
||||
if isinstance(dataset, IterableDataset):
|
||||
LOG.info(
|
||||
"Streaming dataset detected - long sequence filtering will be done on-the-fly"
|
||||
)
|
||||
return dataset
|
||||
|
||||
if not hasattr(dataset, "column_names") or "input_ids" not in dataset.column_names:
|
||||
LOG.warning(
|
||||
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
|
||||
"expected for reward modeling."
|
||||
|
||||
@@ -244,6 +244,12 @@ class AxolotlInputConfig(
|
||||
dataloader_num_workers: int | None = None
|
||||
dataloader_prefetch_factor: int | None = None
|
||||
dataloader_drop_last: bool | None = None
|
||||
streaming: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Enable streaming mode for training datasets to reduce memory usage and enable training on datasets larger than memory"
|
||||
},
|
||||
)
|
||||
|
||||
accelerator_config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@@ -1074,6 +1074,24 @@ class PretrainingValidationMixin:
|
||||
data["accelerator_config"]["dispatch_batches"] = False
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_pretraining_split_batches_accelerate(cls, data):
    """Default accelerate batch handling options when streaming is enabled.

    When ``streaming`` is truthy in the raw config, ``split_batches`` and
    ``dispatch_batches`` under ``accelerator_config`` are set to ``False``
    unless the user already supplied a non-None value for them. The raw
    config dict is updated in place and returned.
    """
    # NOTE(review): the name mentions pretraining but the check keys on
    # `streaming`. Confirm no other validator in this class uses the same
    # name, since a later same-named definition would replace the earlier one.
    # alternatively set ACCELERATE_SPLIT_BATCHES=False
    if not data.get("streaming"):
        return data

    existing = data.get("accelerator_config", {})
    if not existing:
        # Nothing configured (missing, None or empty): install both defaults.
        data["accelerator_config"] = {
            "split_batches": False,
            "dispatch_batches": False,
        }
        return data

    # Only fill in options the user left unset.
    for option in ("split_batches", "dispatch_batches"):
        if existing.get(option) is None:
            data["accelerator_config"][option] = False
    return data
|
||||
|
||||
|
||||
class ModelCompatibilityValidationMixin:
|
||||
"""Validation methods for specific model compatibility."""
|
||||
|
||||
Reference in New Issue
Block a user