From d2e7f27240868b7fab266f7c43707c0c73bef9eb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 20 Aug 2023 09:17:49 -0400 Subject: [PATCH] support user defined prompters, pretokenized datasets in config, local parquet, local arrow files (#348) * support user defined prompters, pretokenized datasets in config, local parquet, local arrow files * fix user defined dataset types * fix for system prompts * fix tests * fix checks for parquet and arrow * aha moment that d.data_files isn't used * add documentation for ds_type to add support for parquet and arrow --- README.md | 1 + src/axolotl/prompt_strategies/__init__.py | 9 +- .../prompt_strategies/alpaca_w_system.py | 2 + src/axolotl/prompt_strategies/user_defined.py | 98 +++++++++++++++++++ src/axolotl/prompters.py | 18 ++-- src/axolotl/utils/data.py | 32 +++++- 6 files changed, 146 insertions(+), 14 deletions(-) create mode 100644 src/axolotl/prompt_strategies/user_defined.py diff --git a/README.md b/README.md index 21d45622d..bd0446f22 100644 --- a/README.md +++ b/README.md @@ -392,6 +392,7 @@ datasets: - path: vicgalle/alpaca-gpt4 # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] type: alpaca # format | format: (chat/instruct) | .load_ + ds_type: # Optional[str] (json|arrow|parquet) defines the datatype when path is a file data_files: # path to source data files shards: # number of shards to split data into name: # name of dataset configuration to load diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index 2f6af208c..e9e567953 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -2,8 +2,10 @@ import importlib +from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig -def load(strategy, tokenizer, cfg): + +def load(strategy, tokenizer, cfg, ds_cfg): try: load_fn = "load" if strategy.split(".")[-1].startswith("load_"): @@ -11,6 +13,9 @@ def load(strategy, tokenizer, cfg): strategy = ".".join(strategy.split(".")[:-1]) mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies") func = getattr(mod, load_fn) - return func(tokenizer, cfg) + load_kwargs = {} + if strategy == "user_defined": + load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg) + return func(tokenizer, cfg, **load_kwargs) except Exception: # pylint: disable=broad-exception-caught return None diff --git a/src/axolotl/prompt_strategies/alpaca_w_system.py b/src/axolotl/prompt_strategies/alpaca_w_system.py index 4b5521d48..8c8cc0743 100644 --- a/src/axolotl/prompt_strategies/alpaca_w_system.py +++ b/src/axolotl/prompt_strategies/alpaca_w_system.py @@ -57,6 +57,8 @@ class SystemDataPrompter(AlpacaPrompter): Alpaca Style Prompter that uses system prompts from the dataset """ + system_format: str = "### System:\n{system}\n\n" + def build_prompt_w_system( self, system: str, diff --git a/src/axolotl/prompt_strategies/user_defined.py b/src/axolotl/prompt_strategies/user_defined.py new file mode 100644 index 000000000..e20e80c3a --- /dev/null +++ b/src/axolotl/prompt_strategies/user_defined.py @@ -0,0 +1,98 @@ +""" +User Defined prompts with configuration from the YML config +""" + +from dataclasses import dataclass +from functools import partial +from typing import Optional, Tuple + +from axolotl.prompt_strategies.alpaca_w_system import ( + InstructionWSystemPromptTokenizingStrategy, + SystemDataPrompter, +) + + +@dataclass +class UserDefinedDatasetConfig: + """ + dataclass configuration representing a userdefined dataset type + """ + + system_prompt: str = "" + field_system: str = "system" + field_instruction: str = "instruction" + field_input: str = "input" + field_output: str = "output" + format: str = "{instruction} {input} " + no_input_format: str = "{instruction} " + system_format: str = "{system}" + + def __getitem__(self, item): + return getattr(self, item) + + +class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy): + """ + Prompt Tokenization Strategy for user defined prompts + """ + + +def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None): + if not ds_cfg: + raise ValueError("Missing dataset prompt configuration") + + system_prompt = "" + if ds_cfg.system_prompt: + system_prompt = ds_cfg.system_prompt + + def parse_instruction_fields( + field_instruction, + field_input, + field_output, + field_system, + system_prompt, + prompt, + ) -> Tuple[str, str, str, str]: + return ( + prompt[field_instruction], + prompt[field_input] if field_input in prompt else "", + prompt[field_output] if field_output in prompt else "", + prompt[field_system] if field_system in prompt else system_prompt, + ) + + turn_format = ds_cfg.format + turn_no_input_format = ds_cfg.no_input_format + system_format = ds_cfg.system_format + + class UserDefinedPrompter(SystemDataPrompter): + """ + Prompter for user defined prompts + """ + + def match_prompt_style(self): + self.turn_format = turn_format + self.turn_no_input_format = turn_no_input_format + self.system_format = system_format + + prompter = UserDefinedPrompter() + + strat = UserDefinedPromptTokenizationStrategy( + prompter, + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + + setattr( + strat, + "parse_instruction_fields", + partial( + parse_instruction_fields, + ds_cfg.field_instruction, + ds_cfg.field_input, + ds_cfg.field_output, + ds_cfg.field_system, + system_prompt, + ), + ) + return strat diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index ed79b0f5d..f1fe7d456 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -26,7 +26,7 @@ class AlpacaPrompter: system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" - system_format: str + system_format: str = "{system}" turn_format: str turn_no_input_format: str prompt_style: Optional[PromptStyle] = None @@ -63,13 +63,17 @@ class AlpacaPrompter: # returns the full prompt from instruction and optional input # if a label (=response, =output) is provided, it's also appended. if input: - res = self.system_prompt + self.turn_format.format( - instruction=instruction, input=input - ) + res = ( + self.system_format.format(system=self.system_prompt) + if self.system_prompt + else "" + ) + self.turn_format.format(instruction=instruction, input=input) else: - res = self.system_no_input_prompt + self.turn_no_input_format.format( - instruction=instruction - ) + res = ( + self.system_format.format(system=self.system_no_input_prompt) + if self.system_prompt + else "" + ) + self.turn_no_input_format.format(instruction=instruction) if output: res = f"{res}{output}" yield res diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index d64b06a10..f6a722a82 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -41,6 +41,7 @@ from axolotl.prompters import ( ShareGPTPrompter, SummarizeTLDRPrompter, ) +from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.trainer import ( calculate_total_num_steps, @@ -160,8 +161,15 @@ def load_tokenized_prepared_datasets( split=None, ) elif local_path.is_file(): + ds_type = "json" + if d.ds_type: + ds_type = d.ds_type + elif ".parquet" in d.path: + ds_type = "parquet" + elif ".arrow" in d.path: + ds_type = "arrow" ds = load_dataset( - "json", + ds_type, name=d.name, data_files=d.path, streaming=False, @@ -198,13 +206,27 @@ def load_tokenized_prepared_datasets( ) else: ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0) + + d_base_type = d_prompt_style = None d_type = d.type - d_type_split = d_type.split(":") - d_base_type = d_type_split[0] - d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None + if isinstance(d_type, str): + d_type_split = d_type.split(":") + d_base_type = d_type_split[0] + d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None if "train" in ds: ds = ds["train"] - if ds_strategy := load(d.type, tokenizer, cfg): + if ( + "input_ids" in ds.features + and "attention_mask" in ds.features + and "labels" in ds.features + ): + # dataset is already tokenized, just drop it straight in + datasets.append(ds) + elif isinstance(d.type, DictDefault): + ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict()) + ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) + datasets.append(ds_wrapper) + elif ds_strategy := load(d.type, tokenizer, cfg, d): ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) datasets.append(ds_wrapper) elif d_base_type == "alpaca":