Allow using tokenizer's default chat template with fallbacks
Summary of changes: 1. Adds `tokenizer_default` as option for `chat_template` in `chat_template` prompt strategy that allows using the chat template from tokenizer's config.json 2. Allows falling back to chat templates available in axolotl if tokenizer does not have a chat template 3. Adds a mistral chat template which supports system message - taken from https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja --- Why? Many popular models are not trained with chatml format. As a result for the model to correctly learn chatml we have to turn on train_on_inputs which requires more compute and time. If we can use the model's already learned chat template we can just learn the output tokens --- Todo: - Write tests
This commit is contained in:
@@ -122,10 +122,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
else False
|
||||
)
|
||||
|
||||
chat_template_str = chat_templates(chat_template, tokenizer=tokenizer)
|
||||
LOG.info(f"Using chat template:\n---\n{chat_template_str!s}\n---")
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
tokenizer,
|
||||
chat_templates(chat_template),
|
||||
chat_template_str,
|
||||
message_field_role=message_field_role,
|
||||
message_field_content=message_field_content,
|
||||
roles=roles,
|
||||
|
||||
@@ -33,7 +33,7 @@ def load(
|
||||
if ds_cfg and "chat_template" in ds_cfg:
|
||||
chat_template = ds_cfg["chat_template"]
|
||||
try:
|
||||
chat_template = chat_templates(chat_template)
|
||||
chat_template = chat_templates(chat_template, tokenizer=tokenizer)
|
||||
except ValueError:
|
||||
pass
|
||||
tokenizer.chat_template = chat_template
|
||||
@@ -248,11 +248,11 @@ class ORPOPrompter(Prompter):
|
||||
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
dataset_parser = ORPODatasetParsingStrategy()
|
||||
|
||||
chat_template_str = chat_templates(cfg.chat_template)
|
||||
|
||||
def transform_fn(sample, tokenizer=None):
|
||||
res = {}
|
||||
|
||||
chat_template_str = chat_templates(cfg.chat_template, tokenizer=tokenizer)
|
||||
|
||||
res["prompt"] = tokenizer.apply_chat_template(
|
||||
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
|
||||
add_generation_prompt=True,
|
||||
|
||||
@@ -2,9 +2,15 @@
|
||||
This module provides functionality for selecting chat templates based on user choices.
|
||||
These templates are used for formatting messages in a conversation.
|
||||
"""
|
||||
import logging
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.chat_templates")
|
||||
|
||||
_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
|
||||
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"
|
||||
|
||||
|
||||
def chat_templates(user_choice: str):
|
||||
def chat_templates(user_choice: str, tokenizer=None):
|
||||
"""
|
||||
Finds the correct chat_template for the tokenizer_config.
|
||||
|
||||
@@ -21,6 +27,7 @@ def chat_templates(user_choice: str):
|
||||
templates = {
|
||||
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||
"mistral": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] | trim + '\n\n' %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{{ bos_token + system_message }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] | trim + eos_token }}{% endif %}{% endfor %}",
|
||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
||||
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
||||
@@ -28,6 +35,33 @@ def chat_templates(user_choice: str):
|
||||
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
|
||||
}
|
||||
|
||||
if user_choice == _DEFAULT_TEMPLATE_CHOICE:
|
||||
if not tokenizer:
|
||||
raise ValueError(
|
||||
f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}"
|
||||
)
|
||||
if not tokenizer.chat_template:
|
||||
raise ValueError(
|
||||
f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
|
||||
f"Please add a chat_template in tokenizer config"
|
||||
)
|
||||
return tokenizer.chat_template
|
||||
|
||||
if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
|
||||
if not tokenizer:
|
||||
raise ValueError(
|
||||
f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
|
||||
)
|
||||
if tokenizer.chat_template:
|
||||
return tokenizer.chat_template
|
||||
|
||||
user_choice = user_choice[
|
||||
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
|
||||
]
|
||||
LOG.warning(
|
||||
f"No chat template found on tokenizer, falling back to {user_choice}"
|
||||
)
|
||||
|
||||
if user_choice in templates:
|
||||
return templates[user_choice]
|
||||
|
||||
|
||||
@@ -7,9 +7,16 @@ Module for pydantic models for configuration
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
StringConstraints,
|
||||
conlist,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from transformers import SchedulerType
|
||||
from transformers.training_args import OptimizerNames
|
||||
|
||||
@@ -179,6 +186,8 @@ class ChatTemplate(str, Enum):
|
||||
cohere = "cohere" # pylint: disable=invalid-name
|
||||
llama3 = "llama3" # pylint: disable=invalid-name
|
||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||
mistral = "mistral" # pylint: disable=invalid-name
|
||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LoftQConfig(BaseModel):
|
||||
@@ -634,7 +643,12 @@ class AxolotlInputConfig(
|
||||
gpu_memory_limit: Optional[Union[int, str]] = None
|
||||
low_cpu_mem_usage: Optional[bool] = None
|
||||
|
||||
chat_template: Optional[ChatTemplate] = None
|
||||
chat_template: Optional[
|
||||
Union[
|
||||
ChatTemplate,
|
||||
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
|
||||
]
|
||||
] = None
|
||||
default_system_message: Optional[str] = None
|
||||
|
||||
# INTERNALS - document for now, generally not set externally
|
||||
|
||||
@@ -283,7 +283,7 @@ def load_tokenizer(cfg):
|
||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||
|
||||
if cfg.chat_template:
|
||||
chat_template_string = chat_templates(cfg.chat_template)
|
||||
chat_template_string = chat_templates(cfg.chat_template, tokenizer=tokenizer)
|
||||
if cfg.default_system_message and cfg.chat_template == "chatml":
|
||||
chat_template_string = chat_template_string.replace(
|
||||
"You are a helpful assistant.", cfg.default_system_message
|
||||
|
||||
Reference in New Issue
Block a user