From 231a67e70bbfc095fc94e057537412cf57a472cf Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 2 Sep 2025 12:08:44 -0400 Subject: [PATCH] Streaming SFT support (#3101) * working * fixes * deprecate --iterable; cleanup * pretrain_multipack_buffer_size -> streaming_multipack_buffer_size * improvements * tests * remove unused * docs, examples * nit * nit * add val_set_size validation * val * nit * min * coderabbito * cleanup * nit * add depr warning, cleanup * nit * fix test, fix quarto * fix * review comments * review comments * fix --- _quarto.yml | 3 +- docs/streaming.qmd | 120 +++++++++ examples/streaming/README.md | 50 ++++ examples/streaming/pretrain.yaml | 57 +++++ examples/streaming/sft.yaml | 55 ++++ src/axolotl/cli/args.py | 8 +- src/axolotl/cli/preprocess.py | 12 +- src/axolotl/common/datasets.py | 2 - src/axolotl/datasets.py | 145 +---------- src/axolotl/prompt_tokenizers.py | 2 +- src/axolotl/utils/collators/__init__.py | 16 +- src/axolotl/utils/data/__init__.py | 10 +- src/axolotl/utils/data/sft.py | 92 +++---- src/axolotl/utils/data/shared.py | 2 - .../data/{pretraining.py => streaming.py} | 59 +++-- src/axolotl/utils/data/utils.py | 11 +- src/axolotl/utils/schemas/config.py | 40 ++- src/axolotl/utils/schemas/validation.py | 82 ++++++ tests/e2e/integrations/test_kd.py | 2 +- tests/e2e/test_streaming.py | 73 ++++++ tests/test_data.py | 4 +- tests/test_packed_dataset.py | 42 ---- tests/test_packed_pretraining.py | 7 +- tests/test_streaming.py | 238 ++++++++++++++++++ 24 files changed, 849 insertions(+), 283 deletions(-) create mode 100644 docs/streaming.qmd create mode 100644 examples/streaming/README.md create mode 100644 examples/streaming/pretrain.yaml create mode 100644 examples/streaming/sft.yaml rename src/axolotl/utils/data/{pretraining.py => streaming.py} (86%) create mode 100644 tests/e2e/test_streaming.py create mode 100644 tests/test_streaming.py diff --git a/_quarto.yml b/_quarto.yml index 934d393cb..3ffb0e627 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -153,7 +153,7 @@ quartodoc: - utils.distributed - utils.dict - utils.optimizers.adopt - - utils.data.pretraining + - utils.data.streaming - utils.data.sft - utils.quantization - title: Schemas @@ -272,6 +272,7 @@ website: contents: - docs/batch_vs_grad.qmd - docs/dataset_preprocessing.qmd + - docs/streaming.qmd - docs/multipack.qmd - docs/mixed_precision.qmd - docs/optimizers.qmd diff --git a/docs/streaming.qmd b/docs/streaming.qmd new file mode 100644 index 000000000..2a233a4fc --- /dev/null +++ b/docs/streaming.qmd @@ -0,0 +1,120 @@ +--- +title: Streaming Datasets +description: How to use streaming mode for large-scale datasets and memory-efficient training +order: 10 +--- + +Streaming enables memory-efficient training with large datasets by loading data +incrementally rather than loading the entire dataset into memory at once. + +Use streaming when: + +- Your dataset is too large to fit in memory (e.g. when you're doing pretraining with massive text corpora) +- You want to start training immediately without preprocessing the entire dataset + +Streaming works with both remote and locally stored datasets! + +::: {.callout-note} +Streaming currently only supports a single dataset. Multi-dataset support will be added soon. 
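When `sample_packing` is enabled, configuring more than one streaming dataset raises a `NotImplementedError` at config validation time.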
+::: + + +## Configuration + +### Basic Streaming + +Enable streaming mode by setting the `streaming` flag: + +```yaml +streaming: true +``` + +### Pretraining with Streaming + +For pretraining tasks, streaming is automatically enabled when using `pretraining_dataset`: + +```yaml +pretraining_dataset: + - path: HuggingFaceFW/fineweb-edu + type: pretrain + text_column: text + split: train + +# Optionally, enable sample packing +streaming_multipack_buffer_size: 10000 +sample_packing: true +``` + +### SFT with Streaming + +For supervised fine-tuning with streaming: + +```yaml +streaming: true +datasets: + - path: tatsu-lab/alpaca + type: alpaca + split: train + +# Optionally, enable sample packing +streaming_multipack_buffer_size: 10000 +sample_packing: true +``` + +## Configuration Options + +### `streaming_multipack_buffer_size` + +Controls the buffer size for multipack streaming (default: 10,000). This determines how +many samples are buffered before packing. Larger buffers can improve packing efficiency +but use more memory. + +### `shuffle_merged_datasets` + +When enabled, shuffles the streaming dataset using the buffer. This requires additional +memory for the shuffle buffer. + +## Sample Packing with Streaming + +Sample packing is supported for streaming datasets. When enabled, multiple samples are +packed into a single sequence to maximize GPU utilization: + +```yaml +sample_packing: true +streaming_multipack_buffer_size: 10000 + +# For SFT: attention is automatically isolated between packed samples +# For pretraining: control with pretrain_multipack_attn +pretrain_multipack_attn: true # prevent cross-attention between packed samples +``` + +For more information, see our [documentation](multipack.qmd) on multipacking. + +## Important Considerations + +### Memory Usage + +While streaming reduces memory usage compared to loading entire datasets, you still need +to consider: + +- You can control the memory usage by adjusting `streaming_multipack_buffer_size` +- Sample packing requires buffering multiple samples +- Shuffling requires additional memory for the shuffle buffer + +### Performance + +- Streaming may have slightly higher latency compared to preprocessed datasets, as samples are processed on-the-fly +- Network speed and disk read speed are important when streaming from remote sources or a local dataset, respectively +- Consider using `axolotl preprocess` for smaller or more frequently used datasets + +### Evaluation Datasets + +Evaluation datasets are not streamed to ensure consistent evaluation metrics. They're +loaded normally even when training uses streaming. + +## Examples + +See the `examples/streaming/` directory for complete configuration examples: + +- `pretrain.yaml`: Pretraining with streaming dataset +- `sft.yaml`: Supervised fine-tuning with streaming diff --git a/examples/streaming/README.md b/examples/streaming/README.md new file mode 100644 index 000000000..cdbb5baea --- /dev/null +++ b/examples/streaming/README.md @@ -0,0 +1,50 @@ +# Streaming Dataset Examples + +This directory contains example configurations for using Axolotl's streaming dataset +functionality, which enables memory-efficient training with large datasets. + +## Examples + +Run the following examples with e.g. `axolotl train examples/streaming/sft.yaml`; no +`axolotl preprocess` required! + +### Pretraining (`pretrain.yaml`) + +Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset +with SmolLM2-135M. 
+ +- Uses `pretraining_dataset` configuration for automatic streaming +- Multipack attention control to prevent cross-attention between packed sequences +- Buffer size configuration for memory management + +### SFT (`sft.yaml`) + +Shows how to use streaming for supervised fine-tuning with the Alpaca dataset. + +- Explicit `streaming: true` flag for SFT datasets +- Memory-efficient training on instruction datasets +- Evaluation datasets are currently not streamed + +## Key Configuration Options + +### `streaming` +- Enables streaming mode for standard datasets +- Automatically enabled for `pretraining_dataset` + +### `streaming_multipack_buffer_size` +- Controls buffer size for sample packing (default: 10,000) +- Larger values improve packing efficiency but use more memory +- Adjust based on available memory + +### `shuffle_merged_datasets` +- Enables shuffling of streaming datasets +- Requires additional memory for shuffle buffer + +### `sample_packing` +- Packs multiple samples into single sequences +- Minimize per-step padding tokens + +## Performance Tips + +- Download small / frequently-used datasets locally for better performance +- Larger buffer sizes improve packing efficiency diff --git a/examples/streaming/pretrain.yaml b/examples/streaming/pretrain.yaml new file mode 100644 index 000000000..bc8edefd6 --- /dev/null +++ b/examples/streaming/pretrain.yaml @@ -0,0 +1,57 @@ +base_model: HuggingFaceTB/SmolLM2-135M + +# Streaming pretraining configuration +pretraining_dataset: + - path: HuggingFaceFW/fineweb-edu + name: sample-10BT + type: pretrain + text_column: text + split: train + +# Streaming-specific settings +streaming_multipack_buffer_size: 10000 +shuffle_merged_datasets: true + +# Training configuration +max_steps: 1000 +output_dir: ./outputs/smollm2-135m-pretrain-streaming + +# Sequence and packing settings +sequence_len: 1024 +sample_packing: true +pretrain_multipack_attn: true # Prevent cross-attention between packed sequences +flash_attention: true + +# Batch size settings +gradient_accumulation_steps: 8 +micro_batch_size: 1 + +# Optimizer and scheduler +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 5e-4 +warmup_ratio: 0.1 +weight_decay: 0.01 + +# Precision and performance +bf16: auto +tf32: true + +# Logging and checkpointing +logging_steps: 10 +save_strategy: steps +save_steps: 250 +save_total_limit: 3 + +# Weights & Biases (optional) +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +# Special tokens +special_tokens: + pad_token: "<|endoftext|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/streaming/sft.yaml b/examples/streaming/sft.yaml new file mode 100644 index 000000000..47b9f493f --- /dev/null +++ b/examples/streaming/sft.yaml @@ -0,0 +1,55 @@ +base_model: HuggingFaceTB/SmolLM2-135M + +# Dataset configuration +datasets: + - path: tatsu-lab/alpaca + type: alpaca + split: train + +# Streaming-specific settings +streaming: true +streaming_multipack_buffer_size: 10000 +shuffle_merged_datasets: true + +# Training configuration +max_steps: 1000 +output_dir: ./outputs/smollm2-135m-sft-streaming + +# Sequence and packing settings +sequence_len: 1024 +sample_packing: true +flash_attention: true + +# Batch size settings +gradient_accumulation_steps: 4 +micro_batch_size: 1 + +# Optimizer and scheduler +optimizer: adamw_torch +lr_scheduler: cosine +learning_rate: 2e-4 +warmup_ratio: 0.1 +weight_decay: 0.0 + +# Precision and performance +bf16: auto +tf32: true + 
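# Note: max_steps (set above) is required when streaming, since the trainer
# cannot infer the length of an iterable dataset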
+# Logging and checkpointing +logging_steps: 10 +save_strategy: steps +save_steps: 100 +save_total_limit: 3 + +# Weights & Biases (optional) +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +# Special tokens +special_tokens: + pad_token: "<|endoftext|>" + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py index 9bb544aff..396e9a8af 100644 --- a/src/axolotl/cli/args.py +++ b/src/axolotl/cli/args.py @@ -14,9 +14,13 @@ class PreprocessCliArgs: prompter: Optional[str] = field(default=None) download: Optional[bool] = field(default=True) iterable: Optional[bool] = field( - default=None, + default=False, metadata={ - "help": "Use IterableDataset for streaming processing of large datasets" + "help": ( + "Deprecated in v0.13.0, will be removed in v0.14.0. For streaming " + "datasets, use 'axolotl train' and set 'streaming: true' in your YAML " + "config, or pass --streaming instead in the CLI." + ) }, ) diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index ff4551c64..6c05a55f1 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -35,10 +35,20 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None: check_accelerate_default_config() check_user_token() + if cli_args.iterable: + LOG.error( + "The --iterable CLI argument for 'axolotl preprocess' is no longer " + "supported. For training, set 'streaming: true' in your YAML config or " + "pass '--streaming' in your 'axolotl train' command for on-the-fly " + "preprocessing." + ) + return + for key in ["skip_prepare_dataset", "pretraining_dataset"]: if cfg.get(key): LOG.error( - f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead." + f"You have set `{key}:`. `preprocess` is not needed. Run the 'axolotl " + "train' CLI directly instead." ) return diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py index e7433e3c2..8d7758e66 100644 --- a/src/axolotl/common/datasets.py +++ b/src/axolotl/common/datasets.py @@ -55,13 +55,11 @@ def load_datasets( """ tokenizer = load_tokenizer(cfg) processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None - preprocess_iterable = getattr(cli_args, "iterable", False) train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets( cfg, tokenizer, processor=processor, - preprocess_iterable=preprocess_iterable, ) if ( diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index b8f9484bc..20acb8521 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -1,18 +1,17 @@ -"""Module containing Dataset functionality""" +""" +Module containing dataset functionality. + +We want this to be a wrapper for an existing dataset that we have loaded. Lets use the +concept of middlewares to wrap each dataset. We'll use the collators later on to pad the +datasets. 
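For example, wrap_dataset_for_tokenized_prompt wraps a loaded dataset in a
TokenizedPromptDataset driven by a PromptTokenizingStrategy.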
+""" -import torch from datasets import Dataset, IterableDataset from axolotl.utils.logging import get_logger from .prompt_tokenizers import PromptTokenizingStrategy -# We want this to be a wrapper for an existing dataset that we have loaded -# lets use the concept of middlewares to wrap each dataset, for example -# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)])) -# let's check to ensure we don't truncate an item in the middle, we'll use -# the collators later on to pad the datasets - LOG = get_logger(__name__) @@ -86,133 +85,3 @@ def wrap_dataset_for_tokenized_prompt( **map_kwargs, ) return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs) - - -# TODO this isn't the best since it can't interleave datasets -class ConstantLengthDataset(IterableDataset): - """Iterable dataset that returns constant length chunks of tokens from stream of - text files. - - Args: - tokenizer: The processor used for processing the data. - dataset: Dataset with text files. - seq_length: Length of token sequences to return. - """ - - def __init__( - self, - tokenizer, - datasets, - seq_length=2048, - ): - self.tokenizer = tokenizer - self.concat_token_id = tokenizer.eos_token_id - self.datasets: list[IterableDataset] = datasets - self.seq_length = seq_length - - vocab_size = len(tokenizer.get_vocab()) - - if vocab_size <= torch.iinfo(torch.int16).max: - self.tokens_dtype = torch.int16 - elif vocab_size <= torch.iinfo(torch.int32).max: - self.tokens_dtype = torch.int32 - else: - self.tokens_dtype = torch.int64 - - def __iter__(self): - buffer = { - "input_ids": [], - "attention_mask": [], - "labels": [], - "position_ids": [], - } - buffer_len = 0 - for dataset in self.datasets: - idx = 0 - iterator = iter(dataset) - more_examples = True - while more_examples: - try: - example = next(iterator) - idx += 1 - except StopIteration: - more_examples = False - example = None - - add_concat_token = False - if example: - example_len = len(example["input_ids"]) - add_concat_token = example["input_ids"][-1] != self.concat_token_id - else: - example_len = 0 - - if not example_len or ( - buffer_len + int(add_concat_token) + example_len > self.seq_length - ): - if buffer["input_ids"]: - input_ids = torch.cat(buffer["input_ids"], dim=-1)[ - : self.seq_length - ] - attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[ - : self.seq_length - ] - position_ids = torch.cat(buffer["position_ids"], dim=-1)[ - : self.seq_length - ] - labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length] - if labels.size() == input_ids.size() and ( - attention_mask.size() == input_ids.size() - ): - yield { - "input_ids": input_ids, - "labels": labels, - "attention_mask": attention_mask, - "position_ids": position_ids, - } - else: - LOG.warning( - "Dropping batch due to tensor size mismatch " - f"input_ids: {input_ids.size()}, " - f"labels: {labels.size()}, " - f"attention_mask: {attention_mask.size()}" - ) - buffer = { - "input_ids": [], - "attention_mask": [], - "labels": [], - "position_ids": [], - } - buffer_len = 0 - idx = 1 - - if example: - # FIXME - # just going to drop data points that are too long - if len(example["input_ids"]) <= self.seq_length: - input_ids = example["input_ids"] - attention_mask = example["attention_mask"] - labels = example["labels"] - - if add_concat_token: - input_ids.append(self.concat_token_id) - attention_mask.append(1) - labels.append(self.concat_token_id) - - input_ids_with_concat = torch.tensor( - input_ids, dtype=self.tokens_dtype - ) - 
attention_mask_with_concat = torch.tensor( - [idx * m for m in attention_mask], dtype=torch.int16 - ) - labels_with_concat = torch.tensor( - labels, dtype=self.tokens_dtype - ) - position_ids = torch.arange( - len(input_ids), dtype=self.tokens_dtype - ) - - buffer["input_ids"].append(input_ids_with_concat) - buffer["attention_mask"].append(attention_mask_with_concat) - buffer["labels"].append(labels_with_concat) - buffer["position_ids"].append(position_ids) - buffer_len += len(input_ids) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 2bf9ec763..a7bd963f8 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -75,7 +75,7 @@ class PromptTokenizingStrategy(abc.ABC): ) -> BatchEncoding: empty = BatchEncoding(data={"input_ids": [], "attention_mask": []}) if not prompt: - LOG.warning("Empty text requested for tokenization.") + LOG.warning_once("Empty text requested for tokenization.") return empty result = self.tokenizer( diff --git a/src/axolotl/utils/collators/__init__.py b/src/axolotl/utils/collators/__init__.py index 8c60f223c..d5e6ad17d 100644 --- a/src/axolotl/utils/collators/__init__.py +++ b/src/axolotl/utils/collators/__init__.py @@ -1,11 +1,17 @@ -""" -shared axolotl collators for multipack, mamba, multimodal -""" +"""Shared axolotl collators for multipacking, mamba, multimodal.""" -from .batching import ( # noqa: F401 +from .batching import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, PretrainingBatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq, ) -from .mamba import MambaDataCollator # noqa: F401 +from .mamba import MambaDataCollator + +__all__ = [ + "DataCollatorForSeq2Seq", + "BatchSamplerDataCollatorForSeq2Seq", + "V2BatchSamplerDataCollatorForSeq2Seq", + "PretrainingBatchSamplerDataCollatorForSeq2Seq", + "MambaDataCollator", +] diff --git a/src/axolotl/utils/data/__init__.py b/src/axolotl/utils/data/__init__.py index d162a7d0b..788f13638 100644 --- a/src/axolotl/utils/data/__init__.py +++ b/src/axolotl/utils/data/__init__.py @@ -1,8 +1,8 @@ """Init for `axolotl.utils.data` module.""" -from axolotl.utils.data.pretraining import ( - encode_pretraining, - wrap_pretraining_dataset, +from axolotl.utils.data.streaming import ( + encode_streaming, + wrap_streaming_dataset, ) from axolotl.utils.data.rl import prepare_preference_datasets from axolotl.utils.data.sft import ( @@ -12,8 +12,8 @@ from axolotl.utils.data.sft import ( from axolotl.utils.data.utils import md5 __all__ = [ - "encode_pretraining", - "wrap_pretraining_dataset", + "encode_streaming", + "wrap_streaming_dataset", "prepare_preference_datasets", "get_dataset_wrapper", "prepare_datasets", diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 2ae7d9052..28732e01d 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -9,13 +9,14 @@ from datasets import ( Dataset, DatasetDict, IterableDataset, + IterableDatasetDict, load_dataset, ) from transformers import PreTrainedTokenizer, ProcessorMixin from axolotl.prompters import Prompter from axolotl.utils.data.lock import FileLockLoader -from axolotl.utils.data.pretraining import wrap_pretraining_dataset +from axolotl.utils.data.streaming import wrap_streaming_dataset from axolotl.utils.data.shared import ( create_train_validation_split, datasets_with_name_generator, @@ -48,7 +49,6 @@ def prepare_datasets( cfg: DictDefault, tokenizer: PreTrainedTokenizer, processor: ProcessorMixin | None = None, - preprocess_iterable: 
bool = False, ) -> tuple[IterableDataset | Dataset, Dataset | None, int, list[Prompter | None]]: """Prepare training and evaluation datasets based on configuration. @@ -56,23 +56,19 @@ def prepare_datasets( cfg: Dictionary mapping `axolotl` config keys to values. tokenizer: Tokenizer to use for processing text. processor: Optional processor for multimodal datasets. - preprocess_iterable: Whether to use iterable preprocessing. Returns: Tuple of (train_dataset, eval_dataset, total_steps, prompters). """ - if cfg.pretraining_dataset: - return _prepare_pretraining_dataset( - cfg, tokenizer, processor, preprocess_iterable - ) - return _prepare_standard_dataset(cfg, tokenizer, processor, preprocess_iterable) + if cfg.streaming or cfg.pretraining_dataset: + return _prepare_streaming_dataset(cfg, tokenizer, processor) + return _prepare_standard_dataset(cfg, tokenizer, processor) def _prepare_standard_dataset( cfg: DictDefault, tokenizer: PreTrainedTokenizer, processor: ProcessorMixin | None, - preprocess_iterable: bool, ) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]: """Prepare standard (non-pretraining) datasets.""" @@ -83,7 +79,6 @@ def _prepare_standard_dataset( cfg, split="train", processor=processor, - preprocess_iterable=preprocess_iterable, ) # Overwrite eval_dataset if test data exists @@ -93,7 +88,6 @@ def _prepare_standard_dataset( cfg, split="test", processor=processor, - preprocess_iterable=preprocess_iterable, ) return train_dataset, eval_dataset, prompters @@ -128,22 +122,40 @@ def _prepare_standard_dataset( return train_dataset, eval_dataset, total_num_steps, prompters -def _prepare_pretraining_dataset( +def _prepare_streaming_dataset( cfg: DictDefault, tokenizer: PreTrainedTokenizer, processor: ProcessorMixin | None, - preprocess_iterable: bool, ) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]: """ - Prepare dataset for pretraining mode. + Prepare dataset for streaming mode. - Note: Pre-training datasets are streamed from the HuggingFace Hub. + Note: Streaming datasets are loaded incrementally from the source. 
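When sample packing is enabled for SFT, only the first entry in cfg.datasets
is currently used; multi-dataset support is tracked in the in-code TODO.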
""" - # Extract pretraining dataset configuration - pretraining_config = _extract_pretraining_config(cfg) + if cfg.pretraining_dataset: + dataset_config = _extract_pretraining_config(cfg) + train_dataset = _load_streaming_dataset(dataset_config, cfg, tokenizer) + elif cfg.sample_packing: + # TODO(djsaunde): Implement for multiple datasets + dataset_config = DictDefault(cfg.datasets[0]) - # Load streaming dataset for training - train_dataset = _load_pretraining_dataset(pretraining_config, cfg, tokenizer) + # Ensure we have a split set - default to 'train' if not specified + if not hasattr(dataset_config, "split") or not dataset_config.split: + dataset_config.split = "train" + train_dataset = _load_streaming_dataset(dataset_config, cfg, tokenizer) + else: + # Use legacy loading function for non-packed streaming datasets + train_dataset, eval_dataset, prompters = _load_and_prepare_datasets( + tokenizer, + cfg, + split="train", + processor=processor, + streaming=True, + ) + + # Return early for non-packed streaming datasets + total_num_steps = cfg.max_steps if cfg.max_steps else -1 + return train_dataset, eval_dataset, total_num_steps, prompters # Load evaluation dataset if specified eval_dataset = None @@ -153,14 +165,12 @@ def _prepare_pretraining_dataset( cfg, split="test", processor=processor, - preprocess_iterable=preprocess_iterable, + streaming=False, ) - if cfg.dataset_exact_deduplication: - LOG.info("Deduplication not available for pretrained datasets") - - # For pretraining, we return max_steps directly from config - return train_dataset, eval_dataset, cfg.max_steps, [] + # For streaming, we return max_steps directly from config or -1 if not set + total_num_steps = cfg.max_steps if cfg.max_steps else -1 + return train_dataset, eval_dataset, total_num_steps, [] def _extract_pretraining_config(cfg: DictDefault) -> DictDefault: @@ -192,7 +202,7 @@ def _extract_pretraining_config(cfg: DictDefault) -> DictDefault: ) -def _load_pretraining_dataset( +def _load_streaming_dataset( pretraining_config: DictDefault, cfg: DictDefault, tokenizer: PreTrainedTokenizer ) -> IterableDataset: """Load and prepare a streaming dataset for pretraining.""" @@ -227,15 +237,11 @@ def _load_pretraining_dataset( iter_dataset = iter_dataset.skip(pretraining_config["skip"]) # Wrap the dataset for pretraining - train_dataset = wrap_pretraining_dataset( + train_dataset = wrap_streaming_dataset( iter_dataset, tokenizer, cfg, dataset_wrapper_partial, - max_tokens=cfg.sequence_len, - batch_size=cfg.micro_batch_size, - seed=cfg.seed, - buffer_size=cfg.pretrain_multipack_buffer_size or 10_000, ) # Format for PyTorch @@ -256,7 +262,7 @@ def _load_tokenized_prepared_datasets( cfg: DictDefault, split: Literal["train", "test"] = "train", processor: ProcessorMixin | None = None, - preprocess_iterable: bool = False, + streaming: bool = False, ) -> tuple[Dataset | DatasetDict, list[Prompter | None]]: """Load or create tokenized and prepared datasets for training or testing. @@ -265,7 +271,7 @@ def _load_tokenized_prepared_datasets( cfg: Configuration object. split: Dataset split to load ('train' or 'test'). processor: Optional processor for multimodal datasets. - preprocess_iterable: Whether to use iterable preprocessing. + streaming: Whether to use iterable preprocessing. Returns: Tuple of (dataset, prompters list). 
@@ -296,7 +302,7 @@ def _load_tokenized_prepared_datasets( tokenizer, split, processor, - preprocess_iterable, + streaming, ) return dataset, prompters @@ -308,7 +314,7 @@ def _load_raw_datasets( tokenizer: PreTrainedTokenizer, split: str, processor: ProcessorMixin | None = None, - preprocess_iterable: bool = False, + streaming: bool = False, ) -> tuple[Dataset, list[Prompter | None]]: """Load, process, merge, and save raw datasets.""" LOG.info("Loading raw datasets...", main_process_only=False) @@ -329,7 +335,7 @@ def _load_raw_datasets( split=split, seed=cfg.seed, processor=processor, - preprocess_iterable=preprocess_iterable, + streaming=streaming, ) datasets.append(dataset_wrapper) prompters.append(dataset_prompter) @@ -337,7 +343,7 @@ def _load_raw_datasets( # Merge datasets dataset = merge_datasets(datasets, cfg) - if not cfg.skip_prepare_dataset: + if not cfg.skip_prepare_dataset and not streaming: if split == "test" and cfg.eval_sequence_len: dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg) else: @@ -361,19 +367,19 @@ def _load_and_process_single_dataset( split: str, seed: int, processor: ProcessorMixin | None = None, - preprocess_iterable: bool = False, + streaming: bool = False, ) -> tuple[Dataset | IterableDataset, Prompter | None]: """Load and process a single dataset based on the passed config.""" # Load the dataset dataset = load_dataset_with_config( - dataset_config, cfg.hf_use_auth_token, streaming=preprocess_iterable + dataset_config, cfg.hf_use_auth_token, streaming=streaming ) # Parse dataset type d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type) # Select the appropriate split - if isinstance(dataset, DatasetDict): + if isinstance(dataset, (DatasetDict, IterableDatasetDict)): if dataset_config.split and dataset_config.split in dataset: dataset = dataset[dataset_config.split] elif split in dataset: @@ -479,7 +485,7 @@ def _load_and_prepare_datasets( cfg: DictDefault, split: Literal["train", "test"] = "train", processor: ProcessorMixin | None = None, - preprocess_iterable: bool = False, + streaming: bool = False, ) -> tuple[Dataset | None, Dataset | None, list[Prompter | None]]: """Load and prepare datasets with optional validation split and sharding. @@ -488,7 +494,7 @@ def _load_and_prepare_datasets( cfg: Configuration object. split: Dataset split to load ('train' or 'test'). processor: Optional processor for multimodal datasets. - preprocess_iterable: Whether to use iterable preprocessing. + streaming: Whether to use iterable preprocessing. Returns: Tuple of (train_dataset, eval_dataset, prompters). 
@@ -499,7 +505,7 @@ def _load_and_prepare_datasets( cfg, split=split, processor=processor, - preprocess_iterable=preprocess_iterable, + streaming=streaming, ) # Apply dataset sharding if configured using shared function diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index 1d7d37f15..6b6e0e281 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -236,11 +236,9 @@ def _load_from_local_path( try: return load_from_disk(dataset_config.path) except FileNotFoundError: - load_dataset_kwargs["streaming"] = False return load_dataset(dataset_config.path, **load_dataset_kwargs) elif local_path.is_file(): dataset_type = get_dataset_type(dataset_config) - load_dataset_kwargs["streaming"] = False return load_dataset( dataset_type, data_files=dataset_config.path, diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/streaming.py similarity index 86% rename from src/axolotl/utils/data/pretraining.py rename to src/axolotl/utils/data/streaming.py index 72c5536e9..2cb35ee7c 100644 --- a/src/axolotl/utils/data/pretraining.py +++ b/src/axolotl/utils/data/streaming.py @@ -1,4 +1,4 @@ -"""data handling specific to pretraining""" +"""Data handling specific to streaming datasets.""" import functools from collections import defaultdict @@ -17,10 +17,10 @@ from axolotl.utils.trainer import process_pretraining_datasets_for_packing LOG = get_logger(__name__) -def encode_pretraining( +def encode_streaming( + examples: Dict[str, List], tokenizer: PreTrainedTokenizerBase, max_tokens: int, - examples: Dict[str, List], text_column: str = "text", concatenate: bool = True, ) -> Dict[str, List]: @@ -176,45 +176,57 @@ def encode_pretraining( return ret -def wrap_pretraining_dataset( +def wrap_streaming_dataset( dataset, tokenizer, cfg, ds_wrapper_fn, - max_tokens=2048, - batch_size=1, - seed=42, - buffer_size=10_000, ): if cfg.sample_packing: + # For SFT (non-pretraining) datasets, always use multipack_attn=True to ensure + # attention isolation between packed sequences + multipack_attn = ( + True if not cfg.pretraining_dataset else cfg.pretrain_multipack_attn + ) + collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq( tokenizer, return_tensors="pt", padding=True, - pad_to_multiple_of=max_tokens, - multipack_attn=cfg.pretrain_multipack_attn, + pad_to_multiple_of=cfg.sequence_len, + multipack_attn=multipack_attn, ) encode = functools.partial( - encode_packed_pretraining, + encode_packed_streaming, collate_fn, ds_wrapper_fn, - max_seq_length=max_tokens, - batch_size=batch_size, - multipack_attn=cfg.pretrain_multipack_attn, + max_seq_length=cfg.sequence_len, + batch_size=cfg.micro_batch_size, + multipack_attn=multipack_attn, ) - # set this to 1 so downstream data_loader doesn't try to increase the batch again + + # Set this to 1 so downstream data_loader doesn't try to increase the batch size + # again cfg.micro_batch_size = 1 else: + # NOTE: This is not reachable for SFT datasets since we use the pre-existing + # loading function for non-packed streaming datasets. Refer to + # _prepare_streaming_datasets in sft.py for that code path. 
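# The text column may be set per pretraining dataset config; fall back to "text"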
+ text_column = ( + getattr(cfg.pretraining_dataset[0], "text_column", "text") or "text" + ) encode = functools.partial( - encode_pretraining, - tokenizer, - max_tokens, - text_column=cfg.pretraining_dataset[0].text_column or "text", + encode_streaming, + tokenizer=tokenizer, + max_tokens=cfg.sequence_len, + text_column=text_column, concatenate=cfg.pretraining_sample_concatenation is True, ) if cfg.shuffle_merged_datasets: - dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size) + dataset = dataset.shuffle( + seed=cfg.seed, buffer_size=cfg.streaming_multipack_buffer_size + ) else: LOG.debug("NOT shuffling merged pretraining datasets") @@ -232,14 +244,13 @@ def wrap_pretraining_dataset( dataset = dataset.map( encode, batched=True, - batch_size=buffer_size, - # input_columns="text", + batch_size=cfg.streaming_multipack_buffer_size, remove_columns=remove_columns, ) return dataset -def encode_packed_pretraining( +def encode_packed_streaming( collate_fn, ds_wrapper: Callable, examples: Dict[str, List], @@ -274,8 +285,6 @@ def encode_packed_pretraining( for batch in sampler: for data in batch: features = train_dataset[data] - if "num_truncated_tokens" in features: - del features["num_truncated_tokens"] if "num_truncated_tokens" in features: del features["num_truncated_tokens"] if "overflow_to_sample_mapping" in features: diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 4868576a0..445a65d6c 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -190,12 +190,21 @@ def handle_long_seq_in_dataset( Returns: Filtered dataset with long sequences removed. """ - if "input_ids" not in dataset.column_names: + if ( + hasattr(dataset, "column_names") + and dataset.column_names + and "input_ids" not in dataset.column_names + ): LOG.warning( "Dataset does not contain 'input_ids' column. Skip drop long seq. This is " "expected for reward modeling." ) return dataset + elif not hasattr(dataset, "column_names") or dataset.column_names is None: + LOG.info( + "Dataset is streaming (IterableDataset), skipping long sequence handling" + ) + return dataset drop_long = functools.partial( drop_long_seq, diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 4b5f571dc..d43c346cd 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -475,12 +475,6 @@ class AxolotlInputConfig( }, ) multipack_real_batches: bool | None = None - pretraining_sample_concatenation: bool | None = Field( - default=None, - json_schema_extra={ - "description": "whether to concatenate samples during pretraining", - }, - ) batch_flattening: Literal["auto"] | bool | None = Field( default=None, @@ -495,13 +489,34 @@ class AxolotlInputConfig( pose_max_context_len: int | None = None pose_num_chunks: int | None = None - pretrain_multipack_buffer_size: int | None = 10_000 + # Deprecated: Use streaming_multipack_buffer_size instead + pretrain_multipack_buffer_size: int | None = Field( + default=None, + deprecated="Deprecated in v0.13.0, will be removed in v0.14.0. 
Use streaming_multipack_buffer_size instead", + ) pretrain_multipack_attn: bool | None = Field( default=True, json_schema_extra={ "description": "whether to prevent cross attention for packed sequences during pretraining", }, ) + pretraining_sample_concatenation: bool | None = Field( + default=None, + json_schema_extra={ + "description": "whether to concatenate samples during pretraining", + }, + ) + + streaming: bool | None = Field( + default=None, + json_schema_extra={"description": "Use streaming mode for loading datasets"}, + ) + streaming_multipack_buffer_size: int | None = Field( + default=10_000, + json_schema_extra={ + "description": "Buffer size for multipack streaming datasets" + }, + ) xformers_attention: bool | None = Field( default=None, @@ -1264,3 +1279,14 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): data["dataset_processes"] = get_default_process_count() return data + + @model_validator(mode="before") + @classmethod + def check_deduplication_with_streaming(cls, data): + if data.get("dataset_exact_deduplication") and ( + data.get("streaming") or data.get("pretraining_dataset") + ): + raise NotImplementedError( + "dataset_exact_deduplication is not available for streaming datasets. " + ) + return data diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 791894990..49add8081 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -60,6 +60,20 @@ class DatasetValidationMixin: raise ValueError("either datasets or pretraining_dataset is required") return data + @model_validator(mode="before") + @classmethod + def check_pretraining_streaming_deprecation(cls, data): + # TODO(djsaunde): remove this check + implement change for 0.13.0 release + if data.get("pretraining_dataset") and not data.get("streaming"): + LOG.warning( + "Setting `pretraining_dataset` without explicitly setting `streaming: " + "true` is deprecated. In a future release, streaming will not be " + "automatically enabled when using pretraining_dataset. Please " + "explicitly set `streaming: true` in your configuration to maintain " + "current behavior." + ) + return data + @model_validator(mode="before") @classmethod def check_push_ds_auth(cls, data): @@ -340,6 +354,30 @@ class TrainingValidationMixin: ) return data + @model_validator(mode="before") + @classmethod + def check_multipack_buffer_size(cls, data): + if data.get("pretrain_multipack_buffer_size") and not data.get( + "streaming_multipack_buffer_size" + ): + LOG.warning( + "`pretrain_multipack_buffer_size` is deprecated in v0.13.0, will be " + "removed in v0.14.0. Use `streaming_multipack_buffer_size` instead." 
+ ) + data["streaming_multipack_buffer_size"] = data[ + "pretrain_multipack_buffer_size" + ] + del data["pretrain_multipack_buffer_size"] + elif data.get("pretrain_multipack_buffer_size") and data.get( + "streaming_multipack_buffer_size" + ): + raise ValueError( + "pretrain_multipack_buffer_size is deprecated, use " + "streaming_multipack_buffer_size; both are set, please remove the " + "deprecated pretrain_multipack_buffer_size setting" + ) + return data + @model_validator(mode="after") def check_fft_possible_bad_config(self): if ( @@ -1074,6 +1112,50 @@ class PretrainingValidationMixin: data["accelerator_config"]["dispatch_batches"] = False return data + @model_validator(mode="before") + @classmethod + def check_pretraining_w_val_set_size(cls, data): + if data.get("pretraining_dataset") and data.get("val_set_size"): + raise ValueError( + "val_set_size is not supported with pretraining_dataset. " + "Use test_datasets to specify evaluation datasets for pretraining." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_streaming_w_val_set_size(cls, data): + if data.get("streaming") and data.get("val_set_size"): + raise ValueError( + "val_set_size is not supported with streaming datasets. " + "Use test_datasets to specify evaluation datasets when streaming is enabled." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_streaming_w_max_steps(cls, data): + if data.get("streaming") and not data.get("max_steps"): + raise ValueError( + "max_steps must be set when using streaming datasets. " + "Trainer cannot infer dataset length for iterable datasets." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_streaming_w_multiple_datasets(cls, data): + if ( + data.get("streaming") + and data.get("sample_packing") + and data.get("datasets") + and len(data.get("datasets")) > 1 + ): + raise NotImplementedError( + "Sample packing with multiple streaming datasets is not yet supported" + ) + return data + class ModelCompatibilityValidationMixin: """Validation methods for specific model compatibility.""" diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py index 98383614b..ff47b9427 100644 --- a/tests/e2e/integrations/test_kd.py +++ b/tests/e2e/integrations/test_kd.py @@ -25,7 +25,7 @@ def min_cfg(temp_dir): "liger_rms_norm": True, "liger_glu_activation": True, "torch_compile": True, - "chat_template": "llama3", + "chat_template": "qwen3", "kd_trainer": True, "kd_ce_alpha": 0.1, "kd_alpha": 0.9, diff --git a/tests/e2e/test_streaming.py b/tests/e2e/test_streaming.py new file mode 100644 index 000000000..5dccf00dd --- /dev/null +++ b/tests/e2e/test_streaming.py @@ -0,0 +1,73 @@ +"""E2E tests for streaming dataset functionality""" + +# pylint: disable=duplicate-code + +import pytest + +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, validate_config +from axolotl.utils.dict import DictDefault + +from .utils import check_model_output_exists, check_tensorboard + + +class TestStreamingDatasets: + """Test case for streaming datasets""" + + @pytest.mark.parametrize( + "sample_packing", + [True, False], + ) + def test_streaming_dataset(self, temp_dir, sample_packing): + """Test streaming datasets""" + + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "flash_attention": True, + "sequence_len": 1024, + "sample_packing": sample_packing, + "pretrain_multipack_attn": sample_packing, + 
"streaming_multipack_buffer_size": 10000, + "dataset_processes": 1, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + # Streaming config + "streaming": True, + "max_steps": 3, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "val_set_size": 0.0, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_torch_fused", + "lr_scheduler": "cosine", + "save_safetensors": True, + "bf16": "auto", + "use_tensorboard": True, + "save_first_step": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg) + + # Verify training actually happened by checking loss decrease + check_tensorboard( + temp_dir + "/runs", + "train/train_loss", + 3.0, + "Train Loss (%s) is too high", + ) diff --git a/tests/test_data.py b/tests/test_data.py index 6d583cfd3..99ed06336 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -6,7 +6,7 @@ import unittest from transformers import LlamaTokenizer -from axolotl.utils.data import encode_pretraining, md5 +from axolotl.utils.data import encode_streaming, md5 from tests.hf_offline_utils import enable_hf_offline @@ -39,7 +39,7 @@ class TestEncodePretraining(unittest.TestCase): "hello, hello", ] } - result = encode_pretraining(self.tokenizer, self.max_tokens, examples) + result = encode_streaming(examples, self.tokenizer, self.max_tokens) self.assertEqual(len(result["input_ids"]), 3) diff --git a/tests/test_packed_dataset.py b/tests/test_packed_dataset.py index 43e4f3d39..64f314e2e 100644 --- a/tests/test_packed_dataset.py +++ b/tests/test_packed_dataset.py @@ -1,16 +1,11 @@ """Module for testing dataset sequence packing""" import unittest -from pathlib import Path -from datasets import Dataset, load_dataset from transformers import AutoTokenizer from axolotl.cli.args import TrainerCliArgs from axolotl.common.datasets import load_datasets -from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset -from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy -from axolotl.prompters import AlpacaPrompter from axolotl.train import setup_model_and_trainer from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault @@ -35,43 +30,6 @@ class TestPacking(unittest.TestCase): } ) - def test_increments_attention(self): - prompter = AlpacaPrompter("chat") - strat = AlpacaPromptTokenizingStrategy( - prompter, - self.tokenizer, - False, - 2048, - ) - dateset = load_dataset( - "json", - data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"), - )["train"] - dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset))) - - constant_len_dataset = ConstantLengthDataset( - self.tokenizer, - [dataset], - seq_length=2048, - ) - packed_dataset = Dataset.from_list(list(constant_len_dataset)) - example = packed_dataset[0] - next_bos_index = ( - example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1 - ) # add one since we sliced - - # first example doesn't have mask reset - assert example["input_ids"][0] == self.tokenizer.bos_token_id - assert example["attention_mask"][0] == 1 - assert example["position_ids"][0] == 0 - assert example["position_ids"][1] == 1 - - # but subsequent one does - assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id - assert example["attention_mask"][next_bos_index] == 2 - 
assert example["position_ids"][next_bos_index] == 0 - assert example["position_ids"][next_bos_index + 1] == 1 - @with_temp_dir def test_lora_packing(self, temp_dir): cfg = DictDefault( diff --git a/tests/test_packed_pretraining.py b/tests/test_packed_pretraining.py index 117bc0dbd..0458f7ba2 100644 --- a/tests/test_packed_pretraining.py +++ b/tests/test_packed_pretraining.py @@ -9,7 +9,7 @@ import torch from datasets import IterableDataset from torch.utils.data import DataLoader -from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset +from axolotl.utils.data import get_dataset_wrapper, wrap_streaming_dataset from axolotl.utils.dict import DictDefault @@ -77,14 +77,11 @@ class TestPretrainingPacking: ) original_bsz = cfg.micro_batch_size - train_dataset = wrap_pretraining_dataset( + train_dataset = wrap_streaming_dataset( dataset, tokenizer_huggyllama, cfg, ds_wrapper_partial, - max_tokens=cfg.sequence_len, - batch_size=cfg.micro_batch_size, - seed=cfg.seed or 42, ) trainer_loader = DataLoader( diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 000000000..54acbb5e4 --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,238 @@ +"""Test streaming configuration and data loading functionality.""" + +import unittest +from unittest.mock import Mock, patch + +from datasets import IterableDataset + +from axolotl.utils.dict import DictDefault +from axolotl.utils.data.sft import ( + _prepare_streaming_dataset, + prepare_datasets, +) +from axolotl.utils.config import validate_config + + +class TestStreamingConfig(unittest.TestCase): + """Test streaming configuration and deprecation handling.""" + + def test_streaming_multipack_buffer_size_deprecation(self): + """Test that pretrain_multipack_buffer_size is properly deprecated.""" + # Test with old config name + cfg_old = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "pretrain_multipack_buffer_size": 5000, + "datasets": [{"path": "test/dataset", "type": "alpaca"}], + "sequence_len": 256, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 0.0001, + } + ) + + with self.assertLogs("axolotl.utils.schemas.validation", level="WARNING") as cm: + validated_cfg = validate_config(cfg_old) + self.assertIn("pretrain_multipack_buffer_size` is deprecated", cm.output[0]) + + self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 5000) + self.assertIsNone( + getattr(validated_cfg, "pretrain_multipack_buffer_size", None) + ) + + def test_streaming_multipack_buffer_size_new(self): + """Test that new streaming_multipack_buffer_size works correctly.""" + cfg_new = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "streaming_multipack_buffer_size": 7000, + "datasets": [{"path": "test/dataset", "type": "alpaca"}], + "sequence_len": 256, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 0.0001, + } + ) + + validated_cfg = validate_config(cfg_new) + self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 7000) + + def test_both_buffer_sizes_raises_error(self): + """Test that having both old and new buffer size configs raises an error.""" + cfg_both = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "pretrain_multipack_buffer_size": 5000, + "streaming_multipack_buffer_size": 7000, + "datasets": [{"path": "test/dataset", "type": "alpaca"}], + "sequence_len": 256, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 0.0001, + } + ) + + with self.assertRaises(ValueError) 
as cm: + validate_config(cfg_both) + self.assertIn("both are set", str(cm.exception)) + + +class TestStreamingDatasetPreparation(unittest.TestCase): + """Test dataset preparation with streaming configuration.""" + + def setUp(self): + self.tokenizer = Mock() + self.tokenizer.pad_token_id = 0 + self.tokenizer.eos_token_id = 1 + + @patch("axolotl.utils.data.sft._prepare_streaming_dataset") + def test_prepare_datasets_with_streaming_true(self, mock_prepare_streaming): + """Test that streaming=True triggers streaming dataset preparation.""" + cfg = DictDefault( + { + "streaming": True, + "datasets": [{"path": "test/dataset", "type": "alpaca"}], + } + ) + + mock_prepare_streaming.return_value = (Mock(), None, 100, []) + + prepare_datasets(cfg, self.tokenizer) + + mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None) + + @patch("axolotl.utils.data.sft._prepare_streaming_dataset") + def test_prepare_datasets_with_pretraining_dataset(self, mock_prepare_streaming): + """Test that pretraining_dataset triggers streaming dataset preparation.""" + cfg = DictDefault( + { + "pretraining_dataset": "test/dataset", + } + ) + + mock_prepare_streaming.return_value = (Mock(), None, 100, []) + + prepare_datasets(cfg, self.tokenizer) + + mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None) + + @patch("axolotl.utils.data.sft._prepare_standard_dataset") + def test_prepare_datasets_without_streaming(self, mock_prepare_standard): + """Test that without streaming, standard dataset preparation is used.""" + cfg = DictDefault( + { + "datasets": [{"path": "test/dataset", "type": "alpaca"}], + } + ) + + mock_prepare_standard.return_value = (Mock(), None, 100, []) + + prepare_datasets(cfg, self.tokenizer) + + mock_prepare_standard.assert_called_once_with(cfg, self.tokenizer, None) + + +class TestStreamingWithSamplePacking(unittest.TestCase): + """Test streaming dataset preparation with sample packing.""" + + def setUp(self): + self.tokenizer = Mock() + self.tokenizer.pad_token_id = 0 + self.tokenizer.eos_token_id = 1 + + @patch("axolotl.utils.data.sft._load_streaming_dataset") + def test_streaming_sft_with_sample_packing_sets_split(self, mock_load_streaming): + """Test that streaming SFT with sample_packing sets default split.""" + cfg = DictDefault( + { + "streaming": True, + "sample_packing": True, + "datasets": [{"path": "test/dataset", "type": "alpaca"}], + "sequence_len": 256, + "micro_batch_size": 1, + } + ) + + mock_load_streaming.return_value = Mock(spec=IterableDataset) + + with patch("axolotl.utils.data.sft._load_and_prepare_datasets"): + _prepare_streaming_dataset(cfg, self.tokenizer, None) + + # Check that the dataset config has split set to 'train' + call_args = mock_load_streaming.call_args + dataset_config = call_args[0][0] + self.assertEqual(dataset_config.split, "train") + + def test_multipack_attn_forced_true_for_sft(self): + """Test that multipack_attn is forced to True for SFT with sample packing.""" + from axolotl.utils.data.streaming import wrap_streaming_dataset + + cfg = DictDefault( + { + "sample_packing": True, + "pretrain_multipack_attn": False, # Should be overridden for SFT + "pretraining_dataset": None, # This makes it SFT + "sequence_len": 256, + "micro_batch_size": 1, + "streaming_multipack_buffer_size": 1000, + "seed": 42, + } + ) + + mock_dataset = Mock() + mock_dataset.features = None # For streaming datasets + mock_dataset.__iter__ = Mock(return_value=iter([])) # Empty iterator + mock_dataset.map = Mock(return_value=mock_dataset) + 
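# Stand-in for the dataset wrapper partial normally built by the SFT data loader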
mock_ds_wrapper = Mock() + + with patch( + "axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq" + ) as mock_collator: + with patch("axolotl.utils.data.streaming.encode_packed_streaming"): + wrap_streaming_dataset( + mock_dataset, self.tokenizer, cfg, mock_ds_wrapper + ) + + # Check that multipack_attn=True was used in the collator + mock_collator.assert_called_once() + call_kwargs = mock_collator.call_args[1] + self.assertTrue(call_kwargs["multipack_attn"]) + + def test_multipack_attn_respects_config_for_pretraining(self): + """Test that multipack_attn respects config for pretraining datasets.""" + from axolotl.utils.data.streaming import wrap_streaming_dataset + + cfg = DictDefault( + { + "sample_packing": True, + "pretrain_multipack_attn": False, # Should be respected for pretraining + "pretraining_dataset": "test/dataset", # This makes it pretraining + "sequence_len": 256, + "micro_batch_size": 1, + "streaming_multipack_buffer_size": 1000, + "seed": 42, + } + ) + + mock_dataset = Mock() + mock_dataset.features = None # For streaming datasets + mock_dataset.__iter__ = Mock(return_value=iter([])) # Empty iterator + mock_dataset.map = Mock(return_value=mock_dataset) + mock_ds_wrapper = Mock() + + with patch( + "axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq" + ) as mock_collator: + with patch("axolotl.utils.data.streaming.encode_packed_streaming"): + wrap_streaming_dataset( + mock_dataset, self.tokenizer, cfg, mock_ds_wrapper + ) + + # Check that multipack_attn=False was used (respecting config) + mock_collator.assert_called_once() + call_kwargs = mock_collator.call_args[1] + self.assertFalse(call_kwargs["multipack_attn"]) + + +if __name__ == "__main__": + unittest.main()