@@ -2,10 +2,8 @@
 Module containing dataset functionality.
 
 We want this to be a wrapper for an existing dataset that we have loaded. Lets use the
-concept of middlewares to wrap each dataset, for example:
-ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)])).
-Let's check to ensure we don't truncate an item in the middle. We'll use the collators
-later on to pad the datasets.
+concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
+datasets.
 """
 
 from typing import Any
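The "middleware" idea the docstring describes is just nested wrappers that each expose the same dataset interface. A minimal sketch of that pattern, where ShuffledDataset is a hypothetical wrapper like the one named in the removed example line:

import random

class ShuffledDataset:
    """Hypothetical middleware: wraps any indexable dataset and shuffles access order."""

    def __init__(self, dataset):
        self.dataset = dataset
        self.order = list(range(len(dataset)))
        random.shuffle(self.order)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[self.order[idx]]

# Each layer adds one behavior behind the same interface, e.g.:
# ConstantLengthDataset(ShuffledDataset(TokenizedPromptDataset(alpaca_dataset)))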
@@ -47,8 +45,6 @@ class TokenizedPromptDataset(Dataset):
 
     def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset:
         """Apply filtering and tokenization."""
-        # For IterableDataset, we can't access features up front. Anyways, we don't care
-        # to remove unused columns from streaming datasets.
         features = None
         if not isinstance(dataset, IterableDataset):
             features = dataset.features.keys()
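The removed comment points at a real limitation of the Hugging Face datasets library: a streaming IterableDataset may not know its schema until examples are produced, so its .features can be None, which is why the code only reads feature names from non-iterable datasets. A small sketch of the difference:

from datasets import Dataset, IterableDataset

ds = Dataset.from_list([{"instruction": "hi", "output": "hello"}])
print(ds.features.keys())  # dict_keys(['instruction', 'output'])

def gen():
    yield {"instruction": "hi", "output": "hello"}

streamed = IterableDataset.from_generator(gen)
print(streamed.features)  # None: schema is unknown until examples are yielded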
@@ -97,8 +93,16 @@ def wrap_dataset_for_tokenized_prompt(
         map_kwargs = {}
         if prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
+
+        # Peek at the first example to get original column names
+        first_example = next(iter(dataset))
+        original_columns = list(first_example.keys())
+
+        # Map the dataset and remove original columns
+        # This ensures only tokenized columns remain
         return dataset.map(
             prompt_tokenizer.tokenize_prompt,
+            remove_columns=original_columns,
             **map_kwargs,
         )
     return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
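The added lines peek at one example to learn the original column names, then pass them to map(..., remove_columns=...) so that only the tokenizer's output columns survive. A standalone sketch of that pattern, with a toy tokenize function standing in for prompt_tokenizer.tokenize_prompt:

from datasets import Dataset

def tokenize_prompt(example):
    # Toy stand-in for prompt_tokenizer.tokenize_prompt.
    return {"input_ids": [ord(ch) for ch in example["text"]]}

dataset = Dataset.from_list([{"text": "ab"}, {"text": "cd"}])
original_columns = list(next(iter(dataset)).keys())  # ['text']

tokenized = dataset.map(tokenize_prompt, remove_columns=original_columns)
print(tokenized.column_names)  # ['input_ids']: only tokenized columns remain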