support for json data as completion

This commit is contained in:
Wing Lian
2023-11-25 16:05:04 -05:00
parent 1115c501b8
commit da154e6d56

View File

@@ -1,6 +1,7 @@
""" """
Basic completion text Basic completion text
""" """
import json
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, Generator, Optional, Tuple from typing import Any, Dict, Generator, Optional, Tuple
@@ -64,6 +65,19 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
return next(iter(self.prompter.build_prompt(instruction, input, response))) return next(iter(self.prompter.build_prompt(instruction, input, response)))
class CompletionJSONPromptTokenizationStrategy(CompletionPromptTokenizingStrategy):
    """
    Completion strategy whose training text is the whole row dumped as a JSON string
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """
        Serialize the full row instead of picking out a single text field.

        :param prompt: the row to serialize; it must be something ``json.dumps``
            accepts, otherwise ``json.dumps`` raises ``TypeError``.
        :return: a 3-tuple whose first element is the JSON text of ``prompt`` and
            whose remaining two elements are always empty strings.
        """
        serialized_row = json.dumps(prompt)
        # Only the first slot carries content; the other two are left blank.
        return serialized_row, "", ""
class CompletionPrompter: class CompletionPrompter:
""" """
Prompter for completion Prompter for completion
@@ -82,7 +96,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strat = CompletionPromptTokenizingStrategy( strat = CompletionPromptTokenizingStrategy(
CompletionPrompter(), CompletionPrompter(),
tokenizer, tokenizer,
cfg.train_on_inputs, True,
cfg.sequence_len, cfg.sequence_len,
max_length=cfg.sequence_len * 64, max_length=cfg.sequence_len * 64,
) )
@@ -90,3 +104,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strat.field = ds_cfg["field"] strat.field = ds_cfg["field"]
return strat return strat
def load_json(tokenizer, cfg):
    """
    Build the tokenizing strategy that trains on each row rendered as JSON.

    :param tokenizer: passed through unchanged to the strategy constructor.
    :param cfg: configuration object; only ``cfg.sequence_len`` is read here.
    :return: a ``CompletionJSONPromptTokenizationStrategy`` wrapping a fresh
        ``CompletionPrompter``.
    """
    prompter = CompletionPrompter()
    sequence_len = cfg.sequence_len
    # The third positional argument is fixed to True rather than read from cfg,
    # matching what ``load`` passes for the plain completion strategy
    # (presumably the train-on-inputs flag -- confirm against the base class).
    return CompletionJSONPromptTokenizationStrategy(
        prompter,
        tokenizer,
        True,
        sequence_len,
        # Upper bound is 64x the configured sequence length, same factor as ``load``.
        max_length=sequence_len * 64,
    )