Update to adapt to sharegpt datasets with "assistant" rather than "gp… (#774)

* Update to adapt to sharegpt datasets with "assistant" rather than "gpt" as the machine answers.

* use a strict option for hanedling incorrect turn data

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
MilesQLi
2023-10-27 22:00:16 -04:00
committed by GitHub
parent d3193beac3
commit 0800885e2f

View File

@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
return SimpleShareGPTPromptTokenizingStrategy(
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
@@ -34,6 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy
def load_role(tokenizer, cfg):
@@ -59,8 +62,26 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
basic sharegpt strategy to grab conversations from the sample row
"""
_strict = True
@property
def strict(self):
return self._strict
@strict.setter
def strict(self, strict):
self._strict = strict
def get_conversation_thread(self, prompt):
return prompt["conversations"]
conversations = prompt["conversations"]
if self.strict:
return conversations
# remap roles - allow for assistant turn
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
turns = [
{"from": role_map[t["from"]], "value": t["value"]} for t in conversations
]
return turns
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):