support user defined prompters, pretokenized datasets in config, local parquet, local arrow files (#348)
* support user defined prompters, pretokenized datasets in config, local parquet, local arrow files * fix user defined dataset types * fix for system prompts * fix tests * fix checks for parquet and arrow * aha moment that d.data_files isn't used * add documentation for ds_type to add support for parquet and arrow
This commit is contained in:
@@ -392,6 +392,7 @@ datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
ds_type: # Optional[str] (json|arrow|parquet) defines the datatype when path is a file
|
||||
data_files: # path to source data files
|
||||
shards: # number of shards to split data into
|
||||
name: # name of dataset configuration to load
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
|
||||
import importlib
|
||||
|
||||
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
||||
|
||||
def load(strategy, tokenizer, cfg):
|
||||
|
||||
def load(strategy, tokenizer, cfg, ds_cfg):
|
||||
try:
|
||||
load_fn = "load"
|
||||
if strategy.split(".")[-1].startswith("load_"):
|
||||
@@ -11,6 +13,9 @@ def load(strategy, tokenizer, cfg):
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
|
||||
func = getattr(mod, load_fn)
|
||||
return func(tokenizer, cfg)
|
||||
load_kwargs = {}
|
||||
if strategy == "user_defined":
|
||||
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
|
||||
return func(tokenizer, cfg, **load_kwargs)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return None
|
||||
|
||||
@@ -57,6 +57,8 @@ class SystemDataPrompter(AlpacaPrompter):
|
||||
Alpaca Style Prompter that uses system prompts from the dataset
|
||||
"""
|
||||
|
||||
system_format: str = "### System:\n{system}\n\n"
|
||||
|
||||
def build_prompt_w_system(
|
||||
self,
|
||||
system: str,
|
||||
|
||||
98
src/axolotl/prompt_strategies/user_defined.py
Normal file
98
src/axolotl/prompt_strategies/user_defined.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
User Defined prompts with configuration from the YML config
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from axolotl.prompt_strategies.alpaca_w_system import (
|
||||
InstructionWSystemPromptTokenizingStrategy,
|
||||
SystemDataPrompter,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserDefinedDatasetConfig:
|
||||
"""
|
||||
dataclass configuration representing a userdefined dataset type
|
||||
"""
|
||||
|
||||
system_prompt: str = ""
|
||||
field_system: str = "system"
|
||||
field_instruction: str = "instruction"
|
||||
field_input: str = "input"
|
||||
field_output: str = "output"
|
||||
format: str = "{instruction} {input} "
|
||||
no_input_format: str = "{instruction} "
|
||||
system_format: str = "{system}"
|
||||
|
||||
def __getitem__(self, item):
|
||||
return getattr(self, item)
|
||||
|
||||
|
||||
class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):
|
||||
"""
|
||||
Prompt Tokenization Strategy for user defined prompts
|
||||
"""
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None):
|
||||
if not ds_cfg:
|
||||
raise ValueError("Missing dataset prompt configuration")
|
||||
|
||||
system_prompt = ""
|
||||
if ds_cfg.system_prompt:
|
||||
system_prompt = ds_cfg.system_prompt
|
||||
|
||||
def parse_instruction_fields(
|
||||
field_instruction,
|
||||
field_input,
|
||||
field_output,
|
||||
field_system,
|
||||
system_prompt,
|
||||
prompt,
|
||||
) -> Tuple[str, str, str, str]:
|
||||
return (
|
||||
prompt[field_instruction],
|
||||
prompt[field_input] if field_input in prompt else "",
|
||||
prompt[field_output] if field_output in prompt else "",
|
||||
prompt[field_system] if field_system in prompt else system_prompt,
|
||||
)
|
||||
|
||||
turn_format = ds_cfg.format
|
||||
turn_no_input_format = ds_cfg.no_input_format
|
||||
system_format = ds_cfg.system_format
|
||||
|
||||
class UserDefinedPrompter(SystemDataPrompter):
|
||||
"""
|
||||
Prompter for user defined prompts
|
||||
"""
|
||||
|
||||
def match_prompt_style(self):
|
||||
self.turn_format = turn_format
|
||||
self.turn_no_input_format = turn_no_input_format
|
||||
self.system_format = system_format
|
||||
|
||||
prompter = UserDefinedPrompter()
|
||||
|
||||
strat = UserDefinedPromptTokenizationStrategy(
|
||||
prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
setattr(
|
||||
strat,
|
||||
"parse_instruction_fields",
|
||||
partial(
|
||||
parse_instruction_fields,
|
||||
ds_cfg.field_instruction,
|
||||
ds_cfg.field_input,
|
||||
ds_cfg.field_output,
|
||||
ds_cfg.field_system,
|
||||
system_prompt,
|
||||
),
|
||||
)
|
||||
return strat
|
||||
@@ -26,7 +26,7 @@ class AlpacaPrompter:
|
||||
|
||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
||||
system_format: str
|
||||
system_format: str = "{system}"
|
||||
turn_format: str
|
||||
turn_no_input_format: str
|
||||
prompt_style: Optional[PromptStyle] = None
|
||||
@@ -63,13 +63,17 @@ class AlpacaPrompter:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
res = self.system_prompt + self.turn_format.format(
|
||||
instruction=instruction, input=input
|
||||
)
|
||||
res = (
|
||||
self.system_format.format(system=self.system_prompt)
|
||||
if self.system_prompt
|
||||
else ""
|
||||
) + self.turn_format.format(instruction=instruction, input=input)
|
||||
else:
|
||||
res = self.system_no_input_prompt + self.turn_no_input_format.format(
|
||||
instruction=instruction
|
||||
)
|
||||
res = (
|
||||
self.system_format.format(system=self.system_no_input_prompt)
|
||||
if self.system_prompt
|
||||
else ""
|
||||
) + self.turn_no_input_format.format(instruction=instruction)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
|
||||
@@ -41,6 +41,7 @@ from axolotl.prompters import (
|
||||
ShareGPTPrompter,
|
||||
SummarizeTLDRPrompter,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process, zero_first
|
||||
from axolotl.utils.trainer import (
|
||||
calculate_total_num_steps,
|
||||
@@ -160,8 +161,15 @@ def load_tokenized_prepared_datasets(
|
||||
split=None,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds_type = "json"
|
||||
if d.ds_type:
|
||||
ds_type = d.ds_type
|
||||
elif ".parquet" in d.path:
|
||||
ds_type = "parquet"
|
||||
elif ".arrow" in d.path:
|
||||
ds_type = "arrow"
|
||||
ds = load_dataset(
|
||||
"json",
|
||||
ds_type,
|
||||
name=d.name,
|
||||
data_files=d.path,
|
||||
streaming=False,
|
||||
@@ -198,13 +206,27 @@ def load_tokenized_prepared_datasets(
|
||||
)
|
||||
else:
|
||||
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
||||
|
||||
d_base_type = d_prompt_style = None
|
||||
d_type = d.type
|
||||
d_type_split = d_type.split(":")
|
||||
d_base_type = d_type_split[0]
|
||||
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
||||
if isinstance(d_type, str):
|
||||
d_type_split = d_type.split(":")
|
||||
d_base_type = d_type_split[0]
|
||||
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
||||
if "train" in ds:
|
||||
ds = ds["train"]
|
||||
if ds_strategy := load(d.type, tokenizer, cfg):
|
||||
if (
|
||||
"input_ids" in ds.features
|
||||
and "attention_mask" in ds.features
|
||||
and "labels" in ds.features
|
||||
):
|
||||
# dataset is already tokenized, just drop it straight in
|
||||
datasets.append(ds)
|
||||
elif isinstance(d.type, DictDefault):
|
||||
ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif ds_strategy := load(d.type, tokenizer, cfg, d):
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "alpaca":
|
||||
|
||||
Reference in New Issue
Block a user