add support for defined train split (#654)
This commit is contained in:
10
README.md
10
README.md
@@ -250,6 +250,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"article": "...", "question": "...", "answer": "..."}
|
{"article": "...", "question": "...", "answer": "..."}
|
||||||
```
|
```
|
||||||
|
- `context_qa.load_v2`: in context question answering (alternate)
|
||||||
|
```json
|
||||||
|
{"context": "...", "question": "...", "answer": "..."}
|
||||||
|
```
|
||||||
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
|
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
|
||||||
```json
|
```json
|
||||||
{"article": "...", "unanswerable_question": "..."}
|
{"article": "...", "unanswerable_question": "..."}
|
||||||
@@ -356,6 +360,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
- path: data.jsonl # or json
|
- path: data.jsonl # or json
|
||||||
ds_type: json # see other options below
|
ds_type: json # see other options below
|
||||||
type: alpaca
|
type: alpaca
|
||||||
|
|
||||||
|
# dataset with splits, but no train split
|
||||||
|
dataset:
|
||||||
|
- path: knowrohit07/know_sql
|
||||||
|
type: context_qa.load_v2
|
||||||
|
train_on_split: validation
|
||||||
```
|
```
|
||||||
|
|
||||||
- loading
|
- loading
|
||||||
|
|||||||
@@ -24,6 +24,15 @@ def load(tokenizer, cfg):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_v2(tokenizer, cfg):
|
||||||
|
return ContextQaV2PromptTokenizingStrategy(
|
||||||
|
ContextV2Prompter(),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AlpacaContextPrompter(AlpacaPrompter):
|
class AlpacaContextPrompter(AlpacaPrompter):
|
||||||
"""
|
"""
|
||||||
Customized system prompted for concise QA
|
Customized system prompted for concise QA
|
||||||
@@ -50,6 +59,38 @@ class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||||
|
"""
|
||||||
|
Tokenization Strategy to combine in-context article with a question and answer
|
||||||
|
"""
|
||||||
|
|
||||||
|
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||||
|
return (
|
||||||
|
"Context: "
|
||||||
|
+ prompt["context"]
|
||||||
|
+ "\nQuestion: "
|
||||||
|
+ prompt["question"]
|
||||||
|
+ "\n",
|
||||||
|
"",
|
||||||
|
"Answer: " + prompt["answer"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextV2Prompter(AlpacaPrompter):
|
||||||
|
"""
|
||||||
|
Customized system prompted for concise QA
|
||||||
|
"""
|
||||||
|
|
||||||
|
system_prompt = ""
|
||||||
|
system_no_input_prompt = ""
|
||||||
|
|
||||||
|
def match_prompt_style(self):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
self.turn_format = "{instruction}\n{input}"
|
||||||
|
self.turn_no_input_format = "{instruction}"
|
||||||
|
self.system_format = "{system}"
|
||||||
|
|
||||||
|
|
||||||
class AlpacaMissingInfoContextPromptTokenizingStrategy(
|
class AlpacaMissingInfoContextPromptTokenizingStrategy(
|
||||||
InstructionPromptTokenizingStrategy
|
InstructionPromptTokenizingStrategy
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -247,6 +247,16 @@ def load_tokenized_prepared_datasets(
|
|||||||
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
||||||
if "train" in ds:
|
if "train" in ds:
|
||||||
ds = ds["train"]
|
ds = ds["train"]
|
||||||
|
elif (
|
||||||
|
isinstance(ds, DatasetDict)
|
||||||
|
and d.train_on_split
|
||||||
|
and d.train_on_split in ds
|
||||||
|
):
|
||||||
|
ds = ds[d.train_on_split]
|
||||||
|
elif isinstance(ds, DatasetDict):
|
||||||
|
raise ValueError(
|
||||||
|
f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
"input_ids" in ds.features
|
"input_ids" in ds.features
|
||||||
and "attention_mask" in ds.features
|
and "attention_mask" in ds.features
|
||||||
|
|||||||
Reference in New Issue
Block a user