support streaming for processing sft datasts?
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
@@ -68,6 +68,24 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_dataset_for_tokenized_prompt(
|
||||||
|
prompt_tokenizer: PromptTokenizingStrategy,
|
||||||
|
dataset: Union[Dataset, IterableDataset],
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if isinstance(dataset, IterableDataset):
|
||||||
|
map_kwargs = {}
|
||||||
|
if prompt_tokenizer.supports_batched:
|
||||||
|
map_kwargs["batched"] = True
|
||||||
|
features = dataset.features.keys()
|
||||||
|
return dataset.map(
|
||||||
|
prompt_tokenizer.tokenize_prompt,
|
||||||
|
remove_columns=features,
|
||||||
|
**map_kwargs,
|
||||||
|
)
|
||||||
|
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# TODO this isn't the best since it can't interleave datasets
|
# TODO this isn't the best since it can't interleave datasets
|
||||||
class ConstantLengthDataset(IterableDataset):
|
class ConstantLengthDataset(IterableDataset):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from datasets import (
|
|||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.datasets import TokenizedPromptDataset
|
from axolotl.datasets import wrap_dataset_for_tokenized_prompt
|
||||||
from axolotl.prompt_strategies import load
|
from axolotl.prompt_strategies import load
|
||||||
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
|
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
|
||||||
from axolotl.prompt_tokenizers import (
|
from axolotl.prompt_tokenizers import (
|
||||||
@@ -450,7 +450,7 @@ def get_dataset_wrapper(
|
|||||||
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
||||||
)
|
)
|
||||||
dataset_prompter = UnsupportedPrompter()
|
dataset_prompter = UnsupportedPrompter()
|
||||||
dataset_wrapper = TokenizedPromptDataset(
|
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -463,7 +463,7 @@ def get_dataset_wrapper(
|
|||||||
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
|
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
|
||||||
):
|
):
|
||||||
dataset_prompter = UnsupportedPrompter()
|
dataset_prompter = UnsupportedPrompter()
|
||||||
dataset_wrapper = TokenizedPromptDataset(
|
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -475,7 +475,7 @@ def get_dataset_wrapper(
|
|||||||
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
|
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
|
||||||
else:
|
else:
|
||||||
dataset_prompter = UnsupportedPrompter()
|
dataset_prompter = UnsupportedPrompter()
|
||||||
dataset_wrapper = TokenizedPromptDataset(
|
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -488,7 +488,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -502,7 +502,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -516,7 +516,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -530,7 +530,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -544,7 +544,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -558,7 +558,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -572,7 +572,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
@@ -586,7 +586,7 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(
|
ds_wrapper = wrap_dataset_for_tokenized_prompt(
|
||||||
ds_strategy,
|
ds_strategy,
|
||||||
dataset,
|
dataset,
|
||||||
**ds_kwargs,
|
**ds_kwargs,
|
||||||
|
|||||||
Reference in New Issue
Block a user