diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index 2d20de4d3..e3bbc8cf7 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -43,13 +43,11 @@ class TokenizedPromptDataset(Dataset):
         )
 
     def process(self, dataset):
-        # Handle both regular Dataset and IterableDataset
+        # For IterableDataset, we can't access features upfront
+        # We'll need to infer from the first batch
+        features = None
         if hasattr(dataset, "features") and dataset.features:
             features = dataset.features.keys()
-        else:
-            # For IterableDataset, we can't access features upfront
-            # We'll need to infer from the first batch
-            features = None
 
         map_kwargs = {}
         if self.prompt_tokenizer.supports_batched:
@@ -70,20 +68,20 @@ class TokenizedPromptDataset(Dataset):
             **filter_kwargs,
         )
 
-        map_kwargs_final = {
+        map_kwargs = {
             **map_kwargs,
             "desc": "Tokenizing Prompts",
         }
 
         # Only add remove_columns for regular datasets
         if features is not None:
-            map_kwargs_final["remove_columns"] = features
-            map_kwargs_final["num_proc"] = self.process_count
-            map_kwargs_final["keep_in_memory"] = self.keep_in_memory
+            map_kwargs["remove_columns"] = features
+            map_kwargs["num_proc"] = self.process_count
+            map_kwargs["keep_in_memory"] = self.keep_in_memory
 
         return dataset.map(
             self.prompt_tokenizer.tokenize_prompt,
-            **map_kwargs_final,
+            **map_kwargs,
         )