Also grab the system prompt from the dataset (#2397) [skip ci]

* Also grab the system prompt from the dataset

* chore: add field_system to docs

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
Wing Lian
2025-04-28 10:11:06 -04:00
committed by GitHub
parent 170cdb5be9
commit 5000cb3fe7
2 changed files with 17 additions and 0 deletions

View File

@@ -33,6 +33,7 @@ class ChatTemplatePrompter(Prompter):
message_field_training: Optional[str] = None,
message_field_training_detail: Optional[str] = None,
field_messages: str = "messages",
field_system: str = "system",
roles: Optional[Dict[str, List[str]]] = None,
drop_system_message: bool = False,
):
@@ -62,6 +63,7 @@ class ChatTemplatePrompter(Prompter):
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.field_messages = field_messages
self.field_system = field_system
self.tokenizer = tokenizer
self.processor: Optional[ProcessorMixin] = processor
self.chat_template = chat_template
@@ -488,6 +490,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def get_conversation_thread(self, prompt):
turns = []
possible_sys_turn = self.transform_message(
prompt[self.prompter.field_messages][0]
)
if (
possible_sys_turn["role"] != "system"
and self.prompter.field_system in prompt
):
turn = {"role": "system", "content": prompt[self.prompter.field_system]}
turns.append(turn)
for message in prompt[self.prompter.field_messages]:
transformed_message = self.transform_message(message)