From 0caa24eab0ec83e02de35acb515f75009723b076 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Thu, 21 Aug 2025 17:35:24 +0000 Subject: [PATCH] comments --- src/axolotl/datasets.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index d70703d6b..5574d3d47 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -2,10 +2,8 @@ Module containing dataset functionality. We want this to be a wrapper for an existing dataset that we have loaded. Lets use the -concept of middlewares to wrap each dataset, for example: -ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)])). -Let's check to ensure we don't truncate an item in the middle. We'll use the collators -later on to pad the datasets. +concept of middlewares to wrap each dataset. We'll use the collators later on to pad the +datasets. """ from typing import Any @@ -47,8 +45,6 @@ class TokenizedPromptDataset(Dataset): def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset: """Apply filtering and tokenization.""" - # For IterableDataset, we can't access features up front. Anyways, we don't care - # to remove unused columns from streaming datasets. features = None if not isinstance(dataset, IterableDataset): features = dataset.features.keys() @@ -97,8 +93,16 @@ def wrap_dataset_for_tokenized_prompt( map_kwargs = {} if prompt_tokenizer.supports_batched: map_kwargs["batched"] = True + + # Peek at the first example to get original column names + first_example = next(iter(dataset)) + original_columns = list(first_example.keys()) + + # Map the dataset and remove original columns + # This ensures only tokenized columns remain return dataset.map( prompt_tokenizer.tokenize_prompt, + remove_columns=original_columns, **map_kwargs, ) return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)