update table for rwkv4 support, fix process count for dataset (#822)

This commit is contained in:
Wing Lian
2023-11-04 23:45:44 -04:00
committed by GitHub
parent 6459ac7357
commit cdc71f73c8
3 changed files with 39 additions and 12 deletions

View File

@@ -74,6 +74,7 @@ Features:
 | gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
 | XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
 | phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
+| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
 
 ## Quickstart ⚡

View File

@@ -2,7 +2,7 @@
 import logging
 import os
-from typing import List
+from typing import List, Optional
 
 import torch
 
 from datasets import Dataset, IterableDataset
@@ -30,14 +30,20 @@ class TokenizedPromptDataset(Dataset):
         self,
         prompt_tokenizer: PromptTokenizingStrategy,
         dataset: IterableDataset,
+        process_count: Optional[int] = None,
         **kwargs,
     ):
         self.prompt_tokenizer = prompt_tokenizer
+        self.process_count = process_count
         super().__init__(self.process(dataset).data, **kwargs)
 
     def process(self, dataset):
         features = dataset.features.keys()
-        num_proc = min(64, os.cpu_count())
+        num_proc = (
+            min(64, self.process_count)
+            if self.process_count
+            else min(64, os.cpu_count())
+        )
         map_kwargs = {}
         if self.prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True

View File

@@ -482,10 +482,14 @@ def get_dataset_wrapper(
             "user_defined", tokenizer, cfg, config_dataset.type.to_dict()
         )
         dataset_prompter = UnsupportedPrompter()
-        dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
     elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
         dataset_prompter = UnsupportedPrompter()
-        dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        dataset_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
     elif d_base_type == "alpaca":
         dataset_prompter = AlpacaPrompter(d_prompt_style)
         ds_strategy = AlpacaPromptTokenizingStrategy(
@@ -494,7 +498,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "explainchoice":
         dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
@@ -504,7 +510,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "concisechoice":
         dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
@@ -514,7 +522,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "summarizetldr":
         dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
@@ -524,7 +534,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "jeopardy":
         dataset_prompter = JeopardyPrompter(d_prompt_style)
@@ -534,7 +546,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "oasst":
         dataset_prompter = AlpacaPrompter(d_prompt_style)
@@ -544,7 +558,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "gpteacher":
         dataset_prompter = GPTeacherPrompter(d_prompt_style)
@@ -554,7 +570,9 @@ def get_dataset_wrapper(
            cfg.train_on_inputs,
            cfg.sequence_len,
        )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     elif d_base_type == "reflection":
         dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
@@ -564,7 +582,9 @@ def get_dataset_wrapper(
             cfg.train_on_inputs,
             cfg.sequence_len,
         )
-        ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
+        ds_wrapper = TokenizedPromptDataset(
+            ds_strategy, dataset, process_count=cfg.dataset_processes
+        )
         dataset_wrapper = ds_wrapper
     else:
         suffix = ""