seems to be working?

This commit is contained in:
Dan Saunders
2025-08-24 00:49:13 +00:00
parent 79ddaebe9a
commit 3a35076513
11 changed files with 1004 additions and 23 deletions

View File

@@ -0,0 +1,61 @@
# Example configuration for streaming pretraining
# This demonstrates how to pretrain on large datasets that don't fit in memory

base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

# Required: max_steps for streaming pretraining
# (a streamed dataset has no known length, so total steps cannot be derived from epochs)
max_steps: 10000

# Pretraining dataset configuration
# These are automatically streamed
pretraining_dataset:
  - path: allenai/c4
    name: en
    type: pretrain
    # Optional: skip N samples (useful for resuming)
    # skip: 1000000

# Can also use multiple pretraining datasets
# pretraining_dataset:
#   - path: allenai/c4
#     name: en
#     type: pretrain
#   - path: HuggingFaceFW/fineweb
#     type: pretrain

# No held-out validation split for streaming pretraining
val_set_size: 0.0

# Sequence and packing configuration
sequence_len: 2048
sample_packing: true
pretrain_multipack_attn: true
pretrain_multipack_buffer_size: 10000 # Buffer size for multipack batching

# Training hyperparameters
gradient_accumulation_steps: 8
micro_batch_size: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 3e-4

# Memory optimizations
bf16: auto
tf32: false
gradient_checkpointing: true
flash_attention: true

# Checkpointing and logging
output_dir: ./outputs/pretrain-streaming
logging_steps: 10
save_steps: 500
save_total_limit: 3 # Keep only last 3 checkpoints

# Warmup
warmup_ratio: 0.1

# Optional: enable wandb for monitoring
# wandb_project: streaming-pretrain
# wandb_entity: your-entity
# wandb_name: c4-pretrain

View File

@@ -0,0 +1,52 @@
# Example configuration for streaming SFT training
# This enables training on datasets larger than memory by streaming them from HuggingFace Hub

base_model: HuggingFaceTB/SmolLM2-135M

# Enable streaming mode for datasets
streaming: true

# When using streaming, max_steps is required
# (the dataset length is unknown, so epochs cannot be computed)
max_steps: 3 # Just test a few steps

# Training datasets - these will be streamed
# datasets:
#   - path: tatsu-lab/alpaca
#     type: alpaca
#     split: train
pretraining_dataset:
  - path: tatsu-lab/alpaca
    type: alpaca
    split: train

# Dataset configuration
sequence_len: 2048
sample_packing: true # Enable multipack batching
pretrain_multipack_attn: true # Enable multipack attention masking
pretrain_multipack_buffer_size: 1000 # Buffer size for packing (smaller for streaming SFT)

special_tokens:
  pad_token: <|endoftext|>

# Training hyperparameters
gradient_accumulation_steps: 4
micro_batch_size: 1 # Always 1 for multipack - sequences are packed into single samples
num_epochs: 1 # With streaming, typically use max_steps instead
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5

# Enable efficient training
bf16: auto
tf32: false
gradient_checkpointing: true
flash_attention: true # Enable flash attention with multipack

# Logging and checkpointing
logging_steps: 10
eval_steps: 100
save_steps: 200
output_dir: ./outputs/streaming-model

# Warmup
warmup_steps: 100

View File

@@ -26,6 +26,8 @@ from axolotl.utils.data.shared import (
save_preprocessed_dataset,
try_load_from_hub,
)
from axolotl.utils.data.streaming import wrap_streaming_sft_dataset
from axolotl.utils.data.streaming_sft import wrap_streaming_sft_dataset_optimized
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
handle_long_seq_in_dataset,
@@ -73,7 +75,7 @@ def _prepare_standard_dataset(
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
preprocess_iterable: bool,
) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]:
) -> tuple[Dataset | IterableDataset, Dataset | None, int, list[Prompter | None]]:
"""Prepare standard (non-pretraining) datasets."""
def _load_datasets():
@@ -118,7 +120,14 @@ def _prepare_standard_dataset(
)
# Calculate total number of training steps
if cfg.max_steps:
# For streaming datasets, we must use max_steps
if isinstance(train_dataset, IterableDataset):
if not cfg.max_steps:
raise ValueError(
"When using streaming datasets, you must set max_steps in your config"
)
total_num_steps = cfg.max_steps
elif cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
)
@@ -342,14 +351,18 @@ def _load_raw_datasets(
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
else:
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
if cfg.sample_packing:
# Skip packing processing for streaming datasets - they handle it differently
if cfg.sample_packing and not isinstance(dataset, IterableDataset):
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
# Save the prepared dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
# Skip saving for streaming datasets as they can't be cached
if not isinstance(dataset, IterableDataset):
# Save the prepared dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
return dataset, prompters
@@ -365,8 +378,10 @@ def _load_and_process_single_dataset(
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Load and process a single dataset based on the passed config."""
# Load the dataset
# Use streaming if enabled in config or if using iterable preprocessing
use_streaming = cfg.streaming or preprocess_iterable
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=preprocess_iterable
dataset_config, cfg.hf_use_auth_token, streaming=use_streaming
)
# Parse dataset type
@@ -391,16 +406,63 @@ def _load_and_process_single_dataset(
num_shards=dataset_config.shards, index=shards_idx
)
# Apply dataset wrapper
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
dataset_base_type=d_base_type,
dataset=dataset,
dataset_prompt_style=d_prompt_style,
processor=processor,
)
# For streaming datasets, we need to handle tokenization differently
if isinstance(dataset, IterableDataset):
# Use pretraining's approach for multipack streaming
if cfg.sample_packing:
# Create the dataset wrapper function once
def ds_wrapper_fn(dataset=None):
wrapped_dataset, prompter = get_dataset_wrapper(
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
dataset_base_type=d_base_type,
dataset=dataset,
dataset_prompt_style=d_prompt_style,
processor=processor,
)
return wrapped_dataset, prompter
# Use optimized streaming wrapper to avoid repeated preprocessing logs
dataset_wrapper = wrap_streaming_sft_dataset_optimized(
dataset,
tokenizer,
cfg,
ds_wrapper_fn,
max_tokens=cfg.sequence_len,
batch_size=max(
1, cfg.sequence_len // 512
), # Estimate sequences per pack
seed=cfg.seed or 42,
buffer_size=cfg.pretrain_multipack_buffer_size or 1_000,
)
else:
# Use regular streaming wrapper
dataset_wrapper = wrap_streaming_sft_dataset(
dataset,
tokenizer,
cfg,
dataset_config,
d_base_type,
d_prompt_style,
processor,
max_tokens=cfg.sequence_len,
buffer_size=10_000,
)
# For streaming, we don't have a specific prompter
dataset_prompter = None
else:
# Apply dataset wrapper for regular datasets
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
dataset_base_type=d_base_type,
dataset=dataset,
dataset_prompt_style=d_prompt_style,
processor=processor,
)
return dataset_wrapper, dataset_prompter

View File

@@ -524,7 +524,9 @@ def generate_dataset_hash_from_config(
return str(md5(config_str))
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
def merge_datasets(
datasets: list[Dataset | IterableDataset], cfg: DictDefault
) -> Dataset | IterableDataset:
"""Merge multiple datasets into one with optional shuffling.
Args:
@@ -534,6 +536,41 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
Returns:
Merged dataset.
"""
# Check if we're dealing with streaming datasets
if any(isinstance(ds, IterableDataset) for ds in datasets):
# All datasets must be streaming for merging
if not all(isinstance(ds, IterableDataset) for ds in datasets):
raise ValueError(
"Cannot mix streaming and non-streaming datasets. "
"Either all datasets must be streaming or none."
)
if len(datasets) == 1:
ds = datasets[0]
# Streaming datasets handle shuffling differently
if cfg.shuffle_merged_datasets and not cfg.curriculum_sampling:
return ds.shuffle(seed=cfg.seed, buffer_size=10_000)
return ds
# Merge streaming datasets
LOG.info("Merging streaming datasets...")
from datasets import interleave_datasets
# For streaming, we interleave datasets instead of concatenating
merged_dataset = interleave_datasets(datasets)
if cfg.shuffle_merged_datasets:
LOG.debug("Shuffling merged streaming datasets...")
if cfg.curriculum_sampling:
LOG.warning(
"Shuffling merged datasets with curriculum sampling is not recommended. "
"This will randomize the order of samples."
)
merged_dataset = merged_dataset.shuffle(seed=cfg.seed, buffer_size=10_000)
return merged_dataset
# Original logic for non-streaming datasets
if len(datasets) == 1:
ds = datasets[0]

View File

@@ -0,0 +1,150 @@
"""Utilities for handling streaming datasets."""
import functools
from collections import defaultdict
from typing import Any, Dict, List
import numpy as np
from datasets import Dataset, IterableDataset
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import add_position_ids
LOG = get_logger(__name__)
def wrap_streaming_sft_dataset(
    dataset: IterableDataset,
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    dataset_config,
    d_base_type: str,
    d_prompt_style: str | None,
    processor: Any | None,
    max_tokens: int = 2048,
    buffer_size: int = 10_000,
) -> IterableDataset:
    """
    Wrap a streaming SFT dataset with tokenization and optional packing.

    This is similar to wrap_pretraining_dataset but for SFT datasets. The raw
    streaming dataset is lazily tokenized through ``dataset.map`` so samples are
    only processed as they are consumed during training.

    Args:
        dataset: The streaming dataset to wrap
        tokenizer: Tokenizer to use
        cfg: Configuration object
        dataset_config: Dataset configuration
        d_base_type: Base dataset type
        d_prompt_style: Prompt style
        processor: Optional processor for multimodal
        max_tokens: Maximum sequence length
        buffer_size: Buffer size for shuffling

    Returns:
        Wrapped streaming dataset ready for training
    """
    # Import here to avoid circular imports
    from axolotl.utils.data.wrappers import get_dataset_wrapper

    # Apply shuffling if configured
    if cfg.shuffle_merged_datasets:
        LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)

    # For streaming datasets, we need to get column names from the first sample
    # (IterableDataset may not expose .column_names until iterated)
    remove_columns = []
    for first_row in dataset:
        remove_columns = list(first_row.keys())
        break

    # Reset dataset after peeking
    # NOTE(review): shuffling again does not "reset" the stream — re-iterating an
    # IterableDataset restarts it anyway; confirm this second shuffle is intended.
    if cfg.shuffle_merged_datasets:
        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)

    # Define the encoding function - always add position_ids for compatibility
    if cfg.sample_packing:
        # For sample packing, we need to handle position_ids
        def encode_streaming_packed(examples: Dict[str, List]) -> Dict[str, List]:
            """Encode examples for streaming with sample packing."""
            # Convert the batch dict to a temporary Dataset for processing
            temp_dataset = Dataset.from_dict(examples)

            # Apply the dataset wrapper to tokenize
            wrapped_dataset, _ = get_dataset_wrapper(
                dataset_config=dataset_config,
                tokenizer=tokenizer,
                cfg=cfg,
                dataset_base_type=d_base_type,
                dataset=temp_dataset,
                dataset_prompt_style=d_prompt_style,
                processor=processor,
            )

            # Convert to dict for processing
            result = {}
            if hasattr(wrapped_dataset, "to_dict"):
                result = wrapped_dataset.to_dict()
            else:
                for key in wrapped_dataset.column_names:
                    result[key] = wrapped_dataset[key]

            # Add position_ids using the existing function
            # NOTE(review): add_position_ids appears designed for a single-sample
            # dict, but `result` here is a batched dict of lists — confirm it
            # handles the batched shape correctly.
            result = add_position_ids(result)

            # For multipack attention, we may need to drop attention_mask
            if cfg.pretrain_multipack_attn and "attention_mask" in result:
                del result["attention_mask"]

            return result

        encode_fn = encode_streaming_packed
    else:
        # Regular encoding without packing - still add position_ids for compatibility
        def encode_streaming(examples: Dict[str, List]) -> Dict[str, List]:
            """Encode examples for streaming."""
            # Convert the batch dict to a temporary Dataset for processing
            temp_dataset = Dataset.from_dict(examples)

            # Apply the dataset wrapper to tokenize
            wrapped_dataset, _ = get_dataset_wrapper(
                dataset_config=dataset_config,
                tokenizer=tokenizer,
                cfg=cfg,
                dataset_base_type=d_base_type,
                dataset=temp_dataset,
                dataset_prompt_style=d_prompt_style,
                processor=processor,
            )

            # Convert to dict format
            result = {}
            if hasattr(wrapped_dataset, "to_dict"):
                result = wrapped_dataset.to_dict()
            else:
                for key in wrapped_dataset.column_names:
                    result[key] = wrapped_dataset[key]

            # Add position_ids even without packing for compatibility
            # NOTE(review): same batched-dict concern as in the packed branch.
            result = add_position_ids(result)

            return result

        encode_fn = encode_streaming

    # Map the encoding function over the streaming dataset
    # (lazy: executed as batches are pulled from the stream)
    dataset = dataset.map(
        encode_fn,
        batched=True,
        batch_size=buffer_size,
        remove_columns=remove_columns,
    )

    # Set format for PyTorch
    dataset = dataset.with_format("torch")

    return dataset

View File

@@ -0,0 +1,219 @@
"""Streaming dataset with multipack support for SFT."""
import functools
from collections import defaultdict
from typing import Any, Dict, List, Optional
import numpy as np
from datasets import Dataset, IterableDataset
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
LOG = get_logger(__name__)
def encode_packed_sft_streaming(
    collate_fn,
    ds_wrapper_fn,
    examples: Dict[str, List],
    dataset_config,
    tokenizer,
    cfg,
    d_base_type: str,
    d_prompt_style: str | None,
    processor: Any | None,
    max_seq_length: int = 2048,
    batch_size: int = 4,
    multipack_attn: Optional[bool] = True,
) -> Dict[str, List]:
    """
    Encode and pack streaming SFT data similar to how pretraining does it.

    This function:
    1. Tokenizes the examples using the dataset wrapper
    2. Adds position_ids for each sequence
    3. Uses MultipackBatchSampler to efficiently pack sequences
    4. Applies the collator to handle attention masks properly

    Args:
        collate_fn: Collator function for handling batches
        ds_wrapper_fn: Function to get the dataset wrapper for tokenization
            NOTE(review): this parameter is accepted but never used — the
            wrapper is obtained via get_dataset_wrapper below; confirm whether
            it should be removed from the signature (callers pass None).
        examples: Dict of lists containing the raw examples
        dataset_config: Configuration for the dataset
        tokenizer: Tokenizer to use
        cfg: Main configuration
        d_base_type: Dataset base type
        d_prompt_style: Prompt style
        processor: Optional processor for multimodal
        max_seq_length: Maximum sequence length
        batch_size: Batch size for packing
        multipack_attn: Whether to use multipack attention

    Returns:
        Dict of packed and processed data ready for training
    """
    # Import here to avoid circular imports
    from axolotl.utils.data.wrappers import get_dataset_wrapper

    # Convert examples to Dataset for processing
    temp_dataset = Dataset.from_dict(examples)

    # Apply the dataset wrapper to tokenize
    train_dataset, _ = get_dataset_wrapper(
        dataset_config=dataset_config,
        tokenizer=tokenizer,
        cfg=cfg,
        dataset_base_type=d_base_type,
        dataset=temp_dataset,
        dataset_prompt_style=d_prompt_style,
        processor=processor,
    )

    # Process for packing - add position_ids and filter long sequences
    # (position_ids are added when multipack_attn is True: skip_position_ids=not multipack_attn)
    train_dataset = process_pretraining_datasets_for_packing(
        train_dataset,
        max_seq_length,
        skip_position_ids=not multipack_attn,
        drop_attention_mask=multipack_attn,
    )

    # Use MultipackBatchSampler to create efficient packed batches
    sampler = MultipackBatchSampler(
        sampler=RandomSampler(train_dataset),
        lengths=get_dataset_lengths(train_dataset),
        batch_size=1,  # We pack multiple sequences into one "batch"
        batch_max_len=batch_size * max_seq_length,  # Total tokens in packed batch
        drop_last=True,
        num_processes=1,  # Single process for streaming
    )

    # Collect packed data
    chunked_data = defaultdict(list)
    for batch in sampler:
        for data in batch:
            features = train_dataset[data]

            # Clean up unnecessary fields
            for field in ["num_truncated_tokens", "overflow_to_sample_mapping"]:
                if field in features:
                    del features[field]

            # Ensure labels exist
            if "labels" not in features:
                features["labels"] = features["input_ids"].copy()

            # Apply collator to handle padding and attention masks
            collated_features = collate_fn(features)

            # Collect features (drop the internal "length" bookkeeping column)
            for feature in features.keys():
                if feature == "length":
                    continue
                chunked_data[feature].append(collated_features[feature].squeeze(0))

    return chunked_data
def wrap_streaming_sft_dataset_with_packing(
    dataset: IterableDataset,
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    dataset_config,
    d_base_type: str,
    d_prompt_style: str | None,
    processor: Any | None,
    max_tokens: int = 2048,
    buffer_size: int = 10_000,
) -> IterableDataset:
    """
    Wrap a streaming SFT dataset with tokenization and multipack batching.

    This creates properly packed batches with:
    - Multiple sequences concatenated together
    - Position IDs that reset for each sequence
    - Attention masks that prevent cross-attention between sequences

    Args:
        dataset: The streaming dataset to wrap
        tokenizer: Tokenizer to use
        cfg: Configuration object
        dataset_config: Dataset configuration
        d_base_type: Base dataset type
        d_prompt_style: Prompt style
        processor: Optional processor for multimodal
        max_tokens: Maximum sequence length
        buffer_size: Buffer size for streaming/shuffling

    Returns:
        Wrapped streaming dataset with multipack batching

    Side effects:
        Mutates ``cfg.micro_batch_size`` to 1 (packed samples must not be
        re-batched by the trainer).
    """
    # Apply shuffling if configured
    if cfg.shuffle_merged_datasets:
        LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)

    # Get column names from first sample (IterableDataset has no eager schema)
    remove_columns = []
    for first_row in dataset:
        remove_columns = list(first_row.keys())
        break

    # Reset dataset after peeking
    # NOTE(review): re-iteration restarts an IterableDataset; confirm this
    # second shuffle is intentional rather than a no-op "reset".
    if cfg.shuffle_merged_datasets:
        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)

    # Create the collator for multipack
    collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
        tokenizer,
        return_tensors="pt",
        padding=True,
        pad_to_multiple_of=max_tokens,
        multipack_attn=cfg.pretrain_multipack_attn,
    )

    # Create the encoding function
    # batch_size here refers to how many sequences to pack together to fill max_tokens
    # The actual batching happens at the DataLoader level with micro_batch_size=1
    pack_batch_size = max(
        1, max_tokens // 512
    )  # Estimate based on typical sequence lengths

    # NOTE(review): the second positional arg (ds_wrapper_fn) is passed as None;
    # encode_packed_sft_streaming currently ignores it and builds its own wrapper.
    encode_fn = functools.partial(
        encode_packed_sft_streaming,
        collate_fn,
        None,  # ds_wrapper_fn will be created inside
        dataset_config=dataset_config,
        tokenizer=tokenizer,
        cfg=cfg,
        d_base_type=d_base_type,
        d_prompt_style=d_prompt_style,
        processor=processor,
        max_seq_length=max_tokens,
        batch_size=pack_batch_size,
        multipack_attn=cfg.pretrain_multipack_attn,
    )

    # Map the encoding function over the streaming dataset
    # This will process data in batches and apply packing
    dataset = dataset.map(
        encode_fn,
        batched=True,
        batch_size=buffer_size,  # Process large batches for efficiency
        remove_columns=remove_columns,
    )

    # Set format for PyTorch
    dataset = dataset.with_format("torch")

    # IMPORTANT: Set micro_batch_size to 1 since we've already packed
    # This prevents the trainer from trying to batch our packed sequences
    cfg.micro_batch_size = 1

    return dataset

View File

@@ -0,0 +1,158 @@
"""Optimized streaming SFT with multipack support that avoids repeated preprocessing."""
import functools
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional
from datasets import Dataset, IterableDataset
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import add_position_ids, drop_long_seq
LOG = get_logger(__name__)
def encode_packed_streaming_sft(
    collate_fn,
    ds_wrapper: Callable,
    examples: Dict[str, List],
    max_seq_length: int = 2048,
    batch_size: int = 4,
    multipack_attn: Optional[bool] = True,
) -> Dict[str, List]:
    """
    Encode streaming SFT data with packing, avoiding repeated preprocessing logs.

    This is similar to encode_packed_pretraining but skips the verbose
    process_pretraining_datasets_for_packing call that logs repeatedly.

    Args:
        collate_fn: Collator applied to each packed group of sequences.
        ds_wrapper: Callable returning ``(tokenized_dataset, prompter)`` for a
            raw ``Dataset`` passed via the ``dataset`` keyword.
        examples: Dict of lists containing the raw examples for this batch.
        max_seq_length: Sequences longer than this are dropped.
        batch_size: Number of sequences to pack into one sample
            (``batch_max_len = batch_size * max_seq_length``).
        multipack_attn: Whether multipack attention masking is enabled.

    Returns:
        Dict of packed and collated feature lists ready for training.
    """
    # Tokenize all the examples
    train_dataset = ds_wrapper(dataset=Dataset.from_dict(examples))[0]

    # Apply filtering and preprocessing directly without verbose logging
    # Filter out long sequences
    def should_keep(sample):
        return drop_long_seq(sample, sequence_len=max_seq_length)

    # Iterate the tokenized dataset once, filtering and annotating as we go
    filtered_samples = []
    for sample in train_dataset:
        if should_keep(sample):
            # BUGFIX: position_ids must be added when multipack attention is
            # enabled, mirroring process_pretraining_datasets_for_packing's
            # skip_position_ids=not multipack_attn. The previous condition
            # (`if not multipack_attn`) was inverted.
            if multipack_attn:
                sample = add_position_ids(sample)
            filtered_samples.append(sample)

    # Nothing survived filtering: return an empty batch with the expected keys
    if not filtered_samples:
        return {"input_ids": [], "labels": [], "attention_mask": []}

    # Convert back to dataset
    train_dataset = Dataset.from_list(filtered_samples)

    # Remove attention_mask if needed for multipack (the collator rebuilds it
    # from position_ids so cross-sequence attention is masked)
    if multipack_attn and "attention_mask" in train_dataset.column_names:
        train_dataset = train_dataset.remove_columns("attention_mask")

    # Use MultipackBatchSampler to create efficient packed batches
    sampler = MultipackBatchSampler(
        sampler=RandomSampler(train_dataset),
        lengths=get_dataset_lengths(train_dataset),
        batch_size=1,  # one packed sample per "batch"
        batch_max_len=batch_size * max_seq_length,  # token budget per pack
        drop_last=True,
        num_processes=1,  # single process for streaming
    )

    # Collect packed data
    chunked_data = defaultdict(list)
    for batch in sampler:
        for data in batch:
            features = train_dataset[data]

            # Clean up unnecessary fields
            for field in ["num_truncated_tokens", "overflow_to_sample_mapping"]:
                if field in features:
                    del features[field]

            # Ensure labels exist
            if "labels" not in features:
                features["labels"] = features["input_ids"].copy()

            # Apply collator
            collated_features = collate_fn(features)

            # Collect features (skip the internal "length" bookkeeping column)
            for feature in features.keys():
                if feature == "length":
                    continue
                chunked_data[feature].append(collated_features[feature].squeeze(0))

    return chunked_data
def wrap_streaming_sft_dataset_optimized(
    dataset: IterableDataset,
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    ds_wrapper_fn: Callable,
    max_tokens: int = 2048,
    batch_size: int = 4,
    seed: int = 42,
    buffer_size: int = 1000,
) -> IterableDataset:
    """
    Wrap a streaming SFT dataset with optimized multipack batching.

    This avoids the repeated preprocessing logs by using a custom encoder
    that applies filtering and position_ids directly.

    Side effects: sets ``cfg.micro_batch_size`` to 1 because the encoder
    already packs sequences into single samples.
    """
    # Collator that pads each packed group and rebuilds multipack masks
    collator = PretrainingBatchSamplerDataCollatorForSeq2Seq(
        tokenizer,
        return_tensors="pt",
        padding=True,
        pad_to_multiple_of=max_tokens,
        multipack_attn=cfg.pretrain_multipack_attn,
    )

    # Bind the packing encoder to this collator / wrapper / limits
    encode = functools.partial(
        encode_packed_streaming_sft,
        collator,
        ds_wrapper_fn,
        max_seq_length=max_tokens,
        batch_size=batch_size,
        multipack_attn=cfg.pretrain_multipack_attn,
    )

    # Optional buffered shuffle of the incoming stream
    if not cfg.shuffle_merged_datasets:
        LOG.debug("NOT shuffling merged streaming datasets")
    else:
        dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)

    # Peek one row to learn which raw columns must be dropped after encoding
    column_names: List[str] = []
    for row in dataset:
        column_names = list(row.keys())
        break

    # Re-apply the shuffle after peeking at the stream
    if cfg.shuffle_merged_datasets:
        dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)

    # Lazily tokenize + pack large chunks of the stream
    dataset = dataset.map(
        encode,
        batched=True,
        batch_size=buffer_size,
        remove_columns=column_names,
    )

    # Packed samples must not be re-batched by the trainer
    cfg.micro_batch_size = 1
    return dataset

View File

@@ -0,0 +1,211 @@
"""Simple streaming SFT with multipack support."""
import functools
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional
import numpy as np
from datasets import Dataset, IterableDataset
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
LOG = get_logger(__name__)
class StreamingMultipackDataset:
    """Dataset that handles streaming with multipack on-the-fly.

    Wraps an IterableDataset: each raw sample is tokenized individually, samples
    are accumulated into a small buffer, and each full buffer is packed into a
    single collated training sample.
    """

    def __init__(
        self,
        base_dataset: IterableDataset,
        tokenizer: PreTrainedTokenizerBase,
        cfg,
        dataset_config,
        d_base_type: str,
        d_prompt_style: str | None,
        processor: Any | None,
        max_tokens: int = 2048,
        pack_length: int = 4,  # How many sequences to collect before packing
    ):
        self.base_dataset = base_dataset
        self.tokenizer = tokenizer
        self.cfg = cfg
        self.max_tokens = max_tokens
        self.pack_length = pack_length

        # Create the dataset wrapper once (import here to avoid circular imports)
        from axolotl.utils.data.wrappers import get_dataset_wrapper

        # Create a dummy dataset to get the wrapper
        # NOTE(review): this assumes the wrapper's constructor only needs these
        # keys; dataset types with other required columns may fail here.
        dummy_data = {"text": ["dummy"], "instruction": ["dummy"], "output": ["dummy"]}
        dummy_dataset = Dataset.from_dict(dummy_data)

        self.dataset_wrapper, _ = get_dataset_wrapper(
            dataset_config=dataset_config,
            tokenizer=tokenizer,
            cfg=cfg,
            dataset_base_type=d_base_type,
            dataset=dummy_dataset,
            dataset_prompt_style=d_prompt_style,
            processor=processor,
        )

        # Create collator for packing
        self.collator = PretrainingBatchSamplerDataCollatorForSeq2Seq(
            tokenizer,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=max_tokens,
            multipack_attn=cfg.pretrain_multipack_attn,
        )

    def __iter__(self):
        """Iterator that yields packed samples."""
        buffer = []

        for sample in self.base_dataset:
            # Convert single sample to dataset for processing
            temp_dataset = Dataset.from_dict({k: [v] for k, v in sample.items()})

            # Tokenize using the dataset wrapper
            # NOTE(review): rebuilding the wrapper via __class__ + public
            # __dict__ attrs is fragile — it assumes the wrapper's constructor
            # accepts exactly its non-private instance attributes as kwargs.
            try:
                tokenized = self.dataset_wrapper.__class__(
                    temp_dataset,
                    **{
                        k: v
                        for k, v in self.dataset_wrapper.__dict__.items()
                        if not k.startswith("_")
                    },
                )

                # Get the tokenized sample
                if len(tokenized) > 0:
                    tokenized_sample = tokenized[0]

                    # Add to buffer
                    buffer.append(tokenized_sample)

                    # When buffer is full, pack and yield
                    if len(buffer) >= self.pack_length:
                        packed_sample = self._pack_buffer(buffer)
                        if packed_sample:
                            yield packed_sample
                        buffer = []

            # Broad catch: a malformed sample is skipped rather than aborting
            # the whole stream (logged at warning level)
            except Exception as e:
                LOG.warning(f"Failed to process sample: {e}")
                continue

        # Process remaining buffer
        if buffer:
            packed_sample = self._pack_buffer(buffer)
            if packed_sample:
                yield packed_sample

    def _pack_buffer(self, buffer: List[Dict]) -> Optional[Dict]:
        """Pack a buffer of tokenized samples.

        Returns a single collated sample, or None if the buffer is empty or
        packing fails.
        """
        if not buffer:
            return None

        try:
            # Create dataset from buffer
            temp_dataset = Dataset.from_list(buffer)

            # Add position_ids and process for packing
            temp_dataset = process_pretraining_datasets_for_packing(
                temp_dataset,
                self.max_tokens,
                skip_position_ids=not self.cfg.pretrain_multipack_attn,
                drop_attention_mask=self.cfg.pretrain_multipack_attn,
            )

            # Use MultipackBatchSampler to create packed batches
            sampler = MultipackBatchSampler(
                sampler=RandomSampler(temp_dataset),
                lengths=get_dataset_lengths(temp_dataset),
                batch_size=1,
                batch_max_len=self.max_tokens,
                drop_last=False,
                num_processes=1,
            )

            # Get packed data
            # NOTE(review): only the first packed batch is returned — any
            # further batches the sampler would yield from this buffer are
            # discarded. Confirm this data loss is acceptable.
            for batch in sampler:
                if batch and batch[0]:  # Check if we have data
                    features = []
                    for idx in batch[0]:  # batch[0] contains the indices
                        sample = temp_dataset[idx]
                        if "labels" not in sample:
                            sample["labels"] = sample["input_ids"].copy()
                        features.append(sample)

                    # Apply collator to create final packed sample
                    if features:
                        packed = self.collator(features)
                        return {
                            k: v.squeeze(0) if v.dim() > 1 else v
                            for k, v in packed.items()
                        }

            return None

        # Broad catch: packing failures drop the buffer rather than crash
        except Exception as e:
            LOG.warning(f"Failed to pack buffer: {e}")
            return None
def wrap_streaming_sft_dataset_simple(
    dataset: IterableDataset,
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    dataset_config,
    d_base_type: str,
    d_prompt_style: str | None,
    processor: Any | None,
    max_tokens: int = 2048,
    buffer_size: int = 10_000,
) -> IterableDataset:
    """
    Wrap a streaming SFT dataset with simple multipack batching.

    This approach processes samples in small groups rather than large batches,
    avoiding the repeated processing issue.

    Side effects: sets ``cfg.micro_batch_size`` to 1.
    """
    # Apply shuffling if configured
    if cfg.shuffle_merged_datasets:
        LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)

    # Create the streaming multipack dataset
    multipack_dataset = StreamingMultipackDataset(
        dataset,
        tokenizer,
        cfg,
        dataset_config,
        d_base_type,
        d_prompt_style,
        processor,
        max_tokens,
        pack_length=max(1, max_tokens // 512),  # Estimate sequences per pack
    )

    # Convert back to IterableDataset
    # NOTE(review): IterableWrapper is NOT a datasets.IterableDataset despite
    # the return annotation — callers that do
    # isinstance(dataset, IterableDataset) (as this codebase does elsewhere)
    # will not recognize it. Consider IterableDataset.from_generator instead.
    class IterableWrapper:
        # Thin iterable adapter around StreamingMultipackDataset
        def __init__(self, dataset):
            self.dataset = dataset

        def __iter__(self):
            return iter(self.dataset)

    wrapped = IterableWrapper(multipack_dataset)

    # Set micro_batch_size to 1 since sequences are already packed
    cfg.micro_batch_size = 1

    return wrapped

View File

@@ -178,8 +178,8 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def handle_long_seq_in_dataset(
dataset: Dataset, sequence_len: int, cfg: DictDefault
) -> Dataset:
dataset: Dataset | IterableDataset, sequence_len: int, cfg: DictDefault
) -> Dataset | IterableDataset:
"""Remove sequences longer than configured maximum from dataset.
Args:
@@ -190,7 +190,14 @@ def handle_long_seq_in_dataset(
Returns:
Filtered dataset with long sequences removed.
"""
if "input_ids" not in dataset.column_names:
# Streaming datasets don't support filtering the same way
if isinstance(dataset, IterableDataset):
LOG.info(
"Streaming dataset detected - long sequence filtering will be done on-the-fly"
)
return dataset
if not hasattr(dataset, "column_names") or "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."

View File

@@ -244,6 +244,12 @@ class AxolotlInputConfig(
dataloader_num_workers: int | None = None
dataloader_prefetch_factor: int | None = None
dataloader_drop_last: bool | None = None
streaming: bool | None = Field(
default=None,
json_schema_extra={
"description": "Enable streaming mode for training datasets to reduce memory usage and enable training on datasets larger than memory"
},
)
accelerator_config: dict[str, Any] | None = None

View File

@@ -1074,6 +1074,24 @@ class PretrainingValidationMixin:
data["accelerator_config"]["dispatch_batches"] = False
return data
@model_validator(mode="before")
@classmethod
def check_pretraining_split_batches_accelerate(cls, data):
    """Force split_batches/dispatch_batches off for streaming runs.

    Streaming dataloaders are incompatible with accelerate splitting or
    dispatching batches; explicit user-set values are left untouched.
    """
    # alternatively set ACCELERATE_SPLIT_BATCHES=False
    if data.get("streaming"):
        accelerator_config = data.get("accelerator_config", {})
        if accelerator_config:
            # Only fill in values the user left unset (None counts as unset)
            for option in ("split_batches", "dispatch_batches"):
                if accelerator_config.get(option) is None:
                    data["accelerator_config"][option] = False
        else:
            # No accelerator_config at all (missing, None, or empty dict)
            data["accelerator_config"] = {
                "split_batches": False,
                "dispatch_batches": False,
            }
    return data
class ModelCompatibilityValidationMixin:
"""Validation methods for specific model compatibility."""