seems to be working?
This commit is contained in:
61
examples/streaming/streaming-pretrain.yml
Normal file
61
examples/streaming/streaming-pretrain.yml
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
# Example configuration for streaming pretraining
|
||||||
|
# This demonstrates how to pretrain on large datasets that don't fit in memory
|
||||||
|
|
||||||
|
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
|
||||||
|
# Required: max_steps for streaming pretraining
|
||||||
|
max_steps: 10000
|
||||||
|
|
||||||
|
# Pretraining dataset configuration
|
||||||
|
# These are automatically streamed
|
||||||
|
pretraining_dataset:
|
||||||
|
- path: allenai/c4
|
||||||
|
name: en
|
||||||
|
type: pretrain
|
||||||
|
# Optional: skip N samples (useful for resuming)
|
||||||
|
# skip: 1000000
|
||||||
|
|
||||||
|
# Can also use multiple pretraining datasets
|
||||||
|
# pretraining_dataset:
|
||||||
|
# - path: allenai/c4
|
||||||
|
# name: en
|
||||||
|
# type: pretrain
|
||||||
|
# - path: HuggingFaceFW/fineweb
|
||||||
|
# type: pretrain
|
||||||
|
|
||||||
|
val_set_size: 0.0
|
||||||
|
|
||||||
|
# Sequence and packing configuration
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
pretrain_multipack_attn: true
|
||||||
|
pretrain_multipack_buffer_size: 10000 # Buffer size for multipack batching
|
||||||
|
|
||||||
|
# Training hyperparameters
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
micro_batch_size: 4
|
||||||
|
optimizer: adamw_torch
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 3e-4
|
||||||
|
|
||||||
|
# Memory optimizations
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
gradient_checkpointing: true
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
# Checkpointing and logging
|
||||||
|
output_dir: ./outputs/pretrain-streaming
|
||||||
|
logging_steps: 10
|
||||||
|
save_steps: 500
|
||||||
|
save_total_limit: 3 # Keep only last 3 checkpoints
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
|
||||||
|
# Optional: enable wandb for monitoring
|
||||||
|
# wandb_project: streaming-pretrain
|
||||||
|
# wandb_entity: your-entity
|
||||||
|
# wandb_name: c4-pretrain
|
||||||
52
examples/streaming/streaming-sft.yml
Normal file
52
examples/streaming/streaming-sft.yml
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
# Example configuration for streaming SFT training
|
||||||
|
# This enables training on datasets larger than memory by streaming them from HuggingFace Hub
|
||||||
|
|
||||||
|
base_model: HuggingFaceTB/SmolLM2-135M
|
||||||
|
|
||||||
|
# Enable streaming mode for datasets
|
||||||
|
streaming: true
|
||||||
|
|
||||||
|
# When using streaming, max_steps is required
|
||||||
|
max_steps: 3 # Just test a few steps
|
||||||
|
|
||||||
|
# Training datasets - these will be streamed (the `datasets:` form is left commented out here; the active entry is `pretraining_dataset` below)
|
||||||
|
# datasets:
|
||||||
|
# - path: tatsu-lab/alpaca
|
||||||
|
# type: alpaca
|
||||||
|
# split: train
|
||||||
|
|
||||||
|
pretraining_dataset:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
split: train
|
||||||
|
|
||||||
|
# Dataset configuration
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true # Enable multipack batching
|
||||||
|
pretrain_multipack_attn: true # Enable multipack attention masking
|
||||||
|
pretrain_multipack_buffer_size: 1000 # Buffer size for packing (smaller for streaming SFT)
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|endoftext|>
|
||||||
|
|
||||||
|
# Training hyperparameters
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1 # Always 1 for multipack - sequences are packed into single samples
|
||||||
|
num_epochs: 1 # With streaming, typically use max_steps instead
|
||||||
|
optimizer: adamw_torch
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 2e-5
|
||||||
|
|
||||||
|
# Enable efficient training
|
||||||
|
bf16: auto
|
||||||
|
tf32: false
|
||||||
|
gradient_checkpointing: true
|
||||||
|
flash_attention: true # Enable flash attention with multipack
|
||||||
|
|
||||||
|
# Logging and checkpointing
|
||||||
|
logging_steps: 10
|
||||||
|
eval_steps: 100
|
||||||
|
save_steps: 200
|
||||||
|
output_dir: ./outputs/streaming-model
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
warmup_steps: 100
|
||||||
@@ -26,6 +26,8 @@ from axolotl.utils.data.shared import (
|
|||||||
save_preprocessed_dataset,
|
save_preprocessed_dataset,
|
||||||
try_load_from_hub,
|
try_load_from_hub,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.data.streaming import wrap_streaming_sft_dataset
|
||||||
|
from axolotl.utils.data.streaming_sft import wrap_streaming_sft_dataset_optimized
|
||||||
from axolotl.utils.data.utils import (
|
from axolotl.utils.data.utils import (
|
||||||
deduplicate_and_log_datasets,
|
deduplicate_and_log_datasets,
|
||||||
handle_long_seq_in_dataset,
|
handle_long_seq_in_dataset,
|
||||||
@@ -73,7 +75,7 @@ def _prepare_standard_dataset(
|
|||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
processor: ProcessorMixin | None,
|
processor: ProcessorMixin | None,
|
||||||
preprocess_iterable: bool,
|
preprocess_iterable: bool,
|
||||||
) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]:
|
) -> tuple[Dataset | IterableDataset, Dataset | None, int, list[Prompter | None]]:
|
||||||
"""Prepare standard (non-pretraining) datasets."""
|
"""Prepare standard (non-pretraining) datasets."""
|
||||||
|
|
||||||
def _load_datasets():
|
def _load_datasets():
|
||||||
@@ -118,7 +120,14 @@ def _prepare_standard_dataset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Calculate total number of training steps
|
# Calculate total number of training steps
|
||||||
if cfg.max_steps:
|
# For streaming datasets, we must use max_steps
|
||||||
|
if isinstance(train_dataset, IterableDataset):
|
||||||
|
if not cfg.max_steps:
|
||||||
|
raise ValueError(
|
||||||
|
"When using streaming datasets, you must set max_steps in your config"
|
||||||
|
)
|
||||||
|
total_num_steps = cfg.max_steps
|
||||||
|
elif cfg.max_steps:
|
||||||
total_num_steps = min(
|
total_num_steps = min(
|
||||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||||
)
|
)
|
||||||
@@ -342,14 +351,18 @@ def _load_raw_datasets(
|
|||||||
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
|
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
|
||||||
else:
|
else:
|
||||||
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
|
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
|
||||||
if cfg.sample_packing:
|
|
||||||
|
# Skip packing processing for streaming datasets - they handle it differently
|
||||||
|
if cfg.sample_packing and not isinstance(dataset, IterableDataset):
|
||||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||||
|
|
||||||
# Save the prepared dataset
|
# Skip saving for streaming datasets as they can't be cached
|
||||||
dataset_hash = generate_dataset_hash_from_config(
|
if not isinstance(dataset, IterableDataset):
|
||||||
cfg, datasets_configs, tokenizer.name_or_path
|
# Save the prepared dataset
|
||||||
)
|
dataset_hash = generate_dataset_hash_from_config(
|
||||||
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
|
cfg, datasets_configs, tokenizer.name_or_path
|
||||||
|
)
|
||||||
|
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
|
||||||
|
|
||||||
return dataset, prompters
|
return dataset, prompters
|
||||||
|
|
||||||
@@ -365,8 +378,10 @@ def _load_and_process_single_dataset(
|
|||||||
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
||||||
"""Load and process a single dataset based on the passed config."""
|
"""Load and process a single dataset based on the passed config."""
|
||||||
# Load the dataset
|
# Load the dataset
|
||||||
|
# Use streaming if enabled in config or if using iterable preprocessing
|
||||||
|
use_streaming = cfg.streaming or preprocess_iterable
|
||||||
dataset = load_dataset_with_config(
|
dataset = load_dataset_with_config(
|
||||||
dataset_config, cfg.hf_use_auth_token, streaming=preprocess_iterable
|
dataset_config, cfg.hf_use_auth_token, streaming=use_streaming
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse dataset type
|
# Parse dataset type
|
||||||
@@ -391,16 +406,63 @@ def _load_and_process_single_dataset(
|
|||||||
num_shards=dataset_config.shards, index=shards_idx
|
num_shards=dataset_config.shards, index=shards_idx
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply dataset wrapper
|
# For streaming datasets, we need to handle tokenization differently
|
||||||
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
if isinstance(dataset, IterableDataset):
|
||||||
dataset_config=dataset_config,
|
# Use pretraining's approach for multipack streaming
|
||||||
tokenizer=tokenizer,
|
if cfg.sample_packing:
|
||||||
cfg=cfg,
|
# Create the dataset wrapper function once
|
||||||
dataset_base_type=d_base_type,
|
def ds_wrapper_fn(dataset=None):
|
||||||
dataset=dataset,
|
wrapped_dataset, prompter = get_dataset_wrapper(
|
||||||
dataset_prompt_style=d_prompt_style,
|
dataset_config=dataset_config,
|
||||||
processor=processor,
|
tokenizer=tokenizer,
|
||||||
)
|
cfg=cfg,
|
||||||
|
dataset_base_type=d_base_type,
|
||||||
|
dataset=dataset,
|
||||||
|
dataset_prompt_style=d_prompt_style,
|
||||||
|
processor=processor,
|
||||||
|
)
|
||||||
|
return wrapped_dataset, prompter
|
||||||
|
|
||||||
|
# Use optimized streaming wrapper to avoid repeated preprocessing logs
|
||||||
|
dataset_wrapper = wrap_streaming_sft_dataset_optimized(
|
||||||
|
dataset,
|
||||||
|
tokenizer,
|
||||||
|
cfg,
|
||||||
|
ds_wrapper_fn,
|
||||||
|
max_tokens=cfg.sequence_len,
|
||||||
|
batch_size=max(
|
||||||
|
1, cfg.sequence_len // 512
|
||||||
|
), # Estimate sequences per pack
|
||||||
|
seed=cfg.seed or 42,
|
||||||
|
buffer_size=cfg.pretrain_multipack_buffer_size or 1_000,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use regular streaming wrapper
|
||||||
|
dataset_wrapper = wrap_streaming_sft_dataset(
|
||||||
|
dataset,
|
||||||
|
tokenizer,
|
||||||
|
cfg,
|
||||||
|
dataset_config,
|
||||||
|
d_base_type,
|
||||||
|
d_prompt_style,
|
||||||
|
processor,
|
||||||
|
max_tokens=cfg.sequence_len,
|
||||||
|
buffer_size=10_000,
|
||||||
|
)
|
||||||
|
|
||||||
|
# For streaming, we don't have a specific prompter
|
||||||
|
dataset_prompter = None
|
||||||
|
else:
|
||||||
|
# Apply dataset wrapper for regular datasets
|
||||||
|
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
||||||
|
dataset_config=dataset_config,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
cfg=cfg,
|
||||||
|
dataset_base_type=d_base_type,
|
||||||
|
dataset=dataset,
|
||||||
|
dataset_prompt_style=d_prompt_style,
|
||||||
|
processor=processor,
|
||||||
|
)
|
||||||
|
|
||||||
return dataset_wrapper, dataset_prompter
|
return dataset_wrapper, dataset_prompter
|
||||||
|
|
||||||
|
|||||||
@@ -524,7 +524,9 @@ def generate_dataset_hash_from_config(
|
|||||||
return str(md5(config_str))
|
return str(md5(config_str))
|
||||||
|
|
||||||
|
|
||||||
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
def merge_datasets(
|
||||||
|
datasets: list[Dataset | IterableDataset], cfg: DictDefault
|
||||||
|
) -> Dataset | IterableDataset:
|
||||||
"""Merge multiple datasets into one with optional shuffling.
|
"""Merge multiple datasets into one with optional shuffling.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -534,6 +536,41 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
|||||||
Returns:
|
Returns:
|
||||||
Merged dataset.
|
Merged dataset.
|
||||||
"""
|
"""
|
||||||
|
# Check if we're dealing with streaming datasets
|
||||||
|
if any(isinstance(ds, IterableDataset) for ds in datasets):
|
||||||
|
# All datasets must be streaming for merging
|
||||||
|
if not all(isinstance(ds, IterableDataset) for ds in datasets):
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot mix streaming and non-streaming datasets. "
|
||||||
|
"Either all datasets must be streaming or none."
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(datasets) == 1:
|
||||||
|
ds = datasets[0]
|
||||||
|
# Streaming datasets handle shuffling differently
|
||||||
|
if cfg.shuffle_merged_datasets and not cfg.curriculum_sampling:
|
||||||
|
return ds.shuffle(seed=cfg.seed, buffer_size=10_000)
|
||||||
|
return ds
|
||||||
|
|
||||||
|
# Merge streaming datasets
|
||||||
|
LOG.info("Merging streaming datasets...")
|
||||||
|
from datasets import interleave_datasets
|
||||||
|
|
||||||
|
# For streaming, we interleave datasets instead of concatenating
|
||||||
|
merged_dataset = interleave_datasets(datasets)
|
||||||
|
|
||||||
|
if cfg.shuffle_merged_datasets:
|
||||||
|
LOG.debug("Shuffling merged streaming datasets...")
|
||||||
|
if cfg.curriculum_sampling:
|
||||||
|
LOG.warning(
|
||||||
|
"Shuffling merged datasets with curriculum sampling is not recommended. "
|
||||||
|
"This will randomize the order of samples."
|
||||||
|
)
|
||||||
|
merged_dataset = merged_dataset.shuffle(seed=cfg.seed, buffer_size=10_000)
|
||||||
|
|
||||||
|
return merged_dataset
|
||||||
|
|
||||||
|
# Original logic for non-streaming datasets
|
||||||
if len(datasets) == 1:
|
if len(datasets) == 1:
|
||||||
ds = datasets[0]
|
ds = datasets[0]
|
||||||
|
|
||||||
|
|||||||
150
src/axolotl/utils/data/streaming.py
Normal file
150
src/axolotl/utils/data/streaming.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
"""Utilities for handling streaming datasets."""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from datasets import Dataset, IterableDataset
|
||||||
|
from torch.utils.data import RandomSampler
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
from axolotl.utils.trainer import add_position_ids
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_streaming_sft_dataset(
    dataset: IterableDataset,
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    dataset_config,
    d_base_type: str,
    d_prompt_style: str | None,
    processor: Any | None,
    max_tokens: int = 2048,
    buffer_size: int = 10_000,
) -> IterableDataset:
    """
    Wrap a streaming SFT dataset so that rows are tokenized lazily.

    Raw rows are pulled in batches of ``buffer_size``, turned into a temporary
    in-memory ``Dataset``, passed through ``get_dataset_wrapper`` for
    tokenization, and then given ``position_ids`` via ``add_position_ids``.
    No sequence packing is performed here.

    Args:
        dataset: The streaming dataset to wrap.
        tokenizer: Tokenizer forwarded to ``get_dataset_wrapper``.
        cfg: Configuration object. Reads ``shuffle_merged_datasets``, ``seed``,
            ``sample_packing`` and ``pretrain_multipack_attn``.
        dataset_config: Per-dataset configuration forwarded to
            ``get_dataset_wrapper``.
        d_base_type: Base dataset type forwarded to ``get_dataset_wrapper``.
        d_prompt_style: Prompt style forwarded to ``get_dataset_wrapper``.
        processor: Optional processor forwarded to ``get_dataset_wrapper``.
        max_tokens: Accepted for interface compatibility; not used by this
            function.
        buffer_size: Used both as the shuffle buffer size and as the ``map``
            batch size.

    Returns:
        The mapped streaming dataset with its format set to ``"torch"``.
    """
    # Import here to avoid circular imports
    from axolotl.utils.data.wrappers import get_dataset_wrapper

    # Learn the raw column names from the first row so they can be removed
    # after tokenization. This is done BEFORE shuffling: peeking through a
    # shuffled stream would first have to fill the shuffle buffer just to
    # yield one row. Each `for` over the dataset starts a fresh iterator, so
    # nothing needs to be "reset" afterwards.
    remove_columns: List[str] = []
    for first_row in dataset:
        remove_columns = list(first_row.keys())
        break

    # Shuffle exactly once. Calling `.shuffle()` a second time would wrap the
    # stream in a second shuffle buffer rather than resetting it.
    if cfg.shuffle_merged_datasets:
        LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)

    # attention_mask is only dropped when packing with multipack attention.
    drop_attention_mask = bool(cfg.sample_packing and cfg.pretrain_multipack_attn)

    def encode_streaming(examples: Dict[str, List]) -> Dict[str, List]:
        """Tokenize one batch of raw rows and attach position_ids."""
        # Convert the batch dict to a temporary Dataset for processing
        temp_dataset = Dataset.from_dict(examples)

        # Apply the dataset wrapper to tokenize
        wrapped_dataset, _ = get_dataset_wrapper(
            dataset_config=dataset_config,
            tokenizer=tokenizer,
            cfg=cfg,
            dataset_base_type=d_base_type,
            dataset=temp_dataset,
            dataset_prompt_style=d_prompt_style,
            processor=processor,
        )

        # Convert back to a dict of columns
        if hasattr(wrapped_dataset, "to_dict"):
            result = wrapped_dataset.to_dict()
        else:
            result = {
                key: wrapped_dataset[key] for key in wrapped_dataset.column_names
            }

        # position_ids are always added, with or without packing
        result = add_position_ids(result)

        if drop_attention_mask and "attention_mask" in result:
            del result["attention_mask"]

        return result

    # Map the encoding function over the streaming dataset
    dataset = dataset.map(
        encode_streaming,
        batched=True,
        batch_size=buffer_size,
        remove_columns=remove_columns,
    )

    # Set format for PyTorch
    return dataset.with_format("torch")
|
||||||
219
src/axolotl/utils/data/streaming_multipack.py
Normal file
219
src/axolotl/utils/data/streaming_multipack.py
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
"""Streaming dataset with multipack support for SFT."""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from datasets import Dataset, IterableDataset
|
||||||
|
from torch.utils.data import RandomSampler
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def encode_packed_sft_streaming(
    collate_fn,
    ds_wrapper_fn,
    examples: Dict[str, List],
    dataset_config,
    tokenizer,
    cfg,
    d_base_type: str,
    d_prompt_style: str | None,
    processor: Any | None,
    max_seq_length: int = 2048,
    batch_size: int = 4,
    multipack_attn: Optional[bool] = True,
) -> Dict[str, List]:
    """
    Tokenize a batch of raw SFT rows and group them into packed samples.

    Steps performed:
        1. Build a temporary ``Dataset`` from ``examples`` and tokenize it
           through ``get_dataset_wrapper``.
        2. Run ``process_pretraining_datasets_for_packing`` with
           ``skip_position_ids=not multipack_attn`` and
           ``drop_attention_mask=multipack_attn``.
        3. Let ``MultipackBatchSampler`` choose groups of row indices whose
           token budget is ``batch_size * max_seq_length``.
        4. Collate every group with ``collate_fn`` and gather the results
           column by column.

    Args:
        collate_fn: Callable applied to the dict-of-lists of one index group;
            its result is indexed by feature name.
        ds_wrapper_fn: Not used by this function; kept for positional
            compatibility with callers.
        examples: Dict of lists containing the raw rows of one map batch.
        dataset_config: Dataset configuration forwarded to
            ``get_dataset_wrapper``.
        tokenizer: Tokenizer forwarded to ``get_dataset_wrapper``.
        cfg: Main configuration forwarded to ``get_dataset_wrapper``.
        d_base_type: Dataset base type forwarded to ``get_dataset_wrapper``.
        d_prompt_style: Prompt style forwarded to ``get_dataset_wrapper``.
        processor: Optional processor forwarded to ``get_dataset_wrapper``.
        max_seq_length: Length limit passed to the packing preprocessing and
            used in the sampler's token budget.
        batch_size: Multiplier for the sampler's token budget.
        multipack_attn: Controls position-id creation and attention-mask
            removal as described above.

    Returns:
        A ``defaultdict(list)`` mapping each feature name (except ``"length"``)
        to the list of collated values, one entry per index group.
    """
    # Import here to avoid circular imports
    from axolotl.utils.data.wrappers import get_dataset_wrapper

    # Tokenize the raw rows through the regular dataset wrapper
    raw_rows = Dataset.from_dict(examples)
    tokenized, _ = get_dataset_wrapper(
        dataset_config=dataset_config,
        tokenizer=tokenizer,
        cfg=cfg,
        dataset_base_type=d_base_type,
        dataset=raw_rows,
        dataset_prompt_style=d_prompt_style,
        processor=processor,
    )

    # Packing preprocessing (long-sequence handling, position_ids, mask drop)
    tokenized = process_pretraining_datasets_for_packing(
        tokenized,
        max_seq_length,
        skip_position_ids=not multipack_attn,
        drop_attention_mask=multipack_attn,
    )

    # The sampler yields batches of index groups; with batch_size=1 each batch
    # holds a single group bounded by the total token budget below.
    packer = MultipackBatchSampler(
        sampler=RandomSampler(tokenized),
        lengths=get_dataset_lengths(tokenized),
        batch_size=1,
        batch_max_len=batch_size * max_seq_length,
        drop_last=True,
        num_processes=1,  # Single process for streaming
    )

    bookkeeping_fields = ("num_truncated_tokens", "overflow_to_sample_mapping")
    packed = defaultdict(list)

    for index_groups in packer:
        for indices in index_groups:
            rows = tokenized[indices]

            # Strip tokenizer bookkeeping columns that must not be collated
            for name in bookkeeping_fields:
                if name in rows:
                    del rows[name]

            # Fall back to input_ids as labels when none were produced
            if "labels" not in rows:
                rows["labels"] = rows["input_ids"].copy()

            collated = collate_fn(rows)

            # Gather every column except the helper "length" column
            for name in rows:
                if name != "length":
                    packed[name].append(collated[name].squeeze(0))

    return packed
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_streaming_sft_dataset_with_packing(
    dataset: IterableDataset,
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    dataset_config,
    d_base_type: str,
    d_prompt_style: str | None,
    processor: Any | None,
    max_tokens: int = 2048,
    buffer_size: int = 10_000,
) -> IterableDataset:
    """
    Wrap a streaming SFT dataset with tokenization and multipack batching.

    Raw rows are pulled in batches of ``buffer_size`` and handed to
    ``encode_packed_sft_streaming``, which tokenizes them and groups them into
    packed samples collated by ``PretrainingBatchSamplerDataCollatorForSeq2Seq``.

    Side effect: ``cfg.micro_batch_size`` is overwritten with ``1``.

    Args:
        dataset: The streaming dataset to wrap.
        tokenizer: Tokenizer used by the collator and the dataset wrapper.
        cfg: Configuration object. Reads ``shuffle_merged_datasets``, ``seed``
            and ``pretrain_multipack_attn``; ``micro_batch_size`` is written.
        dataset_config: Dataset configuration forwarded to the encoder.
        d_base_type: Base dataset type forwarded to the encoder.
        d_prompt_style: Prompt style forwarded to the encoder.
        processor: Optional processor forwarded to the encoder.
        max_tokens: Maximum sequence length; also the collator's
            ``pad_to_multiple_of``.
        buffer_size: Used both as the shuffle buffer size and as the ``map``
            batch size.

    Returns:
        The mapped streaming dataset with its format set to ``"torch"``.
    """
    # Learn the raw column names from the first row so they can be removed
    # after encoding. This is done BEFORE shuffling: peeking through a
    # shuffled stream would first have to fill the shuffle buffer just to
    # yield one row. Each `for` over the dataset starts a fresh iterator, so
    # nothing needs to be "reset" afterwards.
    remove_columns: List[str] = []
    for first_row in dataset:
        remove_columns = list(first_row.keys())
        break

    # Shuffle exactly once. Calling `.shuffle()` a second time would wrap the
    # stream in a second shuffle buffer rather than resetting it.
    if cfg.shuffle_merged_datasets:
        LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)

    # Create the collator for multipack
    collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
        tokenizer,
        return_tensors="pt",
        padding=True,
        pad_to_multiple_of=max_tokens,
        multipack_attn=cfg.pretrain_multipack_attn,
    )

    # Multiplier for the packer's token budget (budget = this * max_tokens).
    # Heuristic: one unit per 512 tokens of max_tokens, at least 1.
    pack_batch_size = max(1, max_tokens // 512)

    # `examples` is the third positional parameter of the encoder, so the
    # first two are bound positionally here and `map` supplies the batch.
    encode_fn = functools.partial(
        encode_packed_sft_streaming,
        collate_fn,
        None,  # ds_wrapper_fn is unused by the encoder
        dataset_config=dataset_config,
        tokenizer=tokenizer,
        cfg=cfg,
        d_base_type=d_base_type,
        d_prompt_style=d_prompt_style,
        processor=processor,
        max_seq_length=max_tokens,
        batch_size=pack_batch_size,
        multipack_attn=cfg.pretrain_multipack_attn,
    )

    # Map the encoding function over the streaming dataset; each map batch is
    # tokenized and packed independently.
    dataset = dataset.map(
        encode_fn,
        batched=True,
        batch_size=buffer_size,  # Process large batches for efficiency
        remove_columns=remove_columns,
    )

    # Set format for PyTorch
    dataset = dataset.with_format("torch")

    # IMPORTANT: Set micro_batch_size to 1 since we've already packed
    # This prevents the trainer from trying to batch our packed sequences
    cfg.micro_batch_size = 1

    return dataset
|
||||||
158
src/axolotl/utils/data/streaming_sft.py
Normal file
158
src/axolotl/utils/data/streaming_sft.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
"""Optimized streaming SFT with multipack support that avoids repeated preprocessing."""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
from datasets import Dataset, IterableDataset
|
||||||
|
from torch.utils.data import RandomSampler
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
from axolotl.utils.trainer import add_position_ids, drop_long_seq
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def encode_packed_streaming_sft(
    collate_fn,
    ds_wrapper: Callable,
    examples: Dict[str, List],
    max_seq_length: int = 2048,
    batch_size: int = 4,
    multipack_attn: Optional[bool] = True,
) -> Dict[str, List]:
    """
    Encode streaming SFT data with packing, avoiding repeated preprocessing logs.

    Filtering and position-id creation are done inline, row by row, instead of
    going through ``process_pretraining_datasets_for_packing``.

    Args:
        collate_fn: Callable applied to the dict-of-lists of one index group;
            its result is indexed by feature name.
        ds_wrapper: Callable invoked as ``ds_wrapper(dataset=...)``; the first
            element of its return value is used as the tokenized dataset.
        examples: Dict of lists containing the raw rows of one map batch.
        max_seq_length: Passed to ``drop_long_seq`` as ``sequence_len`` and
            used in the sampler's token budget.
        batch_size: Multiplier for the sampler's token budget.
        multipack_attn: When truthy, ``position_ids`` are added to each kept
            row and the ``attention_mask`` column is removed.

    Returns:
        A mapping from feature name (except ``"length"``) to the list of
        collated values, one entry per index group. When no row survives
        filtering, a dict with empty ``input_ids``, ``labels`` and
        ``attention_mask`` lists is returned.
    """
    # Tokenize all the examples
    train_dataset = ds_wrapper(dataset=Dataset.from_dict(examples))[0]

    # Keep only rows accepted by drop_long_seq, preparing each kept row
    filtered_samples = []
    for sample in train_dataset:
        if not drop_long_seq(sample, sequence_len=max_seq_length):
            continue
        # Position ids are required when multipack attention is on. This
        # mirrors `skip_position_ids=not multipack_attn` used with
        # process_pretraining_datasets_for_packing; the previous
        # `if not multipack_attn` check had the condition inverted.
        if multipack_attn:
            sample = add_position_ids(sample)
        filtered_samples.append(sample)

    # Nothing left to pack
    if not filtered_samples:
        return {"input_ids": [], "labels": [], "attention_mask": []}

    # Convert back to dataset
    train_dataset = Dataset.from_list(filtered_samples)

    # Remove attention_mask if needed for multipack
    if multipack_attn and "attention_mask" in train_dataset.column_names:
        train_dataset = train_dataset.remove_columns("attention_mask")

    # The sampler yields batches of index groups; with batch_size=1 each batch
    # holds a single group bounded by the total token budget below.
    sampler = MultipackBatchSampler(
        sampler=RandomSampler(train_dataset),
        lengths=get_dataset_lengths(train_dataset),
        batch_size=1,
        batch_max_len=batch_size * max_seq_length,
        drop_last=True,
        num_processes=1,
    )

    # Collect packed data
    chunked_data = defaultdict(list)

    for batch in sampler:
        for data in batch:
            features = train_dataset[data]

            # Clean up unnecessary fields
            for field in ("num_truncated_tokens", "overflow_to_sample_mapping"):
                if field in features:
                    del features[field]

            # Ensure labels exist
            if "labels" not in features:
                features["labels"] = features["input_ids"].copy()

            # Apply collator
            collated_features = collate_fn(features)

            # Collect every column except the helper "length" column
            for feature in features:
                if feature == "length":
                    continue
                chunked_data[feature].append(collated_features[feature].squeeze(0))

    return chunked_data
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_streaming_sft_dataset_optimized(
    dataset: IterableDataset,
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    ds_wrapper_fn: Callable,
    max_tokens: int = 2048,
    batch_size: int = 4,
    seed: int = 42,
    buffer_size: int = 1000,
) -> IterableDataset:
    """
    Wrap a streaming SFT dataset with optimized multipack batching.

    This avoids the repeated preprocessing logs by using a custom encoder
    that applies filtering and position_ids directly.

    Args:
        dataset: Streaming dataset to wrap.
        tokenizer: Tokenizer handed to the packing collator.
        cfg: Run configuration. ``pretrain_multipack_attn`` and
            ``shuffle_merged_datasets`` are read from it.
        ds_wrapper_fn: Callable forwarded to ``encode_packed_streaming_sft``.
        max_tokens: Used both as the collator's ``pad_to_multiple_of`` and as
            the encoder's ``max_seq_length``.
        batch_size: Forwarded to the encoder as ``batch_size``.
        seed: Seed for the shuffle buffer.
        buffer_size: Size of the shuffle buffer and of each mapped batch.

    Returns:
        The dataset with the packing encoder mapped over it.

    Side effects:
        Sets ``cfg.micro_batch_size`` to 1.
    """
    # Create collator for packing
    collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
        tokenizer,
        return_tensors="pt",
        padding=True,
        pad_to_multiple_of=max_tokens,
        multipack_attn=cfg.pretrain_multipack_attn,
    )

    # Create optimized encode function
    encode = functools.partial(
        encode_packed_streaming_sft,
        collate_fn,
        ds_wrapper_fn,
        max_seq_length=max_tokens,
        batch_size=batch_size,
        multipack_attn=cfg.pretrain_multipack_attn,
    )

    # Peek at the first raw row to learn which columns to drop after mapping.
    # This is done BEFORE shuffling: peeking through a shuffled stream would
    # first have to fill the whole shuffle buffer just to read one row.
    # Iterating an IterableDataset starts a fresh pass each time, so the peek
    # does not consume anything and no "reset" is needed afterwards.
    first_row = next(iter(dataset), None)
    remove_columns = list(first_row.keys()) if first_row is not None else []

    # Apply shuffling if configured (exactly once; calling shuffle() again on
    # an already-shuffled dataset would stack a second shuffle buffer).
    if cfg.shuffle_merged_datasets:
        dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
    else:
        LOG.debug("NOT shuffling merged streaming datasets")

    # Map the optimized encoding function
    dataset = dataset.map(
        encode,
        batched=True,
        batch_size=buffer_size,
        remove_columns=remove_columns,
    )

    # Set micro_batch_size to 1 since we've already packed
    cfg.micro_batch_size = 1

    return dataset
|
||||||
211
src/axolotl/utils/data/streaming_sft_simple.py
Normal file
211
src/axolotl/utils/data/streaming_sft_simple.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
"""Simple streaming SFT with multipack support."""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from datasets import Dataset, IterableDataset
|
||||||
|
from torch.utils.data import RandomSampler
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingMultipackDataset:
    """Dataset that handles streaming with multipack on-the-fly.

    Samples from ``base_dataset`` are tokenized one at a time and gathered
    into a buffer of ``pack_length`` tokenized samples. Every full buffer,
    plus the final partial one, is run through ``MultipackBatchSampler`` with
    ``batch_max_len=max_tokens`` and each resulting bin is collated and
    yielded from ``__iter__``.

    NOTE(review): this class only defines ``__iter__``; it is a plain Python
    iterable and does not subclass any dataset base class. Confirm that the
    consumers only iterate over it.
    """

    def __init__(
        self,
        base_dataset: IterableDataset,
        tokenizer: PreTrainedTokenizerBase,
        cfg,
        dataset_config,
        d_base_type: str,
        d_prompt_style: str | None,
        processor: Any | None,
        max_tokens: int = 2048,
        pack_length: int = 4,  # How many sequences to collect before packing
    ):
        """Store the configuration and build the tokenizing wrapper and collator.

        Args:
            base_dataset: Streaming dataset whose rows are tokenized and packed.
            tokenizer: Tokenizer passed to the dataset wrapper and the collator.
            cfg: Run configuration; ``pretrain_multipack_attn`` is read from it.
            dataset_config: Forwarded to ``get_dataset_wrapper``.
            d_base_type: Forwarded as ``dataset_base_type``.
            d_prompt_style: Forwarded as ``dataset_prompt_style``.
            processor: Forwarded to ``get_dataset_wrapper``.
            max_tokens: Bin size for packing and ``pad_to_multiple_of`` for
                the collator.
            pack_length: Number of tokenized samples buffered before packing.
        """
        self.base_dataset = base_dataset
        self.tokenizer = tokenizer
        self.cfg = cfg
        self.max_tokens = max_tokens
        self.pack_length = pack_length

        # Create the dataset wrapper once
        from axolotl.utils.data.wrappers import get_dataset_wrapper

        # Create a dummy dataset to get the wrapper
        dummy_data = {"text": ["dummy"], "instruction": ["dummy"], "output": ["dummy"]}
        dummy_dataset = Dataset.from_dict(dummy_data)

        self.dataset_wrapper, _ = get_dataset_wrapper(
            dataset_config=dataset_config,
            tokenizer=tokenizer,
            cfg=cfg,
            dataset_base_type=d_base_type,
            dataset=dummy_dataset,
            dataset_prompt_style=d_prompt_style,
            processor=processor,
        )

        # Create collator for packing
        self.collator = PretrainingBatchSamplerDataCollatorForSeq2Seq(
            tokenizer,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=max_tokens,
            multipack_attn=cfg.pretrain_multipack_attn,
        )

    def __iter__(self):
        """Iterator that yields packed samples."""
        buffer = []

        for sample in self.base_dataset:
            tokenized_sample = self._tokenize_sample(sample)
            if tokenized_sample is None:
                continue

            buffer.append(tokenized_sample)

            # When buffer is full, pack and yield every bin it produced
            if len(buffer) >= self.pack_length:
                yield from self._pack_buffer_all(buffer)
                buffer = []

        # Process remaining buffer
        if buffer:
            yield from self._pack_buffer_all(buffer)

    def _tokenize_sample(self, sample: Dict) -> Optional[Dict]:
        """Tokenize one raw row; return None if it fails or yields nothing."""
        # Convert single sample to dataset for processing
        temp_dataset = Dataset.from_dict({k: [v] for k, v in sample.items()})

        # Tokenize by re-instantiating the wrapper's class on the one-row
        # dataset. Assumes the wrapper class accepts its public instance
        # attributes as constructor keyword arguments - TODO confirm.
        try:
            tokenized = self.dataset_wrapper.__class__(
                temp_dataset,
                **{
                    k: v
                    for k, v in self.dataset_wrapper.__dict__.items()
                    if not k.startswith("_")
                },
            )
            if len(tokenized) > 0:
                return tokenized[0]
        except Exception as e:  # best-effort: skip samples that cannot be processed
            LOG.warning(f"Failed to process sample: {e}")
        return None

    def _pack_buffer_all(self, buffer: List[Dict]) -> List[Dict]:
        """Pack a buffer of tokenized samples into every bin the sampler yields.

        Returns an empty list when the buffer is empty or packing fails
        (the failure is logged as a warning).
        """
        if not buffer:
            return []

        try:
            # Create dataset from buffer
            temp_dataset = Dataset.from_list(buffer)

            # Add position_ids and process for packing
            temp_dataset = process_pretraining_datasets_for_packing(
                temp_dataset,
                self.max_tokens,
                skip_position_ids=not self.cfg.pretrain_multipack_attn,
                drop_attention_mask=self.cfg.pretrain_multipack_attn,
            )

            # Use MultipackBatchSampler to create packed batches
            sampler = MultipackBatchSampler(
                sampler=RandomSampler(temp_dataset),
                lengths=get_dataset_lengths(temp_dataset),
                batch_size=1,
                batch_max_len=self.max_tokens,
                drop_last=False,
                num_processes=1,
            )

            packed_samples = []
            for batch in sampler:
                if not (batch and batch[0]):  # Check if we have data
                    continue

                # batch[0] contains the indices of one bin. Indexing with the
                # whole list gives a column-oriented mapping (column -> list
                # of per-row values).
                # NOTE(review): assumes the collator expects this columnar
                # form rather than a list of row dicts - confirm against the
                # collator implementation.
                features = temp_dataset[batch[0]]
                if "labels" not in features:
                    features["labels"] = features["input_ids"].copy()

                # Apply collator to create final packed sample
                packed = self.collator(features)
                packed_samples.append(
                    {k: v.squeeze(0) if v.dim() > 1 else v for k, v in packed.items()}
                )

            return packed_samples

        except Exception as e:
            LOG.warning(f"Failed to pack buffer: {e}")
            return []

    def _pack_buffer(self, buffer: List[Dict]) -> Optional[Dict]:
        """Pack a buffer of tokenized samples and return the first bin only.

        Kept for backward compatibility; ``__iter__`` uses
        ``_pack_buffer_all`` so that no bin is dropped.
        """
        packed_samples = self._pack_buffer_all(buffer)
        return packed_samples[0] if packed_samples else None
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_streaming_sft_dataset_simple(
    dataset: IterableDataset,
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    dataset_config,
    d_base_type: str,
    d_prompt_style: str | None,
    processor: Any | None,
    max_tokens: int = 2048,
    buffer_size: int = 10_000,
) -> IterableDataset:
    """
    Wrap a streaming SFT dataset with simple multipack batching.

    Samples are handled in small groups rather than large batches, which
    avoids the repeated processing issue.

    Reads ``cfg.shuffle_merged_datasets`` and ``cfg.seed``; overwrites
    ``cfg.micro_batch_size`` with 1 because the yielded samples are already
    packed.

    NOTE(review): the returned object is a thin wrapper exposing only
    ``__iter__`` (plus a ``dataset`` attribute), not a
    ``datasets.IterableDataset`` as the annotation says - confirm callers
    only iterate over it.
    """

    class IterableWrapper:
        # Minimal adapter: delegates iteration to the wrapped dataset.
        def __init__(self, dataset):
            self.dataset = dataset

        def __iter__(self):
            return iter(self.dataset)

    # Optional buffered shuffle of the incoming stream
    should_shuffle = cfg.shuffle_merged_datasets
    if should_shuffle:
        LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)

    # Rough estimate of how many sequences go into one pack (never below 1)
    sequences_per_pack = max(1, max_tokens // 512)

    packer = StreamingMultipackDataset(
        base_dataset=dataset,
        tokenizer=tokenizer,
        cfg=cfg,
        dataset_config=dataset_config,
        d_base_type=d_base_type,
        d_prompt_style=d_prompt_style,
        processor=processor,
        max_tokens=max_tokens,
        pack_length=sequences_per_pack,
    )
    result = IterableWrapper(packer)

    # Sequences are already packed, so each micro batch holds a single one
    cfg.micro_batch_size = 1

    return result
|
||||||
@@ -178,8 +178,8 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
|||||||
|
|
||||||
|
|
||||||
def handle_long_seq_in_dataset(
|
def handle_long_seq_in_dataset(
|
||||||
dataset: Dataset, sequence_len: int, cfg: DictDefault
|
dataset: Dataset | IterableDataset, sequence_len: int, cfg: DictDefault
|
||||||
) -> Dataset:
|
) -> Dataset | IterableDataset:
|
||||||
"""Remove sequences longer than configured maximum from dataset.
|
"""Remove sequences longer than configured maximum from dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -190,7 +190,14 @@ def handle_long_seq_in_dataset(
|
|||||||
Returns:
|
Returns:
|
||||||
Filtered dataset with long sequences removed.
|
Filtered dataset with long sequences removed.
|
||||||
"""
|
"""
|
||||||
if "input_ids" not in dataset.column_names:
|
# Streaming datasets don't support filtering the same way
|
||||||
|
if isinstance(dataset, IterableDataset):
|
||||||
|
LOG.info(
|
||||||
|
"Streaming dataset detected - long sequence filtering will be done on-the-fly"
|
||||||
|
)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
if not hasattr(dataset, "column_names") or "input_ids" not in dataset.column_names:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
|
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
|
||||||
"expected for reward modeling."
|
"expected for reward modeling."
|
||||||
|
|||||||
@@ -244,6 +244,12 @@ class AxolotlInputConfig(
|
|||||||
dataloader_num_workers: int | None = None
|
dataloader_num_workers: int | None = None
|
||||||
dataloader_prefetch_factor: int | None = None
|
dataloader_prefetch_factor: int | None = None
|
||||||
dataloader_drop_last: bool | None = None
|
dataloader_drop_last: bool | None = None
|
||||||
|
streaming: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Enable streaming mode for training datasets to reduce memory usage and enable training on datasets larger than memory"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
accelerator_config: dict[str, Any] | None = None
|
accelerator_config: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -1074,6 +1074,24 @@ class PretrainingValidationMixin:
|
|||||||
data["accelerator_config"]["dispatch_batches"] = False
|
data["accelerator_config"]["dispatch_batches"] = False
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
@classmethod
def check_pretraining_split_batches_accelerate(cls, data):
    """Default accelerate's batch splitting/dispatching to off when streaming.

    When ``streaming`` is truthy, ``accelerator_config`` is created with
    ``split_batches`` and ``dispatch_batches`` set to False if it is missing
    or empty; otherwise only the keys that are unset (None) are filled in.
    Values the user set explicitly are left alone.

    NOTE(review): the name says "pretraining" but the trigger is
    ``streaming``; confirm no other validator in this mixin uses the same
    name, since a later definition would override an earlier one.
    """
    # alternatively set ACCELERATE_SPLIT_BATCHES=False
    if not data.get("streaming"):
        return data

    current = data.get("accelerator_config", {})
    if not current:
        # Nothing configured yet: install both defaults at once.
        data["accelerator_config"] = {
            "split_batches": False,
            "dispatch_batches": False,
        }
        return data

    # Fill in only the options that were left unset.
    for option in ("split_batches", "dispatch_batches"):
        if current.get(option) is None:
            data["accelerator_config"][option] = False
    return data
|
||||||
|
|
||||||
|
|
||||||
class ModelCompatibilityValidationMixin:
|
class ModelCompatibilityValidationMixin:
|
||||||
"""Validation methods for specific model compatibility."""
|
"""Validation methods for specific model compatibility."""
|
||||||
|
|||||||
Reference in New Issue
Block a user