nit
This commit is contained in:
@@ -43,13 +43,11 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def process(self, dataset):
|
def process(self, dataset):
|
||||||
# Handle both regular Dataset and IterableDataset
|
# For IterableDataset, we can't access features upfront
|
||||||
|
# We'll need to infer from the first batch
|
||||||
|
features = None
|
||||||
if hasattr(dataset, "features") and dataset.features:
|
if hasattr(dataset, "features") and dataset.features:
|
||||||
features = dataset.features.keys()
|
features = dataset.features.keys()
|
||||||
else:
|
|
||||||
# For IterableDataset, we can't access features upfront
|
|
||||||
# We'll need to infer from the first batch
|
|
||||||
features = None
|
|
||||||
|
|
||||||
map_kwargs = {}
|
map_kwargs = {}
|
||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
@@ -70,20 +68,20 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
**filter_kwargs,
|
**filter_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
map_kwargs_final = {
|
map_kwargs = {
|
||||||
**map_kwargs,
|
**map_kwargs,
|
||||||
"desc": "Tokenizing Prompts",
|
"desc": "Tokenizing Prompts",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Only add remove_columns for regular datasets
|
# Only add remove_columns for regular datasets
|
||||||
if features is not None:
|
if features is not None:
|
||||||
map_kwargs_final["remove_columns"] = features
|
map_kwargs["remove_columns"] = features
|
||||||
map_kwargs_final["num_proc"] = self.process_count
|
map_kwargs["num_proc"] = self.process_count
|
||||||
map_kwargs_final["keep_in_memory"] = self.keep_in_memory
|
map_kwargs["keep_in_memory"] = self.keep_in_memory
|
||||||
|
|
||||||
return dataset.map(
|
return dataset.map(
|
||||||
self.prompt_tokenizer.tokenize_prompt,
|
self.prompt_tokenizer.tokenize_prompt,
|
||||||
**map_kwargs_final,
|
**map_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user