black formatting

This commit is contained in:
Wing Lian
2023-05-10 16:01:08 -04:00
parent 7a490a4646
commit 2bc1a5bde1
11 changed files with 132 additions and 64 deletions

View File

@@ -36,10 +36,7 @@ class JeopardyPrompter(AlpacaPrompter):
class CompletionPrompter(AlpacaPrompter):
def build_prompt(
self,
instruction: str
) -> str:
def build_prompt(self, instruction: str) -> str:
return instruction
def get_response(self, output: str) -> str:
@@ -75,7 +72,9 @@ class ReflectAlpacaPrompter:
else:
res = self.prompt_no_input.format(instruction=instruction)
if output and reflection and corrected:
label = self.agent_label.format(output=output, reflection=reflection, corrected=corrected)
label = self.agent_label.format(
output=output, reflection=reflection, corrected=corrected
)
res = f"{res}{label}"
return res
@@ -200,9 +199,13 @@ class ShareGPTPrompter:
if len(parts) != 2:
break
parts[0] += sep
round_len = len(tokenizer(rou)["input_ids"]) - 1 # -1 ignores the bos_token generated for this
round_len = (
len(tokenizer(rou)["input_ids"]) - 1
) # -1 ignores the bos_token generated for this
# we have to strip the initial part, any dangling whitespace creates an additional ghost token
instruction_len = len(tokenizer(parts[0].strip())["input_ids"]) - 1 # -1 ignores the bos_token generated for this
instruction_len = (
len(tokenizer(parts[0].strip())["input_ids"]) - 1
) # -1 ignores the bos_token generated for this
target[cur_len : cur_len + instruction_len] = [
IGNORE_TOKEN_ID
] * instruction_len
@@ -212,7 +215,7 @@ class ShareGPTPrompter:
break
# Fix: Truncate the target to have the same length as input_ids
target = target[:len(tokenized_result["input_ids"])]
target = target[: len(tokenized_result["input_ids"])]
# target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
attention_mask = [