broken
This commit is contained in:
@@ -27,7 +27,6 @@ from axolotl.utils.data.shared import (
|
|||||||
try_load_from_hub,
|
try_load_from_hub,
|
||||||
)
|
)
|
||||||
from axolotl.utils.data.streaming import wrap_streaming_sft_dataset
|
from axolotl.utils.data.streaming import wrap_streaming_sft_dataset
|
||||||
from axolotl.utils.data.streaming_sft import wrap_streaming_sft_dataset_optimized
|
|
||||||
from axolotl.utils.data.utils import (
|
from axolotl.utils.data.utils import (
|
||||||
deduplicate_and_log_datasets,
|
deduplicate_and_log_datasets,
|
||||||
handle_long_seq_in_dataset,
|
handle_long_seq_in_dataset,
|
||||||
@@ -423,18 +422,18 @@ def _load_and_process_single_dataset(
|
|||||||
)
|
)
|
||||||
return wrapped_dataset, prompter
|
return wrapped_dataset, prompter
|
||||||
|
|
||||||
# Use optimized streaming wrapper to avoid repeated preprocessing logs
|
# Use pretraining wrapper for efficient streaming SFT with packing
|
||||||
dataset_wrapper = wrap_streaming_sft_dataset_optimized(
|
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||||
|
|
||||||
|
dataset_wrapper = wrap_pretraining_dataset(
|
||||||
dataset,
|
dataset,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg,
|
cfg,
|
||||||
ds_wrapper_fn,
|
ds_wrapper_fn,
|
||||||
max_tokens=cfg.sequence_len,
|
max_tokens=cfg.sequence_len,
|
||||||
batch_size=max(
|
batch_size=cfg.micro_batch_size,
|
||||||
1, cfg.sequence_len // 512
|
seed=cfg.seed,
|
||||||
), # Estimate sequences per pack
|
buffer_size=cfg.pretrain_multipack_buffer_size,
|
||||||
seed=cfg.seed or 42,
|
|
||||||
buffer_size=cfg.pretrain_multipack_buffer_size or 1_000,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Use regular streaming wrapper
|
# Use regular streaming wrapper
|
||||||
|
|||||||
@@ -1,158 +0,0 @@
|
|||||||
"""Optimized streaming SFT with multipack support that avoids repeated preprocessing."""
|
|
||||||
|
|
||||||
import functools
|
|
||||||
from collections import defaultdict
|
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
|
||||||
|
|
||||||
from datasets import Dataset, IterableDataset
|
|
||||||
from torch.utils.data import RandomSampler
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
|
||||||
|
|
||||||
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|
||||||
from axolotl.utils.trainer import add_position_ids, drop_long_seq
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def encode_packed_streaming_sft(
|
|
||||||
collate_fn,
|
|
||||||
ds_wrapper: Callable,
|
|
||||||
examples: Dict[str, List],
|
|
||||||
max_seq_length: int = 2048,
|
|
||||||
batch_size: int = 4,
|
|
||||||
multipack_attn: Optional[bool] = True,
|
|
||||||
) -> Dict[str, List]:
|
|
||||||
"""
|
|
||||||
Encode streaming SFT data with packing, avoiding repeated preprocessing logs.
|
|
||||||
|
|
||||||
This is similar to encode_packed_pretraining but skips the verbose
|
|
||||||
process_pretraining_datasets_for_packing call that logs repeatedly.
|
|
||||||
"""
|
|
||||||
# Tokenize all the examples
|
|
||||||
train_dataset = ds_wrapper(dataset=Dataset.from_dict(examples))[0]
|
|
||||||
|
|
||||||
# Apply filtering and preprocessing directly without verbose logging
|
|
||||||
# Filter out long sequences
|
|
||||||
def should_keep(sample):
|
|
||||||
return drop_long_seq(sample, sequence_len=max_seq_length)
|
|
||||||
|
|
||||||
# Convert to list for filtering (since we need to iterate anyway)
|
|
||||||
filtered_samples = []
|
|
||||||
for i in range(len(train_dataset)):
|
|
||||||
sample = train_dataset[i]
|
|
||||||
if should_keep(sample):
|
|
||||||
# Add position_ids if needed
|
|
||||||
if not multipack_attn: # skip_position_ids=False when multipack_attn=True
|
|
||||||
sample = add_position_ids(sample)
|
|
||||||
filtered_samples.append(sample)
|
|
||||||
|
|
||||||
# Convert back to dataset
|
|
||||||
if not filtered_samples:
|
|
||||||
return {"input_ids": [], "labels": [], "attention_mask": []}
|
|
||||||
|
|
||||||
train_dataset = Dataset.from_list(filtered_samples)
|
|
||||||
|
|
||||||
# Remove attention_mask if needed for multipack
|
|
||||||
if multipack_attn and "attention_mask" in train_dataset.column_names:
|
|
||||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
|
||||||
|
|
||||||
# Use MultipackBatchSampler to create efficient packed batches
|
|
||||||
sampler = MultipackBatchSampler(
|
|
||||||
sampler=RandomSampler(train_dataset),
|
|
||||||
lengths=get_dataset_lengths(train_dataset),
|
|
||||||
batch_size=1,
|
|
||||||
batch_max_len=batch_size * max_seq_length,
|
|
||||||
drop_last=True,
|
|
||||||
num_processes=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Collect packed data
|
|
||||||
chunked_data = defaultdict(list)
|
|
||||||
|
|
||||||
for batch in sampler:
|
|
||||||
for data in batch:
|
|
||||||
features = train_dataset[data]
|
|
||||||
# Clean up unnecessary fields
|
|
||||||
for field in ["num_truncated_tokens", "overflow_to_sample_mapping"]:
|
|
||||||
if field in features:
|
|
||||||
del features[field]
|
|
||||||
# Ensure labels exist
|
|
||||||
if "labels" not in features:
|
|
||||||
features["labels"] = features["input_ids"].copy()
|
|
||||||
# Apply collator
|
|
||||||
collated_features = collate_fn(features)
|
|
||||||
|
|
||||||
# Collect features
|
|
||||||
for feature in features.keys():
|
|
||||||
if feature == "length":
|
|
||||||
continue
|
|
||||||
chunked_data[feature].append(collated_features[feature].squeeze(0))
|
|
||||||
|
|
||||||
return chunked_data
|
|
||||||
|
|
||||||
|
|
||||||
def wrap_streaming_sft_dataset_optimized(
|
|
||||||
dataset: IterableDataset,
|
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
|
||||||
cfg,
|
|
||||||
ds_wrapper_fn: Callable,
|
|
||||||
max_tokens: int = 2048,
|
|
||||||
batch_size: int = 4,
|
|
||||||
seed: int = 42,
|
|
||||||
buffer_size: int = 1000,
|
|
||||||
) -> IterableDataset:
|
|
||||||
"""
|
|
||||||
Wrap a streaming SFT dataset with optimized multipack batching.
|
|
||||||
|
|
||||||
This avoids the repeated preprocessing logs by using a custom encoder
|
|
||||||
that applies filtering and position_ids directly.
|
|
||||||
"""
|
|
||||||
# Create collator for packing
|
|
||||||
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
|
|
||||||
tokenizer,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding=True,
|
|
||||||
pad_to_multiple_of=max_tokens,
|
|
||||||
multipack_attn=cfg.pretrain_multipack_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create optimized encode function
|
|
||||||
encode = functools.partial(
|
|
||||||
encode_packed_streaming_sft,
|
|
||||||
collate_fn,
|
|
||||||
ds_wrapper_fn,
|
|
||||||
max_seq_length=max_tokens,
|
|
||||||
batch_size=batch_size,
|
|
||||||
multipack_attn=cfg.pretrain_multipack_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply shuffling if configured
|
|
||||||
if cfg.shuffle_merged_datasets:
|
|
||||||
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
|
|
||||||
else:
|
|
||||||
LOG.debug("NOT shuffling merged streaming datasets")
|
|
||||||
|
|
||||||
# Get column names to remove
|
|
||||||
remove_columns = []
|
|
||||||
for first_row in dataset:
|
|
||||||
remove_columns = list(first_row.keys())
|
|
||||||
break
|
|
||||||
|
|
||||||
# Reset dataset after peeking
|
|
||||||
if cfg.shuffle_merged_datasets:
|
|
||||||
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
|
|
||||||
|
|
||||||
# Map the optimized encoding function
|
|
||||||
dataset = dataset.map(
|
|
||||||
encode,
|
|
||||||
batched=True,
|
|
||||||
batch_size=buffer_size,
|
|
||||||
remove_columns=remove_columns,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set micro_batch_size to 1 since we've already packed
|
|
||||||
cfg.micro_batch_size = 1
|
|
||||||
|
|
||||||
return dataset
|
|
||||||
Reference in New Issue
Block a user