nit
This commit is contained in:
@@ -43,13 +43,11 @@ class TokenizedPromptDataset(Dataset):
|
||||
)
|
||||
|
||||
def process(self, dataset):
|
||||
# Handle both regular Dataset and IterableDataset
|
||||
# For IterableDataset, we can't access features upfront
|
||||
# We'll need to infer from the first batch
|
||||
features = None
|
||||
if hasattr(dataset, "features") and dataset.features:
|
||||
features = dataset.features.keys()
|
||||
else:
|
||||
# For IterableDataset, we can't access features upfront
|
||||
# We'll need to infer from the first batch
|
||||
features = None
|
||||
|
||||
map_kwargs = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
@@ -70,20 +68,20 @@ class TokenizedPromptDataset(Dataset):
|
||||
**filter_kwargs,
|
||||
)
|
||||
|
||||
map_kwargs_final = {
|
||||
map_kwargs = {
|
||||
**map_kwargs,
|
||||
"desc": "Tokenizing Prompts",
|
||||
}
|
||||
|
||||
# Only add remove_columns for regular datasets
|
||||
if features is not None:
|
||||
map_kwargs_final["remove_columns"] = features
|
||||
map_kwargs_final["num_proc"] = self.process_count
|
||||
map_kwargs_final["keep_in_memory"] = self.keep_in_memory
|
||||
map_kwargs["remove_columns"] = features
|
||||
map_kwargs["num_proc"] = self.process_count
|
||||
map_kwargs["keep_in_memory"] = self.keep_in_memory
|
||||
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
**map_kwargs_final,
|
||||
**map_kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user