From b7fe46579df65eab85bc5514581eecfa035c3ab9 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 8 Mar 2024 08:08:29 -0500
Subject: [PATCH] make the conversations/messages field configurable for sharegpt

---
 src/axolotl/prompt_strategies/sharegpt.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py
index 6ac7cbafe..bb779e8df 100644
--- a/src/axolotl/prompt_strategies/sharegpt.py
+++ b/src/axolotl/prompt_strategies/sharegpt.py
@@ -39,6 +39,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     )
     if ds_cfg and "strict" in ds_cfg:
         strategy.strict = ds_cfg["strict"]
+    if ds_cfg and "field_messages" in ds_cfg:
+        strategy.field_messages = ds_cfg["field_messages"]
     return strategy
 
 
@@ -83,6 +85,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     """
 
     _strict = False
+    _field_messages = "conversations"
 
     @property
     def strict(self):
@@ -92,8 +95,18 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     def strict(self, strict):
         self._strict = strict
 
+    @property
+    def field_messages(self):
+        """Key used to look up the conversation turns in a prompt."""
+        return self._field_messages
+
+    @field_messages.setter
+    def field_messages(self, field_messages):
+        """Set the key that get_conversation_thread reads from a prompt."""
+        self._field_messages = field_messages
+
     def get_conversation_thread(self, prompt):
-        conversations = prompt["conversations"]
+        conversations = prompt[self.field_messages]
         if self.strict:
             return conversations
         role_key = "from"