Streaming SFT support (#3101)

* working

* fixes

* deprecate --iterable; cleanup

* pretrain_multipack_buffer_size -> streaming_multipack_buffer_size

* improvements

* tests

* remove unused

* docs, examples

* nit

* nit

* add val_set_size validation

* val

* nit

* min

* coderabbito

* cleanup

* nit

* add depr warning, cleanup

* nit

* fix test, fix quarto

* fix

* review comments

* review comments

* fix
Dan Saunders committed 2025-09-02 12:08:44 -04:00 (committed by GitHub)
parent 0094a2d744 · commit 231a67e70b
24 changed files with 849 additions and 283 deletions


@@ -153,7 +153,7 @@ quartodoc:
- utils.distributed
- utils.dict
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.streaming
- utils.data.sft
- utils.quantization
- title: Schemas
@@ -272,6 +272,7 @@ website:
contents:
- docs/batch_vs_grad.qmd
- docs/dataset_preprocessing.qmd
- docs/streaming.qmd
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd

docs/streaming.qmd (new file, 120 lines)

@@ -0,0 +1,120 @@
---
title: Streaming Datasets
description: How to use streaming mode for large-scale datasets and memory-efficient training
order: 10
---
Streaming enables memory-efficient training with large datasets by loading data
incrementally rather than loading the entire dataset into memory at once.
Use streaming when:
- Your dataset is too large to fit in memory (e.g., pretraining on massive text corpora)
- You want to start training immediately without preprocessing the entire dataset
Streaming works with both remote and locally stored datasets!
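For instance, a minimal sketch of streaming a local JSON file; the path and `ds_type` values here are illustrative, not from a shipped example:
```yaml
streaming: true
max_steps: 500  # required for streaming; see below
datasets:
  - path: ./data/train.jsonl  # hypothetical local file
    ds_type: json
    type: alpaca
```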
::: {.callout-note}
Streaming currently only supports a single dataset. Multi-dataset support will be added soon.
:::
## Configuration
### Basic Streaming
Enable streaming mode by setting the `streaming` flag:
```yaml
streaming: true
```
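When streaming is enabled, `max_steps` must also be set, since the trainer cannot infer the length of an iterable dataset:
```yaml
streaming: true
max_steps: 1000
```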
### Pretraining with Streaming
For pretraining tasks, streaming is automatically enabled when using `pretraining_dataset`:
```yaml
pretraining_dataset:
  - path: HuggingFaceFW/fineweb-edu
    type: pretrain
    text_column: text
    split: train

# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
### SFT with Streaming
For supervised fine-tuning with streaming:
```yaml
streaming: true
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
    split: train

# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
## Configuration Options
### `streaming_multipack_buffer_size`
Controls the buffer size for multipack streaming (default: 10,000). This determines how
many samples are buffered before packing. Larger buffers can improve packing efficiency
but use more memory.
### `shuffle_merged_datasets`
When enabled, shuffles the streaming dataset using a shuffle buffer sized by
`streaming_multipack_buffer_size`. This requires additional memory to hold the buffered
samples.
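A sketch combining the two options; note that the shuffle buffer reuses `streaming_multipack_buffer_size` as its size:
```yaml
shuffle_merged_datasets: true
streaming_multipack_buffer_size: 10000  # also used as the shuffle buffer size
```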
## Sample Packing with Streaming
Sample packing is supported for streaming datasets. When enabled, multiple samples are
packed into a single sequence to maximize GPU utilization:
```yaml
sample_packing: true
streaming_multipack_buffer_size: 10000
# For SFT: attention is automatically isolated between packed samples
# For pretraining: control with pretrain_multipack_attn
pretrain_multipack_attn: true # prevent cross-attention between packed samples
```
For more information, see our [documentation](multipack.qmd) on multipacking.
## Important Considerations
### Memory Usage
While streaming reduces memory usage compared to loading entire datasets, you still need
to consider:
- The multipack buffer holds up to `streaming_multipack_buffer_size` samples; lower it to reduce memory usage (see the sketch below)
- Sample packing requires buffering multiple samples before they can be packed
- Shuffling requires additional memory for the shuffle buffer
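For memory-constrained setups, a sketch that trades packing efficiency for a smaller footprint:
```yaml
streaming_multipack_buffer_size: 2000  # smaller buffer: less memory, less efficient packing
shuffle_merged_datasets: false         # skip the shuffle buffer entirely
```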
### Performance
- Streaming may have slightly higher latency compared to preprocessed datasets, as samples are processed on-the-fly
- Network bandwidth matters when streaming from remote sources; disk read speed matters for local datasets
- Consider using `axolotl preprocess` for smaller or more frequently used datasets
### Evaluation Datasets
Evaluation datasets are not streamed to ensure consistent evaluation metrics. They're
loaded normally even when training uses streaming.
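Because `val_set_size` is not supported with streaming, use `test_datasets` to specify evaluation data. A sketch, with an illustrative dataset path:
```yaml
streaming: true
max_steps: 1000
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
    split: train
test_datasets:
  - path: mhenrichsen/alpaca_2k_test  # loaded eagerly, not streamed
    type: alpaca
    split: train
```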
## Examples
See the `examples/streaming/` directory for complete configuration examples:
- `pretrain.yaml`: Pretraining with streaming dataset
- `sft.yaml`: Supervised fine-tuning with streaming


@@ -0,0 +1,50 @@
# Streaming Dataset Examples
This directory contains example configurations for using Axolotl's streaming dataset
functionality, which enables memory-efficient training with large datasets.
## Examples
Run the following examples with e.g. `axolotl train examples/streaming/sft.yaml`; no
`axolotl preprocess` required!
### Pretraining (`pretrain.yaml`)
Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset
with SmolLM2-135M.
- Uses `pretraining_dataset` configuration for automatic streaming
- Multipack attention control to prevent cross-attention between packed sequences
- Buffer size configuration for memory management
### SFT (`sft.yaml`)
Shows how to use streaming for supervised fine-tuning with the Alpaca dataset.
- Explicit `streaming: true` flag for SFT datasets
- Memory-efficient training on instruction datasets
- Evaluation datasets are currently not streamed
## Key Configuration Options
### `streaming`
- Enables streaming mode for standard datasets
- Automatically enabled for `pretraining_dataset`
### `streaming_multipack_buffer_size`
- Controls buffer size for sample packing (default: 10,000)
- Larger values improve packing efficiency but use more memory
- Adjust based on available memory
### `shuffle_merged_datasets`
- Enables shuffling of streaming datasets
- Requires additional memory for shuffle buffer
### `sample_packing`
- Packs multiple samples into single sequences
- Minimizes per-step padding tokens (see the combined sketch below)
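Putting these options together, a minimal packed streaming SFT sketch, adapted from `sft.yaml`:
```yaml
streaming: true
max_steps: 1000
sample_packing: true
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
    split: train
```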
## Performance Tips
- Download small or frequently used datasets locally for better performance
- Larger buffer sizes improve packing efficiency


@@ -0,0 +1,57 @@
base_model: HuggingFaceTB/SmolLM2-135M
# Streaming pretraining configuration
pretraining_dataset:
  - path: HuggingFaceFW/fineweb-edu
    name: sample-10BT
    type: pretrain
    text_column: text
    split: train
# Streaming-specific settings
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-pretrain-streaming
# Sequence and packing settings
sequence_len: 1024
sample_packing: true
pretrain_multipack_attn: true # Prevent cross-attention between packed sequences
flash_attention: true
# Batch size settings
gradient_accumulation_steps: 8
micro_batch_size: 1
# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 5e-4
warmup_ratio: 0.1
weight_decay: 0.01
# Precision and performance
bf16: auto
tf32: true
# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 250
save_total_limit: 3
# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# Special tokens
special_tokens:
  pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config


@@ -0,0 +1,55 @@
base_model: HuggingFaceTB/SmolLM2-135M
# Dataset configuration
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
    split: train
# Streaming-specific settings
streaming: true
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-sft-streaming
# Sequence and packing settings
sequence_len: 1024
sample_packing: true
flash_attention: true
# Batch size settings
gradient_accumulation_steps: 4
micro_batch_size: 1
# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-4
warmup_ratio: 0.1
weight_decay: 0.0
# Precision and performance
bf16: auto
tf32: true
# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 100
save_total_limit: 3
# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# Special tokens
special_tokens:
  pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config


@@ -14,9 +14,13 @@ class PreprocessCliArgs:
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(
default=None,
default=False,
metadata={
"help": "Use IterableDataset for streaming processing of large datasets"
"help": (
"Deprecated in v0.13.0, will be removed in v0.14.0. For streaming "
"datasets, use 'axolotl train' and set 'streaming: true' in your YAML "
"config, or pass --streaming instead in the CLI."
)
},
)


@@ -35,10 +35,20 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
check_accelerate_default_config()
check_user_token()
if cli_args.iterable:
LOG.error(
"The --iterable CLI argument for 'axolotl preprocess' is no longer "
"supported. For training, set 'streaming: true' in your YAML config or "
"pass '--streaming' in your 'axolotl train' command for on-the-fly "
"preprocessing."
)
return
for key in ["skip_prepare_dataset", "pretraining_dataset"]:
if cfg.get(key):
LOG.error(
f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
f"You have set `{key}:`. `preprocess` is not needed. Run the 'axolotl "
"train' CLI directly instead."
)
return


@@ -55,13 +55,11 @@ def load_datasets(
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = getattr(cli_args, "iterable", False)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg,
tokenizer,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if (


@@ -1,18 +1,17 @@
"""Module containing Dataset functionality"""
"""
Module containing dataset functionality.
We want this to be a wrapper for an existing dataset that we have loaded. Let's use the
concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
datasets.
"""
import torch
from datasets import Dataset, IterableDataset
from axolotl.utils.logging import get_logger
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
# lets use the concept of middlewares to wrap each dataset, for example
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
LOG = get_logger(__name__)
@@ -86,133 +85,3 @@ def wrap_dataset_for_tokenized_prompt(
**map_kwargs,
)
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""Iterable dataset that returns constant length chunks of tokens from stream of
text files.
Args:
tokenizer: The processor used for processing the data.
dataset: Dataset with text files.
seq_length: Length of token sequences to return.
"""
def __init__(
self,
tokenizer,
datasets,
seq_length=2048,
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.datasets: list[IterableDataset] = datasets
self.seq_length = seq_length
vocab_size = len(tokenizer.get_vocab())
if vocab_size <= torch.iinfo(torch.int16).max:
self.tokens_dtype = torch.int16
elif vocab_size <= torch.iinfo(torch.int32).max:
self.tokens_dtype = torch.int32
else:
self.tokens_dtype = torch.int64
def __iter__(self):
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
idx = 0
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
example = next(iterator)
idx += 1
except StopIteration:
more_examples = False
example = None
add_concat_token = False
if example:
example_len = len(example["input_ids"])
add_concat_token = example["input_ids"][-1] != self.concat_token_id
else:
example_len = 0
if not example_len or (
buffer_len + int(add_concat_token) + example_len > self.seq_length
):
if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
: self.seq_length
]
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
):
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
LOG.warning(
"Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
if example:
# FIXME
# just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if add_concat_token:
input_ids.append(self.concat_token_id)
attention_mask.append(1)
labels.append(self.concat_token_id)
input_ids_with_concat = torch.tensor(
input_ids, dtype=self.tokens_dtype
)
attention_mask_with_concat = torch.tensor(
[idx * m for m in attention_mask], dtype=torch.int16
)
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype
)
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids)


@@ -75,7 +75,7 @@ class PromptTokenizingStrategy(abc.ABC):
) -> BatchEncoding:
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
if not prompt:
LOG.warning("Empty text requested for tokenization.")
LOG.warning_once("Empty text requested for tokenization.")
return empty
result = self.tokenizer(


@@ -1,11 +1,17 @@
"""
shared axolotl collators for multipack, mamba, multimodal
"""
"""Shared axolotl collators for multipacking, mamba, multimodal."""
from .batching import ( # noqa: F401
from .batching import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
PretrainingBatchSamplerDataCollatorForSeq2Seq,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from .mamba import MambaDataCollator # noqa: F401
from .mamba import MambaDataCollator
__all__ = [
"DataCollatorForSeq2Seq",
"BatchSamplerDataCollatorForSeq2Seq",
"V2BatchSamplerDataCollatorForSeq2Seq",
"PretrainingBatchSamplerDataCollatorForSeq2Seq",
"MambaDataCollator",
]


@@ -1,8 +1,8 @@
"""Init for `axolotl.utils.data` module."""
from axolotl.utils.data.pretraining import (
encode_pretraining,
wrap_pretraining_dataset,
from axolotl.utils.data.streaming import (
encode_streaming,
wrap_streaming_dataset,
)
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import (
@@ -12,8 +12,8 @@ from axolotl.utils.data.sft import (
from axolotl.utils.data.utils import md5
__all__ = [
"encode_pretraining",
"wrap_pretraining_dataset",
"encode_streaming",
"wrap_streaming_dataset",
"prepare_preference_datasets",
"get_dataset_wrapper",
"prepare_datasets",


@@ -9,13 +9,14 @@ from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
load_dataset,
)
from transformers import PreTrainedTokenizer, ProcessorMixin
from axolotl.prompters import Prompter
from axolotl.utils.data.lock import FileLockLoader
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.streaming import wrap_streaming_dataset
from axolotl.utils.data.shared import (
create_train_validation_split,
datasets_with_name_generator,
@@ -48,7 +49,6 @@ def prepare_datasets(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[IterableDataset | Dataset, Dataset | None, int, list[Prompter | None]]:
"""Prepare training and evaluation datasets based on configuration.
@@ -56,23 +56,19 @@ def prepare_datasets(
cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: Tokenizer to use for processing text.
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (train_dataset, eval_dataset, total_steps, prompters).
"""
if cfg.pretraining_dataset:
return _prepare_pretraining_dataset(
cfg, tokenizer, processor, preprocess_iterable
)
return _prepare_standard_dataset(cfg, tokenizer, processor, preprocess_iterable)
if cfg.streaming or cfg.pretraining_dataset:
return _prepare_streaming_dataset(cfg, tokenizer, processor)
return _prepare_standard_dataset(cfg, tokenizer, processor)
def _prepare_standard_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
preprocess_iterable: bool,
) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]:
"""Prepare standard (non-pretraining) datasets."""
@@ -83,7 +79,6 @@ def _prepare_standard_dataset(
cfg,
split="train",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
# Overwrite eval_dataset if test data exists
@@ -93,7 +88,6 @@ def _prepare_standard_dataset(
cfg,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
return train_dataset, eval_dataset, prompters
@@ -128,22 +122,40 @@ def _prepare_standard_dataset(
return train_dataset, eval_dataset, total_num_steps, prompters
def _prepare_pretraining_dataset(
def _prepare_streaming_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
preprocess_iterable: bool,
) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:
"""
Prepare dataset for pretraining mode.
Prepare dataset for streaming mode.
Note: Pre-training datasets are streamed from the HuggingFace Hub.
Note: Streaming datasets are loaded incrementally from the source.
"""
# Extract pretraining dataset configuration
pretraining_config = _extract_pretraining_config(cfg)
if cfg.pretraining_dataset:
dataset_config = _extract_pretraining_config(cfg)
train_dataset = _load_streaming_dataset(dataset_config, cfg, tokenizer)
elif cfg.sample_packing:
# TODO(djsaunde): Implement for multiple datasets
dataset_config = DictDefault(cfg.datasets[0])
# Load streaming dataset for training
train_dataset = _load_pretraining_dataset(pretraining_config, cfg, tokenizer)
# Ensure we have a split set - default to 'train' if not specified
if not hasattr(dataset_config, "split") or not dataset_config.split:
dataset_config.split = "train"
train_dataset = _load_streaming_dataset(dataset_config, cfg, tokenizer)
else:
# Use legacy loading function for non-packed streaming datasets
train_dataset, eval_dataset, prompters = _load_and_prepare_datasets(
tokenizer,
cfg,
split="train",
processor=processor,
streaming=True,
)
# Return early for non-packed streaming datasets
total_num_steps = cfg.max_steps if cfg.max_steps else -1
return train_dataset, eval_dataset, total_num_steps, prompters
# Load evaluation dataset if specified
eval_dataset = None
@@ -153,14 +165,12 @@ def _prepare_pretraining_dataset(
cfg,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
streaming=False,
)
if cfg.dataset_exact_deduplication:
LOG.info("Deduplication not available for pretrained datasets")
# For pretraining, we return max_steps directly from config
return train_dataset, eval_dataset, cfg.max_steps, []
# For streaming, we return max_steps directly from config or -1 if not set
total_num_steps = cfg.max_steps if cfg.max_steps else -1
return train_dataset, eval_dataset, total_num_steps, []
def _extract_pretraining_config(cfg: DictDefault) -> DictDefault:
@@ -192,7 +202,7 @@ def _extract_pretraining_config(cfg: DictDefault) -> DictDefault:
)
def _load_pretraining_dataset(
def _load_streaming_dataset(
pretraining_config: DictDefault, cfg: DictDefault, tokenizer: PreTrainedTokenizer
) -> IterableDataset:
"""Load and prepare a streaming dataset for pretraining."""
@@ -227,15 +237,11 @@ def _load_pretraining_dataset(
iter_dataset = iter_dataset.skip(pretraining_config["skip"])
# Wrap the dataset for pretraining
train_dataset = wrap_pretraining_dataset(
train_dataset = wrap_streaming_dataset(
iter_dataset,
tokenizer,
cfg,
dataset_wrapper_partial,
max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed,
buffer_size=cfg.pretrain_multipack_buffer_size or 10_000,
)
# Format for PyTorch
@@ -256,7 +262,7 @@ def _load_tokenized_prepared_datasets(
cfg: DictDefault,
split: Literal["train", "test"] = "train",
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
streaming: bool = False,
) -> tuple[Dataset | DatasetDict, list[Prompter | None]]:
"""Load or create tokenized and prepared datasets for training or testing.
@@ -265,7 +271,7 @@ def _load_tokenized_prepared_datasets(
cfg: Configuration object.
split: Dataset split to load ('train' or 'test').
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
streaming: Whether to use iterable preprocessing.
Returns:
Tuple of (dataset, prompters list).
@@ -296,7 +302,7 @@ def _load_tokenized_prepared_datasets(
tokenizer,
split,
processor,
preprocess_iterable,
streaming,
)
return dataset, prompters
@@ -308,7 +314,7 @@ def _load_raw_datasets(
tokenizer: PreTrainedTokenizer,
split: str,
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
streaming: bool = False,
) -> tuple[Dataset, list[Prompter | None]]:
"""Load, process, merge, and save raw datasets."""
LOG.info("Loading raw datasets...", main_process_only=False)
@@ -329,7 +335,7 @@ def _load_raw_datasets(
split=split,
seed=cfg.seed,
processor=processor,
preprocess_iterable=preprocess_iterable,
streaming=streaming,
)
datasets.append(dataset_wrapper)
prompters.append(dataset_prompter)
@@ -337,7 +343,7 @@ def _load_raw_datasets(
# Merge datasets
dataset = merge_datasets(datasets, cfg)
if not cfg.skip_prepare_dataset:
if not cfg.skip_prepare_dataset and not streaming:
if split == "test" and cfg.eval_sequence_len:
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
else:
@@ -361,19 +367,19 @@ def _load_and_process_single_dataset(
split: str,
seed: int,
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
streaming: bool = False,
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Load and process a single dataset based on the passed config."""
# Load the dataset
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=preprocess_iterable
dataset_config, cfg.hf_use_auth_token, streaming=streaming
)
# Parse dataset type
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
# Select the appropriate split
if isinstance(dataset, DatasetDict):
if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
if dataset_config.split and dataset_config.split in dataset:
dataset = dataset[dataset_config.split]
elif split in dataset:
@@ -479,7 +485,7 @@ def _load_and_prepare_datasets(
cfg: DictDefault,
split: Literal["train", "test"] = "train",
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
streaming: bool = False,
) -> tuple[Dataset | None, Dataset | None, list[Prompter | None]]:
"""Load and prepare datasets with optional validation split and sharding.
@@ -488,7 +494,7 @@ def _load_and_prepare_datasets(
cfg: Configuration object.
split: Dataset split to load ('train' or 'test').
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
streaming: Whether to use iterable preprocessing.
Returns:
Tuple of (train_dataset, eval_dataset, prompters).
@@ -499,7 +505,7 @@ def _load_and_prepare_datasets(
cfg,
split=split,
processor=processor,
preprocess_iterable=preprocess_iterable,
streaming=streaming,
)
# Apply dataset sharding if configured using shared function


@@ -236,11 +236,9 @@ def _load_from_local_path(
try:
return load_from_disk(dataset_config.path)
except FileNotFoundError:
load_dataset_kwargs["streaming"] = False
return load_dataset(dataset_config.path, **load_dataset_kwargs)
elif local_path.is_file():
dataset_type = get_dataset_type(dataset_config)
load_dataset_kwargs["streaming"] = False
return load_dataset(
dataset_type,
data_files=dataset_config.path,


@@ -1,4 +1,4 @@
"""data handling specific to pretraining"""
"""Data handling specific to streaming datasets."""
import functools
from collections import defaultdict
@@ -17,10 +17,10 @@ from axolotl.utils.trainer import process_pretraining_datasets_for_packing
LOG = get_logger(__name__)
def encode_pretraining(
def encode_streaming(
examples: Dict[str, List],
tokenizer: PreTrainedTokenizerBase,
max_tokens: int,
examples: Dict[str, List],
text_column: str = "text",
concatenate: bool = True,
) -> Dict[str, List]:
@@ -176,45 +176,57 @@ def encode_pretraining(
return ret
def wrap_pretraining_dataset(
def wrap_streaming_dataset(
dataset,
tokenizer,
cfg,
ds_wrapper_fn,
max_tokens=2048,
batch_size=1,
seed=42,
buffer_size=10_000,
):
if cfg.sample_packing:
# For SFT (non-pretraining) datasets, always use multipack_attn=True to ensure
# attention isolation between packed sequences
multipack_attn = (
True if not cfg.pretraining_dataset else cfg.pretrain_multipack_attn
)
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=max_tokens,
multipack_attn=cfg.pretrain_multipack_attn,
pad_to_multiple_of=cfg.sequence_len,
multipack_attn=multipack_attn,
)
encode = functools.partial(
encode_packed_pretraining,
encode_packed_streaming,
collate_fn,
ds_wrapper_fn,
max_seq_length=max_tokens,
batch_size=batch_size,
multipack_attn=cfg.pretrain_multipack_attn,
max_seq_length=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
multipack_attn=multipack_attn,
)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
# Set this to 1 so downstream data_loader doesn't try to increase the batch size
# again
cfg.micro_batch_size = 1
else:
# NOTE: This is not reachable for SFT datasets since we use the pre-existing
# loading function for non-packed streaming datasets. Refer to
# _prepare_streaming_datasets in sft.py for that code path.
text_column = (
getattr(cfg.pretraining_dataset[0], "text_column", "text") or "text"
)
encode = functools.partial(
encode_pretraining,
tokenizer,
max_tokens,
text_column=cfg.pretraining_dataset[0].text_column or "text",
encode_streaming,
tokenizer=tokenizer,
max_tokens=cfg.sequence_len,
text_column=text_column,
concatenate=cfg.pretraining_sample_concatenation is True,
)
if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
dataset = dataset.shuffle(
seed=cfg.seed, buffer_size=cfg.streaming_multipack_buffer_size
)
else:
LOG.debug("NOT shuffling merged pretraining datasets")
@@ -232,14 +244,13 @@ def wrap_pretraining_dataset(
dataset = dataset.map(
encode,
batched=True,
batch_size=buffer_size,
# input_columns="text",
batch_size=cfg.streaming_multipack_buffer_size,
remove_columns=remove_columns,
)
return dataset
def encode_packed_pretraining(
def encode_packed_streaming(
collate_fn,
ds_wrapper: Callable,
examples: Dict[str, List],
@@ -274,8 +285,6 @@ def encode_packed_pretraining(
for batch in sampler:
for data in batch:
features = train_dataset[data]
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "overflow_to_sample_mapping" in features:


@@ -190,12 +190,21 @@ def handle_long_seq_in_dataset(
Returns:
Filtered dataset with long sequences removed.
"""
if "input_ids" not in dataset.column_names:
if (
hasattr(dataset, "column_names")
and dataset.column_names
and "input_ids" not in dataset.column_names
):
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return dataset
elif not hasattr(dataset, "column_names") or dataset.column_names is None:
LOG.info(
"Dataset is streaming (IterableDataset), skipping long sequence handling"
)
return dataset
drop_long = functools.partial(
drop_long_seq,


@@ -475,12 +475,6 @@ class AxolotlInputConfig(
},
)
multipack_real_batches: bool | None = None
pretraining_sample_concatenation: bool | None = Field(
default=None,
json_schema_extra={
"description": "whether to concatenate samples during pretraining",
},
)
batch_flattening: Literal["auto"] | bool | None = Field(
default=None,
@@ -495,13 +489,34 @@ class AxolotlInputConfig(
pose_max_context_len: int | None = None
pose_num_chunks: int | None = None
pretrain_multipack_buffer_size: int | None = 10_000
# Deprecated: Use streaming_multipack_buffer_size instead
pretrain_multipack_buffer_size: int | None = Field(
default=None,
deprecated="Deprecated in v0.13.0, will be removed in v0.14.0. Use streaming_multipack_buffer_size instead",
)
pretrain_multipack_attn: bool | None = Field(
default=True,
json_schema_extra={
"description": "whether to prevent cross attention for packed sequences during pretraining",
},
)
pretraining_sample_concatenation: bool | None = Field(
default=None,
json_schema_extra={
"description": "whether to concatenate samples during pretraining",
},
)
streaming: bool | None = Field(
default=None,
json_schema_extra={"description": "Use streaming mode for loading datasets"},
)
streaming_multipack_buffer_size: int | None = Field(
default=10_000,
json_schema_extra={
"description": "Buffer size for multipack streaming datasets"
},
)
xformers_attention: bool | None = Field(
default=None,
@@ -1264,3 +1279,14 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
data["dataset_processes"] = get_default_process_count()
return data
@model_validator(mode="before")
@classmethod
def check_deduplication_with_streaming(cls, data):
if data.get("dataset_exact_deduplication") and (
data.get("streaming") or data.get("pretraining_dataset")
):
raise NotImplementedError(
"dataset_exact_deduplication is not available for streaming datasets. "
)
return data


@@ -60,6 +60,20 @@ class DatasetValidationMixin:
raise ValueError("either datasets or pretraining_dataset is required")
return data
@model_validator(mode="before")
@classmethod
def check_pretraining_streaming_deprecation(cls, data):
# TODO(djsaunde): remove this check + implement change for 0.13.0 release
if data.get("pretraining_dataset") and not data.get("streaming"):
LOG.warning(
"Setting `pretraining_dataset` without explicitly setting `streaming: "
"true` is deprecated. In a future release, streaming will not be "
"automatically enabled when using pretraining_dataset. Please "
"explicitly set `streaming: true` in your configuration to maintain "
"current behavior."
)
return data
@model_validator(mode="before")
@classmethod
def check_push_ds_auth(cls, data):
@@ -340,6 +354,30 @@ class TrainingValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_multipack_buffer_size(cls, data):
if data.get("pretrain_multipack_buffer_size") and not data.get(
"streaming_multipack_buffer_size"
):
LOG.warning(
"`pretrain_multipack_buffer_size` is deprecated in v0.13.0, will be "
"removed in v0.14.0. Use `streaming_multipack_buffer_size` instead."
)
data["streaming_multipack_buffer_size"] = data[
"pretrain_multipack_buffer_size"
]
del data["pretrain_multipack_buffer_size"]
elif data.get("pretrain_multipack_buffer_size") and data.get(
"streaming_multipack_buffer_size"
):
raise ValueError(
"pretrain_multipack_buffer_size is deprecated, use "
"streaming_multipack_buffer_size; both are set, please remove the "
"deprecated pretrain_multipack_buffer_size setting"
)
return data
@model_validator(mode="after")
def check_fft_possible_bad_config(self):
if (
@@ -1074,6 +1112,50 @@ class PretrainingValidationMixin:
data["accelerator_config"]["dispatch_batches"] = False
return data
@model_validator(mode="before")
@classmethod
def check_pretraining_w_val_set_size(cls, data):
if data.get("pretraining_dataset") and data.get("val_set_size"):
raise ValueError(
"val_set_size is not supported with pretraining_dataset. "
"Use test_datasets to specify evaluation datasets for pretraining."
)
return data
@model_validator(mode="before")
@classmethod
def check_streaming_w_val_set_size(cls, data):
if data.get("streaming") and data.get("val_set_size"):
raise ValueError(
"val_set_size is not supported with streaming datasets. "
"Use test_datasets to specify evaluation datasets when streaming is enabled."
)
return data
@model_validator(mode="before")
@classmethod
def check_streaming_w_max_steps(cls, data):
if data.get("streaming") and not data.get("max_steps"):
raise ValueError(
"max_steps must be set when using streaming datasets. "
"Trainer cannot infer dataset length for iterable datasets."
)
return data
@model_validator(mode="before")
@classmethod
def check_streaming_w_multiple_datasets(cls, data):
if (
data.get("streaming")
and data.get("sample_packing")
and data.get("datasets")
and len(data.get("datasets")) > 1
):
raise NotImplementedError(
"Sample packing with multiple streaming datasets is not yet supported"
)
return data
class ModelCompatibilityValidationMixin:
"""Validation methods for specific model compatibility."""


@@ -25,7 +25,7 @@ def min_cfg(temp_dir):
"liger_rms_norm": True,
"liger_glu_activation": True,
"torch_compile": True,
"chat_template": "llama3",
"chat_template": "qwen3",
"kd_trainer": True,
"kd_ce_alpha": 0.1,
"kd_alpha": 0.9,


@@ -0,0 +1,73 @@
"""E2E tests for streaming dataset functionality"""
# pylint: disable=duplicate-code
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard
class TestStreamingDatasets:
"""Test case for streaming datasets"""
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
def test_streaming_dataset(self, temp_dir, sample_packing):
"""Test streaming datasets"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"sequence_len": 1024,
"sample_packing": sample_packing,
"pretrain_multipack_attn": sample_packing,
"streaming_multipack_buffer_size": 10000,
"dataset_processes": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
# Streaming config
"streaming": True,
"max_steps": 3,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
# Verify training actually happened by checking loss decrease
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
3.0,
"Train Loss (%s) is too high",
)


@@ -6,7 +6,7 @@ import unittest
from transformers import LlamaTokenizer
from axolotl.utils.data import encode_pretraining, md5
from axolotl.utils.data import encode_streaming, md5
from tests.hf_offline_utils import enable_hf_offline
@@ -39,7 +39,7 @@ class TestEncodePretraining(unittest.TestCase):
"hello, hello",
]
}
result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
result = encode_streaming(examples, self.tokenizer, self.max_tokens)
self.assertEqual(len(result["input_ids"]), 3)


@@ -1,16 +1,11 @@
"""Module for testing dataset sequence packing"""
import unittest
from pathlib import Path
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
from axolotl.train import setup_model_and_trainer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
@@ -35,43 +30,6 @@ class TestPacking(unittest.TestCase):
}
)
def test_increments_attention(self):
prompter = AlpacaPrompter("chat")
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
dateset = load_dataset(
"json",
data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
)["train"]
dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset)))
constant_len_dataset = ConstantLengthDataset(
self.tokenizer,
[dataset],
seq_length=2048,
)
packed_dataset = Dataset.from_list(list(constant_len_dataset))
example = packed_dataset[0]
next_bos_index = (
example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
) # add one since we sliced
# first example doesn't have mask reset
assert example["input_ids"][0] == self.tokenizer.bos_token_id
assert example["attention_mask"][0] == 1
assert example["position_ids"][0] == 0
assert example["position_ids"][1] == 1
# but subsequent one does
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
assert example["attention_mask"][next_bos_index] == 2
assert example["position_ids"][next_bos_index] == 0
assert example["position_ids"][next_bos_index + 1] == 1
@with_temp_dir
def test_lora_packing(self, temp_dir):
cfg = DictDefault(


@@ -9,7 +9,7 @@ import torch
from datasets import IterableDataset
from torch.utils.data import DataLoader
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
from axolotl.utils.data import get_dataset_wrapper, wrap_streaming_dataset
from axolotl.utils.dict import DictDefault
@@ -77,14 +77,11 @@ class TestPretrainingPacking:
)
original_bsz = cfg.micro_batch_size
train_dataset = wrap_pretraining_dataset(
train_dataset = wrap_streaming_dataset(
dataset,
tokenizer_huggyllama,
cfg,
ds_wrapper_partial,
max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed or 42,
)
trainer_loader = DataLoader(

tests/test_streaming.py (new file, 238 lines)

@@ -0,0 +1,238 @@
"""Test streaming configuration and data loading functionality."""
import unittest
from unittest.mock import Mock, patch
from datasets import IterableDataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.data.sft import (
_prepare_streaming_dataset,
prepare_datasets,
)
from axolotl.utils.config import validate_config
class TestStreamingConfig(unittest.TestCase):
"""Test streaming configuration and deprecation handling."""
def test_streaming_multipack_buffer_size_deprecation(self):
"""Test that pretrain_multipack_buffer_size is properly deprecated."""
# Test with old config name
cfg_old = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"pretrain_multipack_buffer_size": 5000,
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
"sequence_len": 256,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 0.0001,
}
)
with self.assertLogs("axolotl.utils.schemas.validation", level="WARNING") as cm:
validated_cfg = validate_config(cfg_old)
self.assertIn("pretrain_multipack_buffer_size` is deprecated", cm.output[0])
self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 5000)
self.assertIsNone(
getattr(validated_cfg, "pretrain_multipack_buffer_size", None)
)
def test_streaming_multipack_buffer_size_new(self):
"""Test that new streaming_multipack_buffer_size works correctly."""
cfg_new = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"streaming_multipack_buffer_size": 7000,
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
"sequence_len": 256,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 0.0001,
}
)
validated_cfg = validate_config(cfg_new)
self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 7000)
def test_both_buffer_sizes_raises_error(self):
"""Test that having both old and new buffer size configs raises an error."""
cfg_both = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"pretrain_multipack_buffer_size": 5000,
"streaming_multipack_buffer_size": 7000,
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
"sequence_len": 256,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 0.0001,
}
)
with self.assertRaises(ValueError) as cm:
validate_config(cfg_both)
self.assertIn("both are set", str(cm.exception))
class TestStreamingDatasetPreparation(unittest.TestCase):
"""Test dataset preparation with streaming configuration."""
def setUp(self):
self.tokenizer = Mock()
self.tokenizer.pad_token_id = 0
self.tokenizer.eos_token_id = 1
@patch("axolotl.utils.data.sft._prepare_streaming_dataset")
def test_prepare_datasets_with_streaming_true(self, mock_prepare_streaming):
"""Test that streaming=True triggers streaming dataset preparation."""
cfg = DictDefault(
{
"streaming": True,
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
}
)
mock_prepare_streaming.return_value = (Mock(), None, 100, [])
prepare_datasets(cfg, self.tokenizer)
mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None)
@patch("axolotl.utils.data.sft._prepare_streaming_dataset")
def test_prepare_datasets_with_pretraining_dataset(self, mock_prepare_streaming):
"""Test that pretraining_dataset triggers streaming dataset preparation."""
cfg = DictDefault(
{
"pretraining_dataset": "test/dataset",
}
)
mock_prepare_streaming.return_value = (Mock(), None, 100, [])
prepare_datasets(cfg, self.tokenizer)
mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None)
@patch("axolotl.utils.data.sft._prepare_standard_dataset")
def test_prepare_datasets_without_streaming(self, mock_prepare_standard):
"""Test that without streaming, standard dataset preparation is used."""
cfg = DictDefault(
{
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
}
)
mock_prepare_standard.return_value = (Mock(), None, 100, [])
prepare_datasets(cfg, self.tokenizer)
mock_prepare_standard.assert_called_once_with(cfg, self.tokenizer, None)
class TestStreamingWithSamplePacking(unittest.TestCase):
"""Test streaming dataset preparation with sample packing."""
def setUp(self):
self.tokenizer = Mock()
self.tokenizer.pad_token_id = 0
self.tokenizer.eos_token_id = 1
@patch("axolotl.utils.data.sft._load_streaming_dataset")
def test_streaming_sft_with_sample_packing_sets_split(self, mock_load_streaming):
"""Test that streaming SFT with sample_packing sets default split."""
cfg = DictDefault(
{
"streaming": True,
"sample_packing": True,
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
"sequence_len": 256,
"micro_batch_size": 1,
}
)
mock_load_streaming.return_value = Mock(spec=IterableDataset)
with patch("axolotl.utils.data.sft._load_and_prepare_datasets"):
_prepare_streaming_dataset(cfg, self.tokenizer, None)
# Check that the dataset config has split set to 'train'
call_args = mock_load_streaming.call_args
dataset_config = call_args[0][0]
self.assertEqual(dataset_config.split, "train")
def test_multipack_attn_forced_true_for_sft(self):
"""Test that multipack_attn is forced to True for SFT with sample packing."""
from axolotl.utils.data.streaming import wrap_streaming_dataset
cfg = DictDefault(
{
"sample_packing": True,
"pretrain_multipack_attn": False, # Should be overridden for SFT
"pretraining_dataset": None, # This makes it SFT
"sequence_len": 256,
"micro_batch_size": 1,
"streaming_multipack_buffer_size": 1000,
"seed": 42,
}
)
mock_dataset = Mock()
mock_dataset.features = None # For streaming datasets
mock_dataset.__iter__ = Mock(return_value=iter([])) # Empty iterator
mock_dataset.map = Mock(return_value=mock_dataset)
mock_ds_wrapper = Mock()
with patch(
"axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq"
) as mock_collator:
with patch("axolotl.utils.data.streaming.encode_packed_streaming"):
wrap_streaming_dataset(
mock_dataset, self.tokenizer, cfg, mock_ds_wrapper
)
# Check that multipack_attn=True was used in the collator
mock_collator.assert_called_once()
call_kwargs = mock_collator.call_args[1]
self.assertTrue(call_kwargs["multipack_attn"])
def test_multipack_attn_respects_config_for_pretraining(self):
"""Test that multipack_attn respects config for pretraining datasets."""
from axolotl.utils.data.streaming import wrap_streaming_dataset
cfg = DictDefault(
{
"sample_packing": True,
"pretrain_multipack_attn": False, # Should be respected for pretraining
"pretraining_dataset": "test/dataset", # This makes it pretraining
"sequence_len": 256,
"micro_batch_size": 1,
"streaming_multipack_buffer_size": 1000,
"seed": 42,
}
)
mock_dataset = Mock()
mock_dataset.features = None # For streaming datasets
mock_dataset.__iter__ = Mock(return_value=iter([])) # Empty iterator
mock_dataset.map = Mock(return_value=mock_dataset)
mock_ds_wrapper = Mock()
with patch(
"axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq"
) as mock_collator:
with patch("axolotl.utils.data.streaming.encode_packed_streaming"):
wrap_streaming_dataset(
mock_dataset, self.tokenizer, cfg, mock_ds_wrapper
)
# Check that multipack_attn=False was used (respecting config)
mock_collator.assert_called_once()
call_kwargs = mock_collator.call_args[1]
self.assertFalse(call_kwargs["multipack_attn"])
if __name__ == "__main__":
unittest.main()