update table for rwkv4 support, fix process count for dataset (#822)

This commit is contained in:
Wing Lian
2023-11-04 23:45:44 -04:00
committed by GitHub
parent 6459ac7357
commit cdc71f73c8
3 changed files with 39 additions and 12 deletions

View File

@@ -2,7 +2,7 @@
import logging
import os
from typing import List
from typing import List, Optional
import torch
from datasets import Dataset, IterableDataset
@@ -30,14 +30,20 @@ class TokenizedPromptDataset(Dataset):
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset,
process_count: Optional[int] = None,
**kwargs,
):
self.prompt_tokenizer = prompt_tokenizer
self.process_count = process_count
super().__init__(self.process(dataset).data, **kwargs)
def process(self, dataset):
features = dataset.features.keys()
num_proc = min(64, os.cpu_count())
num_proc = (
min(64, self.process_count)
if self.process_count
else min(64, os.cpu_count())
)
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True

View File

@@ -482,10 +482,14 @@ def get_dataset_wrapper(
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
)
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
elif d_base_type == "alpaca":
dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaPromptTokenizingStrategy(
@@ -494,7 +498,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "explainchoice":
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
@@ -504,7 +510,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "concisechoice":
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
@@ -514,7 +522,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "summarizetldr":
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
@@ -524,7 +534,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "jeopardy":
dataset_prompter = JeopardyPrompter(d_prompt_style)
@@ -534,7 +546,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "oasst":
dataset_prompter = AlpacaPrompter(d_prompt_style)
@@ -544,7 +558,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "gpteacher":
dataset_prompter = GPTeacherPrompter(d_prompt_style)
@@ -554,7 +570,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "reflection":
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
@@ -564,7 +582,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
else:
suffix = ""