support streaming for processing sft datasts?

This commit is contained in:
Wing Lian
2025-01-01 09:11:14 -05:00
parent 9ed455ef8c
commit 6bbe3ac641
2 changed files with 31 additions and 13 deletions

View File

@@ -2,7 +2,7 @@
import logging
import os
from typing import List, Optional
from typing import List, Optional, Union
import torch
from datasets import Dataset, IterableDataset
@@ -68,6 +68,24 @@ class TokenizedPromptDataset(Dataset):
)
def wrap_dataset_for_tokenized_prompt(
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Union[Dataset, IterableDataset],
**kwargs,
):
if isinstance(dataset, IterableDataset):
map_kwargs = {}
if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
features = dataset.features.keys()
return dataset.map(
prompt_tokenizer.tokenize_prompt,
remove_columns=features,
**map_kwargs,
)
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""

View File

@@ -15,7 +15,7 @@ from datasets import (
from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import TokenizedPromptDataset
from axolotl.datasets import wrap_dataset_for_tokenized_prompt
from axolotl.prompt_strategies import load
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
from axolotl.prompt_tokenizers import (
@@ -444,7 +444,7 @@ def get_dataset_wrapper(
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
)
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -457,7 +457,7 @@ def get_dataset_wrapper(
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -469,7 +469,7 @@ def get_dataset_wrapper(
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
else:
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -482,7 +482,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -496,7 +496,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -510,7 +510,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -524,7 +524,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -538,7 +538,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -552,7 +552,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -566,7 +566,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -580,7 +580,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,