commit 0caa24eab0
parent 68bb70bbae
Author: Dan Saunders
Date:   2025-08-21 17:35:24 +00:00

@@ -2,10 +2,8 @@
 Module containing dataset functionality.
 
 We want this to be a wrapper for an existing dataset that we have loaded. Lets use the
-concept of middlewares to wrap each dataset, for example:
-ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)])).
-Let's check to ensure we don't truncate an item in the middle. We'll use the collators
-later on to pad the datasets.
+concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
+datasets.
 """
 
 from typing import Any
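
Note: the middleware idea the docstring describes is just nested dataset wrappers. A
minimal sketch of the composition pattern, reusing the wrapper names from the old
docstring example (the stub classes below are hypothetical, not this module's real
implementations):

    class DatasetMiddleware:
        """Base wrapper: stores and delegates to the dataset it wraps."""

        def __init__(self, dataset):
            self.dataset = dataset

        def __iter__(self):
            return iter(self.dataset)


    class TokenizedPromptDataset(DatasetMiddleware): ...   # tokenize each prompt
    class ShuffledDataset(DatasetMiddleware): ...          # reorder items
    class ConstantLengthDataset(DatasetMiddleware): ...    # pack examples to fixed length


    # Middlewares compose inside-out, innermost applied first, as in the old example:
    # ConstantLengthDataset(ShuffledDataset(TokenizedPromptDataset(alpaca_dataset)))

Each layer only needs to know about the dataset it wraps, so transforms can be stacked
in any order.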
@@ -47,8 +45,6 @@ class TokenizedPromptDataset(Dataset):
     def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset:
         """Apply filtering and tokenization."""
-        # For IterableDataset, we can't access features up front. Anyways, we don't care
-        # to remove unused columns from streaming datasets.
         features = None
         if not isinstance(dataset, IterableDataset):
             features = dataset.features.keys()
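
For context: an in-memory datasets.Dataset knows its schema via .features, while a
streaming IterableDataset may not until data is actually read, which is presumably why
process() only collects column names in the non-streaming branch. A small illustration
with made-up data:

    from datasets import Dataset

    # In-memory datasets expose their schema immediately.
    ds = Dataset.from_dict({"instruction": ["hi"], "output": ["hello"]})
    print(list(ds.features.keys()))  # ['instruction', 'output']

    # A streaming dataset (e.g. load_dataset(..., streaming=True)) is an
    # IterableDataset; its .features can be None until examples are
    # inspected, so unused columns are simply left in place there.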
@@ -97,8 +93,16 @@ def wrap_dataset_for_tokenized_prompt(
     map_kwargs = {}
     if prompt_tokenizer.supports_batched:
         map_kwargs["batched"] = True
+    # Peek at the first example to get original column names
+    first_example = next(iter(dataset))
+    original_columns = list(first_example.keys())
+    # Map the dataset and remove original columns
+    # This ensures only tokenized columns remain
+    return dataset.map(
+        prompt_tokenizer.tokenize_prompt,
+        remove_columns=original_columns,
+        **map_kwargs,
+    )
     return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
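
The new remove_columns logic matters because Dataset.map() keeps the input columns
alongside whatever the mapping function returns. A quick illustration with made-up data
and a toy stand-in for prompt_tokenizer.tokenize_prompt:

    from datasets import Dataset

    ds = Dataset.from_dict({"instruction": ["hi"], "output": ["hello"]})

    def fake_tokenize(example):
        # Toy stand-in for prompt_tokenizer.tokenize_prompt.
        return {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]}

    # Without remove_columns, the original text columns survive the map:
    print(ds.map(fake_tokenize).column_names)
    # e.g. ['instruction', 'output', 'input_ids', 'attention_mask']

    # With remove_columns, only the tokenized columns remain:
    print(ds.map(fake_tokenize, remove_columns=ds.column_names).column_names)
    # ['input_ids', 'attention_mask']

Peeking with next(iter(dataset)) rather than reading dataset.column_names presumably
keeps this working for IterableDataset too, where column names may not be available up
front.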