bugfix for sample packing
@@ -94,19 +94,10 @@ def wrap_dataset_for_tokenized_prompt(
         if prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
-
-        def peek_and_get_columns():
-            # Create a fresh iterator just for peeking
-            temp_iter = iter(dataset)
-            first_example = next(temp_iter)
-            return list(first_example.keys())
-
-        original_columns = peek_and_get_columns()
-
         # Map the dataset and remove original columns
         # This ensures only tokenized columns remain
         return dataset.map(
             prompt_tokenizer.tokenize_prompt,
-            remove_columns=original_columns,
+            remove_columns=list(dataset.features.keys()),
             **map_kwargs,
         )
     return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
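Note: the replaced keyword argument is the substance of this hunk. Dataset.map keeps the input columns alongside whatever the mapping function returns unless they are listed in remove_columns, and dataset.features.keys() names exactly those input columns, so only the tokenized columns survive. A minimal sketch of that behavior (the toy tokenize function and column names are illustrative, not from this commit):

    from datasets import Dataset

    dataset = Dataset.from_dict({"instruction": ["hi"], "output": ["hello"]})

    def tokenize(example):
        # Stand-in for prompt_tokenizer.tokenize_prompt: map text to token ids
        text = example["instruction"] + " " + example["output"]
        ids = [ord(ch) for ch in text]
        return {"input_ids": ids, "labels": ids}

    tokenized = dataset.map(tokenize, remove_columns=list(dataset.features.keys()))
    print(tokenized.column_names)  # ['input_ids', 'labels']: originals are gone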
@@ -10,7 +10,6 @@ from typing import List, Optional
 
 import numpy as np
 import torch
 import torch.cuda
 from datasets import IterableDataset, disable_caching, enable_caching
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.utils import is_torch_bf16_gpu_available
@@ -23,6 +22,65 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 LOG = get_logger(__name__)
 
 
+def _create_filtered_iterable_dataset(dataset, filter_fn, batched=False):
+    """
+    Create a filtered IterableDataset, working around the HuggingFace datasets
+    limitation that filtering can break when a dataset's features are None.
+    """
+
+    def filtered_generator():
+        """Generator that yields only samples that pass the filter function."""
+        if batched:
+            batch = []
+            batch_size = 1000  # Process in batches of 1000
+
+            for sample in dataset:
+                batch.append(sample)
+
+                if len(batch) >= batch_size:
+                    # Create a batch dict from list of samples
+                    batch_dict = {}
+                    for key in batch[0].keys():
+                        batch_dict[key] = [sample[key] for sample in batch]
+
+                    # Apply filter function to batch
+                    keep_mask = filter_fn(batch_dict)
+
+                    # Yield samples that should be kept
+                    for i, keep in enumerate(keep_mask):
+                        if keep:
+                            yield batch[i]
+
+                    batch = []
+
+            # Process remaining samples in batch
+            if batch:
+                batch_dict = {}
+                for key in batch[0].keys():
+                    batch_dict[key] = [sample[key] for sample in batch]
+
+                keep_mask = filter_fn(batch_dict)
+
+                for i, keep in enumerate(keep_mask):
+                    if keep:
+                        yield batch[i]
+        else:
+            # For non-batched filtering, apply filter to each sample individually
+            for sample in dataset:
+                if filter_fn(sample):
+                    yield sample
+
+    # Create new IterableDataset from the filtered generator
+    filtered_dataset = IterableDataset.from_generator(filtered_generator)
+
+    # Preserve the original features if they exist
+    # pylint:disable=protected-access
+    if hasattr(dataset, "_info") and dataset._info.features is not None:
+        filtered_dataset._info.features = dataset._info.features
+
+    return filtered_dataset
+
+
 @torch.jit.script
 def weighted_cross_entropy(
     logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
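Note: a quick usage sketch of the new helper (the toy data and predicate are invented for illustration). The filter_fn is expected in the same shape as a batched datasets predicate: a dict of column lists in, a list of booleans out.

    from datasets import IterableDataset

    def gen():
        for i in range(5):
            yield {"input_ids": list(range(i)), "labels": list(range(i))}

    raw = IterableDataset.from_generator(gen)

    def keep_nonempty(batch):
        # Batched predicate: one boolean per sample in the batch dict
        return [len(ids) > 0 for ids in batch["input_ids"]]

    filtered = _create_filtered_iterable_dataset(raw, keep_nonempty, batched=True)
    assert sum(1 for _ in filtered) == 4  # the i == 0 sample has no tokens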
@@ -282,12 +340,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     drop_long_kwargs = {}
     if filter_map_kwargs:
         drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
-    train_dataset = train_dataset.filter(
-        drop_no_trainable_tokens,
-        batched=True,
-        **filter_map_kwargs,
-        **drop_long_kwargs,
-    )
+
+    # For IterableDatasets, always use custom filtering to avoid features issues
+    if isinstance(train_dataset, IterableDataset):
+        # IterableDatasets often have None features after transformations,
+        # so we use our custom filter implementation that doesn't rely on features
+        train_dataset = _create_filtered_iterable_dataset(
+            train_dataset, drop_no_trainable_tokens, batched=True
+        )
+    else:
+        train_dataset = train_dataset.filter(
+            drop_no_trainable_tokens,
+            batched=True,
+            **filter_map_kwargs,
+            **drop_long_kwargs,
+        )
     if prior_len:
         dropped = prior_len - len(train_dataset)
         if dropped:
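Note: drop_no_trainable_tokens itself is outside this diff. As a rough sketch of what such a batched predicate typically looks like for causal-LM labels (assuming the conventional -100 ignore index, which this commit does not show):

    def drop_no_trainable_tokens_sketch(batch):
        # Keep a sample only if at least one label position is trainable,
        # i.e. not masked with the -100 ignore index (assumed convention).
        return [
            any(token != -100 for token in labels)
            for labels in batch["labels"]
        ]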