support streaming for processing sft datasts?

This commit is contained in:
Wing Lian
2025-01-01 09:11:14 -05:00
parent 9ed455ef8c
commit 6bbe3ac641
2 changed files with 31 additions and 13 deletions

View File

@@ -2,7 +2,7 @@
import logging import logging
import os import os
from typing import List, Optional from typing import List, Optional, Union
import torch import torch
from datasets import Dataset, IterableDataset from datasets import Dataset, IterableDataset
@@ -68,6 +68,24 @@ class TokenizedPromptDataset(Dataset):
) )
def wrap_dataset_for_tokenized_prompt(
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Union[Dataset, IterableDataset],
**kwargs,
):
if isinstance(dataset, IterableDataset):
map_kwargs = {}
if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
features = dataset.features.keys()
return dataset.map(
prompt_tokenizer.tokenize_prompt,
remove_columns=features,
**map_kwargs,
)
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
# TODO this isn't the best since it can't interleave datasets # TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset): class ConstantLengthDataset(IterableDataset):
""" """

View File

@@ -15,7 +15,7 @@ from datasets import (
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import TokenizedPromptDataset from axolotl.datasets import wrap_dataset_for_tokenized_prompt
from axolotl.prompt_strategies import load from axolotl.prompt_strategies import load
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
@@ -444,7 +444,7 @@ def get_dataset_wrapper(
"user_defined", tokenizer, cfg, config_dataset.type.to_dict() "user_defined", tokenizer, cfg, config_dataset.type.to_dict()
) )
dataset_prompter = UnsupportedPrompter() dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset( dataset_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy, ds_strategy,
dataset, dataset,
**ds_kwargs, **ds_kwargs,
@@ -457,7 +457,7 @@ def get_dataset_wrapper(
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
): ):
dataset_prompter = UnsupportedPrompter() dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset( dataset_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy, ds_strategy,
dataset, dataset,
**ds_kwargs, **ds_kwargs,
@@ -469,7 +469,7 @@ def get_dataset_wrapper(
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs) dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
else: else:
dataset_prompter = UnsupportedPrompter() dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset( dataset_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy, ds_strategy,
dataset, dataset,
**ds_kwargs, **ds_kwargs,
@@ -482,7 +482,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset( ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy, ds_strategy,
dataset, dataset,
**ds_kwargs, **ds_kwargs,
@@ -496,7 +496,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset( ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy, ds_strategy,
dataset, dataset,
**ds_kwargs, **ds_kwargs,
@@ -510,7 +510,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset( ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy, ds_strategy,
dataset, dataset,
**ds_kwargs, **ds_kwargs,
@@ -524,7 +524,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset( ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy, ds_strategy,
dataset, dataset,
**ds_kwargs, **ds_kwargs,
@@ -538,7 +538,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset( ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy, ds_strategy,
dataset, dataset,
**ds_kwargs, **ds_kwargs,
@@ -552,7 +552,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset( ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy, ds_strategy,
dataset, dataset,
**ds_kwargs, **ds_kwargs,
@@ -566,7 +566,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset( ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy, ds_strategy,
dataset, dataset,
**ds_kwargs, **ds_kwargs,
@@ -580,7 +580,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset( ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy, ds_strategy,
dataset, dataset,
**ds_kwargs, **ds_kwargs,