This commit is contained in:
Dan Saunders
2025-08-21 17:35:24 +00:00
parent 68bb70bbae
commit 0caa24eab0

View File

@@ -2,10 +2,8 @@
Module containing dataset functionality. Module containing dataset functionality.
We want this to be a wrapper for an existing dataset that we have loaded. Let's use the We want this to be a wrapper for an existing dataset that we have loaded. Let's use the
concept of middlewares to wrap each dataset, for example: concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)])). datasets.
Let's check to ensure we don't truncate an item in the middle. We'll use the collators
later on to pad the datasets.
""" """
from typing import Any from typing import Any
@@ -47,8 +45,6 @@ class TokenizedPromptDataset(Dataset):
def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset: def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset:
"""Apply filtering and tokenization.""" """Apply filtering and tokenization."""
# For IterableDataset, we can't access features up front. Anyways, we don't care
# to remove unused columns from streaming datasets.
features = None features = None
if not isinstance(dataset, IterableDataset): if not isinstance(dataset, IterableDataset):
features = dataset.features.keys() features = dataset.features.keys()
@@ -97,8 +93,16 @@ def wrap_dataset_for_tokenized_prompt(
map_kwargs = {} map_kwargs = {}
if prompt_tokenizer.supports_batched: if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True map_kwargs["batched"] = True
# Peek at the first example to get original column names
first_example = next(iter(dataset))
original_columns = list(first_example.keys())
# Map the dataset and remove original columns
# This ensures only tokenized columns remain
return dataset.map( return dataset.map(
prompt_tokenizer.tokenize_prompt, prompt_tokenizer.tokenize_prompt,
remove_columns=original_columns,
**map_kwargs, **map_kwargs,
) )
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs) return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)