comments
This commit is contained in:
@@ -2,10 +2,8 @@
|
|||||||
Module containing dataset functionality.
|
Module containing dataset functionality.
|
||||||
|
|
||||||
We want this to be a wrapper for an existing dataset that we have loaded. Let's use the
|
We want this to be a wrapper for an existing dataset that we have loaded. Let's use the
|
||||||
concept of middlewares to wrap each dataset, for example:
|
concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
|
||||||
ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)])).
|
datasets.
|
||||||
Let's check to ensure we don't truncate an item in the middle. We'll use the collators
|
|
||||||
later on to pad the datasets.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -47,8 +45,6 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
|
|
||||||
def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset:
|
def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset:
|
||||||
"""Apply filtering and tokenization."""
|
"""Apply filtering and tokenization."""
|
||||||
# For IterableDataset, we can't access features up front. Anyways, we don't care
|
|
||||||
# to remove unused columns from streaming datasets.
|
|
||||||
features = None
|
features = None
|
||||||
if not isinstance(dataset, IterableDataset):
|
if not isinstance(dataset, IterableDataset):
|
||||||
features = dataset.features.keys()
|
features = dataset.features.keys()
|
||||||
@@ -97,8 +93,16 @@ def wrap_dataset_for_tokenized_prompt(
|
|||||||
map_kwargs = {}
|
map_kwargs = {}
|
||||||
if prompt_tokenizer.supports_batched:
|
if prompt_tokenizer.supports_batched:
|
||||||
map_kwargs["batched"] = True
|
map_kwargs["batched"] = True
|
||||||
|
|
||||||
|
# Peek at the first example to get original column names
|
||||||
|
first_example = next(iter(dataset))
|
||||||
|
original_columns = list(first_example.keys())
|
||||||
|
|
||||||
|
# Map the dataset and remove original columns
|
||||||
|
# This ensures only tokenized columns remain
|
||||||
return dataset.map(
|
return dataset.map(
|
||||||
prompt_tokenizer.tokenize_prompt,
|
prompt_tokenizer.tokenize_prompt,
|
||||||
|
remove_columns=original_columns,
|
||||||
**map_kwargs,
|
**map_kwargs,
|
||||||
)
|
)
|
||||||
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
|
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user