This commit is contained in:
Dan Saunders
2025-08-19 18:12:09 +00:00
parent 16ff01df85
commit b6431083be

View File

@@ -43,13 +43,11 @@ class TokenizedPromptDataset(Dataset):
) )
def process(self, dataset): def process(self, dataset):
# Handle both regular Dataset and IterableDataset # For IterableDataset, we can't access features upfront
# We'll need to infer from the first batch
features = None
if hasattr(dataset, "features") and dataset.features: if hasattr(dataset, "features") and dataset.features:
features = dataset.features.keys() features = dataset.features.keys()
else:
# For IterableDataset, we can't access features upfront
# We'll need to infer from the first batch
features = None
map_kwargs = {} map_kwargs = {}
if self.prompt_tokenizer.supports_batched: if self.prompt_tokenizer.supports_batched:
@@ -70,20 +68,20 @@ class TokenizedPromptDataset(Dataset):
**filter_kwargs, **filter_kwargs,
) )
map_kwargs_final = { map_kwargs = {
**map_kwargs, **map_kwargs,
"desc": "Tokenizing Prompts", "desc": "Tokenizing Prompts",
} }
# Only add remove_columns for regular datasets # Only add remove_columns for regular datasets
if features is not None: if features is not None:
map_kwargs_final["remove_columns"] = features map_kwargs["remove_columns"] = features
map_kwargs_final["num_proc"] = self.process_count map_kwargs["num_proc"] = self.process_count
map_kwargs_final["keep_in_memory"] = self.keep_in_memory map_kwargs["keep_in_memory"] = self.keep_in_memory
return dataset.map( return dataset.map(
self.prompt_tokenizer.tokenize_prompt, self.prompt_tokenizer.tokenize_prompt,
**map_kwargs_final, **map_kwargs,
) )