make the conversations/messages field configurable for sharegpt
This commit is contained in:
@@ -39,6 +39,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
)
|
||||
if ds_cfg and "strict" in ds_cfg:
|
||||
strategy.strict = ds_cfg["strict"]
|
||||
if ds_cfg and "field_messages" in ds_cfg:
|
||||
strategy.field_messages = ds_cfg["field_messages"]
|
||||
return strategy
|
||||
|
||||
|
||||
@@ -83,6 +85,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
|
||||
_strict = False
|
||||
_field_messages = "conversations"
|
||||
|
||||
@property
|
||||
def strict(self):
|
||||
@@ -92,8 +95,16 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
def strict(self, strict):
|
||||
self._strict = strict
|
||||
|
||||
@property
def field_messages(self):
    """Return the key used to fetch the message list from a prompt record.

    Falls back to the class-level ``_field_messages`` value
    ("conversations") until the setter stores a different key.
    """
    # Must return the field name, not the strict flag: the value is used
    # directly as the lookup key into the prompt in get_conversation_thread.
    return self._field_messages

@field_messages.setter
def field_messages(self, field_messages):
    """Store the key under which a prompt record keeps its message list."""
    self._field_messages = field_messages
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
conversations = prompt[self.field_messages]
|
||||
if self.strict:
|
||||
return conversations
|
||||
role_key = "from"
|
||||
|
||||
Reference in New Issue
Block a user