Fix mypy typing

This commit is contained in:
NanoCode012
2023-05-29 18:13:39 +09:00
parent f1232b35ba
commit e9650d3ae4
8 changed files with 190 additions and 33 deletions

View File

@@ -3,7 +3,7 @@
import dataclasses
import logging
from enum import auto, Enum
from typing import List, Union, Generator
from typing import List, Optional, Union, Generator
IGNORE_TOKEN_ID = -100
@@ -24,7 +24,7 @@ class AlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
prompt_style = None
prompt_style: Optional[PromptStyle] = None
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
@@ -231,18 +231,18 @@ class Conversation:
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
sep2: Optional[str] = None
def get_prompt(self) -> Generator[str, None, None]:
seps = [self.sep, self.sep2]
preamble = self.system + seps[0]
# seps = [self.sep, self.sep2]
preamble = self.system + self.sep
yield preamble
for _, (role, message) in enumerate(self.messages):
if message:
yield (role + ":", " " + message)
yield role + ":" + " " + message
else:
logging.warning(f"role with empty message: {role}")
yield (role + ":",)
yield role + ":"
def copy(self):
return Conversation(