support for custom messages field in sharegpt (#1651)

This commit is contained in:
Wing Lian
2024-05-23 13:03:22 -04:00
committed by GitHub
parent 84bb8061ba
commit bbfed318bc
2 changed files with 13 additions and 1 deletions

View File

@@ -86,6 +86,8 @@ def build_loader(
)
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
strategy.strict = ds_cfg["strict"]
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy
return _load
@@ -97,6 +99,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
_strict = False
_messages = "conversations"
@property
def strict(self):
@@ -106,8 +109,16 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
def strict(self, strict):
self._strict = strict
# Read accessor for the name of the key under which a prompt record stores
# its list of conversation turns. get_conversation_thread() indexes the
# prompt with this value; the class-level default is "conversations".
@property
def messages(self):
return self._messages
# Write accessor that overrides the conversation-turns key name. The loader
# assigns ds_cfg["field_messages"] here when that option is present in the
# dataset config. The value is stored as given; no validation is performed.
@messages.setter
def messages(self, messages):
self._messages = messages
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
conversations = prompt[self.messages]
if self.strict:
return conversations
role_key = "from"

View File

@@ -109,6 +109,7 @@ class SFTDataset(BaseModel):
field: Optional[str] = None
field_human: Optional[str] = None
field_model: Optional[str] = None
field_messages: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None