diff --git a/README.md b/README.md index c9575d659..bafec22a6 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,7 @@ Features: | gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ | | XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ | | phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ | +| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ | ## Quickstart ⚡ diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index 32b2e0cc2..d5362643f 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -2,7 +2,7 @@ import logging import os -from typing import List +from typing import List, Optional import torch from datasets import Dataset, IterableDataset @@ -30,14 +30,20 @@ class TokenizedPromptDataset(Dataset): self, prompt_tokenizer: PromptTokenizingStrategy, dataset: IterableDataset, + process_count: Optional[int] = None, **kwargs, ): self.prompt_tokenizer = prompt_tokenizer + self.process_count = process_count super().__init__(self.process(dataset).data, **kwargs) def process(self, dataset): features = dataset.features.keys() - num_proc = min(64, os.cpu_count()) + num_proc = ( + min(64, self.process_count) + if self.process_count + else min(64, os.cpu_count()) + ) map_kwargs = {} if self.prompt_tokenizer.supports_batched: map_kwargs["batched"] = True diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 124b607b3..697c26baa 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -482,10 +482,14 @@ def get_dataset_wrapper( "user_defined", tokenizer, cfg, config_dataset.type.to_dict() ) dataset_prompter = UnsupportedPrompter() - dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + dataset_wrapper = TokenizedPromptDataset( + ds_strategy, dataset, process_count=cfg.dataset_processes + ) elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset): dataset_prompter = UnsupportedPrompter() - dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + dataset_wrapper = TokenizedPromptDataset( + ds_strategy, dataset, process_count=cfg.dataset_processes + ) elif d_base_type == "alpaca": dataset_prompter = AlpacaPrompter(d_prompt_style) ds_strategy = AlpacaPromptTokenizingStrategy( @@ -494,7 +498,9 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, dataset, process_count=cfg.dataset_processes + ) dataset_wrapper = ds_wrapper elif d_base_type == "explainchoice": dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style) @@ -504,7 +510,9 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, dataset, process_count=cfg.dataset_processes + ) dataset_wrapper = ds_wrapper elif d_base_type == "concisechoice": dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style) @@ -514,7 +522,9 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, dataset, process_count=cfg.dataset_processes + ) dataset_wrapper = ds_wrapper elif d_base_type == "summarizetldr": dataset_prompter = SummarizeTLDRPrompter(d_prompt_style) @@ -524,7 +534,9 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, dataset, process_count=cfg.dataset_processes + ) dataset_wrapper = ds_wrapper elif d_base_type == "jeopardy": dataset_prompter = JeopardyPrompter(d_prompt_style) @@ -534,7 +546,9 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, dataset, process_count=cfg.dataset_processes + ) dataset_wrapper = ds_wrapper elif d_base_type == "oasst": dataset_prompter = AlpacaPrompter(d_prompt_style) @@ -544,7 +558,9 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, dataset, process_count=cfg.dataset_processes + ) dataset_wrapper = ds_wrapper elif d_base_type == "gpteacher": dataset_prompter = GPTeacherPrompter(d_prompt_style) @@ -554,7 +570,9 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, dataset, process_count=cfg.dataset_processes + ) dataset_wrapper = ds_wrapper elif d_base_type == "reflection": dataset_prompter = ReflectAlpacaPrompter(d_prompt_style) @@ -564,7 +582,9 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) + ds_wrapper = TokenizedPromptDataset( + ds_strategy, dataset, process_count=cfg.dataset_processes + ) dataset_wrapper = ds_wrapper else: suffix = ""