support custom field for completion from yml (#580)
* support custom field for completion from yml
* remove legacy completion check and add doc
* update README docs
This commit is contained in:
@@ -322,6 +322,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
- path: EleutherAI/pile
|
- path: EleutherAI/pile
|
||||||
name: enron_emails
|
name: enron_emails
|
||||||
type: completion # format from earlier
|
type: completion # format from earlier
|
||||||
|
field: text # Optional[str] default: text, field to use for completion data
|
||||||
|
|
||||||
# huggingface repo with multiple named configurations/subsets
|
# huggingface repo with multiple named configurations/subsets
|
||||||
datasets:
|
datasets:
|
||||||
@@ -444,6 +445,9 @@ datasets:
|
|||||||
# 'no_input_format' cannot include {input}
|
# 'no_input_format' cannot include {input}
|
||||||
no_input_format: "{instruction} "
|
no_input_format: "{instruction} "
|
||||||
|
|
||||||
|
# for completion datasets, uses the provided field if not `text`
|
||||||
|
field:
|
||||||
|
|
||||||
# axolotl attempts to save the dataset as an arrow after packing the data together so
|
# axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||||
# subsequent training attempts load faster, relative path
|
# subsequent training attempts load faster, relative path
|
||||||
dataset_prepared_path: data/last_run_prepared
|
dataset_prepared_path: data/last_run_prepared
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Module to load prompt strategies."""
|
"""Module to load prompt strategies."""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
|
import inspect
|
||||||
|
|
||||||
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
||||||
|
|
||||||
@@ -16,6 +17,10 @@ def load(strategy, tokenizer, cfg, ds_cfg):
|
|||||||
load_kwargs = {}
|
load_kwargs = {}
|
||||||
if strategy == "user_defined":
|
if strategy == "user_defined":
|
||||||
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
|
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
|
||||||
|
else:
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
if "ds_cfg" in sig.parameters:
|
||||||
|
load_kwargs["ds_cfg"] = ds_cfg
|
||||||
return func(tokenizer, cfg, **load_kwargs)
|
return func(tokenizer, cfg, **load_kwargs)
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
return None
|
return None
|
||||||
|
|||||||
20
src/axolotl/prompt_strategies/completion.py
Normal file
20
src/axolotl/prompt_strategies/completion.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
"""
|
||||||
|
Basic completion text
|
||||||
|
"""
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from axolotl.prompt_tokenizers import CompletionPromptTokenizingStrategy
|
||||||
|
from axolotl.prompters import CompletionPrompter
|
||||||
|
|
||||||
|
|
||||||
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
|
strat = CompletionPromptTokenizingStrategy(
|
||||||
|
CompletionPrompter(),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
if ds_cfg and "field" in ds_cfg:
|
||||||
|
strat.field = ds_cfg["field"]
|
||||||
|
|
||||||
|
return strat
|
||||||
@@ -245,8 +245,31 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
|||||||
Tokenizing strategy for Completion prompts.
|
Tokenizing strategy for Completion prompts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_field: str = "text"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def field(self) -> str:
|
||||||
|
return self._field
|
||||||
|
|
||||||
|
@field.setter
|
||||||
|
def field(self, new_field: str):
|
||||||
|
self._field = new_field
|
||||||
|
|
||||||
|
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||||
|
return (
|
||||||
|
prompt[self.field],
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
full_prompt = self._build_full_prompt(prompt["text"], None, None)
|
(
|
||||||
|
instruction,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
) = self.parse_instruction_fields(prompt)
|
||||||
|
|
||||||
|
full_prompt = self._build_full_prompt(instruction, None, None)
|
||||||
tokenized_full_prompt = self._tokenize(full_prompt)
|
tokenized_full_prompt = self._tokenize(full_prompt)
|
||||||
|
|
||||||
return tokenized_full_prompt
|
return tokenized_full_prompt
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ from axolotl.prompt_tokenizers import (
|
|||||||
AlpacaMultipleChoicePromptTokenizingStrategy,
|
AlpacaMultipleChoicePromptTokenizingStrategy,
|
||||||
AlpacaPromptTokenizingStrategy,
|
AlpacaPromptTokenizingStrategy,
|
||||||
AlpacaReflectionPTStrategy,
|
AlpacaReflectionPTStrategy,
|
||||||
CompletionPromptTokenizingStrategy,
|
|
||||||
GPTeacherPromptTokenizingStrategy,
|
GPTeacherPromptTokenizingStrategy,
|
||||||
JeopardyPromptTokenizingStrategy,
|
JeopardyPromptTokenizingStrategy,
|
||||||
OpenAssistantPromptTokenizingStrategy,
|
OpenAssistantPromptTokenizingStrategy,
|
||||||
@@ -31,7 +30,6 @@ from axolotl.prompt_tokenizers import (
|
|||||||
)
|
)
|
||||||
from axolotl.prompters import (
|
from axolotl.prompters import (
|
||||||
AlpacaPrompter,
|
AlpacaPrompter,
|
||||||
CompletionPrompter,
|
|
||||||
GPTeacherPrompter,
|
GPTeacherPrompter,
|
||||||
JeopardyPrompter,
|
JeopardyPrompter,
|
||||||
MultipleChoiceConcisePrompter,
|
MultipleChoiceConcisePrompter,
|
||||||
@@ -327,15 +325,6 @@ def load_tokenized_prepared_datasets(
|
|||||||
)
|
)
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||||
datasets.append(ds_wrapper)
|
datasets.append(ds_wrapper)
|
||||||
elif d_base_type == "completion":
|
|
||||||
ds_strategy = CompletionPromptTokenizingStrategy(
|
|
||||||
CompletionPrompter(),
|
|
||||||
tokenizer,
|
|
||||||
cfg.train_on_inputs,
|
|
||||||
cfg.sequence_len,
|
|
||||||
)
|
|
||||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
|
||||||
datasets.append(ds_wrapper)
|
|
||||||
else:
|
else:
|
||||||
suffix = ""
|
suffix = ""
|
||||||
if ":load_" in d.type:
|
if ":load_" in d.type:
|
||||||
|
|||||||
Reference in New Issue
Block a user