Compare commits

..

5 Commits

Author SHA1 Message Date
Dan Saunders
4870638734 initial impl of streaming preprocessing 2025-08-19 23:10:54 +00:00
Dan Saunders
b25078397c nit 2025-08-19 18:12:09 +00:00
Dan Saunders
ba681125d7 separate streaming and pretraining 2025-08-19 18:05:05 +00:00
VED
c10eb811fa data_parallel_size in VllmServeCliArgs (#3074)
* data_parallel_size in VllmServeCliArgs

* moved to 43
2025-08-18 08:44:37 -04:00
VED
0eef385b1a [feat] truncation support with excess_length_strategy (#3068) [skip ci]
* feat:truncation support with excess_len

* pre-commit

* excess_length_strategy

* requested changes

* lint

* added handle_long_seq_in_dataset in sft

* comments improved
2025-08-18 08:39:13 -04:00
12 changed files with 460 additions and 46 deletions

View File

@@ -40,6 +40,12 @@ class VllmServeCliArgs:
default=None,
metadata={"help": "Number of tensor parallel workers to use."},
)
data_parallel_size: Optional[int] = field(
default=None,
metadata={
"help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference."
},
)
host: Optional[str] = field(
default=None, # nosec B104
metadata={"help": "Host address to run the server on."},

View File

@@ -40,6 +40,7 @@ from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
MambaDataCollator,
StreamingDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
@@ -422,9 +423,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
is_eval=False,
**kwargs,
):
from datasets import IterableDataset
if isinstance(self.train_dataset, IterableDataset) and not is_eval:
LOG.info("Using StreamingDataCollator")
return StreamingDataCollator(
tokenizer=self.tokenizer,
cfg=self.cfg,
prompter=None,
**kwargs,
)
if training_args.pretraining:
if (
not self.cfg.pretraining_sample_concatenation
self.cfg.pretraining_sample_concatenation is False
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
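
The new branch dispatches on the dataset type rather than on a config flag: streaming runs hand the trainer an IterableDataset, so the builder returns the raw-text StreamingDataCollator before any of the pretraining or packing collators are considered. A minimal sketch of the same check, using only the datasets library:

```python
from datasets import Dataset, IterableDataset

ds = Dataset.from_dict({"text": ["a", "b", "c"]})
streamed = ds.to_iterable_dataset()  # same rows, now exposed as an IterableDataset

print(isinstance(ds, IterableDataset))        # False -> existing collator paths are used
print(isinstance(streamed, IterableDataset))  # True  -> StreamingDataCollator is returned
```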

View File

@@ -272,20 +272,6 @@ class AxolotlTrainer(
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
if (
self.args.accelerator_config is not None
and self.args.accelerator_config.split_batches
and self.args.accelerator_config.dispatch_batches
):
if self.args.sample_packing and self.args.pretraining:
if not self.args.eval_sample_packing and not is_training:
dataloader_params["batch_size"] *= self.accelerator.num_processes
else:
dataloader_params["batch_size"] = self.accelerator.num_processes
elif not self.args.sample_packing and self.args.pretraining:
dataloader_params["batch_size"] *= self.accelerator.num_processes
if self.args.sample_packing and (
(is_training and not self.args.pretraining)
or (not is_training and self.args.eval_sample_packing is not False)

View File

@@ -43,7 +43,11 @@ class TokenizedPromptDataset(Dataset):
)
def process(self, dataset):
features = dataset.features.keys()
# For IterableDataset, we can't access features upfront
# We'll need to infer from the first batch
features = None
if hasattr(dataset, "features") and dataset.features:
features = dataset.features.keys()
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
@@ -54,18 +58,29 @@ class TokenizedPromptDataset(Dataset):
hasattr(self.prompt_tokenizer, "filter_rows")
and self.prompt_tokenizer.filter_rows
):
filter_kwargs = {"desc": "Strategy Filtering Rows"}
# Only add num_proc for regular datasets
if features is not None:
filter_kwargs["num_proc"] = self.process_count
dataset = dataset.filter(
self.prompt_tokenizer.filter_rows,
num_proc=self.process_count,
desc="Strategy Filtering Rows",
**filter_kwargs,
)
map_kwargs = {
**map_kwargs,
"desc": "Tokenizing Prompts",
}
# Only add remove_columns for regular datasets
if features is not None:
map_kwargs["remove_columns"] = features
map_kwargs["num_proc"] = self.process_count
map_kwargs["keep_in_memory"] = self.keep_in_memory
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=self.process_count,
remove_columns=features,
keep_in_memory=self.keep_in_memory,
desc="Tokenizing Prompts",
**map_kwargs,
)
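
The features guard exists because an IterableDataset often has no schema until rows are actually read, and (in the datasets versions this appears to target) its map/filter do not accept num_proc, so those kwargs are only attached when features is available. A small standalone check of the first point:

```python
from datasets import Dataset, IterableDataset

regular = Dataset.from_dict({"text": ["hello"], "label": [0]})
print(regular.features.keys())  # dict_keys(['text', 'label']) -> usable as remove_columns

def gen():
    yield {"text": "hello", "label": 0}

streamed = IterableDataset.from_generator(gen)
print(streamed.features)  # None -> columns are unknown until rows are materialized
```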

View File

@@ -1,11 +1,19 @@
"""
shared axolotl collators for multipack, mamba, multimodal
"""
"""Shared axolotl collators for multipack, mamba, multimodal, etc."""
from .batching import ( # noqa: F401
from .batching import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
PretrainingBatchSamplerDataCollatorForSeq2Seq,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from .mamba import MambaDataCollator # noqa: F401
from .mamba import MambaDataCollator
from .streaming import StreamingDataCollator
__all__ = [
"BatchSamplerDataCollatorForSeq2Seq",
"DataCollatorForSeq2Seq",
"PretrainingBatchSamplerDataCollatorForSeq2Seq",
"V2BatchSamplerDataCollatorForSeq2Seq",
"MambaDataCollator",
"StreamingDataCollator",
]

View File

@@ -0,0 +1,146 @@
from dataclasses import dataclass
from typing import Any, List
import torch
from transformers import PreTrainedTokenizerBase, default_data_collator
from transformers.utils import PaddingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.dict import DictDefault
@dataclass
class StreamingDataCollator:
tokenizer: PreTrainedTokenizerBase
cfg: DictDefault
prompter: Prompter | None = None
padding: bool | str | PaddingStrategy = True
max_length: int | None = None
pad_to_multiple_of: int | None = None
label_pad_token_id: int = -100
return_tensors: str = "pt"
def __post_init__(self):
if self.max_length is None:
self.max_length = self.cfg.sequence_len
def __call__(self, raw_batch: List[dict]) -> dict[str, Any]:
processed_samples = []
for raw_sample in raw_batch:
formatted_sample = raw_sample
if self.prompter:
formatted_sample = self._apply_prompt_formatting(raw_sample)
tokenized_sample = self._tokenize_sample(formatted_sample)
if len(tokenized_sample["input_ids"]) > self.max_length:
tokenized_sample = self._truncate_sample(tokenized_sample)
if tokenized_sample.get("input_ids"):
processed_samples.append(tokenized_sample)
return self._pad_and_batch(processed_samples)
def _apply_prompt_formatting(self, raw_sample: dict) -> dict:
formatted_text = self.prompter.build_prompt(
instruction=raw_sample.get("instruction", ""),
input=raw_sample.get("input", ""),
output=raw_sample.get("output", ""),
)
return {"text": formatted_text}
def _tokenize_sample(self, sample: dict) -> dict:
text = sample.get("text", sample.get("content", ""))
if not text:
instruction = sample.get("instruction", "")
input_text = sample.get("input", "")
output_text = sample.get("output", "")
parts = []
if instruction:
parts.append(f"Instruction: {instruction}")
if input_text:
parts.append(f"Input: {input_text}")
if output_text:
parts.append(f"Output: {output_text}")
text = "\n".join(parts)
if not text:
return {"input_ids": [], "attention_mask": [], "labels": []}
tokenized = self.tokenizer(
text,
truncation=False,
padding=False,
return_tensors=None,
)
tokenized["labels"] = tokenized["input_ids"].copy()
return tokenized
def _truncate_sample(self, tokenized_sample: dict) -> dict:
max_len = self.max_length
for key in ["input_ids", "attention_mask", "labels"]:
if key in tokenized_sample:
tokenized_sample[key] = tokenized_sample[key][:max_len]
return tokenized_sample
def _pad_and_batch(self, processed_samples: List[dict]) -> dict[str, Any]:
if not processed_samples:
processed_samples = [
{
"input_ids": [self.tokenizer.eos_token_id],
"attention_mask": [1],
"labels": [self.tokenizer.eos_token_id],
}
]
batch_samples = []
for sample in processed_samples:
batch_sample = {}
for key, value in sample.items():
if key in ["input_ids", "attention_mask", "labels"]:
batch_sample[key] = torch.tensor(value, dtype=torch.long)
batch_samples.append(batch_sample)
if self.padding:
max_len_in_batch = max(len(sample["input_ids"]) for sample in batch_samples)
for sample in batch_samples:
current_len = len(sample["input_ids"])
pad_len = max_len_in_batch - current_len
if pad_len > 0:
pad_token_id = (
self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
)
sample["input_ids"] = torch.cat(
[
sample["input_ids"],
torch.full((pad_len,), pad_token_id, dtype=torch.long),
]
)
sample["attention_mask"] = torch.cat(
[
sample["attention_mask"],
torch.zeros(pad_len, dtype=torch.long),
]
)
sample["labels"] = torch.cat(
[
sample["labels"],
torch.full(
(pad_len,), self.label_pad_token_id, dtype=torch.long
),
]
)
batch = {}
for key in ["input_ids", "attention_mask", "labels"]:
if key in batch_samples[0]:
batch[key] = torch.stack([sample[key] for sample in batch_samples])
return batch
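
A hypothetical usage sketch of the collator added above, assuming the import paths introduced elsewhere in this diff; the model name and values are only illustrative:

```python
from transformers import AutoTokenizer
from axolotl.utils.collators import StreamingDataCollator
from axolotl.utils.dict import DictDefault

tokenizer = AutoTokenizer.from_pretrained("gpt2")
cfg = DictDefault({"sequence_len": 32})
collator = StreamingDataCollator(tokenizer=tokenizer, cfg=cfg)

# Raw, untokenized rows as they would arrive from a streaming dataset.
batch = collator(
    [
        {"text": "The quick brown fox jumps over the lazy dog."},
        {"instruction": "Translate to French", "input": "hello", "output": "bonjour"},
    ]
)

print(batch["input_ids"].shape)  # (2, longest_sequence_in_this_batch)
# The shorter row is right-padded: pad_token_id (or eos) for input_ids,
# zeros for attention_mask, and label_pad_token_id (-100) for labels.
```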

View File

@@ -9,6 +9,7 @@ from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
load_dataset,
)
from transformers import PreTrainedTokenizer, ProcessorMixin
@@ -28,7 +29,7 @@ from axolotl.utils.data.shared import (
)
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
drop_long_seq_in_dataset,
handle_long_seq_in_dataset,
retry_on_request_exceptions,
)
from axolotl.utils.data.wrappers import get_dataset_wrapper
@@ -43,6 +44,18 @@ from axolotl.utils.trainer import (
LOG = get_logger(__name__)
def _determine_streaming_mode(cfg: DictDefault) -> bool:
"""Determine if we should use streaming mode based on config."""
if cfg.streaming is not None:
return cfg.streaming
# Default to streaming for pretraining datasets
if cfg.pretraining_dataset:
return True
return False
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_datasets(
cfg: DictDefault,
@@ -61,11 +74,52 @@ def prepare_datasets(
Returns:
Tuple of (train_dataset, eval_dataset, total_steps, prompters).
"""
if cfg.pretraining_dataset:
return _prepare_pretraining_dataset(
cfg, tokenizer, processor, preprocess_iterable
streaming_mode = _determine_streaming_mode(cfg)
if streaming_mode:
if cfg.pretraining_dataset:
return _prepare_streaming_pretraining_dataset(cfg, tokenizer, processor)
else:
return _prepare_streaming_sft_dataset(cfg, tokenizer, processor)
else:
if cfg.pretraining_dataset:
return _prepare_pretraining_dataset(
cfg, tokenizer, processor, preprocess_iterable=False
)
else:
return _prepare_standard_dataset(
cfg, tokenizer, processor, preprocess_iterable=False
)
def _prepare_streaming_sft_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:
LOG.info("Loading streaming datasets")
raw_datasets = _load_raw_datasets_for_streaming(cfg, split="train")
eval_dataset = None
if cfg.test_datasets:
eval_raw_datasets = _load_raw_datasets_for_streaming(
cfg, split="test", dataset_configs=cfg.test_datasets
)
return _prepare_standard_dataset(cfg, tokenizer, processor, preprocess_iterable)
eval_dataset = _process_eval_dataset_minimal(
eval_raw_datasets, cfg, tokenizer, processor
)
elif cfg.val_set_size:
LOG.info("Validation splits not supported for streaming datasets")
if not cfg.max_steps:
raise ValueError("max_steps must be set when using streaming datasets")
total_num_steps = cfg.max_steps
LOG.info(f"Maximum steps: {total_num_steps}")
prompters = [None] * len(cfg.datasets) if cfg.datasets else []
return raw_datasets, eval_dataset, total_num_steps, prompters
def _prepare_standard_dataset(
@@ -339,9 +393,9 @@ def _load_raw_datasets(
if not cfg.skip_prepare_dataset:
if split == "test" and cfg.eval_sequence_len:
dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
else:
dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
@@ -373,7 +427,7 @@ def _load_and_process_single_dataset(
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
# Select the appropriate split
if isinstance(dataset, DatasetDict):
if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
if dataset_config.split and dataset_config.split in dataset:
dataset = dataset[dataset_config.split]
elif split in dataset:
@@ -512,3 +566,78 @@ def _load_and_prepare_datasets(
train_dataset, eval_dataset = _handle_test_dataset_split(dataset, cfg)
return train_dataset, eval_dataset, prompters
def _load_raw_datasets_for_streaming(
cfg: DictDefault, split: str = "train", dataset_configs: list | None = None
) -> IterableDataset:
configs = (
dataset_configs
if dataset_configs is not None
else (cfg.datasets if split == "train" else cfg.test_datasets)
)
if not configs:
raise ValueError(f"No dataset configurations found for split '{split}'")
datasets = []
for dataset_config in datasets_with_name_generator(configs):
raw_dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=True
)
if isinstance(raw_dataset, (DatasetDict, IterableDatasetDict)):
if dataset_config.split and dataset_config.split in raw_dataset:
raw_dataset = raw_dataset[dataset_config.split]
elif split in raw_dataset:
raw_dataset = raw_dataset[split]
else:
raise ValueError(
f"no {split} split found for dataset {dataset_config.path}, "
"you may specify a split with 'split: ...'"
)
datasets.append(raw_dataset)
if len(datasets) == 1:
return datasets[0]
else:
return merge_datasets(datasets, cfg)
def _process_eval_dataset_minimal(
raw_dataset: IterableDataset,
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
) -> Dataset | None:
LOG.info("Eval dataset processing skipped for streaming")
return None
def _prepare_streaming_pretraining_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:
pretraining_config = _extract_pretraining_config(cfg)
train_dataset = load_dataset_with_config(
pretraining_config, cfg.hf_use_auth_token, streaming=True
)
if isinstance(train_dataset, (DatasetDict, IterableDatasetDict)):
if pretraining_config.split and pretraining_config.split in train_dataset:
train_dataset = train_dataset[pretraining_config.split]
elif "train" in train_dataset:
train_dataset = train_dataset["train"]
else:
raise ValueError("no train split found for pretraining dataset")
if not cfg.max_steps:
raise ValueError("max_steps must be set when using streaming datasets")
total_num_steps = cfg.max_steps
LOG.info(f"Maximum steps: {total_num_steps}")
return train_dataset, None, total_num_steps, []
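
_load_raw_datasets_for_streaming ultimately relies on load_dataset(..., streaming=True), which for most hub datasets returns an IterableDatasetDict keyed by split; the split-selection logic above mirrors that shape. A standalone sketch of the underlying call (the dataset name is just an example):

```python
from datasets import load_dataset

raw = load_dataset("allenai/c4", "en", streaming=True)  # lazy shards, no full download
print(type(raw).__name__)  # IterableDatasetDict

train = raw["train"]  # pick the split, as the helper above does
print(next(iter(train)).keys())  # e.g. dict_keys(['text', 'timestamp', 'url'])
```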

View File

@@ -148,7 +148,36 @@ def deduplicate_and_log_datasets(
return dataset, other_dataset
def drop_long_seq_in_dataset(
def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
"""
Truncate samples whose sequence length is too long (> sequence_len)
or drop those too short (< min_sequence_len).
"""
min_sequence_len = min_sequence_len or 2
input_ids = sample["input_ids"]
results = []
# Batched (input_ids is a list of lists)
for i, seq in enumerate(input_ids):
length = len(seq)
if length < min_sequence_len:
results.append(False)
elif length > sequence_len:
sample["input_ids"][i] = seq[:sequence_len]
if "attention_mask" in sample:
sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
if "labels" in sample:
sample["labels"][i] = sample["labels"][i][:sequence_len]
if "position_ids" in sample:
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
results.append(True)
else:
results.append(True)
return results
def handle_long_seq_in_dataset(
dataset: Dataset, sequence_len: int, cfg: DictDefault
) -> Dataset:
"""Remove sequences longer than configured maximum from dataset.
@@ -161,12 +190,18 @@ def drop_long_seq_in_dataset(
Returns:
Filtered dataset with long sequences removed.
"""
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return dataset
if hasattr(dataset, "column_names") and dataset.column_names:
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return dataset
else:
# For IterableDataset, we can't check columns upfront, so skip for streaming
if isinstance(dataset, IterableDataset):
LOG.info("Skipping drop_long_seq for streaming datasets (not compatible)")
return dataset
drop_long = functools.partial(
drop_long_seq,
@@ -192,8 +227,21 @@ def drop_long_seq_in_dataset(
if filter_map_kwargs:
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
if excess_length_strategy == "truncate":
process_fn = functools.partial(
truncate_long_seq,
sequence_len=sequence_len,
min_sequence_len=cfg.min_sample_len,
)
drop_long_kwargs["desc"] = (
f"Truncating/Filtering Sequences (target_len={sequence_len})"
)
else:
process_fn = drop_long
dataset = dataset.filter(
drop_long,
process_fn,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
@@ -201,6 +249,11 @@ def drop_long_seq_in_dataset(
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset")
action = (
"truncated/filtered"
if excess_length_strategy == "truncate"
else "dropped"
)
LOG.warning(f"{action.title()} {dropped} samples from dataset")
return dataset
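
Called directly on a batched dict-of-lists sample, the new helper returns a keep/drop mask and truncates over-long rows in place; this is the behaviour the 'truncate' strategy plugs into the filter call above. A small worked example with made-up token ids, assuming the helper is importable from axolotl.utils.data.utils after this change:

```python
from axolotl.utils.data.utils import truncate_long_seq

sample = {
    "input_ids": [[1], [1, 2, 3, 4, 5, 6], [1, 2, 3]],
    "attention_mask": [[1], [1, 1, 1, 1, 1, 1], [1, 1, 1]],
    "labels": [[1], [1, 2, 3, 4, 5, 6], [1, 2, 3]],
}

keep = truncate_long_seq(sample, sequence_len=4, min_sequence_len=2)
print(keep)                    # [False, True, True] -> row 0 is shorter than min_sequence_len
print(sample["input_ids"][1])  # [1, 2, 3, 4] -> the over-long row was truncated in place
```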

View File

@@ -414,6 +414,12 @@ class AxolotlInputConfig(
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
},
)
excess_length_strategy: Literal["drop", "truncate"] | None = Field(
default=None,
json_schema_extra={
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
},
)
eval_sequence_len: int | None = Field(
default=None,
json_schema_extra={
@@ -926,6 +932,34 @@ class AxolotlInputConfig(
fix_untrained_tokens: int | list[int] | None = None
streaming: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use streaming datasets (IterableDataset) for processing large datasets that don't fit in memory. When True, data is loaded on-demand during training without upfront preprocessing. Requires max_steps to be set. Pre-training datasets default to streaming unless explicitly set to False."
},
)
streaming_dataset_mixing_strategy: str | None = Field(
default="round_robin",
json_schema_extra={
"description": "Strategy for mixing multiple streaming datasets: 'round_robin' (equal sampling), 'weighted' (use streaming_mixing_weights), or 'random' (random sampling with equal probability)."
},
)
streaming_mixing_weights: list[float] | None = Field(
default=None,
json_schema_extra={
"description": "Weights for weighted mixing strategy when using multiple streaming datasets. Must sum to 1.0 and have same length as datasets list. Only used when streaming_dataset_mixing_strategy='weighted'."
},
)
streaming_buffer_per_dataset: int | None = Field(
default=1000,
json_schema_extra={
"description": "Buffer size per dataset when mixing multiple streaming datasets. Higher values may improve mixing quality but use more memory."
},
)
# INTERNALS - document for now, generally not set externally
is_preprocess: bool | None = None
preprocess_iterable: bool | None = None
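
For the mixing options, the standard mechanism in the datasets library is interleave_datasets: round-robin corresponds to equal sampling and 'weighted' to an explicit probability vector that must line up with the datasets list. A generic sketch of that mechanism (not necessarily the exact code path axolotl takes internally):

```python
from datasets import Dataset, interleave_datasets

a = Dataset.from_dict({"text": [f"a{i}" for i in range(100)]}).to_iterable_dataset()
b = Dataset.from_dict({"text": [f"b{i}" for i in range(100)]}).to_iterable_dataset()

# round_robin-style mixing: alternate between the streams
mixed_equal = interleave_datasets([a, b])

# weighted mixing: probabilities align with the dataset list and sum to 1.0
mixed_weighted = interleave_datasets([a, b], probabilities=[0.8, 0.2], seed=42)

print([row["text"][0] for row in mixed_weighted.take(10)])  # typically mostly 'a' rows
```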

View File

@@ -1337,6 +1337,30 @@ class GRPOVllmValidationMixin:
# pylint: disable=too-many-ancestors
class StreamingValidationMixin:
"""Validation methods related to streaming datasets."""
@model_validator(mode="after")
def check_streaming_requires_max_steps(self):
"""Ensure max_steps is set when using streaming datasets."""
# Check if streaming is explicitly enabled
streaming_enabled = getattr(self, "streaming", None) is True
# Check if pretraining dataset exists (defaults to streaming)
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
streaming_default_for_pretraining = (
has_pretraining and getattr(self, "streaming", None) is None
)
# If streaming is enabled (explicitly or by default for pretraining)
if streaming_enabled or streaming_default_for_pretraining:
max_steps = getattr(self, "max_steps", None)
if not max_steps:
raise ValueError("max_steps must be set when using streaming datasets")
return self
class ValidationMixin(
DatasetValidationMixin,
AttentionValidationMixin,
@@ -1347,6 +1371,7 @@ class ValidationMixin(
SystemValidationMixin,
ChatTemplateValidationMixin,
PretrainingValidationMixin,
StreamingValidationMixin,
ModelCompatibilityValidationMixin,
ComplexValidationMixin,
GRPOVllmValidationMixin,
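
The validator follows the same pydantic v2 model_validator(mode="after") pattern as the surrounding mixins. A stripped-down, self-contained version of the rule (field names copied from the diff; the tiny model wrapping them is hypothetical):

```python
from pydantic import BaseModel, ValidationError, model_validator

class TinyStreamingConfig(BaseModel):
    streaming: bool | None = None
    pretraining_dataset: str | None = None
    max_steps: int | None = None

    @model_validator(mode="after")
    def check_streaming_requires_max_steps(self):
        streaming_enabled = self.streaming is True
        pretraining_defaults_to_streaming = (
            self.pretraining_dataset is not None and self.streaming is None
        )
        if (streaming_enabled or pretraining_defaults_to_streaming) and not self.max_steps:
            raise ValueError("max_steps must be set when using streaming datasets")
        return self

try:
    TinyStreamingConfig(streaming=True)  # streaming on, no max_steps -> rejected
except ValidationError as err:
    print(err.errors()[0]["msg"])

print(TinyStreamingConfig(streaming=True, max_steps=100).max_steps)  # 100 -> accepted
```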

View File

@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.completion import load
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.data.utils import handle_long_seq_in_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
)
train_dataset = concatenate_datasets([dataset_wrapper])
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
train_dataset = handle_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
lengths = get_dataset_lengths(train_dataset)
batch_sampler = MultipackBatchSampler(