support streaming for processing sft datasts?
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
@@ -68,6 +68,24 @@ class TokenizedPromptDataset(Dataset):
|
||||
)
|
||||
|
||||
|
||||
def wrap_dataset_for_tokenized_prompt(
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: Union[Dataset, IterableDataset],
|
||||
**kwargs,
|
||||
):
|
||||
if isinstance(dataset, IterableDataset):
|
||||
map_kwargs = {}
|
||||
if prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
features = dataset.features.keys()
|
||||
return dataset.map(
|
||||
prompt_tokenizer.tokenize_prompt,
|
||||
remove_columns=features,
|
||||
**map_kwargs,
|
||||
)
|
||||
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
|
||||
|
||||
|
||||
# TODO this isn't the best since it can't interleave datasets
|
||||
class ConstantLengthDataset(IterableDataset):
|
||||
"""
|
||||
|
||||
@@ -15,7 +15,7 @@ from datasets import (
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.datasets import wrap_dataset_for_tokenized_prompt
|
||||
from axolotl.prompt_strategies import load
|
||||
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
|
||||
from axolotl.prompt_tokenizers import (
|
||||
@@ -444,7 +444,7 @@ def get_dataset_wrapper(
|
||||
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
||||
)
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -457,7 +457,7 @@ def get_dataset_wrapper(
|
||||
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
|
||||
):
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -469,7 +469,7 @@ def get_dataset_wrapper(
|
||||
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
|
||||
else:
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -482,7 +482,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -496,7 +496,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -510,7 +510,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -524,7 +524,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -538,7 +538,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -552,7 +552,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -566,7 +566,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
@@ -580,7 +580,7 @@ def get_dataset_wrapper(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||
ds_strategy,
|
||||
dataset,
|
||||
**ds_kwargs,
|
||||
|
||||
Reference in New Issue
Block a user