update table for rwkv4 support, fix process count for dataset (#822)
This commit is contained in:
@@ -74,6 +74,7 @@ Features:
|
|||||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||||
|
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||||
|
|
||||||
|
|
||||||
## Quickstart ⚡
|
## Quickstart ⚡
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset
|
||||||
@@ -30,14 +30,20 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
self,
|
self,
|
||||||
prompt_tokenizer: PromptTokenizingStrategy,
|
prompt_tokenizer: PromptTokenizingStrategy,
|
||||||
dataset: IterableDataset,
|
dataset: IterableDataset,
|
||||||
|
process_count: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.prompt_tokenizer = prompt_tokenizer
|
self.prompt_tokenizer = prompt_tokenizer
|
||||||
|
self.process_count = process_count
|
||||||
super().__init__(self.process(dataset).data, **kwargs)
|
super().__init__(self.process(dataset).data, **kwargs)
|
||||||
|
|
||||||
def process(self, dataset):
|
def process(self, dataset):
|
||||||
features = dataset.features.keys()
|
features = dataset.features.keys()
|
||||||
num_proc = min(64, os.cpu_count())
|
num_proc = (
|
||||||
|
min(64, self.process_count)
|
||||||
|
if self.process_count
|
||||||
|
else min(64, os.cpu_count())
|
||||||
|
)
|
||||||
map_kwargs = {}
|
map_kwargs = {}
|
||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
map_kwargs["batched"] = True
|
map_kwargs["batched"] = True
|
||||||
|
|||||||
@@ -482,10 +482,14 @@ def get_dataset_wrapper(
|
|||||||
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
||||||
)
|
)
|
||||||
dataset_prompter = UnsupportedPrompter()
|
dataset_prompter = UnsupportedPrompter()
|
||||||
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
dataset_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
||||||
dataset_prompter = UnsupportedPrompter()
|
dataset_prompter = UnsupportedPrompter()
|
||||||
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
dataset_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
elif d_base_type == "alpaca":
|
elif d_base_type == "alpaca":
|
||||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||||
@@ -494,7 +498,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "explainchoice":
|
elif d_base_type == "explainchoice":
|
||||||
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
|
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
|
||||||
@@ -504,7 +510,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "concisechoice":
|
elif d_base_type == "concisechoice":
|
||||||
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
|
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
|
||||||
@@ -514,7 +522,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "summarizetldr":
|
elif d_base_type == "summarizetldr":
|
||||||
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
|
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
|
||||||
@@ -524,7 +534,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "jeopardy":
|
elif d_base_type == "jeopardy":
|
||||||
dataset_prompter = JeopardyPrompter(d_prompt_style)
|
dataset_prompter = JeopardyPrompter(d_prompt_style)
|
||||||
@@ -534,7 +546,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "oasst":
|
elif d_base_type == "oasst":
|
||||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||||
@@ -544,7 +558,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "gpteacher":
|
elif d_base_type == "gpteacher":
|
||||||
dataset_prompter = GPTeacherPrompter(d_prompt_style)
|
dataset_prompter = GPTeacherPrompter(d_prompt_style)
|
||||||
@@ -554,7 +570,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
elif d_base_type == "reflection":
|
elif d_base_type == "reflection":
|
||||||
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
|
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
|
||||||
@@ -564,7 +582,9 @@ def get_dataset_wrapper(
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
ds_wrapper = TokenizedPromptDataset(
|
||||||
|
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||||
|
)
|
||||||
dataset_wrapper = ds_wrapper
|
dataset_wrapper = ds_wrapper
|
||||||
else:
|
else:
|
||||||
suffix = ""
|
suffix = ""
|
||||||
|
|||||||
Reference in New Issue
Block a user