feat: add chat_template kwargs (#2694) [skip ci]
This commit is contained in:
@@ -29,12 +29,13 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
chat_template: str,
|
chat_template: str,
|
||||||
processor=None,
|
processor=None,
|
||||||
max_length=2048,
|
max_length=2048,
|
||||||
message_property_mappings: Dict[str, str] | None = None,
|
message_property_mappings: dict[str, str] | None = None,
|
||||||
message_field_training: str | None = None,
|
message_field_training: str | None = None,
|
||||||
message_field_training_detail: str | None = None,
|
message_field_training_detail: str | None = None,
|
||||||
field_messages: str = "messages",
|
field_messages: str = "messages",
|
||||||
field_system: str = "system",
|
field_system: str = "system",
|
||||||
roles: Dict[str, List[str]] | None = None,
|
roles: dict[str, list[str]] | None = None,
|
||||||
|
chat_template_kwargs: dict[str, Any] | None = None,
|
||||||
drop_system_message: bool = False,
|
drop_system_message: bool = False,
|
||||||
):
|
):
|
||||||
# check if message_property_mappings is None or empty dict
|
# check if message_property_mappings is None or empty dict
|
||||||
@@ -68,6 +69,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.processor: ProcessorMixin | None = processor
|
self.processor: ProcessorMixin | None = processor
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
|
self.chat_template_kwargs = chat_template_kwargs or {}
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.drop_system_message = drop_system_message
|
self.drop_system_message = drop_system_message
|
||||||
|
|
||||||
@@ -85,6 +87,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
chat_template=self.chat_template,
|
chat_template=self.chat_template,
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
add_generation_prompt=add_generation_prompt,
|
add_generation_prompt=add_generation_prompt,
|
||||||
|
**self.chat_template_kwargs,
|
||||||
)
|
)
|
||||||
batch = self.processor(
|
batch = self.processor(
|
||||||
text=text,
|
text=text,
|
||||||
@@ -103,6 +106,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
conversation,
|
conversation,
|
||||||
add_generation_prompt=add_generation_prompt,
|
add_generation_prompt=add_generation_prompt,
|
||||||
chat_template=self.chat_template,
|
chat_template=self.chat_template,
|
||||||
|
**self.chat_template_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_offsets_for_train_detail(
|
def get_offsets_for_train_detail(
|
||||||
@@ -779,6 +783,7 @@ class StrategyLoader:
|
|||||||
prompter_params = {
|
prompter_params = {
|
||||||
"tokenizer": tokenizer,
|
"tokenizer": tokenizer,
|
||||||
"chat_template": chat_template_string,
|
"chat_template": chat_template_string,
|
||||||
|
"chat_template_kwargs": cfg.get("chat_template_kwargs", {}),
|
||||||
"message_property_mappings": dataset_config.get(
|
"message_property_mappings": dataset_config.get(
|
||||||
"message_property_mappings", {}
|
"message_property_mappings", {}
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -314,6 +314,7 @@ class AxolotlInputConfig(
|
|||||||
| Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")]
|
| Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")]
|
||||||
) | None = None
|
) | None = None
|
||||||
chat_template_jinja: str | None = None
|
chat_template_jinja: str | None = None
|
||||||
|
chat_template_kwargs: dict[str, Any] | None = None
|
||||||
eot_tokens: list[str] | None = None
|
eot_tokens: list[str] | None = None
|
||||||
default_system_message: str | None = None
|
default_system_message: str | None = None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user