Compare commits
1 Commits
fix-previe
...
completion
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da154e6d56 |
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Basic completion text
|
Basic completion text
|
||||||
"""
|
"""
|
||||||
|
import json
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Dict, Generator, Optional, Tuple
|
from typing import Any, Dict, Generator, Optional, Tuple
|
||||||
|
|
||||||
@@ -64,6 +65,19 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
|||||||
return next(iter(self.prompter.build_prompt(instruction, input, response)))
|
return next(iter(self.prompter.build_prompt(instruction, input, response)))
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionJSONPromptTokenizationStrategy(CompletionPromptTokenizingStrategy):
|
||||||
|
"""
|
||||||
|
Strategy to return the stringified JSON of the entire row as the training data
|
||||||
|
"""
|
||||||
|
|
||||||
|
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||||
|
return (
|
||||||
|
json.dumps(prompt),
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CompletionPrompter:
|
class CompletionPrompter:
|
||||||
"""
|
"""
|
||||||
Prompter for completion
|
Prompter for completion
|
||||||
@@ -82,7 +96,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
strat = CompletionPromptTokenizingStrategy(
|
strat = CompletionPromptTokenizingStrategy(
|
||||||
CompletionPrompter(),
|
CompletionPrompter(),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
True,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
max_length=cfg.sequence_len * 64,
|
max_length=cfg.sequence_len * 64,
|
||||||
)
|
)
|
||||||
@@ -90,3 +104,15 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
strat.field = ds_cfg["field"]
|
strat.field = ds_cfg["field"]
|
||||||
|
|
||||||
return strat
|
return strat
|
||||||
|
|
||||||
|
|
||||||
|
def load_json(tokenizer, cfg):
|
||||||
|
strat = CompletionJSONPromptTokenizationStrategy(
|
||||||
|
CompletionPrompter(),
|
||||||
|
tokenizer,
|
||||||
|
True,
|
||||||
|
cfg.sequence_len,
|
||||||
|
max_length=cfg.sequence_len * 64,
|
||||||
|
)
|
||||||
|
|
||||||
|
return strat
|
||||||
|
|||||||
Reference in New Issue
Block a user