Lint pygmalion

This commit is contained in:
NanoCode012
2023-05-29 14:05:22 +09:00
parent 7eb33a77dd
commit 01c8a333b3

View File

@@ -1,3 +1,5 @@
"""Module containing the PygmalionPromptTokenizingStrategy and PygmalionPrompter class"""
import copy
import logging
from collections import defaultdict
@@ -9,10 +11,14 @@ IGNORE_TOKEN_ID = -100
class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for Pygmalion.
"""
bot_prefix_token_ids = []
def __init__(self, prompter, tokenizer, *args, **kwargs):
super().__init__(prompter, tokenizer)
super().__init__(prompter, tokenizer, *args, **kwargs)
res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
self.bot_prefix_token_ids = res["input_ids"]
@@ -23,7 +29,7 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
"labels": [],
}
current_len = 0
for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
for _, part in enumerate(self.prompter.build_prompt(prompt["conversations"])):
role, message = part
if role == "system":
prefix = "<|system|>"
@@ -96,10 +102,16 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
class PygmalionPrompter:
"""
Prompter for Pygmalion.
"""
def __init__(self, *args, **kwargs):
pass
def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]:
def build_prompt(
self, source, *args, **kwargs # pylint: disable=unused-argument
) -> Generator[str, None, None]:
for msg in source:
yield msg["role"], msg["value"]