From d7fa60662ea1b65d53d5ff5d3f4fcf4a590dd9ea Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 3 Jun 2025 14:25:26 -0700 Subject: [PATCH] feat: add chat_template kwargs (#2694) [skip ci] --- src/axolotl/prompt_strategies/chat_template.py | 9 +++++++-- src/axolotl/utils/schemas/config.py | 1 + 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index ebb151876..a0fd8d911 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -29,12 +29,13 @@ class ChatTemplatePrompter(Prompter): chat_template: str, processor=None, max_length=2048, - message_property_mappings: Dict[str, str] | None = None, + message_property_mappings: dict[str, str] | None = None, message_field_training: str | None = None, message_field_training_detail: str | None = None, field_messages: str = "messages", field_system: str = "system", - roles: Dict[str, List[str]] | None = None, + roles: dict[str, list[str]] | None = None, + chat_template_kwargs: dict[str, Any] | None = None, drop_system_message: bool = False, ): # check if message_property_mappings is None or empty dict @@ -68,6 +69,7 @@ class ChatTemplatePrompter(Prompter): self.tokenizer = tokenizer self.processor: ProcessorMixin | None = processor self.chat_template = chat_template + self.chat_template_kwargs = chat_template_kwargs or {} self.max_length = max_length self.drop_system_message = drop_system_message @@ -85,6 +87,7 @@ class ChatTemplatePrompter(Prompter): chat_template=self.chat_template, tokenize=False, add_generation_prompt=add_generation_prompt, + **self.chat_template_kwargs, ) batch = self.processor( text=text, @@ -103,6 +106,7 @@ class ChatTemplatePrompter(Prompter): conversation, add_generation_prompt=add_generation_prompt, chat_template=self.chat_template, + **self.chat_template_kwargs, ) def get_offsets_for_train_detail( @@ -779,6 +783,7 @@ class StrategyLoader: prompter_params = { "tokenizer": tokenizer, "chat_template": chat_template_string, + "chat_template_kwargs": cfg.get("chat_template_kwargs", {}), "message_property_mappings": dataset_config.get( "message_property_mappings", {} ), diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index e7bd16892..e5f105053 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -314,6 +314,7 @@ class AxolotlInputConfig( | Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")] ) | None = None chat_template_jinja: str | None = None + chat_template_kwargs: dict[str, Any] | None = None eot_tokens: list[str] | None = None default_system_message: str | None = None