initial impl of streaming preprocessing
This commit is contained in:
@@ -40,6 +40,7 @@ from axolotl.utils.collators import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
MambaDataCollator,
|
||||
StreamingDataCollator,
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||
@@ -422,6 +423,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
is_eval=False,
|
||||
**kwargs,
|
||||
):
|
||||
from datasets import IterableDataset
|
||||
|
||||
if isinstance(self.train_dataset, IterableDataset) and not is_eval:
|
||||
LOG.info("Using StreamingDataCollator")
|
||||
return StreamingDataCollator(
|
||||
tokenizer=self.tokenizer,
|
||||
cfg=self.cfg,
|
||||
prompter=None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if training_args.pretraining:
|
||||
if (
|
||||
self.cfg.pretraining_sample_concatenation is False
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
"""
|
||||
shared axolotl collators for multipack, mamba, multimodal
|
||||
"""
|
||||
"""Shared axolotl collators for multipack, mamba, multimodal, etc."""
|
||||
|
||||
from .batching import ( # noqa: F401
|
||||
from .batching import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
PretrainingBatchSamplerDataCollatorForSeq2Seq,
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
from .mamba import MambaDataCollator # noqa: F401
|
||||
from .mamba import MambaDataCollator
|
||||
from .streaming import StreamingDataCollator
|
||||
|
||||
__all__ = [
|
||||
"BatchSamplerDataCollatorForSeq2Seq",
|
||||
"DataCollatorForSeq2Seq",
|
||||
"PretrainingBatchSamplerDataCollatorForSeq2Seq",
|
||||
"V2BatchSamplerDataCollatorForSeq2Seq",
|
||||
"MambaDataCollator",
|
||||
"StreamingDataCollator",
|
||||
]
|
||||
|
||||
146
src/axolotl/utils/collators/streaming.py
Normal file
146
src/axolotl/utils/collators/streaming.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizerBase, default_data_collator
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
from axolotl.prompters import Prompter
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamingDataCollator:
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
cfg: DictDefault
|
||||
prompter: Prompter | None = None
|
||||
padding: bool | str | PaddingStrategy = True
|
||||
max_length: int | None = None
|
||||
pad_to_multiple_of: int | None = None
|
||||
label_pad_token_id: int = -100
|
||||
return_tensors: str = "pt"
|
||||
|
||||
def __post_init__(self):
|
||||
if self.max_length is None:
|
||||
self.max_length = self.cfg.sequence_len
|
||||
|
||||
def __call__(self, raw_batch: List[dict]) -> dict[str, Any]:
|
||||
processed_samples = []
|
||||
|
||||
for raw_sample in raw_batch:
|
||||
formatted_sample = raw_sample
|
||||
if self.prompter:
|
||||
formatted_sample = self._apply_prompt_formatting(raw_sample)
|
||||
|
||||
tokenized_sample = self._tokenize_sample(formatted_sample)
|
||||
|
||||
if len(tokenized_sample["input_ids"]) > self.max_length:
|
||||
tokenized_sample = self._truncate_sample(tokenized_sample)
|
||||
|
||||
if tokenized_sample.get("input_ids"):
|
||||
processed_samples.append(tokenized_sample)
|
||||
|
||||
return self._pad_and_batch(processed_samples)
|
||||
|
||||
def _apply_prompt_formatting(self, raw_sample: dict) -> dict:
|
||||
formatted_text = self.prompter.build_prompt(
|
||||
instruction=raw_sample.get("instruction", ""),
|
||||
input=raw_sample.get("input", ""),
|
||||
output=raw_sample.get("output", ""),
|
||||
)
|
||||
return {"text": formatted_text}
|
||||
|
||||
def _tokenize_sample(self, sample: dict) -> dict:
|
||||
text = sample.get("text", sample.get("content", ""))
|
||||
|
||||
if not text:
|
||||
instruction = sample.get("instruction", "")
|
||||
input_text = sample.get("input", "")
|
||||
output_text = sample.get("output", "")
|
||||
|
||||
parts = []
|
||||
if instruction:
|
||||
parts.append(f"Instruction: {instruction}")
|
||||
if input_text:
|
||||
parts.append(f"Input: {input_text}")
|
||||
if output_text:
|
||||
parts.append(f"Output: {output_text}")
|
||||
text = "\n".join(parts)
|
||||
|
||||
if not text:
|
||||
return {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
|
||||
tokenized = self.tokenizer(
|
||||
text,
|
||||
truncation=False,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
|
||||
tokenized["labels"] = tokenized["input_ids"].copy()
|
||||
return tokenized
|
||||
|
||||
def _truncate_sample(self, tokenized_sample: dict) -> dict:
|
||||
max_len = self.max_length
|
||||
for key in ["input_ids", "attention_mask", "labels"]:
|
||||
if key in tokenized_sample:
|
||||
tokenized_sample[key] = tokenized_sample[key][:max_len]
|
||||
return tokenized_sample
|
||||
|
||||
def _pad_and_batch(self, processed_samples: List[dict]) -> dict[str, Any]:
|
||||
if not processed_samples:
|
||||
processed_samples = [
|
||||
{
|
||||
"input_ids": [self.tokenizer.eos_token_id],
|
||||
"attention_mask": [1],
|
||||
"labels": [self.tokenizer.eos_token_id],
|
||||
}
|
||||
]
|
||||
|
||||
batch_samples = []
|
||||
for sample in processed_samples:
|
||||
batch_sample = {}
|
||||
for key, value in sample.items():
|
||||
if key in ["input_ids", "attention_mask", "labels"]:
|
||||
batch_sample[key] = torch.tensor(value, dtype=torch.long)
|
||||
batch_samples.append(batch_sample)
|
||||
|
||||
if self.padding:
|
||||
max_len_in_batch = max(len(sample["input_ids"]) for sample in batch_samples)
|
||||
|
||||
for sample in batch_samples:
|
||||
current_len = len(sample["input_ids"])
|
||||
pad_len = max_len_in_batch - current_len
|
||||
|
||||
if pad_len > 0:
|
||||
pad_token_id = (
|
||||
self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
sample["input_ids"] = torch.cat(
|
||||
[
|
||||
sample["input_ids"],
|
||||
torch.full((pad_len,), pad_token_id, dtype=torch.long),
|
||||
]
|
||||
)
|
||||
sample["attention_mask"] = torch.cat(
|
||||
[
|
||||
sample["attention_mask"],
|
||||
torch.zeros(pad_len, dtype=torch.long),
|
||||
]
|
||||
)
|
||||
sample["labels"] = torch.cat(
|
||||
[
|
||||
sample["labels"],
|
||||
torch.full(
|
||||
(pad_len,), self.label_pad_token_id, dtype=torch.long
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
batch = {}
|
||||
for key in ["input_ids", "attention_mask", "labels"]:
|
||||
if key in batch_samples[0]:
|
||||
batch[key] = torch.stack([sample[key] for sample in batch_samples])
|
||||
|
||||
return batch
|
||||
@@ -74,18 +74,52 @@ def prepare_datasets(
|
||||
Returns:
|
||||
Tuple of (train_dataset, eval_dataset, total_steps, prompters).
|
||||
"""
|
||||
# Determine streaming mode from config
|
||||
streaming_mode = _determine_streaming_mode(cfg)
|
||||
|
||||
# Override preprocess_iterable parameter with streaming config
|
||||
if streaming_mode:
|
||||
preprocess_iterable = True
|
||||
if cfg.pretraining_dataset:
|
||||
return _prepare_streaming_pretraining_dataset(cfg, tokenizer, processor)
|
||||
else:
|
||||
return _prepare_streaming_sft_dataset(cfg, tokenizer, processor)
|
||||
else:
|
||||
if cfg.pretraining_dataset:
|
||||
return _prepare_pretraining_dataset(
|
||||
cfg, tokenizer, processor, preprocess_iterable=False
|
||||
)
|
||||
else:
|
||||
return _prepare_standard_dataset(
|
||||
cfg, tokenizer, processor, preprocess_iterable=False
|
||||
)
|
||||
|
||||
if cfg.pretraining_dataset:
|
||||
return _prepare_pretraining_dataset(
|
||||
cfg, tokenizer, processor, preprocess_iterable
|
||||
|
||||
def _prepare_streaming_sft_dataset(
|
||||
cfg: DictDefault,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
processor: ProcessorMixin | None,
|
||||
) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:
|
||||
LOG.info("Loading streaming datasets")
|
||||
|
||||
raw_datasets = _load_raw_datasets_for_streaming(cfg, split="train")
|
||||
|
||||
eval_dataset = None
|
||||
if cfg.test_datasets:
|
||||
eval_raw_datasets = _load_raw_datasets_for_streaming(
|
||||
cfg, split="test", dataset_configs=cfg.test_datasets
|
||||
)
|
||||
return _prepare_standard_dataset(cfg, tokenizer, processor, preprocess_iterable)
|
||||
eval_dataset = _process_eval_dataset_minimal(
|
||||
eval_raw_datasets, cfg, tokenizer, processor
|
||||
)
|
||||
elif cfg.val_set_size:
|
||||
LOG.info("Validation splits not supported for streaming datasets")
|
||||
|
||||
if not cfg.max_steps:
|
||||
raise ValueError("max_steps must be set when using streaming datasets")
|
||||
|
||||
total_num_steps = cfg.max_steps
|
||||
LOG.info(f"Maximum steps: {total_num_steps}")
|
||||
|
||||
prompters = [None] * len(cfg.datasets) if cfg.datasets else []
|
||||
return raw_datasets, eval_dataset, total_num_steps, prompters
|
||||
|
||||
|
||||
def _prepare_standard_dataset(
|
||||
@@ -138,19 +172,12 @@ def _prepare_standard_dataset(
|
||||
)
|
||||
|
||||
# Calculate total number of training steps
|
||||
if isinstance(train_dataset, IterableDataset):
|
||||
# For streaming datasets, we must use max_steps
|
||||
if not cfg.max_steps:
|
||||
raise ValueError("max_steps must be set when using streaming datasets")
|
||||
total_num_steps = cfg.max_steps
|
||||
if cfg.max_steps:
|
||||
total_num_steps = min(
|
||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||
)
|
||||
else:
|
||||
# For regular datasets, calculate from dataset size or use max_steps
|
||||
if cfg.max_steps:
|
||||
total_num_steps = min(
|
||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||
)
|
||||
else:
|
||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||
return train_dataset, eval_dataset, total_num_steps, prompters
|
||||
|
||||
@@ -445,17 +472,14 @@ def _parse_dataset_type(d_type: str) -> tuple[str | None, str | None]:
|
||||
|
||||
|
||||
def _handle_train_dataset_split(
|
||||
dataset: Dataset | IterableDataset, cfg: DictDefault
|
||||
) -> tuple[Dataset | IterableDataset, Dataset | IterableDataset | None]:
|
||||
dataset: Dataset, cfg: DictDefault
|
||||
) -> tuple[Dataset, Dataset | None]:
|
||||
"""Handle processing for train split, including validation set creation."""
|
||||
val_set_size = (
|
||||
int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
|
||||
)
|
||||
|
||||
if val_set_size:
|
||||
if isinstance(dataset, IterableDataset):
|
||||
LOG.info("Validation splits not supported for streaming datasets, skipping")
|
||||
return dataset, None
|
||||
# Create train/validation split
|
||||
train_dataset, eval_dataset = create_train_validation_split(
|
||||
dataset, cfg, val_set_size
|
||||
@@ -463,33 +487,27 @@ def _handle_train_dataset_split(
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
# No validation split - apply deduplication if needed and return as train dataset
|
||||
if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset):
|
||||
if cfg.dataset_exact_deduplication:
|
||||
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
||||
else:
|
||||
if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset):
|
||||
LOG.info("Deduplication skipped for streaming datasets (not compatible)")
|
||||
train_dataset = dataset
|
||||
|
||||
return train_dataset, None
|
||||
|
||||
|
||||
def _handle_test_dataset_split(
|
||||
dataset: Dataset | IterableDataset, cfg: DictDefault
|
||||
) -> tuple[None, Dataset | IterableDataset | None]:
|
||||
dataset: Dataset, cfg: DictDefault
|
||||
) -> tuple[None, Dataset | None]:
|
||||
"""Handle processing for test split."""
|
||||
if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset):
|
||||
if cfg.dataset_exact_deduplication:
|
||||
eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
|
||||
else:
|
||||
if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset):
|
||||
LOG.info("Deduplication skipped for streaming datasets (not compatible)")
|
||||
eval_dataset = dataset
|
||||
|
||||
return None, eval_dataset
|
||||
|
||||
|
||||
def _apply_dataset_sharding(
|
||||
dataset: Dataset | IterableDataset, cfg: DictDefault
|
||||
) -> Dataset | IterableDataset:
|
||||
def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
||||
"""Apply dataset sharding if configured.
|
||||
|
||||
Args:
|
||||
@@ -548,3 +566,78 @@ def _load_and_prepare_datasets(
|
||||
train_dataset, eval_dataset = _handle_test_dataset_split(dataset, cfg)
|
||||
|
||||
return train_dataset, eval_dataset, prompters
|
||||
|
||||
|
||||
def _load_raw_datasets_for_streaming(
|
||||
cfg: DictDefault, split: str = "train", dataset_configs: list | None = None
|
||||
) -> IterableDataset:
|
||||
configs = (
|
||||
dataset_configs
|
||||
if dataset_configs is not None
|
||||
else (cfg.datasets if split == "train" else cfg.test_datasets)
|
||||
)
|
||||
|
||||
if not configs:
|
||||
raise ValueError(f"No dataset configurations found for split '{split}'")
|
||||
|
||||
datasets = []
|
||||
for dataset_config in datasets_with_name_generator(configs):
|
||||
raw_dataset = load_dataset_with_config(
|
||||
dataset_config, cfg.hf_use_auth_token, streaming=True
|
||||
)
|
||||
|
||||
if isinstance(raw_dataset, (DatasetDict, IterableDatasetDict)):
|
||||
if dataset_config.split and dataset_config.split in raw_dataset:
|
||||
raw_dataset = raw_dataset[dataset_config.split]
|
||||
elif split in raw_dataset:
|
||||
raw_dataset = raw_dataset[split]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"no {split} split found for dataset {dataset_config.path}, "
|
||||
"you may specify a split with 'split: ...'"
|
||||
)
|
||||
|
||||
datasets.append(raw_dataset)
|
||||
|
||||
if len(datasets) == 1:
|
||||
return datasets[0]
|
||||
else:
|
||||
return merge_datasets(datasets, cfg)
|
||||
|
||||
|
||||
def _process_eval_dataset_minimal(
|
||||
raw_dataset: IterableDataset,
|
||||
cfg: DictDefault,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
processor: ProcessorMixin | None,
|
||||
) -> Dataset | None:
|
||||
LOG.info("Eval dataset processing skipped for streaming")
|
||||
return None
|
||||
|
||||
|
||||
def _prepare_streaming_pretraining_dataset(
|
||||
cfg: DictDefault,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
processor: ProcessorMixin | None,
|
||||
) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:
|
||||
pretraining_config = _extract_pretraining_config(cfg)
|
||||
|
||||
train_dataset = load_dataset_with_config(
|
||||
pretraining_config, cfg.hf_use_auth_token, streaming=True
|
||||
)
|
||||
|
||||
if isinstance(train_dataset, (DatasetDict, IterableDatasetDict)):
|
||||
if pretraining_config.split and pretraining_config.split in train_dataset:
|
||||
train_dataset = train_dataset[pretraining_config.split]
|
||||
elif "train" in train_dataset:
|
||||
train_dataset = train_dataset["train"]
|
||||
else:
|
||||
raise ValueError("no train split found for pretraining dataset")
|
||||
|
||||
if not cfg.max_steps:
|
||||
raise ValueError("max_steps must be set when using streaming datasets")
|
||||
|
||||
total_num_steps = cfg.max_steps
|
||||
LOG.info(f"Maximum steps: {total_num_steps}")
|
||||
|
||||
return train_dataset, None, total_num_steps, []
|
||||
|
||||
Reference in New Issue
Block a user