feat: add chat_template kwargs (#2694) [skip ci]

This commit is contained in:
NanoCode012
2025-06-03 14:25:26 -07:00
committed by GitHub
parent 1d91d905c9
commit d7fa60662e
2 changed files with 8 additions and 2 deletions

View File

@@ -29,12 +29,13 @@ class ChatTemplatePrompter(Prompter):
chat_template: str, chat_template: str,
processor=None, processor=None,
max_length=2048, max_length=2048,
message_property_mappings: Dict[str, str] | None = None, message_property_mappings: dict[str, str] | None = None,
message_field_training: str | None = None, message_field_training: str | None = None,
message_field_training_detail: str | None = None, message_field_training_detail: str | None = None,
field_messages: str = "messages", field_messages: str = "messages",
field_system: str = "system", field_system: str = "system",
roles: Dict[str, List[str]] | None = None, roles: dict[str, list[str]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
drop_system_message: bool = False, drop_system_message: bool = False,
): ):
# check if message_property_mappings is None or empty dict # check if message_property_mappings is None or empty dict
@@ -68,6 +69,7 @@ class ChatTemplatePrompter(Prompter):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.processor: ProcessorMixin | None = processor self.processor: ProcessorMixin | None = processor
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_kwargs = chat_template_kwargs or {}
self.max_length = max_length self.max_length = max_length
self.drop_system_message = drop_system_message self.drop_system_message = drop_system_message
@@ -85,6 +87,7 @@ class ChatTemplatePrompter(Prompter):
chat_template=self.chat_template, chat_template=self.chat_template,
tokenize=False, tokenize=False,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
**self.chat_template_kwargs,
) )
batch = self.processor( batch = self.processor(
text=text, text=text,
@@ -103,6 +106,7 @@ class ChatTemplatePrompter(Prompter):
conversation, conversation,
add_generation_prompt=add_generation_prompt, add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template, chat_template=self.chat_template,
**self.chat_template_kwargs,
) )
def get_offsets_for_train_detail( def get_offsets_for_train_detail(
@@ -779,6 +783,7 @@ class StrategyLoader:
prompter_params = { prompter_params = {
"tokenizer": tokenizer, "tokenizer": tokenizer,
"chat_template": chat_template_string, "chat_template": chat_template_string,
"chat_template_kwargs": cfg.get("chat_template_kwargs", {}),
"message_property_mappings": dataset_config.get( "message_property_mappings": dataset_config.get(
"message_property_mappings", {} "message_property_mappings", {}
), ),

View File

@@ -314,6 +314,7 @@ class AxolotlInputConfig(
| Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")] | Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")]
) | None = None ) | None = None
chat_template_jinja: str | None = None chat_template_jinja: str | None = None
chat_template_kwargs: dict[str, Any] | None = None
eot_tokens: list[str] | None = None eot_tokens: list[str] | None = None
default_system_message: str | None = None default_system_message: str | None = None