From 7cd0a317cbfd9ebb705256ad96617cd0b93fdc83 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 1 Jan 2025 09:11:14 -0500 Subject: [PATCH] support streaming for processing sft datasets? --- src/axolotl/datasets.py | 20 +++++++++++++++++++- src/axolotl/utils/data/sft.py | 24 ++++++++++++------------ 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 005cc41bc..13ae13d3b 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -2,7 +2,7 @@ import logging import os -from typing import List, Optional +from typing import List, Optional, Union import torch from datasets import Dataset, IterableDataset @@ -68,6 +68,24 @@ class TokenizedPromptDataset(Dataset): ) +def wrap_dataset_for_tokenized_prompt( + prompt_tokenizer: PromptTokenizingStrategy, + dataset: Union[Dataset, IterableDataset], + **kwargs, +): + if isinstance(dataset, IterableDataset): + map_kwargs = {} + if prompt_tokenizer.supports_batched: + map_kwargs["batched"] = True + features = dataset.features.keys() + return dataset.map( + prompt_tokenizer.tokenize_prompt, + remove_columns=features, + **map_kwargs, + ) + return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs) + + # TODO this isn't the best since it can't interleave datasets class ConstantLengthDataset(IterableDataset): """ diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 47ba5e88b..98a2a1af8 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -15,7 +15,7 @@ from datasets import ( from transformers import PreTrainedTokenizerBase from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH -from axolotl.datasets import TokenizedPromptDataset +from axolotl.datasets import wrap_dataset_for_tokenized_prompt from axolotl.prompt_strategies import load from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load from axolotl.prompt_tokenizers import ( @@ -450,7 +450,7 @@ def 
get_dataset_wrapper( "user_defined", tokenizer, cfg, config_dataset.type.to_dict() ) dataset_prompter = UnsupportedPrompter() - dataset_wrapper = TokenizedPromptDataset( + dataset_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -463,7 +463,7 @@ def get_dataset_wrapper( config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset ): dataset_prompter = UnsupportedPrompter() - dataset_wrapper = TokenizedPromptDataset( + dataset_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -475,7 +475,7 @@ def get_dataset_wrapper( dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs) else: dataset_prompter = UnsupportedPrompter() - dataset_wrapper = TokenizedPromptDataset( + dataset_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -488,7 +488,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -502,7 +502,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -516,7 +516,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -530,7 +530,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -544,7 +544,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -558,7 +558,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = 
TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -572,7 +572,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -586,7 +586,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs,