Support for JSON data as completion

This commit is contained in:
Wing Lian
2023-11-25 16:05:04 -05:00
parent 1115c501b8
commit da154e6d56

View File

@@ -1,6 +1,7 @@
"""
Basic completion text
"""
import json
from collections import defaultdict
from typing import Any, Dict, Generator, Optional, Tuple
@@ -64,6 +65,19 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
return next(iter(self.prompter.build_prompt(instruction, input, response)))
class CompletionJSONPromptTokenizationStrategy(CompletionPromptTokenizingStrategy):
    """
    Completion strategy whose training text is the whole dataset row,
    serialized to a JSON string.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """
        Serialize ``prompt`` with ``json.dumps`` and return it as the first
        element of a 3-tuple; the other two elements are always empty strings.

        NOTE(review): the three slots presumably correspond to
        (instruction, input, response) as consumed by the parent strategy --
        confirm against the base class.
        """
        serialized_row = json.dumps(prompt)
        return serialized_row, "", ""
class CompletionPrompter:
"""
Prompter for completion
@@ -82,7 +96,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strat = CompletionPromptTokenizingStrategy(
CompletionPrompter(),
tokenizer,
cfg.train_on_inputs,
True,
cfg.sequence_len,
max_length=cfg.sequence_len * 64,
)
@@ -90,3 +104,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strat.field = ds_cfg["field"]
return strat
def load_json(tokenizer, cfg):
    """
    Build a ``CompletionJSONPromptTokenizationStrategy`` around a fresh
    ``CompletionPrompter`` and the given tokenizer.

    ``cfg.sequence_len`` is passed through as the fourth positional argument,
    and ``max_length`` is set to 64 times that value.
    """
    prompter = CompletionPrompter()
    seq_len = cfg.sequence_len
    return CompletionJSONPromptTokenizationStrategy(
        prompter,
        tokenizer,
        # Hard-coded rather than read from cfg; presumably the
        # train_on_inputs flag of the strategy constructor -- TODO confirm.
        True,
        seq_len,
        max_length=seq_len * 64,
    )