From da154e6d5631815ac64c130d69cd6350af8f85ad Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 25 Nov 2023 16:05:04 -0500 Subject: [PATCH] support for json data as completion --- src/axolotl/prompt_strategies/completion.py | 28 ++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/axolotl/prompt_strategies/completion.py b/src/axolotl/prompt_strategies/completion.py index 3285e667c..5ec849334 100644 --- a/src/axolotl/prompt_strategies/completion.py +++ b/src/axolotl/prompt_strategies/completion.py @@ -1,6 +1,7 @@ """ Basic completion text """ +import json from collections import defaultdict from typing import Any, Dict, Generator, Optional, Tuple @@ -64,6 +65,19 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): return next(iter(self.prompter.build_prompt(instruction, input, response))) +class CompletionJSONPromptTokenizationStrategy(CompletionPromptTokenizingStrategy): + """ + Strategy to return the stringified JSON of the entire row as the training data + """ + + def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]: + return ( + json.dumps(prompt), + "", + "", + ) + + class CompletionPrompter: """ Prompter for completion @@ -82,7 +96,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): strat = CompletionPromptTokenizingStrategy( CompletionPrompter(), tokenizer, - cfg.train_on_inputs, + True, cfg.sequence_len, max_length=cfg.sequence_len * 64, ) @@ -90,3 +104,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): strat.field = ds_cfg["field"] return strat + + +def load_json(tokenizer, cfg): + strat = CompletionJSONPromptTokenizationStrategy( + CompletionPrompter(), + tokenizer, + True, + cfg.sequence_len, + max_length=cfg.sequence_len * 64, + ) + + return strat