Compare commits
1 Commit
sp-rl
...
completion
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da154e6d56 |
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Basic completion text
|
||||
"""
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, Generator, Optional, Tuple
|
||||
|
||||
@@ -64,6 +65,19 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
return next(iter(self.prompter.build_prompt(instruction, input, response)))
|
||||
|
||||
|
||||
class CompletionJSONPromptTokenizationStrategy(CompletionPromptTokenizingStrategy):
    """
    Completion strategy whose training text is the whole dataset row,
    serialized to a JSON string.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """
        Serialize ``prompt`` with ``json.dumps`` and return it as the first
        element of a 3-tuple; the remaining two elements are empty strings.

        NOTE(review): the three slots presumably map to
        (instruction, input, response) as consumed by the parent strategy --
        confirm against the base class.
        """
        serialized_row = json.dumps(prompt)
        return serialized_row, "", ""
|
||||
|
||||
|
||||
class CompletionPrompter:
|
||||
"""
|
||||
Prompter for completion
|
||||
@@ -82,7 +96,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
strat = CompletionPromptTokenizingStrategy(
|
||||
CompletionPrompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
True,
|
||||
cfg.sequence_len,
|
||||
max_length=cfg.sequence_len * 64,
|
||||
)
|
||||
@@ -90,3 +104,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
strat.field = ds_cfg["field"]
|
||||
|
||||
return strat
|
||||
|
||||
|
||||
def load_json(tokenizer, cfg):
    """
    Build a ``CompletionJSONPromptTokenizationStrategy`` for ``tokenizer``.

    The strategy is created with a fresh ``CompletionPrompter``, a literal
    ``True`` as its third positional argument, ``cfg.sequence_len`` as the
    fourth, and ``max_length`` set to 64 times ``cfg.sequence_len``.

    NOTE(review): the ``True`` positional is presumably the
    ``train_on_inputs`` flag of the base strategy -- confirm against its
    constructor signature.
    """
    prompter = CompletionPrompter()
    length_cap = cfg.sequence_len * 64
    return CompletionJSONPromptTokenizationStrategy(
        prompter,
        tokenizer,
        True,
        cfg.sequence_len,
        max_length=length_cap,
    )
|
||||
|
||||
Reference in New Issue
Block a user