grab sys prompt too from dataset (#2397) [skip ci]

* grab sys prompt too from dataset

* chore: add field_system to docs

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Wing Lian
2025-04-28 10:11:06 -04:00
committed by GitHub
parent 170cdb5be9
commit 5000cb3fe7
2 changed files with 17 additions and 0 deletions

View File

@@ -154,6 +154,10 @@ datasets:
# Key containing the messages (default: "messages")
field_messages: messages
# Key containing the system message (default: "system")
# If the system message is not present in the dataset sample, it will be loaded from the field_system property.
field_system: system
# Mapping of properties from the input dataset to the chat template.
# (default: message_property_mappings={'role':'role', 'content':'content'})
# If a property exists in the template but not in this mapping, the system will attempt

View File

@@ -33,6 +33,7 @@ class ChatTemplatePrompter(Prompter):
message_field_training: Optional[str] = None,
message_field_training_detail: Optional[str] = None,
field_messages: str = "messages",
field_system: str = "system",
roles: Optional[Dict[str, List[str]]] = None,
drop_system_message: bool = False,
):
@@ -62,6 +63,7 @@ class ChatTemplatePrompter(Prompter):
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.field_messages = field_messages
self.field_system = field_system
self.tokenizer = tokenizer
self.processor: Optional[ProcessorMixin] = processor
self.chat_template = chat_template
@@ -488,6 +490,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def get_conversation_thread(self, prompt):
turns = []
possible_sys_turn = self.transform_message(
prompt[self.prompter.field_messages][0]
)
if (
possible_sys_turn["role"] != "system"
and self.prompter.field_system in prompt
):
turn = {"role": "system", "content": prompt[self.prompter.field_system]}
turns.append(turn)
for message in prompt[self.prompter.field_messages]:
transformed_message = self.transform_message(message)