commit d3bea3a2eb
parent 2e2302aae3
Author: Dan Saunders
Date:   2025-08-25 16:51:36 +00:00

2 changed files with 7 additions and 166 deletions

View File

@@ -27,7 +27,6 @@ from axolotl.utils.data.shared import (
     try_load_from_hub,
 )
 from axolotl.utils.data.streaming import wrap_streaming_sft_dataset
-from axolotl.utils.data.streaming_sft import wrap_streaming_sft_dataset_optimized
 from axolotl.utils.data.utils import (
     deduplicate_and_log_datasets,
     handle_long_seq_in_dataset,
@@ -423,18 +422,18 @@ def _load_and_process_single_dataset(
             )
             return wrapped_dataset, prompter
 
-        # Use optimized streaming wrapper to avoid repeated preprocessing logs
-        dataset_wrapper = wrap_streaming_sft_dataset_optimized(
+        # Use pretraining wrapper for efficient streaming SFT with packing
+        from axolotl.utils.data.pretraining import wrap_pretraining_dataset
+
+        dataset_wrapper = wrap_pretraining_dataset(
             dataset,
             tokenizer,
             cfg,
             ds_wrapper_fn,
             max_tokens=cfg.sequence_len,
-            batch_size=max(
-                1, cfg.sequence_len // 512
-            ),  # Estimate sequences per pack
-            seed=cfg.seed or 42,
-            buffer_size=cfg.pretrain_multipack_buffer_size or 1_000,
+            batch_size=cfg.micro_batch_size,
+            seed=cfg.seed,
+            buffer_size=cfg.pretrain_multipack_buffer_size,
         )
     else:
         # Use regular streaming wrapper
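
Context for the swap above: under sample packing, both the deleted helper and wrap_pretraining_dataset concatenate several tokenized samples into single rows with a token budget of batch_size * max_tokens, which is why the hand-rolled max(1, cfg.sequence_len // 512) estimate could be replaced by the configured cfg.micro_batch_size. A minimal, self-contained sketch of that packing step (plain Python; pack_greedy is an illustrative name, not an axolotl API, and axolotl's MultipackBatchSampler is considerably more sophisticated):

from typing import Dict, List

def pack_greedy(samples: List[Dict[str, list]], batch_max_len: int) -> List[List[int]]:
    """Group sample indices so each pack's total token count fits batch_max_len."""
    # First-fit decreasing: place the longest sequences first, reuse leftover budget.
    order = sorted(
        range(len(samples)), key=lambda i: len(samples[i]["input_ids"]), reverse=True
    )
    packs: List[List[int]] = []
    budgets: List[int] = []  # remaining token budget of each pack
    for idx in order:
        length = len(samples[idx]["input_ids"])
        for j in range(len(packs)):
            if budgets[j] >= length:
                packs[j].append(idx)
                budgets[j] -= length
                break
        else:  # no existing pack has room; open a new one
            packs.append([idx])
            budgets.append(batch_max_len - length)
    return packs

# With batch_size=2 and max_tokens=8, the per-pack budget is 2 * 8 = 16 tokens.
samples = [{"input_ids": [0] * n} for n in (7, 6, 5, 3, 2)]
print(pack_greedy(samples, batch_max_len=2 * 8))  # -> [[0, 1, 3], [2, 4]]

The real encoder additionally handles labels, attention via position_ids, and collation, but the budget arithmetic above is the part the changed batch_size argument feeds.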

View File

@@ -1,158 +0,0 @@
-"""Optimized streaming SFT with multipack support that avoids repeated preprocessing."""
-
-import functools
-from collections import defaultdict
-from typing import Any, Callable, Dict, List, Optional
-
-from datasets import Dataset, IterableDataset
-from torch.utils.data import RandomSampler
-from transformers import PreTrainedTokenizerBase
-
-from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
-from axolotl.utils.logging import get_logger
-from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
-from axolotl.utils.trainer import add_position_ids, drop_long_seq
-
-LOG = get_logger(__name__)
-
-
-def encode_packed_streaming_sft(
-    collate_fn,
-    ds_wrapper: Callable,
-    examples: Dict[str, List],
-    max_seq_length: int = 2048,
-    batch_size: int = 4,
-    multipack_attn: Optional[bool] = True,
-) -> Dict[str, List]:
-    """
-    Encode streaming SFT data with packing, avoiding repeated preprocessing logs.
-
-    This is similar to encode_packed_pretraining but skips the verbose
-    process_pretraining_datasets_for_packing call that logs repeatedly.
-    """
-    # Tokenize all the examples
-    train_dataset = ds_wrapper(dataset=Dataset.from_dict(examples))[0]
-
-    # Apply filtering and preprocessing directly without verbose logging
-    # Filter out long sequences
-    def should_keep(sample):
-        return drop_long_seq(sample, sequence_len=max_seq_length)
-
-    # Convert to list for filtering (since we need to iterate anyway)
-    filtered_samples = []
-    for i in range(len(train_dataset)):
-        sample = train_dataset[i]
-        if should_keep(sample):
-            # Add position_ids if needed
-            if not multipack_attn:  # skip_position_ids=False when multipack_attn=True
-                sample = add_position_ids(sample)
-            filtered_samples.append(sample)
-
-    # Convert back to dataset
-    if not filtered_samples:
-        return {"input_ids": [], "labels": [], "attention_mask": []}
-
-    train_dataset = Dataset.from_list(filtered_samples)
-
-    # Remove attention_mask if needed for multipack
-    if multipack_attn and "attention_mask" in train_dataset.column_names:
-        train_dataset = train_dataset.remove_columns("attention_mask")
-
-    # Use MultipackBatchSampler to create efficient packed batches
-    sampler = MultipackBatchSampler(
-        sampler=RandomSampler(train_dataset),
-        lengths=get_dataset_lengths(train_dataset),
-        batch_size=1,
-        batch_max_len=batch_size * max_seq_length,
-        drop_last=True,
-        num_processes=1,
-    )
-
-    # Collect packed data
-    chunked_data = defaultdict(list)
-    for batch in sampler:
-        for data in batch:
-            features = train_dataset[data]
-            # Clean up unnecessary fields
-            for field in ["num_truncated_tokens", "overflow_to_sample_mapping"]:
-                if field in features:
-                    del features[field]
-            # Ensure labels exist
-            if "labels" not in features:
-                features["labels"] = features["input_ids"].copy()
-            # Apply collator
-            collated_features = collate_fn(features)
-            # Collect features
-            for feature in features.keys():
-                if feature == "length":
-                    continue
-                chunked_data[feature].append(collated_features[feature].squeeze(0))
-
-    return chunked_data
-
-
-def wrap_streaming_sft_dataset_optimized(
-    dataset: IterableDataset,
-    tokenizer: PreTrainedTokenizerBase,
-    cfg,
-    ds_wrapper_fn: Callable,
-    max_tokens: int = 2048,
-    batch_size: int = 4,
-    seed: int = 42,
-    buffer_size: int = 1000,
-) -> IterableDataset:
-    """
-    Wrap a streaming SFT dataset with optimized multipack batching.
-
-    This avoids the repeated preprocessing logs by using a custom encoder
-    that applies filtering and position_ids directly.
-    """
-    # Create collator for packing
-    collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
-        tokenizer,
-        return_tensors="pt",
-        padding=True,
-        pad_to_multiple_of=max_tokens,
-        multipack_attn=cfg.pretrain_multipack_attn,
-    )
-
-    # Create optimized encode function
-    encode = functools.partial(
-        encode_packed_streaming_sft,
-        collate_fn,
-        ds_wrapper_fn,
-        max_seq_length=max_tokens,
-        batch_size=batch_size,
-        multipack_attn=cfg.pretrain_multipack_attn,
-    )
-
-    # Apply shuffling if configured
-    if cfg.shuffle_merged_datasets:
-        dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
-    else:
-        LOG.debug("NOT shuffling merged streaming datasets")
-
-    # Get column names to remove
-    remove_columns = []
-    for first_row in dataset:
-        remove_columns = list(first_row.keys())
-        break
-
-    # Reset dataset after peeking
-    if cfg.shuffle_merged_datasets:
-        dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
-
-    # Map the optimized encoding function
-    dataset = dataset.map(
-        encode,
-        batched=True,
-        batch_size=buffer_size,
-        remove_columns=remove_columns,
-    )
-
-    # Set micro_batch_size to 1 since we've already packed
-    cfg.micro_batch_size = 1
-
-    return dataset
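
A footnote on the streaming pattern the deleted wrapper used: because IterableDataset.map with batched=True replaces the raw columns with packed ones, the wrapper first peeks at one row to learn which columns to drop, then shuffles within a bounded buffer and maps a batched encoder over the stream. A small, self-contained sketch of that pattern using only the datasets library (the toy encode below merely concatenates each batch's ids; it stands in for the real packed encoder):

from datasets import Dataset

def encode(batch):
    # Toy packed encoder: concatenate every input_ids list in the batch into one row.
    packed = [tok for ids in batch["input_ids"] for tok in ids]
    return {"input_ids": [packed], "labels": [packed.copy()]}

stream = Dataset.from_dict(
    {"input_ids": [[1, 2], [3], [4, 5, 6]], "text": ["a", "b", "c"]}
).to_iterable_dataset()

# Peek at the first row to learn the raw column names, as the wrapper did.
remove_columns = list(next(iter(stream)).keys())

stream = stream.shuffle(seed=42, buffer_size=8)
stream = stream.map(encode, batched=True, batch_size=3, remove_columns=remove_columns)
print(next(iter(stream)))  # {'input_ids': [...], 'labels': [...]}

Columns named in remove_columns are dropped before the function's output is merged back, so returning a fresh input_ids column is safe; that is what let the wrapper discard every original column wholesale.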