support streaming for processing sft datasts?
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
@@ -68,6 +68,24 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_dataset_for_tokenized_prompt(
|
||||||
|
prompt_tokenizer: PromptTokenizingStrategy,
|
||||||
|
dataset: Union[Dataset, IterableDataset],
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if isinstance(dataset, IterableDataset):
|
||||||
|
map_kwargs = {}
|
||||||
|
if prompt_tokenizer.supports_batched:
|
||||||
|
map_kwargs["batched"] = True
|
||||||
|
features = dataset.features.keys()
|
||||||
|
return dataset.map(
|
||||||
|
prompt_tokenizer.tokenize_prompt,
|
||||||
|
remove_columns=features,
|
||||||
|
**map_kwargs,
|
||||||
|
)
|
||||||
|
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# TODO this isn't the best since it can't interleave datasets
|
# TODO this isn't the best since it can't interleave datasets
|
||||||
class ConstantLengthDataset(IterableDataset):
|
class ConstantLengthDataset(IterableDataset):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from datasets import (
|
|||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.datasets import TokenizedPromptDataset
|
from axolotl.datasets import wrap_dataset_for_tokenized_prompt
|
||||||
from axolotl.prompt_strategies import load
|
from axolotl.prompt_strategies import load
|
||||||
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
|
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
|
||||||
from axolotl.prompt_tokenizers import (
|
from axolotl.prompt_tokenizers import (
|
||||||
@@ -444,7 +444,7 @@ def get_dataset_wrapper(
|
|||||||
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
||||||
)
|
)
|
||||||
dataset_prompter = UnsupportedPrompter()
|
dataset_prompter = UnsupportedPrompter()
|
||||||
dataset_wrapper = TokenizedPromptDataset(
|
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -457,7 +457,7 @@ def get_dataset_wrapper(
|
|||||||
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
|
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
|
||||||
):
|
):
|
||||||
dataset_prompter = UnsupportedPrompter()
|
dataset_prompter = UnsupportedPrompter()
|
||||||
dataset_wrapper = TokenizedPromptDataset(
|
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -469,7 +469,7 @@ def get_dataset_wrapper(
|
|||||||
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
|
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
|
||||||
else:
|
else:
|
||||||
dataset_prompter = UnsupportedPrompter()
|
dataset_prompter = UnsupportedPrompter()
|
||||||
dataset_wrapper = TokenizedPromptDataset(
|
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -482,7 +482,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -496,7 +496,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -510,7 +510,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -524,7 +524,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -538,7 +538,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -552,7 +552,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -566,7 +566,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -580,7 +580,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
|
|||||||
Reference in New Issue
Block a user