Allow using tokenizer's default chat template with fallbacks

Summary of changes:

1. Adds `tokenizer_default` as an option for `chat_template` in the
   `chat_template` prompt strategy, allowing use of the chat template
   from the tokenizer's config (tokenizer_config.json); see the usage
   sketch after this list
2. Allows falling back to the chat templates shipped with axolotl if the
   tokenizer does not have a chat template
3. Adds a mistral chat template which supports a system message, taken
   from https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja
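
For illustration, a minimal sketch of the two new choices (the model name
and surrounding setup are hypothetical; `chat_templates` and the choice
strings come from this change):

    from transformers import AutoTokenizer
    from axolotl.utils.chat_templates import chat_templates

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

    # Use the template shipped with the tokenizer; raises if it has none
    template = chat_templates("tokenizer_default", tokenizer=tokenizer)

    # Same, but fall back to axolotl's chatml template when the tokenizer
    # defines no chat template of its own
    template = chat_templates("tokenizer_default_fallback_chatml", tokenizer=tokenizer)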

---

Why?

Many popular models are not trained with the chatml format. As a result,
for the model to correctly learn chatml we have to turn on
train_on_inputs, which requires more compute and time. If we can use the
chat template the model has already learned, we only need to learn the
output tokens; a sketch of that masking idea follows.
---

Todo:

- Write tests
Chirag Jain
2024-07-10 02:12:34 +05:30
parent 47e1916484
commit 5edaad5b8b
5 changed files with 60 additions and 9 deletions

View File

@@ -122,10 +122,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         else False
     )
+    chat_template_str = chat_templates(chat_template, tokenizer=tokenizer)
+    LOG.info(f"Using chat template:\n---\n{chat_template_str!s}\n---")
     strategy = ChatTemplateStrategy(
         ChatTemplatePrompter(
             tokenizer,
-            chat_templates(chat_template),
+            chat_template_str,
             message_field_role=message_field_role,
             message_field_content=message_field_content,
             roles=roles,

View File

@@ -33,7 +33,7 @@ def load(
     if ds_cfg and "chat_template" in ds_cfg:
         chat_template = ds_cfg["chat_template"]
         try:
-            chat_template = chat_templates(chat_template)
+            chat_template = chat_templates(chat_template, tokenizer=tokenizer)
         except ValueError:
             pass
     tokenizer.chat_template = chat_template
@@ -248,11 +248,11 @@ class ORPOPrompter(Prompter):
 def argilla(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
     dataset_parser = ORPODatasetParsingStrategy()
 
-    chat_template_str = chat_templates(cfg.chat_template)
 
     def transform_fn(sample, tokenizer=None):
         res = {}
+        chat_template_str = chat_templates(cfg.chat_template, tokenizer=tokenizer)
         res["prompt"] = tokenizer.apply_chat_template(
             [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
             add_generation_prompt=True,

View File

@@ -2,9 +2,15 @@
 This module provides functionality for selecting chat templates based on user choices.
 These templates are used for formatting messages in a conversation.
 """
+import logging
+
+LOG = logging.getLogger("axolotl.utils.chat_templates")
+
+_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
+_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"
 
-def chat_templates(user_choice: str):
+def chat_templates(user_choice: str, tokenizer=None):
     """
     Finds the correct chat_template for the tokenizer_config.
@@ -21,6 +27,7 @@ def chat_templates(user_choice: str):
     templates = {
         "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
         "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # I don't know what this one is called. Used by Mistral/Mixtral.
+        "mistral": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] | trim + '\n\n' %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{{ bos_token + system_message }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] | trim + eos_token }}{% endif %}{% endfor %}",
         "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
         "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
         "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
@@ -28,6 +35,33 @@ def chat_templates(user_choice: str):
         "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
     }
 
+    if user_choice == _DEFAULT_TEMPLATE_CHOICE:
+        if not tokenizer:
+            raise ValueError(
+                f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}"
+            )
+        if not tokenizer.chat_template:
+            raise ValueError(
+                f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
+                f"Please add a chat_template in tokenizer config"
+            )
+        return tokenizer.chat_template
+
+    if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
+        if not tokenizer:
+            raise ValueError(
+                f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
+            )
+        if tokenizer.chat_template:
+            return tokenizer.chat_template
+        user_choice = user_choice[
+            len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
+        ]
+        LOG.warning(
+            f"No chat template found on tokenizer, falling back to {user_choice}"
+        )
+
     if user_choice in templates:
         return templates[user_choice]
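
As a sketch of the fallback path above (the model choice is illustrative;
GPT-2's tokenizer ships without a chat template, so the warning fires and
axolotl's chatml template is returned):

    from transformers import AutoTokenizer
    from axolotl.utils.chat_templates import chat_templates

    tok = AutoTokenizer.from_pretrained("gpt2")  # defines no chat_template
    template = chat_templates("tokenizer_default_fallback_chatml", tokenizer=tok)
    # logs: No chat template found on tokenizer, falling back to chatml
    assert template == chat_templates("chatml")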

View File

@@ -7,9 +7,16 @@ Module for pydantic models for configuration
 import logging
 import os
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
 
-from pydantic import BaseModel, Field, conlist, field_validator, model_validator
+from pydantic import (
+    BaseModel,
+    Field,
+    StringConstraints,
+    conlist,
+    field_validator,
+    model_validator,
+)
 from transformers import SchedulerType
 from transformers.training_args import OptimizerNames
@@ -179,6 +186,8 @@ class ChatTemplate(str, Enum):
     cohere = "cohere"  # pylint: disable=invalid-name
     llama3 = "llama3"  # pylint: disable=invalid-name
     phi_3 = "phi_3"  # pylint: disable=invalid-name
+    mistral = "mistral"  # pylint: disable=invalid-name
+    tokenizer_default = "tokenizer_default"  # pylint: disable=invalid-name
 
 
 class LoftQConfig(BaseModel):
@@ -634,7 +643,12 @@ class AxolotlInputConfig(
     gpu_memory_limit: Optional[Union[int, str]] = None
     low_cpu_mem_usage: Optional[bool] = None
 
-    chat_template: Optional[ChatTemplate] = None
+    chat_template: Optional[
+        Union[
+            ChatTemplate,
+            Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
+        ]
+    ] = None
     default_system_message: Optional[str] = None
 
     # INTERNALS - document for now, generally not set externally
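
A minimal sketch of what the widened annotation accepts (DemoConfig is a
hypothetical stand-in for the real config class, isolating just this field):

    from typing import Annotated, Optional
    from pydantic import BaseModel, StringConstraints, ValidationError

    class DemoConfig(BaseModel):  # stand-in for the field above
        chat_template: Optional[
            Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")]
        ] = None

    DemoConfig(chat_template="tokenizer_default_fallback_chatml")  # accepted
    try:
        DemoConfig(chat_template="not_a_fallback")
    except ValidationError:
        pass  # rejected by the pattern; the real field also accepts ChatTemplate enum values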

View File

@@ -283,7 +283,7 @@ def load_tokenizer(cfg):
         LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
 
     if cfg.chat_template:
-        chat_template_string = chat_templates(cfg.chat_template)
+        chat_template_string = chat_templates(cfg.chat_template, tokenizer=tokenizer)
         if cfg.default_system_message and cfg.chat_template == "chatml":
             chat_template_string = chat_template_string.replace(
                 "You are a helpful assistant.", cfg.default_system_message
             )