diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py
index 6ac7cbafe..bb779e8df 100644
--- a/src/axolotl/prompt_strategies/sharegpt.py
+++ b/src/axolotl/prompt_strategies/sharegpt.py
@@ -39,6 +39,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     )
     if ds_cfg and "strict" in ds_cfg:
         strategy.strict = ds_cfg["strict"]
+    if ds_cfg and "field_messages" in ds_cfg:
+        strategy.field_messages = ds_cfg["field_messages"]
     return strategy
 
 
@@ -83,6 +85,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     """
 
     _strict = False
+    _field_messages = "conversations"
 
     @property
     def strict(self):
@@ -92,8 +95,18 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     def strict(self, strict):
         self._strict = strict
 
+    @property
+    def field_messages(self):
+        """Key of the prompt dict that holds the list of conversation turns."""
+        return self._field_messages
+
+    @field_messages.setter
+    def field_messages(self, field_messages):
+        """Override the prompt key read by get_conversation_thread."""
+        self._field_messages = field_messages
+
     def get_conversation_thread(self, prompt):
-        conversations = prompt["conversations"]
+        conversations = prompt[self.field_messages]
         if self.strict:
             return conversations
         role_key = "from"