support user defined prompters, pretokenized datasets in config, local parquet, local arrow files (#348)
* support user defined prompters, pretokenized datasets in config, local parquet, local arrow files * fix user defined dataset types * fix for system prompts * fix tests * fix checks for parquet and arrow * aha moment that d.data_files isn't used * add documentation for ds_type to add support for parquet and arrow
This commit is contained in:
@@ -392,6 +392,7 @@ datasets:
|
|||||||
- path: vicgalle/alpaca-gpt4
|
- path: vicgalle/alpaca-gpt4
|
||||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||||
|
ds_type: # Optional[str] (json|arrow|parquet) defines the datatype when path is a file
|
||||||
data_files: # path to source data files
|
data_files: # path to source data files
|
||||||
shards: # number of shards to split data into
|
shards: # number of shards to split data into
|
||||||
name: # name of dataset configuration to load
|
name: # name of dataset configuration to load
|
||||||
|
|||||||
@@ -2,8 +2,10 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
|
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
||||||
|
|
||||||
def load(strategy, tokenizer, cfg):
|
|
||||||
|
def load(strategy, tokenizer, cfg, ds_cfg):
|
||||||
try:
|
try:
|
||||||
load_fn = "load"
|
load_fn = "load"
|
||||||
if strategy.split(".")[-1].startswith("load_"):
|
if strategy.split(".")[-1].startswith("load_"):
|
||||||
@@ -11,6 +13,9 @@ def load(strategy, tokenizer, cfg):
|
|||||||
strategy = ".".join(strategy.split(".")[:-1])
|
strategy = ".".join(strategy.split(".")[:-1])
|
||||||
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
|
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
|
||||||
func = getattr(mod, load_fn)
|
func = getattr(mod, load_fn)
|
||||||
return func(tokenizer, cfg)
|
load_kwargs = {}
|
||||||
|
if strategy == "user_defined":
|
||||||
|
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
|
||||||
|
return func(tokenizer, cfg, **load_kwargs)
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -57,6 +57,8 @@ class SystemDataPrompter(AlpacaPrompter):
|
|||||||
Alpaca Style Prompter that uses system prompts from the dataset
|
Alpaca Style Prompter that uses system prompts from the dataset
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
system_format: str = "### System:\n{system}\n\n"
|
||||||
|
|
||||||
def build_prompt_w_system(
|
def build_prompt_w_system(
|
||||||
self,
|
self,
|
||||||
system: str,
|
system: str,
|
||||||
|
|||||||
98
src/axolotl/prompt_strategies/user_defined.py
Normal file
98
src/axolotl/prompt_strategies/user_defined.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""
|
||||||
|
User Defined prompts with configuration from the YML config
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from axolotl.prompt_strategies.alpaca_w_system import (
|
||||||
|
InstructionWSystemPromptTokenizingStrategy,
|
||||||
|
SystemDataPrompter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UserDefinedDatasetConfig:
|
||||||
|
"""
|
||||||
|
dataclass configuration representing a userdefined dataset type
|
||||||
|
"""
|
||||||
|
|
||||||
|
system_prompt: str = ""
|
||||||
|
field_system: str = "system"
|
||||||
|
field_instruction: str = "instruction"
|
||||||
|
field_input: str = "input"
|
||||||
|
field_output: str = "output"
|
||||||
|
format: str = "{instruction} {input} "
|
||||||
|
no_input_format: str = "{instruction} "
|
||||||
|
system_format: str = "{system}"
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return getattr(self, item)
|
||||||
|
|
||||||
|
|
||||||
|
class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):
|
||||||
|
"""
|
||||||
|
Prompt Tokenization Strategy for user defined prompts
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None):
|
||||||
|
if not ds_cfg:
|
||||||
|
raise ValueError("Missing dataset prompt configuration")
|
||||||
|
|
||||||
|
system_prompt = ""
|
||||||
|
if ds_cfg.system_prompt:
|
||||||
|
system_prompt = ds_cfg.system_prompt
|
||||||
|
|
||||||
|
def parse_instruction_fields(
|
||||||
|
field_instruction,
|
||||||
|
field_input,
|
||||||
|
field_output,
|
||||||
|
field_system,
|
||||||
|
system_prompt,
|
||||||
|
prompt,
|
||||||
|
) -> Tuple[str, str, str, str]:
|
||||||
|
return (
|
||||||
|
prompt[field_instruction],
|
||||||
|
prompt[field_input] if field_input in prompt else "",
|
||||||
|
prompt[field_output] if field_output in prompt else "",
|
||||||
|
prompt[field_system] if field_system in prompt else system_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
turn_format = ds_cfg.format
|
||||||
|
turn_no_input_format = ds_cfg.no_input_format
|
||||||
|
system_format = ds_cfg.system_format
|
||||||
|
|
||||||
|
class UserDefinedPrompter(SystemDataPrompter):
|
||||||
|
"""
|
||||||
|
Prompter for user defined prompts
|
||||||
|
"""
|
||||||
|
|
||||||
|
def match_prompt_style(self):
|
||||||
|
self.turn_format = turn_format
|
||||||
|
self.turn_no_input_format = turn_no_input_format
|
||||||
|
self.system_format = system_format
|
||||||
|
|
||||||
|
prompter = UserDefinedPrompter()
|
||||||
|
|
||||||
|
strat = UserDefinedPromptTokenizationStrategy(
|
||||||
|
prompter,
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
setattr(
|
||||||
|
strat,
|
||||||
|
"parse_instruction_fields",
|
||||||
|
partial(
|
||||||
|
parse_instruction_fields,
|
||||||
|
ds_cfg.field_instruction,
|
||||||
|
ds_cfg.field_input,
|
||||||
|
ds_cfg.field_output,
|
||||||
|
ds_cfg.field_system,
|
||||||
|
system_prompt,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return strat
|
||||||
@@ -26,7 +26,7 @@ class AlpacaPrompter:
|
|||||||
|
|
||||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
||||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
||||||
system_format: str
|
system_format: str = "{system}"
|
||||||
turn_format: str
|
turn_format: str
|
||||||
turn_no_input_format: str
|
turn_no_input_format: str
|
||||||
prompt_style: Optional[PromptStyle] = None
|
prompt_style: Optional[PromptStyle] = None
|
||||||
@@ -63,13 +63,17 @@ class AlpacaPrompter:
|
|||||||
# returns the full prompt from instruction and optional input
|
# returns the full prompt from instruction and optional input
|
||||||
# if a label (=response, =output) is provided, it's also appended.
|
# if a label (=response, =output) is provided, it's also appended.
|
||||||
if input:
|
if input:
|
||||||
res = self.system_prompt + self.turn_format.format(
|
res = (
|
||||||
instruction=instruction, input=input
|
self.system_format.format(system=self.system_prompt)
|
||||||
)
|
if self.system_prompt
|
||||||
|
else ""
|
||||||
|
) + self.turn_format.format(instruction=instruction, input=input)
|
||||||
else:
|
else:
|
||||||
res = self.system_no_input_prompt + self.turn_no_input_format.format(
|
res = (
|
||||||
instruction=instruction
|
self.system_format.format(system=self.system_no_input_prompt)
|
||||||
)
|
if self.system_prompt
|
||||||
|
else ""
|
||||||
|
) + self.turn_no_input_format.format(instruction=instruction)
|
||||||
if output:
|
if output:
|
||||||
res = f"{res}{output}"
|
res = f"{res}{output}"
|
||||||
yield res
|
yield res
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ from axolotl.prompters import (
|
|||||||
ShareGPTPrompter,
|
ShareGPTPrompter,
|
||||||
SummarizeTLDRPrompter,
|
SummarizeTLDRPrompter,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
from axolotl.utils.trainer import (
|
from axolotl.utils.trainer import (
|
||||||
calculate_total_num_steps,
|
calculate_total_num_steps,
|
||||||
@@ -160,8 +161,15 @@ def load_tokenized_prepared_datasets(
|
|||||||
split=None,
|
split=None,
|
||||||
)
|
)
|
||||||
elif local_path.is_file():
|
elif local_path.is_file():
|
||||||
|
ds_type = "json"
|
||||||
|
if d.ds_type:
|
||||||
|
ds_type = d.ds_type
|
||||||
|
elif ".parquet" in d.path:
|
||||||
|
ds_type = "parquet"
|
||||||
|
elif ".arrow" in d.path:
|
||||||
|
ds_type = "arrow"
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
"json",
|
ds_type,
|
||||||
name=d.name,
|
name=d.name,
|
||||||
data_files=d.path,
|
data_files=d.path,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
@@ -198,13 +206,27 @@ def load_tokenized_prepared_datasets(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
||||||
|
|
||||||
|
d_base_type = d_prompt_style = None
|
||||||
d_type = d.type
|
d_type = d.type
|
||||||
d_type_split = d_type.split(":")
|
if isinstance(d_type, str):
|
||||||
d_base_type = d_type_split[0]
|
d_type_split = d_type.split(":")
|
||||||
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
d_base_type = d_type_split[0]
|
||||||
|
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
||||||
if "train" in ds:
|
if "train" in ds:
|
||||||
ds = ds["train"]
|
ds = ds["train"]
|
||||||
if ds_strategy := load(d.type, tokenizer, cfg):
|
if (
|
||||||
|
"input_ids" in ds.features
|
||||||
|
and "attention_mask" in ds.features
|
||||||
|
and "labels" in ds.features
|
||||||
|
):
|
||||||
|
# dataset is already tokenized, just drop it straight in
|
||||||
|
datasets.append(ds)
|
||||||
|
elif isinstance(d.type, DictDefault):
|
||||||
|
ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
|
||||||
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||||
|
datasets.append(ds_wrapper)
|
||||||
|
elif ds_strategy := load(d.type, tokenizer, cfg, d):
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
elif d_base_type == "alpaca":
|
elif d_base_type == "alpaca":
|
||||||
|
|||||||
Reference in New Issue
Block a user