This commit is contained in:
Dan Saunders
2025-08-19 18:12:09 +00:00
parent 16ff01df85
commit b6431083be

View File

@@ -43,13 +43,11 @@ class TokenizedPromptDataset(Dataset):
)
def process(self, dataset):
# Handle both regular Dataset and IterableDataset
# For IterableDataset, we can't access features upfront
# We'll need to infer from the first batch
features = None
if hasattr(dataset, "features") and dataset.features:
features = dataset.features.keys()
else:
# For IterableDataset, we can't access features upfront
# We'll need to infer from the first batch
features = None
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
@@ -70,20 +68,20 @@ class TokenizedPromptDataset(Dataset):
**filter_kwargs,
)
map_kwargs_final = {
map_kwargs = {
**map_kwargs,
"desc": "Tokenizing Prompts",
}
# Only add remove_columns for regular datasets
if features is not None:
map_kwargs_final["remove_columns"] = features
map_kwargs_final["num_proc"] = self.process_count
map_kwargs_final["keep_in_memory"] = self.keep_in_memory
map_kwargs["remove_columns"] = features
map_kwargs["num_proc"] = self.process_count
map_kwargs["keep_in_memory"] = self.keep_in_memory
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
**map_kwargs_final,
**map_kwargs,
)