diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index de847bcd8..44de3c975 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -94,19 +94,10 @@ def wrap_dataset_for_tokenized_prompt(
         if prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
 
-        def peek_and_get_columns():
-            # Create a fresh iterator just for peeking
-            temp_iter = iter(dataset)
-            first_example = next(temp_iter)
-            return list(first_example.keys())
-
-        original_columns = peek_and_get_columns()
-
         # Map the dataset and remove original columns
-        # This ensures only tokenized columns remain
         return dataset.map(
             prompt_tokenizer.tokenize_prompt,
-            remove_columns=original_columns,
+            remove_columns=list(dataset.features.keys()),
             **map_kwargs,
         )
     return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index faadc93bc..7f486825f 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -10,7 +10,6 @@ from typing import List, Optional
 
 import numpy as np
 import torch
-import torch.cuda
 from datasets import IterableDataset, disable_caching, enable_caching
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.utils import is_torch_bf16_gpu_available
@@ -23,6 +22,65 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 
 LOG = get_logger(__name__)
 
+def _create_filtered_iterable_dataset(dataset, filter_fn, batched=False):
+    """
+    Create a filtered IterableDataset that works around a HuggingFace datasets
+    limitation.
+    """
+
+    def filtered_generator():
+        """Generator that yields only samples that pass the filter function."""
+        if batched:
+            batch = []
+            batch_size = 1000  # Process in batches of 1000
+
+            for sample in dataset:
+                batch.append(sample)
+
+                if len(batch) >= batch_size:
+                    # Create a batch dict from list of samples
+                    batch_dict = {}
+                    for key in batch[0].keys():
+                        batch_dict[key] = [sample[key] for sample in batch]
+
+                    # Apply filter function to batch
+                    keep_mask = filter_fn(batch_dict)
+
+                    # Yield samples that should be kept
+                    for i, keep in enumerate(keep_mask):
+                        if keep:
+                            yield batch[i]
+
+                    batch = []
+
+            # Process remaining samples in batch
+            if batch:
+                batch_dict = {}
+                for key in batch[0].keys():
+                    batch_dict[key] = [sample[key] for sample in batch]
+
+                keep_mask = filter_fn(batch_dict)
+
+                for i, keep in enumerate(keep_mask):
+                    if keep:
+                        yield batch[i]
+        else:
+            # For non-batched filtering, apply filter to each sample individually
+            for sample in dataset:
+                if filter_fn(sample):
+                    yield sample
+
+    # Create new IterableDataset from the filtered generator
+    filtered_dataset = IterableDataset.from_generator(filtered_generator)
+
+    # Preserve the original features if they exist
+    # pylint:disable=protected-access
+    if hasattr(dataset, "_info") and dataset._info.features is not None:
+        filtered_dataset._info.features = dataset._info.features
+
+    return filtered_dataset
+
+
 @torch.jit.script
 def weighted_cross_entropy(
     logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
@@ -282,12 +340,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     drop_long_kwargs = {}
     if filter_map_kwargs:
         drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
-    train_dataset = train_dataset.filter(
-        drop_no_trainable_tokens,
-        batched=True,
-        **filter_map_kwargs,
-        **drop_long_kwargs,
-    )
+
+    # For IterableDatasets, always use custom filtering to avoid features issues
+    if isinstance(train_dataset, IterableDataset):
+        # IterableDatasets often have None features after transformations,
+        # so we use our custom filter implementation that doesn't rely on features
+        train_dataset = _create_filtered_iterable_dataset(
+            train_dataset, drop_no_trainable_tokens, batched=True
+        )
+    else:
+        train_dataset = train_dataset.filter(
+            drop_no_trainable_tokens,
+            batched=True,
+            **filter_map_kwargs,
+            **drop_long_kwargs,
+        )
    if prior_len:
        dropped = prior_len - len(train_dataset)
        if dropped:
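For context on the `src/axolotl/datasets.py` change: rather than consuming a sample just to learn the column names, the patch reads them from the dataset's schema via `dataset.features`. A minimal sketch of the idea, assuming a map-style `Dataset`; the `instruction`/`output` columns and the stand-in tokenizer below are illustrative, not from the repo:

```python
from datasets import Dataset

# Toy prompt dataset; the column names are illustrative only.
ds = Dataset.from_dict({"instruction": ["hi"], "output": ["hello"]})

# Column names come straight from the schema; no iterator peeking needed.
assert list(ds.features.keys()) == ["instruction", "output"]

# After mapping with remove_columns, only the newly produced columns remain.
tokenized = ds.map(
    lambda example: {"input_ids": [0, 1], "labels": [0, 1]},  # stand-in tokenizer
    remove_columns=list(ds.features.keys()),
)
assert sorted(tokenized.column_names) == ["input_ids", "labels"]
```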
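Similarly, a minimal usage sketch of the new `_create_filtered_iterable_dataset` helper in `src/axolotl/utils/trainer.py`. The toy generator and the `keep_trainable` filter are stand-ins for the real tokenized stream and `drop_no_trainable_tokens`; only the helper itself comes from the patch:

```python
from datasets import IterableDataset

def gen():
    # Toy tokenized samples; the first has no trainable (non -100) labels.
    yield {"input_ids": [1, 2, 3], "labels": [-100, -100, -100]}
    yield {"input_ids": [4, 5, 6], "labels": [-100, 5, 6]}

stream = IterableDataset.from_generator(gen)

def keep_trainable(batch):
    # Batched filter in the style of drop_no_trainable_tokens: receives a
    # dict of columns and returns one bool per row of the batch.
    return [any(tok != -100 for tok in labels) for labels in batch["labels"]]

filtered = _create_filtered_iterable_dataset(stream, keep_trainable, batched=True)
assert [sample["input_ids"] for sample in filtered] == [[4, 5, 6]]
```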