remove unused

Dan Saunders
2025-08-25 15:46:25 +00:00
parent 3a35076513
commit 2e2302aae3
4 changed files with 0 additions and 543 deletions

View File

@@ -1,61 +0,0 @@
# Example configuration for streaming pretraining
# This demonstrates how to pretrain on large datasets that don't fit in memory

base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

# Required: max_steps for streaming pretraining
max_steps: 10000

# Pretraining dataset configuration
# These are automatically streamed
pretraining_dataset:
  - path: allenai/c4
    name: en
    type: pretrain
    # Optional: skip N samples (useful for resuming)
    # skip: 1000000

# Can also use multiple pretraining datasets
# pretraining_dataset:
#   - path: allenai/c4
#     name: en
#     type: pretrain
#   - path: HuggingFaceFW/fineweb
#     type: pretrain

val_set_size: 0.0

# Sequence and packing configuration
sequence_len: 2048
sample_packing: true
pretrain_multipack_attn: true
pretrain_multipack_buffer_size: 10000  # Buffer size for multipack batching

# Training hyperparameters
gradient_accumulation_steps: 8
micro_batch_size: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 3e-4

# Memory optimizations
bf16: auto
tf32: false
gradient_checkpointing: true
flash_attention: true

# Checkpointing and logging
output_dir: ./outputs/pretrain-streaming
logging_steps: 10
save_steps: 500
save_total_limit: 3  # Keep only last 3 checkpoints

# Warmup
warmup_ratio: 0.1

# Optional: enable wandb for monitoring
# wandb_project: streaming-pretrain
# wandb_entity: your-entity
# wandb_name: c4-pretrain
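
For orientation, a minimal sketch of what the streaming behind `pretraining_dataset` amounts to, assuming the Hugging Face `datasets` library; the variable names are illustrative only, and the commented `skip` mirrors the optional resume setting above:

from datasets import load_dataset

# Stream allenai/c4 (English config) lazily instead of downloading it up front;
# this is roughly what the pretraining_dataset entry above resolves to.
ds = load_dataset("allenai/c4", "en", split="train", streaming=True)

# Optional resume: skip the first N samples, as the commented `skip` key would.
# ds = ds.skip(1_000_000)

for i, sample in enumerate(ds):
    print(sample["text"][:80])  # raw text field consumed for pretraining
    if i >= 2:
        break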

View File

@@ -1,52 +0,0 @@
# Example configuration for streaming SFT training
# This enables training on datasets larger than memory by streaming them from HuggingFace Hub

base_model: HuggingFaceTB/SmolLM2-135M

# Enable streaming mode for datasets
streaming: true

# When using streaming, max_steps is required
max_steps: 3  # Just test a few steps

# Training datasets - these will be streamed
# datasets:
#   - path: tatsu-lab/alpaca
#     type: alpaca
#     split: train
pretraining_dataset:
  - path: tatsu-lab/alpaca
    type: alpaca
    split: train

# Dataset configuration
sequence_len: 2048
sample_packing: true  # Enable multipack batching
pretrain_multipack_attn: true  # Enable multipack attention masking
pretrain_multipack_buffer_size: 1000  # Buffer size for packing (smaller for streaming SFT)

special_tokens:
  pad_token: <|endoftext|>

# Training hyperparameters
gradient_accumulation_steps: 4
micro_batch_size: 1  # Always 1 for multipack - sequences are packed into single samples
num_epochs: 1  # With streaming, typically use max_steps instead
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5

# Enable efficient training
bf16: auto
tf32: false
gradient_checkpointing: true
flash_attention: true  # Enable flash attention with multipack

# Logging and checkpointing
logging_steps: 10
eval_steps: 100
save_steps: 200
output_dir: ./outputs/streaming-model

# Warmup
warmup_steps: 100
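
A rough illustration of what `sample_packing` with `pretrain_multipack_attn` produces per training sample: several tokenized examples concatenated into one sequence, with `position_ids` restarting at 0 for each packed example so the multipack attention mask can keep them independent. The token values below are made up:

# Toy sketch: pack three short tokenized examples into one sample
# (later padded out to sequence_len). position_ids reset per example.
examples = [[101, 7, 8, 9, 102], [101, 3, 4, 102], [101, 5, 6, 7, 8, 102]]

input_ids, position_ids = [], []
for seq in examples:
    input_ids.extend(seq)
    position_ids.extend(range(len(seq)))

print(input_ids)     # one concatenated sample
print(position_ids)  # [0,1,2,3,4, 0,1,2,3, 0,1,2,3,4,5] - resets mark boundaries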

View File

@@ -1,219 +0,0 @@
"""Streaming dataset with multipack support for SFT."""
import functools
from collections import defaultdict
from typing import Any, Dict, List, Optional
import numpy as np
from datasets import Dataset, IterableDataset
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
LOG = get_logger(__name__)
def encode_packed_sft_streaming(
collate_fn,
ds_wrapper_fn,
examples: Dict[str, List],
dataset_config,
tokenizer,
cfg,
d_base_type: str,
d_prompt_style: str | None,
processor: Any | None,
max_seq_length: int = 2048,
batch_size: int = 4,
multipack_attn: Optional[bool] = True,
) -> Dict[str, List]:
"""
Encode and pack streaming SFT data similar to how pretraining does it.
This function:
1. Tokenizes the examples using the dataset wrapper
2. Adds position_ids for each sequence
3. Uses MultipackBatchSampler to efficiently pack sequences
4. Applies the collator to handle attention masks properly
Args:
collate_fn: Collator function for handling batches
ds_wrapper_fn: Function to get the dataset wrapper for tokenization
examples: Dict of lists containing the raw examples
dataset_config: Configuration for the dataset
tokenizer: Tokenizer to use
cfg: Main configuration
d_base_type: Dataset base type
d_prompt_style: Prompt style
processor: Optional processor for multimodal
max_seq_length: Maximum sequence length
batch_size: Batch size for packing
multipack_attn: Whether to use multipack attention
Returns:
Dict of packed and processed data ready for training
"""
# Import here to avoid circular imports
from axolotl.utils.data.wrappers import get_dataset_wrapper
# Convert examples to Dataset for processing
temp_dataset = Dataset.from_dict(examples)
# Apply the dataset wrapper to tokenize
train_dataset, _ = get_dataset_wrapper(
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
dataset_base_type=d_base_type,
dataset=temp_dataset,
dataset_prompt_style=d_prompt_style,
processor=processor,
)
# Process for packing - add position_ids and filter long sequences
train_dataset = process_pretraining_datasets_for_packing(
train_dataset,
max_seq_length,
skip_position_ids=not multipack_attn,
drop_attention_mask=multipack_attn,
)
# Use MultipackBatchSampler to create efficient packed batches
sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset),
lengths=get_dataset_lengths(train_dataset),
batch_size=1, # We pack multiple sequences into one "batch"
batch_max_len=batch_size * max_seq_length, # Total tokens in packed batch
drop_last=True,
num_processes=1, # Single process for streaming
)
# Collect packed data
chunked_data = defaultdict(list)
for batch in sampler:
for data in batch:
features = train_dataset[data]
# Clean up unnecessary fields
for field in ["num_truncated_tokens", "overflow_to_sample_mapping"]:
if field in features:
del features[field]
# Ensure labels exist
if "labels" not in features:
features["labels"] = features["input_ids"].copy()
# Apply collator to handle padding and attention masks
collated_features = collate_fn(features)
# Collect features
for feature in features.keys():
if feature == "length":
continue
chunked_data[feature].append(collated_features[feature].squeeze(0))
return chunked_data
def wrap_streaming_sft_dataset_with_packing(
dataset: IterableDataset,
tokenizer: PreTrainedTokenizerBase,
cfg,
dataset_config,
d_base_type: str,
d_prompt_style: str | None,
processor: Any | None,
max_tokens: int = 2048,
buffer_size: int = 10_000,
) -> IterableDataset:
"""
Wrap a streaming SFT dataset with tokenization and multipack batching.
This creates properly packed batches with:
- Multiple sequences concatenated together
- Position IDs that reset for each sequence
- Attention masks that prevent cross-attention between sequences
Args:
dataset: The streaming dataset to wrap
tokenizer: Tokenizer to use
cfg: Configuration object
dataset_config: Dataset configuration
d_base_type: Base dataset type
d_prompt_style: Prompt style
processor: Optional processor for multimodal
max_tokens: Maximum sequence length
buffer_size: Buffer size for streaming/shuffling
Returns:
Wrapped streaming dataset with multipack batching
"""
# Apply shuffling if configured
if cfg.shuffle_merged_datasets:
LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
# Get column names from first sample
remove_columns = []
for first_row in dataset:
remove_columns = list(first_row.keys())
break
# Reset dataset after peeking
if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
# Create the collator for multipack
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=max_tokens,
multipack_attn=cfg.pretrain_multipack_attn,
)
# Create the encoding function
# batch_size here refers to how many sequences to pack together to fill max_tokens
# The actual batching happens at the DataLoader level with micro_batch_size=1
pack_batch_size = max(
1, max_tokens // 512
) # Estimate based on typical sequence lengths
encode_fn = functools.partial(
encode_packed_sft_streaming,
collate_fn,
None, # ds_wrapper_fn will be created inside
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
d_base_type=d_base_type,
d_prompt_style=d_prompt_style,
processor=processor,
max_seq_length=max_tokens,
batch_size=pack_batch_size,
multipack_attn=cfg.pretrain_multipack_attn,
)
# Map the encoding function over the streaming dataset
# This will process data in batches and apply packing
dataset = dataset.map(
encode_fn,
batched=True,
batch_size=buffer_size, # Process large batches for efficiency
remove_columns=remove_columns,
)
# Set format for PyTorch
dataset = dataset.with_format("torch")
# IMPORTANT: Set micro_batch_size to 1 since we've already packed
# This prevents the trainer from trying to batch our packed sequences
cfg.micro_batch_size = 1
return dataset
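
The heart of `wrap_streaming_sft_dataset_with_packing` is `IterableDataset.map(..., batched=True)` with a large batch size, so packing happens lazily as examples stream in. A standalone sketch of that pattern follows, with a trivial `pack_batch` standing in for `encode_packed_sft_streaming`; the column names assume the tatsu-lab/alpaca schema:

from datasets import load_dataset


def pack_batch(batch):
    # Trivial stand-in for encode_packed_sft_streaming: fuse every two
    # formatted examples into one "packed" string.
    texts = batch["text"]
    return {"packed": [" ".join(texts[i : i + 2]) for i in range(0, len(texts), 2)]}


stream = load_dataset("tatsu-lab/alpaca", split="train", streaming=True)
stream = stream.map(
    pack_batch,
    batched=True,
    batch_size=8,  # the wrapper passes buffer_size here
    remove_columns=["instruction", "input", "output", "text"],
)

for i, row in enumerate(stream):
    print(row["packed"][:60])
    if i >= 1:
        break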

View File

@@ -1,211 +0,0 @@
"""Simple streaming SFT with multipack support."""
import functools
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional
import numpy as np
from datasets import Dataset, IterableDataset
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
LOG = get_logger(__name__)
class StreamingMultipackDataset:
"""Dataset that handles streaming with multipack on-the-fly."""
def __init__(
self,
base_dataset: IterableDataset,
tokenizer: PreTrainedTokenizerBase,
cfg,
dataset_config,
d_base_type: str,
d_prompt_style: str | None,
processor: Any | None,
max_tokens: int = 2048,
pack_length: int = 4, # How many sequences to collect before packing
):
self.base_dataset = base_dataset
self.tokenizer = tokenizer
self.cfg = cfg
self.max_tokens = max_tokens
self.pack_length = pack_length
# Create the dataset wrapper once
from axolotl.utils.data.wrappers import get_dataset_wrapper
# Create a dummy dataset to get the wrapper
dummy_data = {"text": ["dummy"], "instruction": ["dummy"], "output": ["dummy"]}
dummy_dataset = Dataset.from_dict(dummy_data)
self.dataset_wrapper, _ = get_dataset_wrapper(
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
dataset_base_type=d_base_type,
dataset=dummy_dataset,
dataset_prompt_style=d_prompt_style,
processor=processor,
)
# Create collator for packing
self.collator = PretrainingBatchSamplerDataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=max_tokens,
multipack_attn=cfg.pretrain_multipack_attn,
)
def __iter__(self):
"""Iterator that yields packed samples."""
buffer = []
for sample in self.base_dataset:
# Convert single sample to dataset for processing
temp_dataset = Dataset.from_dict({k: [v] for k, v in sample.items()})
# Tokenize using the dataset wrapper
try:
tokenized = self.dataset_wrapper.__class__(
temp_dataset,
**{
k: v
for k, v in self.dataset_wrapper.__dict__.items()
if not k.startswith("_")
},
)
# Get the tokenized sample
if len(tokenized) > 0:
tokenized_sample = tokenized[0]
# Add to buffer
buffer.append(tokenized_sample)
# When buffer is full, pack and yield
if len(buffer) >= self.pack_length:
packed_sample = self._pack_buffer(buffer)
if packed_sample:
yield packed_sample
buffer = []
except Exception as e:
LOG.warning(f"Failed to process sample: {e}")
continue
# Process remaining buffer
if buffer:
packed_sample = self._pack_buffer(buffer)
if packed_sample:
yield packed_sample
def _pack_buffer(self, buffer: List[Dict]) -> Optional[Dict]:
"""Pack a buffer of tokenized samples."""
if not buffer:
return None
try:
# Create dataset from buffer
temp_dataset = Dataset.from_list(buffer)
# Add position_ids and process for packing
temp_dataset = process_pretraining_datasets_for_packing(
temp_dataset,
self.max_tokens,
skip_position_ids=not self.cfg.pretrain_multipack_attn,
drop_attention_mask=self.cfg.pretrain_multipack_attn,
)
# Use MultipackBatchSampler to create packed batches
sampler = MultipackBatchSampler(
sampler=RandomSampler(temp_dataset),
lengths=get_dataset_lengths(temp_dataset),
batch_size=1,
batch_max_len=self.max_tokens,
drop_last=False,
num_processes=1,
)
# Get packed data
for batch in sampler:
if batch and batch[0]: # Check if we have data
features = []
for idx in batch[0]: # batch[0] contains the indices
sample = temp_dataset[idx]
if "labels" not in sample:
sample["labels"] = sample["input_ids"].copy()
features.append(sample)
# Apply collator to create final packed sample
if features:
packed = self.collator(features)
return {
k: v.squeeze(0) if v.dim() > 1 else v
for k, v in packed.items()
}
return None
except Exception as e:
LOG.warning(f"Failed to pack buffer: {e}")
return None
def wrap_streaming_sft_dataset_simple(
dataset: IterableDataset,
tokenizer: PreTrainedTokenizerBase,
cfg,
dataset_config,
d_base_type: str,
d_prompt_style: str | None,
processor: Any | None,
max_tokens: int = 2048,
buffer_size: int = 10_000,
) -> IterableDataset:
"""
Wrap a streaming SFT dataset with simple multipack batching.
This approach processes samples in small groups rather than large batches,
avoiding the repeated processing issue.
"""
# Apply shuffling if configured
if cfg.shuffle_merged_datasets:
LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
# Create the streaming multipack dataset
multipack_dataset = StreamingMultipackDataset(
dataset,
tokenizer,
cfg,
dataset_config,
d_base_type,
d_prompt_style,
processor,
max_tokens,
pack_length=max(1, max_tokens // 512), # Estimate sequences per pack
)
# Convert back to IterableDataset
class IterableWrapper:
def __init__(self, dataset):
self.dataset = dataset
def __iter__(self):
return iter(self.dataset)
wrapped = IterableWrapper(multipack_dataset)
# Set micro_batch_size to 1 since sequences are already packed
cfg.micro_batch_size = 1
return wrapped
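
The buffering in `StreamingMultipackDataset.__iter__` reduces to: tokenize each incoming sample, hold `pack_length` of them, pack, yield, and flush whatever is left at the end. A self-contained toy of that flow, with a fake tokenizer in place of the dataset wrapper and plain concatenation in place of `_pack_buffer`:

def fake_tokenize(text):
    # Illustrative stand-in for the dataset-wrapper tokenization step.
    return {"input_ids": list(range(len(text)))}


def packed_stream(texts, pack_length=4):
    buffer = []
    for text in texts:
        buffer.append(fake_tokenize(text))
        if len(buffer) >= pack_length:
            yield {"input_ids": sum((s["input_ids"] for s in buffer), [])}
            buffer = []
    if buffer:  # flush the tail, mirroring the post-loop packing above
        yield {"input_ids": sum((s["input_ids"] for s in buffer), [])}


for sample in packed_stream(["ab", "cde", "f", "ghij", "kl"], pack_length=2):
    print(len(sample["input_ids"]))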