From 2e2302aae3c0802e32cb728ae713bf8d1a464da1 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Mon, 25 Aug 2025 15:46:25 +0000
Subject: [PATCH] remove unused

---
 examples/streaming/streaming-pretrain.yml     |  61 -----
 examples/streaming/streaming-sft.yml          |  52 -----
 src/axolotl/utils/data/streaming_multipack.py | 219 ------------------
 .../utils/data/streaming_sft_simple.py        | 211 -----------------
 4 files changed, 543 deletions(-)
 delete mode 100644 examples/streaming/streaming-pretrain.yml
 delete mode 100644 examples/streaming/streaming-sft.yml
 delete mode 100644 src/axolotl/utils/data/streaming_multipack.py
 delete mode 100644 src/axolotl/utils/data/streaming_sft_simple.py

diff --git a/examples/streaming/streaming-pretrain.yml b/examples/streaming/streaming-pretrain.yml
deleted file mode 100644
index fdfba8621..000000000
--- a/examples/streaming/streaming-pretrain.yml
+++ /dev/null
@@ -1,61 +0,0 @@
-# Example configuration for streaming pretraining
-# This demonstrates how to pretrain on large datasets that don't fit in memory
-
-base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
-model_type: LlamaForCausalLM
-tokenizer_type: LlamaTokenizer
-
-# Required: max_steps for streaming pretraining
-max_steps: 10000
-
-# Pretraining dataset configuration
-# These are automatically streamed
-pretraining_dataset:
-  - path: allenai/c4
-    name: en
-    type: pretrain
-    # Optional: skip N samples (useful for resuming)
-    # skip: 1000000
-
-# Can also use multiple pretraining datasets
-# pretraining_dataset:
-#   - path: allenai/c4
-#     name: en
-#     type: pretrain
-#   - path: HuggingFaceFW/fineweb
-#     type: pretrain
-
-val_set_size: 0.0
-
-# Sequence and packing configuration
-sequence_len: 2048
-sample_packing: true
-pretrain_multipack_attn: true
-pretrain_multipack_buffer_size: 10000 # Buffer size for multipack batching
-
-# Training hyperparameters
-gradient_accumulation_steps: 8
-micro_batch_size: 4
-optimizer: adamw_torch
-lr_scheduler: cosine
-learning_rate: 3e-4
-
-# Memory optimizations
-bf16: auto
-tf32: false
-gradient_checkpointing: true
-flash_attention: true
-
-# Checkpointing and logging
-output_dir: ./outputs/pretrain-streaming
-logging_steps: 10
-save_steps: 500
-save_total_limit: 3 # Keep only last 3 checkpoints
-
-# Warmup
-warmup_ratio: 0.1
-
-# Optional: enable wandb for monitoring
-# wandb_project: streaming-pretrain
-# wandb_entity: your-entity
-# wandb_name: c4-pretrain
diff --git a/examples/streaming/streaming-sft.yml b/examples/streaming/streaming-sft.yml
deleted file mode 100644
index 10b1d812a..000000000
--- a/examples/streaming/streaming-sft.yml
+++ /dev/null
@@ -1,52 +0,0 @@
-# Example configuration for streaming SFT training
-# This enables training on datasets larger than memory by streaming them from HuggingFace Hub
-
-base_model: HuggingFaceTB/SmolLM2-135M
-
-# Enable streaming mode for datasets
-streaming: true
-
-# When using streaming, max_steps is required
-max_steps: 3 # Just test a few steps
-
-# Training datasets - these will be streamed
-# datasets:
-#   - path: tatsu-lab/alpaca
-#     type: alpaca
-#     split: train
-
-pretraining_dataset:
-  - path: tatsu-lab/alpaca
-    type: alpaca
-    split: train
-
-# Dataset configuration
-sequence_len: 2048
-sample_packing: true # Enable multipack batching
-pretrain_multipack_attn: true # Enable multipack attention masking
-pretrain_multipack_buffer_size: 1000 # Buffer size for packing (smaller for streaming SFT)
-special_tokens:
-  pad_token: <|endoftext|>
-
-# Training hyperparameters
-gradient_accumulation_steps: 4
-micro_batch_size: 1 # Always 1 for multipack - sequences are packed into single samples
-num_epochs: 1 # With streaming, typically use max_steps instead
-optimizer: adamw_torch
-lr_scheduler: cosine
-learning_rate: 2e-5
-
-# Enable efficient training
-bf16: auto
-tf32: false
-gradient_checkpointing: true
-flash_attention: true # Enable flash attention with multipack
-
-# Logging and checkpointing
-logging_steps: 10
-eval_steps: 100
-save_steps: 200
-output_dir: ./outputs/streaming-model
-
-# Warmup
-warmup_steps: 100
diff --git a/src/axolotl/utils/data/streaming_multipack.py b/src/axolotl/utils/data/streaming_multipack.py
deleted file mode 100644
index a67d4f815..000000000
--- a/src/axolotl/utils/data/streaming_multipack.py
+++ /dev/null
@@ -1,219 +0,0 @@
-"""Streaming dataset with multipack support for SFT."""
-
-import functools
-from collections import defaultdict
-from typing import Any, Dict, List, Optional
-
-import numpy as np
-from datasets import Dataset, IterableDataset
-from torch.utils.data import RandomSampler
-from transformers import PreTrainedTokenizerBase
-
-from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
-from axolotl.utils.logging import get_logger
-from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
-from axolotl.utils.trainer import process_pretraining_datasets_for_packing
-
-LOG = get_logger(__name__)
-
-
-def encode_packed_sft_streaming(
-    collate_fn,
-    ds_wrapper_fn,
-    examples: Dict[str, List],
-    dataset_config,
-    tokenizer,
-    cfg,
-    d_base_type: str,
-    d_prompt_style: str | None,
-    processor: Any | None,
-    max_seq_length: int = 2048,
-    batch_size: int = 4,
-    multipack_attn: Optional[bool] = True,
-) -> Dict[str, List]:
-    """
-    Encode and pack streaming SFT data similar to how pretraining does it.
-
-    This function:
-    1. Tokenizes the examples using the dataset wrapper
-    2. Adds position_ids for each sequence
-    3. Uses MultipackBatchSampler to efficiently pack sequences
-    4. Applies the collator to handle attention masks properly
-
-    Args:
-        collate_fn: Collator function for handling batches
-        ds_wrapper_fn: Function to get the dataset wrapper for tokenization
-        examples: Dict of lists containing the raw examples
-        dataset_config: Configuration for the dataset
-        tokenizer: Tokenizer to use
-        cfg: Main configuration
-        d_base_type: Dataset base type
-        d_prompt_style: Prompt style
-        processor: Optional processor for multimodal
-        max_seq_length: Maximum sequence length
-        batch_size: Batch size for packing
-        multipack_attn: Whether to use multipack attention
-
-    Returns:
-        Dict of packed and processed data ready for training
-    """
-    # Import here to avoid circular imports
-    from axolotl.utils.data.wrappers import get_dataset_wrapper
-
-    # Convert examples to Dataset for processing
-    temp_dataset = Dataset.from_dict(examples)
-
-    # Apply the dataset wrapper to tokenize
-    train_dataset, _ = get_dataset_wrapper(
-        dataset_config=dataset_config,
-        tokenizer=tokenizer,
-        cfg=cfg,
-        dataset_base_type=d_base_type,
-        dataset=temp_dataset,
-        dataset_prompt_style=d_prompt_style,
-        processor=processor,
-    )
-
-    # Process for packing - add position_ids and filter long sequences
-    train_dataset = process_pretraining_datasets_for_packing(
-        train_dataset,
-        max_seq_length,
-        skip_position_ids=not multipack_attn,
-        drop_attention_mask=multipack_attn,
-    )
-
-    # Use MultipackBatchSampler to create efficient packed batches
-    sampler = MultipackBatchSampler(
-        sampler=RandomSampler(train_dataset),
-        lengths=get_dataset_lengths(train_dataset),
-        batch_size=1,  # We pack multiple sequences into one "batch"
-        batch_max_len=batch_size * max_seq_length,  # Total tokens in packed batch
-        drop_last=True,
-        num_processes=1,  # Single process for streaming
-    )
-
-    # Collect packed data
-    chunked_data = defaultdict(list)
-
-    for batch in sampler:
-        for data in batch:
-            features = train_dataset[data]
-
-            # Clean up unnecessary fields
-            for field in ["num_truncated_tokens", "overflow_to_sample_mapping"]:
-                if field in features:
-                    del features[field]
-
-            # Ensure labels exist
-            if "labels" not in features:
-                features["labels"] = features["input_ids"].copy()
-
-            # Apply collator to handle padding and attention masks
-            collated_features = collate_fn(features)
-
-            # Collect features
-            for feature in features.keys():
-                if feature == "length":
-                    continue
-                chunked_data[feature].append(collated_features[feature].squeeze(0))
-
-    return chunked_data
-
-
-def wrap_streaming_sft_dataset_with_packing(
-    dataset: IterableDataset,
-    tokenizer: PreTrainedTokenizerBase,
-    cfg,
-    dataset_config,
-    d_base_type: str,
-    d_prompt_style: str | None,
-    processor: Any | None,
-    max_tokens: int = 2048,
-    buffer_size: int = 10_000,
-) -> IterableDataset:
-    """
-    Wrap a streaming SFT dataset with tokenization and multipack batching.
-
-    This creates properly packed batches with:
-    - Multiple sequences concatenated together
-    - Position IDs that reset for each sequence
-    - Attention masks that prevent cross-attention between sequences
-
-    Args:
-        dataset: The streaming dataset to wrap
-        tokenizer: Tokenizer to use
-        cfg: Configuration object
-        dataset_config: Dataset configuration
-        d_base_type: Base dataset type
-        d_prompt_style: Prompt style
-        processor: Optional processor for multimodal
-        max_tokens: Maximum sequence length
-        buffer_size: Buffer size for streaming/shuffling
-
-    Returns:
-        Wrapped streaming dataset with multipack batching
-    """
-
-    # Apply shuffling if configured
-    if cfg.shuffle_merged_datasets:
-        LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
-        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
-
-    # Get column names from first sample
-    remove_columns = []
-    for first_row in dataset:
-        remove_columns = list(first_row.keys())
-        break
-
-    # Reset dataset after peeking
-    if cfg.shuffle_merged_datasets:
-        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
-
-    # Create the collator for multipack
-    collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
-        tokenizer,
-        return_tensors="pt",
-        padding=True,
-        pad_to_multiple_of=max_tokens,
-        multipack_attn=cfg.pretrain_multipack_attn,
-    )
-
-    # Create the encoding function
-    # batch_size here refers to how many sequences to pack together to fill max_tokens
-    # The actual batching happens at the DataLoader level with micro_batch_size=1
-    pack_batch_size = max(
-        1, max_tokens // 512
-    )  # Estimate based on typical sequence lengths
-
-    encode_fn = functools.partial(
-        encode_packed_sft_streaming,
-        collate_fn,
-        None,  # ds_wrapper_fn will be created inside
-        dataset_config=dataset_config,
-        tokenizer=tokenizer,
-        cfg=cfg,
-        d_base_type=d_base_type,
-        d_prompt_style=d_prompt_style,
-        processor=processor,
-        max_seq_length=max_tokens,
-        batch_size=pack_batch_size,
-        multipack_attn=cfg.pretrain_multipack_attn,
-    )
-
-    # Map the encoding function over the streaming dataset
-    # This will process data in batches and apply packing
-    dataset = dataset.map(
-        encode_fn,
-        batched=True,
-        batch_size=buffer_size,  # Process large batches for efficiency
-        remove_columns=remove_columns,
-    )
-
-    # Set format for PyTorch
-    dataset = dataset.with_format("torch")
-
-    # IMPORTANT: Set micro_batch_size to 1 since we've already packed
-    # This prevents the trainer from trying to batch our packed sequences
-    cfg.micro_batch_size = 1
-
-    return dataset
diff --git a/src/axolotl/utils/data/streaming_sft_simple.py b/src/axolotl/utils/data/streaming_sft_simple.py
deleted file mode 100644
index 045ded991..000000000
--- a/src/axolotl/utils/data/streaming_sft_simple.py
+++ /dev/null
@@ -1,211 +0,0 @@
-"""Simple streaming SFT with multipack support."""
-
-import functools
-from collections import defaultdict
-from typing import Any, Callable, Dict, List, Optional
-
-import numpy as np
-from datasets import Dataset, IterableDataset
-from torch.utils.data import RandomSampler
-from transformers import PreTrainedTokenizerBase
-
-from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
-from axolotl.utils.logging import get_logger
-from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
-from axolotl.utils.trainer import process_pretraining_datasets_for_packing
-
-LOG = get_logger(__name__)
-
-
-class StreamingMultipackDataset:
-    """Dataset that handles streaming with multipack on-the-fly."""
-
-    def __init__(
-        self,
-        base_dataset: IterableDataset,
-        tokenizer: PreTrainedTokenizerBase,
-        cfg,
-        dataset_config,
-        d_base_type: str,
-        d_prompt_style: str | None,
-        processor: Any | None,
-        max_tokens: int = 2048,
-        pack_length: int = 4,  # How many sequences to collect before packing
-    ):
-        self.base_dataset = base_dataset
-        self.tokenizer = tokenizer
-        self.cfg = cfg
-        self.max_tokens = max_tokens
-        self.pack_length = pack_length
-
-        # Create the dataset wrapper once
-        from axolotl.utils.data.wrappers import get_dataset_wrapper
-
-        # Create a dummy dataset to get the wrapper
-        dummy_data = {"text": ["dummy"], "instruction": ["dummy"], "output": ["dummy"]}
-        dummy_dataset = Dataset.from_dict(dummy_data)
-
-        self.dataset_wrapper, _ = get_dataset_wrapper(
-            dataset_config=dataset_config,
-            tokenizer=tokenizer,
-            cfg=cfg,
-            dataset_base_type=d_base_type,
-            dataset=dummy_dataset,
-            dataset_prompt_style=d_prompt_style,
-            processor=processor,
-        )
-
-        # Create collator for packing
-        self.collator = PretrainingBatchSamplerDataCollatorForSeq2Seq(
-            tokenizer,
-            return_tensors="pt",
-            padding=True,
-            pad_to_multiple_of=max_tokens,
-            multipack_attn=cfg.pretrain_multipack_attn,
-        )
-
-    def __iter__(self):
-        """Iterator that yields packed samples."""
-        buffer = []
-
-        for sample in self.base_dataset:
-            # Convert single sample to dataset for processing
-            temp_dataset = Dataset.from_dict({k: [v] for k, v in sample.items()})
-
-            # Tokenize using the dataset wrapper
-            try:
-                tokenized = self.dataset_wrapper.__class__(
-                    temp_dataset,
-                    **{
-                        k: v
-                        for k, v in self.dataset_wrapper.__dict__.items()
-                        if not k.startswith("_")
-                    },
-                )
-
-                # Get the tokenized sample
-                if len(tokenized) > 0:
-                    tokenized_sample = tokenized[0]
-
-                    # Add to buffer
-                    buffer.append(tokenized_sample)
-
-                    # When buffer is full, pack and yield
-                    if len(buffer) >= self.pack_length:
-                        packed_sample = self._pack_buffer(buffer)
-                        if packed_sample:
-                            yield packed_sample
-                        buffer = []
-
-            except Exception as e:
-                LOG.warning(f"Failed to process sample: {e}")
-                continue
-
-        # Process remaining buffer
-        if buffer:
-            packed_sample = self._pack_buffer(buffer)
-            if packed_sample:
-                yield packed_sample
-
-    def _pack_buffer(self, buffer: List[Dict]) -> Optional[Dict]:
-        """Pack a buffer of tokenized samples."""
-        if not buffer:
-            return None
-
-        try:
-            # Create dataset from buffer
-            temp_dataset = Dataset.from_list(buffer)
-
-            # Add position_ids and process for packing
-            temp_dataset = process_pretraining_datasets_for_packing(
-                temp_dataset,
-                self.max_tokens,
-                skip_position_ids=not self.cfg.pretrain_multipack_attn,
-                drop_attention_mask=self.cfg.pretrain_multipack_attn,
-            )
-
-            # Use MultipackBatchSampler to create packed batches
-            sampler = MultipackBatchSampler(
-                sampler=RandomSampler(temp_dataset),
-                lengths=get_dataset_lengths(temp_dataset),
-                batch_size=1,
-                batch_max_len=self.max_tokens,
-                drop_last=False,
-                num_processes=1,
-            )
-
-            # Get packed data
-            for batch in sampler:
-                if batch and batch[0]:  # Check if we have data
-                    features = []
-                    for idx in batch[0]:  # batch[0] contains the indices
-                        sample = temp_dataset[idx]
-                        if "labels" not in sample:
-                            sample["labels"] = sample["input_ids"].copy()
-                        features.append(sample)
-
-                    # Apply collator to create final packed sample
-                    if features:
-                        packed = self.collator(features)
-                        return {
-                            k: v.squeeze(0) if v.dim() > 1 else v
-                            for k, v in packed.items()
-                        }
-
-            return None
-
-        except Exception as e:
-            LOG.warning(f"Failed to pack buffer: {e}")
-            return None
-
-
-def wrap_streaming_sft_dataset_simple(
-    dataset: IterableDataset,
-    tokenizer: PreTrainedTokenizerBase,
-    cfg,
-    dataset_config,
-    d_base_type: str,
-    d_prompt_style: str | None,
-    processor: Any | None,
-    max_tokens: int = 2048,
-    buffer_size: int = 10_000,
-) -> IterableDataset:
-    """
-    Wrap a streaming SFT dataset with simple multipack batching.
-
-    This approach processes samples in small groups rather than large batches,
-    avoiding the repeated processing issue.
-    """
-
-    # Apply shuffling if configured
-    if cfg.shuffle_merged_datasets:
-        LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
-        dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
-
-    # Create the streaming multipack dataset
-    multipack_dataset = StreamingMultipackDataset(
-        dataset,
-        tokenizer,
-        cfg,
-        dataset_config,
-        d_base_type,
-        d_prompt_style,
-        processor,
-        max_tokens,
-        pack_length=max(1, max_tokens // 512),  # Estimate sequences per pack
-    )
-
-    # Convert back to IterableDataset
-    class IterableWrapper:
-        def __init__(self, dataset):
-            self.dataset = dataset
-
-        def __iter__(self):
-            return iter(self.dataset)
-
-    wrapped = IterableWrapper(multipack_dataset)
-
-    # Set micro_batch_size to 1 since sequences are already packed
-    cfg.micro_batch_size = 1
-
-    return wrapped