open orca support

This commit is contained in:
Wing Lian
2023-07-01 00:19:41 -04:00
parent bc8a2e5547
commit 78a1e1fa12
2 changed files with 27 additions and 0 deletions

View File

@@ -195,6 +195,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"message_1": "...", "message_2": "..."}
```
- `alpaca_w_system.load_open_orca`: support for Open Orca datasets with included system prompts, using the instruct prompt style
```json
{"system_prompt": "...", "question": "...", "response": "..."}
```
- `context_qa`: in context question answering from an article
```json
{"article": "...", "question": "...", "answer": "..."}

View File

@@ -75,6 +75,20 @@ class SystemDataPrompter(AlpacaPrompter):
yield res
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
    """
    Tokenizing strategy for OpenOrca datasets
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
        """Map an OpenOrca record onto (instruction, input, response, system).

        OpenOrca rows have `question`/`response`/`system_prompt` keys and no
        separate input field, so the input slot is always empty.
        """
        instruction = prompt["question"]
        response = prompt["response"]
        system = prompt["system_prompt"]
        return instruction, "", response, system
def load(tokenizer, cfg):
return InstructionWSystemPromptTokenizingStrategy(
SystemDataPrompter(PromptStyle.CHAT.value),
@@ -82,3 +96,12 @@ def load(tokenizer, cfg):
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_open_orca(tokenizer, cfg):
    """Build the OpenOrca tokenizing strategy.

    Uses the instruct-style system prompter (unlike `load`, which uses the
    chat style) together with the tokenizer and the `train_on_inputs` /
    `sequence_len` settings from the config.
    """
    prompter = SystemDataPrompter(PromptStyle.INSTRUCT.value)
    strategy = OpenOrcaPromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    return strategy