From 5000cb3fe76b53128a374ad16a44630e0101b506 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 28 Apr 2025 10:11:06 -0400 Subject: [PATCH] grab sys prompt too from dataset (#2397) [skip ci] * grab sys prompt too from dataset * chore: add field_system to docs --------- Co-authored-by: NanoCode012 --- docs/config.qmd | 4 ++++ src/axolotl/prompt_strategies/chat_template.py | 13 +++++++++++++ 2 files changed, 17 insertions(+) diff --git a/docs/config.qmd b/docs/config.qmd index a67734498..cb39e1d54 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -154,6 +154,10 @@ datasets: # Key containing the messages (default: "messages") field_messages: messages + # Key containing the system message (default: "system") + # If the first message in field_messages is not a system message, the system message will be loaded from this key of the dataset sample when the key is present. + field_system: system + # Mapping of properties from the input dataset to the chat template. # (default: message_property_mappings={'role':'role', 'content':'content'}) # If a property exists in the template but not in this mapping, the system will attempt diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 918c56329..d16eb34e1 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -33,6 +33,7 @@ class ChatTemplatePrompter(Prompter): message_field_training: Optional[str] = None, message_field_training_detail: Optional[str] = None, field_messages: str = "messages", + field_system: str = "system", roles: Optional[Dict[str, List[str]]] = None, drop_system_message: bool = False, ): @@ -62,6 +63,7 @@ class ChatTemplatePrompter(Prompter): self.message_field_training = message_field_training self.message_field_training_detail = message_field_training_detail self.field_messages = field_messages + self.field_system = field_system self.tokenizer = tokenizer self.processor: Optional[ProcessorMixin] = processor self.chat_template = 
chat_template @@ -488,6 +490,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): def get_conversation_thread(self, prompt): turns = [] + + possible_sys_turn = self.transform_message( + prompt[self.prompter.field_messages][0] + ) + if ( + possible_sys_turn["role"] != "system" + and self.prompter.field_system in prompt + ): + turn = {"role": "system", "content": prompt[self.prompter.field_system]} + turns.append(turn) + for message in prompt[self.prompter.field_messages]: transformed_message = self.transform_message(message)