bugfix for sample packing

This commit is contained in:
Dan Saunders
2025-08-22 04:33:48 +00:00
parent 49bd6ece4a
commit 53bbca2591
2 changed files with 75 additions and 17 deletions

View File

@@ -94,19 +94,10 @@ def wrap_dataset_for_tokenized_prompt(
if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
def peek_and_get_columns():
# Create a fresh iterator just for peeking
temp_iter = iter(dataset)
first_example = next(temp_iter)
return list(first_example.keys())
original_columns = peek_and_get_columns()
# Map the dataset and remove original columns
# This ensures only tokenized columns remain
return dataset.map(
prompt_tokenizer.tokenize_prompt,
remove_columns=original_columns,
remove_columns=list(dataset.features.keys()),
**map_kwargs,
)
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)

View File

@@ -10,7 +10,6 @@ from typing import List, Optional
import numpy as np
import torch
import torch.cuda
from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
@@ -23,6 +22,65 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
def _create_filtered_iterable_dataset(dataset, filter_fn, batched=False):
"""
Create a filtered IterableDataset that works around a HuggingFace datasets
limitation.
"""
def filtered_generator():
"""Generator that yields only samples that pass the filter function."""
if batched:
batch = []
batch_size = 1000 # Process in batches of 1000
for sample in dataset:
batch.append(sample)
if len(batch) >= batch_size:
# Create a batch dict from list of samples
batch_dict = {}
for key in batch[0].keys():
batch_dict[key] = [sample[key] for sample in batch]
# Apply filter function to batch
keep_mask = filter_fn(batch_dict)
# Yield samples that should be kept
for i, keep in enumerate(keep_mask):
if keep:
yield batch[i]
batch = []
# Process remaining samples in batch
if batch:
batch_dict = {}
for key in batch[0].keys():
batch_dict[key] = [sample[key] for sample in batch]
keep_mask = filter_fn(batch_dict)
for i, keep in enumerate(keep_mask):
if keep:
yield batch[i]
else:
# For non-batched filtering, apply filter to each sample individually
for sample in dataset:
if filter_fn(sample):
yield sample
# Create new IterableDataset from the filtered generator
filtered_dataset = IterableDataset.from_generator(filtered_generator)
# Preserve the original features if they exist
# pylint:disable=protected-access
if hasattr(dataset, "_info") and dataset._info.features is not None:
filtered_dataset._info.features = dataset._info.features
return filtered_dataset
@torch.jit.script
def weighted_cross_entropy(
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
@@ -282,12 +340,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
train_dataset = train_dataset.filter(
drop_no_trainable_tokens,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
# For IterableDatasets, always use custom filtering to avoid features issues
if isinstance(train_dataset, IterableDataset):
# IterableDatasets often have None features after transformations,
# so we use our custom filter implementation that doesn't rely on features
train_dataset = _create_filtered_iterable_dataset(
train_dataset, drop_no_trainable_tokens, batched=True
)
else:
train_dataset = train_dataset.filter(
drop_no_trainable_tokens,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(train_dataset)
if dropped: