# Patch summary (commentary only; these '#' lines sit above the first
# `diff --git` header and are not part of any hunk).
#
# Purpose: make the dataset key that holds the ShareGPT message list
# configurable instead of being hard-coded to "conversations".
#
# src/axolotl/prompt_strategies/sharegpt.py
#   * build_loader hunk: when ds_cfg is truthy, contains the key
#     "field_messages", and the strategy has a `messages` attribute,
#     ds_cfg["field_messages"] is assigned to strategy.messages. This mirrors
#     the existing "strict" handling on the lines just above it.
#   * SimpleShareGPTPromptTokenizingStrategy: adds the class attribute
#     `_messages = "conversations"` plus a `messages` property/setter pair
#     that reads and writes `_messages`.
#   * get_conversation_thread: reads `prompt[self.messages]` instead of
#     `prompt["conversations"]`.
#
# src/axolotl/utils/config/models/input/v0_4_1/__init__.py
#   * SFTDataset: adds `field_messages: Optional[str] = None`.
#
# NOTE(review): the loader guard tests key presence, not the value. Since
# SFTDataset declares field_messages as Optional with a None default, a
# ds_cfg that carries the key with a None value would presumably set
# strategy.messages to None and make `prompt[self.messages]` fail. Whether
# ds_cfg can reach the loader in that shape is not visible here; confirm
# against the caller.
#
# NOTE(review): this patch was received collapsed onto a single line. Line
# breaks, indentation, and blank context lines below were reconstructed from
# the `@@` hunk line counts; verify against the upstream commit before
# applying. The final hunk's header declares 6 old / 7 new lines but only
# 4 / 5 are visible, so the patch appears truncated at the end.
diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py
index 5f0e7a895..8b452ae19 100644
--- a/src/axolotl/prompt_strategies/sharegpt.py
+++ b/src/axolotl/prompt_strategies/sharegpt.py
@@ -86,6 +86,8 @@ def build_loader(
         )
         if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
             strategy.strict = ds_cfg["strict"]
+        if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
+            strategy.messages = ds_cfg["field_messages"]
         return strategy
 
     return _load
@@ -97,6 +99,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     """
 
     _strict = False
+    _messages = "conversations"
 
     @property
     def strict(self):
@@ -106,8 +109,16 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     def strict(self, strict):
         self._strict = strict
 
+    @property
+    def messages(self):
+        return self._messages
+
+    @messages.setter
+    def messages(self, messages):
+        self._messages = messages
+
     def get_conversation_thread(self, prompt):
-        conversations = prompt["conversations"]
+        conversations = prompt[self.messages]
         if self.strict:
             return conversations
         role_key = "from"
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 31750ac15..b6eafa7a3 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -109,6 +109,7 @@ class SFTDataset(BaseModel):
     field: Optional[str] = None
     field_human: Optional[str] = None
     field_model: Optional[str] = None
+    field_messages: Optional[str] = None
     roles: Optional[Dict[str, List[str]]] = None