Compare commits

...

3 Commits

Author        SHA1        Message               Date
Dan Saunders  d3bea3a2eb  broken                2025-08-25 16:51:36 +00:00
Dan Saunders  2e2302aae3  remove unused         2025-08-25 15:46:25 +00:00
Dan Saunders  3a35076513  seems to be working?  2025-08-25 14:22:32 +00:00
6 changed files with 302 additions and 23 deletions

View File

@@ -26,6 +26,7 @@ from axolotl.utils.data.shared import (
     save_preprocessed_dataset,
     try_load_from_hub,
 )
+from axolotl.utils.data.streaming import wrap_streaming_sft_dataset
 from axolotl.utils.data.utils import (
     deduplicate_and_log_datasets,
     handle_long_seq_in_dataset,
@@ -73,7 +74,7 @@ def _prepare_standard_dataset(
     tokenizer: PreTrainedTokenizer,
     processor: ProcessorMixin | None,
     preprocess_iterable: bool,
-) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]:
+) -> tuple[Dataset | IterableDataset, Dataset | None, int, list[Prompter | None]]:
     """Prepare standard (non-pretraining) datasets."""

     def _load_datasets():
@@ -118,7 +119,14 @@ def _prepare_standard_dataset(
     )

     # Calculate total number of training steps
-    if cfg.max_steps:
+    # For streaming datasets, we must use max_steps
+    if isinstance(train_dataset, IterableDataset):
+        if not cfg.max_steps:
+            raise ValueError(
+                "When using streaming datasets, you must set max_steps in your config"
+            )
+        total_num_steps = cfg.max_steps
+    elif cfg.max_steps:
         total_num_steps = min(
             calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
         )
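
The elif preserves the existing bounded-steps behavior for map-style datasets; streaming datasets take cfg.max_steps verbatim because an IterableDataset exposes no length from which a step count could be derived. A minimal illustration using only the datasets library:

from datasets import Dataset

stream = Dataset.from_dict({"text": ["a", "b"]}).to_iterable_dataset()
len(stream)  # TypeError: object of type 'IterableDataset' has no len()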
@@ -342,14 +350,18 @@ def _load_raw_datasets(
         dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
     else:
         dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)

-    if cfg.sample_packing:
+    # Skip packing processing for streaming datasets - they handle it differently
+    if cfg.sample_packing and not isinstance(dataset, IterableDataset):
         dataset, _ = process_datasets_for_packing(cfg, dataset, None)

-    # Save the prepared dataset
-    dataset_hash = generate_dataset_hash_from_config(
-        cfg, datasets_configs, tokenizer.name_or_path
-    )
-    save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
+    # Skip saving for streaming datasets as they can't be cached
+    if not isinstance(dataset, IterableDataset):
+        # Save the prepared dataset
+        dataset_hash = generate_dataset_hash_from_config(
+            cfg, datasets_configs, tokenizer.name_or_path
+        )
+        save_preprocessed_dataset(cfg, dataset, dataset_hash, split)

     return dataset, prompters
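
The cache skip follows from the API surface: a streaming dataset is never materialized to Arrow, so there is nothing to persist. A quick illustrative check:

from datasets import Dataset

stream = Dataset.from_dict({"x": [1]}).to_iterable_dataset()
hasattr(stream, "save_to_disk")  # False - only map-style Dataset objects can be written to disk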
@@ -365,8 +377,10 @@ def _load_and_process_single_dataset(
 ) -> tuple[Dataset | IterableDataset, Prompter | None]:
     """Load and process a single dataset based on the passed config."""
     # Load the dataset
+    # Use streaming if enabled in config or if using iterable preprocessing
+    use_streaming = cfg.streaming or preprocess_iterable
     dataset = load_dataset_with_config(
-        dataset_config, cfg.hf_use_auth_token, streaming=preprocess_iterable
+        dataset_config, cfg.hf_use_auth_token, streaming=use_streaming
     )

     # Parse dataset type
@@ -391,16 +405,63 @@ def _load_and_process_single_dataset(
             num_shards=dataset_config.shards, index=shards_idx
         )

-    # Apply dataset wrapper
-    dataset_wrapper, dataset_prompter = get_dataset_wrapper(
-        dataset_config=dataset_config,
-        tokenizer=tokenizer,
-        cfg=cfg,
-        dataset_base_type=d_base_type,
-        dataset=dataset,
-        dataset_prompt_style=d_prompt_style,
-        processor=processor,
-    )
+    # For streaming datasets, we need to handle tokenization differently
+    if isinstance(dataset, IterableDataset):
+        # Use pretraining's approach for multipack streaming
+        if cfg.sample_packing:
+            # Create the dataset wrapper function once
+            def ds_wrapper_fn(dataset=None):
+                wrapped_dataset, prompter = get_dataset_wrapper(
+                    dataset_config=dataset_config,
+                    tokenizer=tokenizer,
+                    cfg=cfg,
+                    dataset_base_type=d_base_type,
+                    dataset=dataset,
+                    dataset_prompt_style=d_prompt_style,
+                    processor=processor,
+                )
+                return wrapped_dataset, prompter
+
+            # Use pretraining wrapper for efficient streaming SFT with packing
+            from axolotl.utils.data.pretraining import wrap_pretraining_dataset
+
+            dataset_wrapper = wrap_pretraining_dataset(
+                dataset,
+                tokenizer,
+                cfg,
+                ds_wrapper_fn,
+                max_tokens=cfg.sequence_len,
+                batch_size=cfg.micro_batch_size,
+                seed=cfg.seed,
+                buffer_size=cfg.pretrain_multipack_buffer_size,
+            )
+        else:
+            # Use regular streaming wrapper
+            dataset_wrapper = wrap_streaming_sft_dataset(
+                dataset,
+                tokenizer,
+                cfg,
+                dataset_config,
+                d_base_type,
+                d_prompt_style,
+                processor,
+                max_tokens=cfg.sequence_len,
+                buffer_size=10_000,
+            )
+
+        # For streaming, we don't have a specific prompter
+        dataset_prompter = None
+    else:
+        # Apply dataset wrapper for regular datasets
+        dataset_wrapper, dataset_prompter = get_dataset_wrapper(
+            dataset_config=dataset_config,
+            tokenizer=tokenizer,
+            cfg=cfg,
+            dataset_base_type=d_base_type,
+            dataset=dataset,
+            dataset_prompt_style=d_prompt_style,
+            processor=processor,
+        )

     return dataset_wrapper, dataset_prompter
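
The isinstance dispatch above works because Hugging Face load_dataset returns a different type when streaming is enabled. A short sketch (the JSON file name is hypothetical):

from datasets import IterableDataset, load_dataset

ds = load_dataset("json", data_files="train.jsonl", split="train", streaming=True)
isinstance(ds, IterableDataset)  # True - this is what selects the streaming branch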

View File

@@ -524,7 +524,9 @@ def generate_dataset_hash_from_config(
     return str(md5(config_str))


-def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
+def merge_datasets(
+    datasets: list[Dataset | IterableDataset], cfg: DictDefault
+) -> Dataset | IterableDataset:
     """Merge multiple datasets into one with optional shuffling.

     Args:
@@ -534,6 +536,41 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
     Returns:
         Merged dataset.
     """
+    # Check if we're dealing with streaming datasets
+    if any(isinstance(ds, IterableDataset) for ds in datasets):
+        # All datasets must be streaming for merging
+        if not all(isinstance(ds, IterableDataset) for ds in datasets):
+            raise ValueError(
+                "Cannot mix streaming and non-streaming datasets. "
+                "Either all datasets must be streaming or none."
+            )
+
+        if len(datasets) == 1:
+            ds = datasets[0]
+            # Streaming datasets handle shuffling differently
+            if cfg.shuffle_merged_datasets and not cfg.curriculum_sampling:
+                return ds.shuffle(seed=cfg.seed, buffer_size=10_000)
+            return ds
+
+        # Merge streaming datasets
+        LOG.info("Merging streaming datasets...")
+        from datasets import interleave_datasets
+
+        # For streaming, we interleave datasets instead of concatenating
+        merged_dataset = interleave_datasets(datasets)
+
+        if cfg.shuffle_merged_datasets:
+            LOG.debug("Shuffling merged streaming datasets...")
+            if cfg.curriculum_sampling:
+                LOG.warning(
+                    "Shuffling merged datasets with curriculum sampling is not recommended. "
+                    "This will randomize the order of samples."
+                )
+            merged_dataset = merged_dataset.shuffle(seed=cfg.seed, buffer_size=10_000)
+
+        return merged_dataset
+
+    # Original logic for non-streaming datasets
     if len(datasets) == 1:
         ds = datasets[0]
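
interleave_datasets alternates samples across the input streams rather than appending one stream after another, which is why it stands in for concatenate_datasets on the streaming path. A small sketch of the default round-robin behavior:

from datasets import Dataset, interleave_datasets

a = Dataset.from_dict({"x": [1, 2]}).to_iterable_dataset()
b = Dataset.from_dict({"x": [10, 20]}).to_iterable_dataset()
merged = interleave_datasets([a, b])
print([row["x"] for row in merged])  # [1, 10, 2, 20], stopping once a stream is exhausted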

View File

@@ -0,0 +1,150 @@
+"""Utilities for handling streaming datasets."""
+
+import functools
+from collections import defaultdict
+from typing import Any, Dict, List
+
+import numpy as np
+from datasets import Dataset, IterableDataset
+from torch.utils.data import RandomSampler
+from transformers import PreTrainedTokenizerBase
+
+from axolotl.utils.collators import DataCollatorForSeq2Seq
+from axolotl.utils.logging import get_logger
+from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
+from axolotl.utils.trainer import add_position_ids
+
+LOG = get_logger(__name__)
+
+
+def wrap_streaming_sft_dataset(
+    dataset: IterableDataset,
+    tokenizer: PreTrainedTokenizerBase,
+    cfg,
+    dataset_config,
+    d_base_type: str,
+    d_prompt_style: str | None,
+    processor: Any | None,
+    max_tokens: int = 2048,
+    buffer_size: int = 10_000,
+) -> IterableDataset:
+    """
+    Wrap a streaming SFT dataset with tokenization and optional packing.
+
+    This is similar to wrap_pretraining_dataset but for SFT datasets.
+
+    Args:
+        dataset: The streaming dataset to wrap
+        tokenizer: Tokenizer to use
+        cfg: Configuration object
+        dataset_config: Dataset configuration
+        d_base_type: Base dataset type
+        d_prompt_style: Prompt style
+        processor: Optional processor for multimodal
+        max_tokens: Maximum sequence length
+        buffer_size: Buffer size for shuffling
+
+    Returns:
+        Wrapped streaming dataset ready for training
+    """
+    # Import here to avoid circular imports
+    from axolotl.utils.data.wrappers import get_dataset_wrapper
+
+    # Apply shuffling if configured
+    if cfg.shuffle_merged_datasets:
+        LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
+        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
+
+    # For streaming datasets, we need to get column names from the first sample
+    remove_columns = []
+    for first_row in dataset:
+        remove_columns = list(first_row.keys())
+        break
+
+    # Reset dataset after peeking
+    if cfg.shuffle_merged_datasets:
+        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
+
+    # Define the encoding function - always add position_ids for compatibility
+    if cfg.sample_packing:
+        # For sample packing, we need to handle position_ids
+        def encode_streaming_packed(examples: Dict[str, List]) -> Dict[str, List]:
+            """Encode examples for streaming with sample packing."""
+            # Convert the batch dict to a temporary Dataset for processing
+            temp_dataset = Dataset.from_dict(examples)
+
+            # Apply the dataset wrapper to tokenize
+            wrapped_dataset, _ = get_dataset_wrapper(
+                dataset_config=dataset_config,
+                tokenizer=tokenizer,
+                cfg=cfg,
+                dataset_base_type=d_base_type,
+                dataset=temp_dataset,
+                dataset_prompt_style=d_prompt_style,
+                processor=processor,
+            )
+
+            # Convert to dict for processing
+            result = {}
+            if hasattr(wrapped_dataset, "to_dict"):
+                result = wrapped_dataset.to_dict()
+            else:
+                for key in wrapped_dataset.column_names:
+                    result[key] = wrapped_dataset[key]
+
+            # Add position_ids using the existing function
+            result = add_position_ids(result)
+
+            # For multipack attention, we may need to drop attention_mask
+            if cfg.pretrain_multipack_attn and "attention_mask" in result:
+                del result["attention_mask"]
+
+            return result
+
+        encode_fn = encode_streaming_packed
+    else:
+        # Regular encoding without packing - still add position_ids for compatibility
+        def encode_streaming(examples: Dict[str, List]) -> Dict[str, List]:
+            """Encode examples for streaming."""
+            # Convert the batch dict to a temporary Dataset for processing
+            temp_dataset = Dataset.from_dict(examples)
+
+            # Apply the dataset wrapper to tokenize
+            wrapped_dataset, _ = get_dataset_wrapper(
+                dataset_config=dataset_config,
+                tokenizer=tokenizer,
+                cfg=cfg,
+                dataset_base_type=d_base_type,
+                dataset=temp_dataset,
+                dataset_prompt_style=d_prompt_style,
+                processor=processor,
+            )
+
+            # Convert to dict format
+            result = {}
+            if hasattr(wrapped_dataset, "to_dict"):
+                result = wrapped_dataset.to_dict()
+            else:
+                for key in wrapped_dataset.column_names:
+                    result[key] = wrapped_dataset[key]
+
+            # Add position_ids even without packing for compatibility
+            result = add_position_ids(result)
+
+            return result
+
+        encode_fn = encode_streaming
+
+    # Map the encoding function over the streaming dataset
+    dataset = dataset.map(
+        encode_fn,
+        batched=True,
+        batch_size=buffer_size,
+        remove_columns=remove_columns,
+    )
+
+    # Set format for PyTorch
+    dataset = dataset.with_format("torch")
+
+    return dataset
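
The core pattern in this file is a lazy batched map with remove_columns, so raw text columns are dropped and only tokenized fields flow to the collator. A self-contained sketch with a stand-in tokenizer (fake_tokenize is illustrative, not part of the diff):

from datasets import Dataset

stream = Dataset.from_dict({"text": ["hello world", "foo bar baz"]}).to_iterable_dataset()

def fake_tokenize(batch):
    # stand-in for get_dataset_wrapper + tokenizer; one id list per example
    return {"input_ids": [[len(word) for word in text.split()] for text in batch["text"]]}

stream = stream.map(fake_tokenize, batched=True, remove_columns=["text"])
print(next(iter(stream)))  # {'input_ids': [5, 5]} - mapping happens lazily, per batch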

View File

@@ -178,8 +178,8 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):


 def handle_long_seq_in_dataset(
-    dataset: Dataset, sequence_len: int, cfg: DictDefault
-) -> Dataset:
+    dataset: Dataset | IterableDataset, sequence_len: int, cfg: DictDefault
+) -> Dataset | IterableDataset:
     """Remove sequences longer than configured maximum from dataset.

     Args:
@@ -190,7 +190,14 @@ def handle_long_seq_in_dataset(
     Returns:
         Filtered dataset with long sequences removed.
     """
-    if "input_ids" not in dataset.column_names:
+    # Streaming datasets don't support filtering the same way
+    if isinstance(dataset, IterableDataset):
+        LOG.info(
+            "Streaming dataset detected - long sequence filtering will be done on-the-fly"
+        )
+        return dataset
+
+    if not hasattr(dataset, "column_names") or "input_ids" not in dataset.column_names:
         LOG.warning(
             "Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
             "expected for reward modeling."

View File

@@ -244,6 +244,12 @@ class AxolotlInputConfig(
     dataloader_num_workers: int | None = None
     dataloader_prefetch_factor: int | None = None
     dataloader_drop_last: bool | None = None
+    streaming: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Enable streaming mode for training datasets to reduce memory usage and enable training on datasets larger than memory"
+        },
+    )
     accelerator_config: dict[str, Any] | None = None

View File

@@ -1074,6 +1074,24 @@ class PretrainingValidationMixin:
                 data["accelerator_config"]["dispatch_batches"] = False
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_pretraining_split_batches_accelerate(cls, data):
+        # alternatively set ACCELERATE_SPLIT_BATCHES=False
+        if data.get("streaming"):
+            accelerator_config = data.get("accelerator_config", {})
+            if not accelerator_config:
+                data["accelerator_config"] = {
+                    "split_batches": False,
+                    "dispatch_batches": False,
+                }
+            else:
+                if accelerator_config.get("split_batches") is None:
+                    data["accelerator_config"]["split_batches"] = False
+                if accelerator_config.get("dispatch_batches") is None:
+                    data["accelerator_config"]["dispatch_batches"] = False
+        return data
+
+
 class ModelCompatibilityValidationMixin:
     """Validation methods for specific model compatibility."""