diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index e5bc21762..a1fee25e3 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -40,6 +40,7 @@ from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, MambaDataCollator, + StreamingDataCollator, V2BatchSamplerDataCollatorForSeq2Seq, ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator @@ -422,6 +423,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): is_eval=False, **kwargs, ): + from datasets import IterableDataset + + if isinstance(self.train_dataset, IterableDataset) and not is_eval: + LOG.info("Using StreamingDataCollator") + return StreamingDataCollator( + tokenizer=self.tokenizer, + cfg=self.cfg, + prompter=None, + **kwargs, + ) + if training_args.pretraining: if ( self.cfg.pretraining_sample_concatenation is False diff --git a/src/axolotl/utils/collators/__init__.py b/src/axolotl/utils/collators/__init__.py index 8c60f223c..aaed3b085 100644 --- a/src/axolotl/utils/collators/__init__.py +++ b/src/axolotl/utils/collators/__init__.py @@ -1,11 +1,19 @@ -""" -shared axolotl collators for multipack, mamba, multimodal -""" +"""Shared axolotl collators for multipack, mamba, multimodal, etc.""" -from .batching import ( # noqa: F401 +from .batching import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, PretrainingBatchSamplerDataCollatorForSeq2Seq, V2BatchSamplerDataCollatorForSeq2Seq, ) -from .mamba import MambaDataCollator # noqa: F401 +from .mamba import MambaDataCollator +from .streaming import StreamingDataCollator + +__all__ = [ + "BatchSamplerDataCollatorForSeq2Seq", + "DataCollatorForSeq2Seq", + "PretrainingBatchSamplerDataCollatorForSeq2Seq", + "V2BatchSamplerDataCollatorForSeq2Seq", + "MambaDataCollator", + "StreamingDataCollator", +] diff --git a/src/axolotl/utils/collators/streaming.py b/src/axolotl/utils/collators/streaming.py new file mode 100644 index 000000000..1bd47a67d --- /dev/null +++ b/src/axolotl/utils/collators/streaming.py @@ -0,0 +1,146 @@ +from dataclasses import dataclass +from typing import Any, List + +import torch +from transformers import PreTrainedTokenizerBase, default_data_collator +from transformers.utils import PaddingStrategy + +from axolotl.prompters import Prompter +from axolotl.utils.dict import DictDefault + + +@dataclass +class StreamingDataCollator: + tokenizer: PreTrainedTokenizerBase + cfg: DictDefault + prompter: Prompter | None = None + padding: bool | str | PaddingStrategy = True + max_length: int | None = None + pad_to_multiple_of: int | None = None + label_pad_token_id: int = -100 + return_tensors: str = "pt" + + def __post_init__(self): + if self.max_length is None: + self.max_length = self.cfg.sequence_len + + def __call__(self, raw_batch: List[dict]) -> dict[str, Any]: + processed_samples = [] + + for raw_sample in raw_batch: + formatted_sample = raw_sample + if self.prompter: + formatted_sample = self._apply_prompt_formatting(raw_sample) + + tokenized_sample = self._tokenize_sample(formatted_sample) + + if len(tokenized_sample["input_ids"]) > self.max_length: + tokenized_sample = self._truncate_sample(tokenized_sample) + + if tokenized_sample.get("input_ids"): + processed_samples.append(tokenized_sample) + + return self._pad_and_batch(processed_samples) + + def _apply_prompt_formatting(self, raw_sample: dict) -> dict: + formatted_text = self.prompter.build_prompt( + instruction=raw_sample.get("instruction", ""), + input=raw_sample.get("input", ""), + output=raw_sample.get("output", ""), + ) + return {"text": formatted_text} + + def _tokenize_sample(self, sample: dict) -> dict: + text = sample.get("text", sample.get("content", "")) + + if not text: + instruction = sample.get("instruction", "") + input_text = sample.get("input", "") + output_text = sample.get("output", "") + + parts = [] + if instruction: + parts.append(f"Instruction: {instruction}") + if input_text: + parts.append(f"Input: {input_text}") + if output_text: + parts.append(f"Output: {output_text}") + text = "\n".join(parts) + + if not text: + return {"input_ids": [], "attention_mask": [], "labels": []} + + tokenized = self.tokenizer( + text, + truncation=False, + padding=False, + return_tensors=None, + ) + + tokenized["labels"] = tokenized["input_ids"].copy() + return tokenized + + def _truncate_sample(self, tokenized_sample: dict) -> dict: + max_len = self.max_length + for key in ["input_ids", "attention_mask", "labels"]: + if key in tokenized_sample: + tokenized_sample[key] = tokenized_sample[key][:max_len] + return tokenized_sample + + def _pad_and_batch(self, processed_samples: List[dict]) -> dict[str, Any]: + if not processed_samples: + processed_samples = [ + { + "input_ids": [self.tokenizer.eos_token_id], + "attention_mask": [1], + "labels": [self.tokenizer.eos_token_id], + } + ] + + batch_samples = [] + for sample in processed_samples: + batch_sample = {} + for key, value in sample.items(): + if key in ["input_ids", "attention_mask", "labels"]: + batch_sample[key] = torch.tensor(value, dtype=torch.long) + batch_samples.append(batch_sample) + + if self.padding: + max_len_in_batch = max(len(sample["input_ids"]) for sample in batch_samples) + + for sample in batch_samples: + current_len = len(sample["input_ids"]) + pad_len = max_len_in_batch - current_len + + if pad_len > 0: + pad_token_id = ( + self.tokenizer.pad_token_id or self.tokenizer.eos_token_id + ) + + sample["input_ids"] = torch.cat( + [ + sample["input_ids"], + torch.full((pad_len,), pad_token_id, dtype=torch.long), + ] + ) + sample["attention_mask"] = torch.cat( + [ + sample["attention_mask"], + torch.zeros(pad_len, dtype=torch.long), + ] + ) + sample["labels"] = torch.cat( + [ + sample["labels"], + torch.full( + (pad_len,), self.label_pad_token_id, dtype=torch.long + ), + ] + ) + + batch = {} + for key in ["input_ids", "attention_mask", "labels"]: + if key in batch_samples[0]: + batch[key] = torch.stack([sample[key] for sample in batch_samples]) + + return batch diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 17afba9c2..2cb221973 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -74,18 +74,52 @@ def prepare_datasets( Returns: Tuple of (train_dataset, eval_dataset, total_steps, prompters). """ - # Determine streaming mode from config streaming_mode = _determine_streaming_mode(cfg) - # Override preprocess_iterable parameter with streaming config if streaming_mode: - preprocess_iterable = True + if cfg.pretraining_dataset: + return _prepare_streaming_pretraining_dataset(cfg, tokenizer, processor) + else: + return _prepare_streaming_sft_dataset(cfg, tokenizer, processor) + else: + if cfg.pretraining_dataset: + return _prepare_pretraining_dataset( + cfg, tokenizer, processor, preprocess_iterable=False + ) + else: + return _prepare_standard_dataset( + cfg, tokenizer, processor, preprocess_iterable=False + ) - if cfg.pretraining_dataset: - return _prepare_pretraining_dataset( - cfg, tokenizer, processor, preprocess_iterable + +def _prepare_streaming_sft_dataset( + cfg: DictDefault, + tokenizer: PreTrainedTokenizer, + processor: ProcessorMixin | None, +) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]: + LOG.info("Loading streaming datasets") + + raw_datasets = _load_raw_datasets_for_streaming(cfg, split="train") + + eval_dataset = None + if cfg.test_datasets: + eval_raw_datasets = _load_raw_datasets_for_streaming( + cfg, split="test", dataset_configs=cfg.test_datasets ) - return _prepare_standard_dataset(cfg, tokenizer, processor, preprocess_iterable) + eval_dataset = _process_eval_dataset_minimal( + eval_raw_datasets, cfg, tokenizer, processor + ) + elif cfg.val_set_size: + LOG.info("Validation splits not supported for streaming datasets") + + if not cfg.max_steps: + raise ValueError("max_steps must be set when using streaming datasets") + + total_num_steps = cfg.max_steps + LOG.info(f"Maximum steps: {total_num_steps}") + + prompters = [None] * len(cfg.datasets) if cfg.datasets else [] + return raw_datasets, eval_dataset, total_num_steps, prompters def _prepare_standard_dataset( @@ -138,19 +172,12 @@ def _prepare_standard_dataset( ) # Calculate total number of training steps - if isinstance(train_dataset, IterableDataset): - # For streaming datasets, we must use max_steps - if not cfg.max_steps: - raise ValueError("max_steps must be set when using streaming datasets") - total_num_steps = cfg.max_steps + if cfg.max_steps: + total_num_steps = min( + calculate_total_num_steps(cfg, train_dataset), cfg.max_steps + ) else: - # For regular datasets, calculate from dataset size or use max_steps - if cfg.max_steps: - total_num_steps = min( - calculate_total_num_steps(cfg, train_dataset), cfg.max_steps - ) - else: - total_num_steps = calculate_total_num_steps(cfg, train_dataset) + total_num_steps = calculate_total_num_steps(cfg, train_dataset) LOG.info(f"Maximum number of steps set at {total_num_steps}") return train_dataset, eval_dataset, total_num_steps, prompters @@ -445,17 +472,14 @@ def _parse_dataset_type(d_type: str) -> tuple[str | None, str | None]: def _handle_train_dataset_split( - dataset: Dataset | IterableDataset, cfg: DictDefault -) -> tuple[Dataset | IterableDataset, Dataset | IterableDataset | None]: + dataset: Dataset, cfg: DictDefault +) -> tuple[Dataset, Dataset | None]: """Handle processing for train split, including validation set creation.""" val_set_size = ( int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size) ) if val_set_size: - if isinstance(dataset, IterableDataset): - LOG.info("Validation splits not supported for streaming datasets, skipping") - return dataset, None # Create train/validation split train_dataset, eval_dataset = create_train_validation_split( dataset, cfg, val_set_size @@ -463,33 +487,27 @@ def _handle_train_dataset_split( return train_dataset, eval_dataset # No validation split - apply deduplication if needed and return as train dataset - if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset): + if cfg.dataset_exact_deduplication: train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset) else: - if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset): - LOG.info("Deduplication skipped for streaming datasets (not compatible)") train_dataset = dataset return train_dataset, None def _handle_test_dataset_split( - dataset: Dataset | IterableDataset, cfg: DictDefault -) -> tuple[None, Dataset | IterableDataset | None]: + dataset: Dataset, cfg: DictDefault +) -> tuple[None, Dataset | None]: """Handle processing for test split.""" - if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset): + if cfg.dataset_exact_deduplication: eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset) else: - if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset): - LOG.info("Deduplication skipped for streaming datasets (not compatible)") eval_dataset = dataset return None, eval_dataset -def _apply_dataset_sharding( - dataset: Dataset | IterableDataset, cfg: DictDefault -) -> Dataset | IterableDataset: +def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset: """Apply dataset sharding if configured. Args: @@ -548,3 +566,78 @@ def _load_and_prepare_datasets( train_dataset, eval_dataset = _handle_test_dataset_split(dataset, cfg) return train_dataset, eval_dataset, prompters + + +def _load_raw_datasets_for_streaming( + cfg: DictDefault, split: str = "train", dataset_configs: list | None = None +) -> IterableDataset: + configs = ( + dataset_configs + if dataset_configs is not None + else (cfg.datasets if split == "train" else cfg.test_datasets) + ) + + if not configs: + raise ValueError(f"No dataset configurations found for split '{split}'") + + datasets = [] + for dataset_config in datasets_with_name_generator(configs): + raw_dataset = load_dataset_with_config( + dataset_config, cfg.hf_use_auth_token, streaming=True + ) + + if isinstance(raw_dataset, (DatasetDict, IterableDatasetDict)): + if dataset_config.split and dataset_config.split in raw_dataset: + raw_dataset = raw_dataset[dataset_config.split] + elif split in raw_dataset: + raw_dataset = raw_dataset[split] + else: + raise ValueError( + f"no {split} split found for dataset {dataset_config.path}, " + "you may specify a split with 'split: ...'" + ) + + datasets.append(raw_dataset) + + if len(datasets) == 1: + return datasets[0] + else: + return merge_datasets(datasets, cfg) + + +def _process_eval_dataset_minimal( + raw_dataset: IterableDataset, + cfg: DictDefault, + tokenizer: PreTrainedTokenizer, + processor: ProcessorMixin | None, +) -> Dataset | None: + LOG.info("Eval dataset processing skipped for streaming") + return None + + +def _prepare_streaming_pretraining_dataset( + cfg: DictDefault, + tokenizer: PreTrainedTokenizer, + processor: ProcessorMixin | None, +) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]: + pretraining_config = _extract_pretraining_config(cfg) + + train_dataset = load_dataset_with_config( + pretraining_config, cfg.hf_use_auth_token, streaming=True + ) + + if isinstance(train_dataset, (DatasetDict, IterableDatasetDict)): + if pretraining_config.split and pretraining_config.split in train_dataset: + train_dataset = train_dataset[pretraining_config.split] + elif "train" in train_dataset: + train_dataset = train_dataset["train"] + else: + raise ValueError("no train split found for pretraining dataset") + + if not cfg.max_steps: + raise ValueError("max_steps must be set when using streaming datasets") + + total_num_steps = cfg.max_steps + LOG.info(f"Maximum steps: {total_num_steps}") + + return train_dataset, None, total_num_steps, []