From 5edaad5b8be116085b9aefbb39a845a2b0309853 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Wed, 10 Jul 2024 02:12:34 +0530 Subject: [PATCH] Allow using tokenizer's default chat template with fallbacks Summary of changes: 1. Adds `tokenizer_default` as an option for `chat_template` in the `chat_template` prompt strategy that allows using the chat template from the tokenizer's tokenizer_config.json 2. Allows falling back to chat templates available in axolotl if the tokenizer does not have a chat template 3. Adds a mistral chat template which supports a system message - taken from https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja --- Why? Many popular models are not trained with the chatml format. As a result, for the model to correctly learn chatml, we have to turn on train_on_inputs, which requires more compute and time. If we can use the model's already-learned chat template, we can just learn the output tokens. --- Todo: - Write tests --- .../prompt_strategies/chat_template.py | 5 ++- .../prompt_strategies/orpo/chat_template.py | 6 ++-- src/axolotl/utils/chat_templates.py | 36 ++++++++++++++++++- .../config/models/input/v0_4_1/__init__.py | 20 +++++++++-- src/axolotl/utils/models.py | 2 +- 5 files changed, 60 insertions(+), 9 deletions(-) diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 8c7a8dd4f..1626cc456 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -122,10 +122,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): else False ) + chat_template_str = chat_templates(chat_template, tokenizer=tokenizer) + LOG.info(f"Using chat template:\n---\n{chat_template_str!s}\n---") + strategy = ChatTemplateStrategy( ChatTemplatePrompter( tokenizer, - chat_templates(chat_template), + chat_template_str, message_field_role=message_field_role, message_field_content=message_field_content, roles=roles, diff --git 
a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py index bba694856..f9e328ee6 100644 --- a/src/axolotl/prompt_strategies/orpo/chat_template.py +++ b/src/axolotl/prompt_strategies/orpo/chat_template.py @@ -33,7 +33,7 @@ def load( if ds_cfg and "chat_template" in ds_cfg: chat_template = ds_cfg["chat_template"] try: - chat_template = chat_templates(chat_template) + chat_template = chat_templates(chat_template, tokenizer=tokenizer) except ValueError: pass tokenizer.chat_template = chat_template @@ -248,11 +248,11 @@ class ORPOPrompter(Prompter): def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument dataset_parser = ORPODatasetParsingStrategy() - chat_template_str = chat_templates(cfg.chat_template) - def transform_fn(sample, tokenizer=None): res = {} + chat_template_str = chat_templates(cfg.chat_template, tokenizer=tokenizer) + res["prompt"] = tokenizer.apply_chat_template( [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages], add_generation_prompt=True, diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 725934cf5..d3e6e4d82 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -2,9 +2,15 @@ This module provides functionality for selecting chat templates based on user choices. These templates are used for formatting messages in a conversation. """ +import logging + +LOG = logging.getLogger("axolotl.utils.chat_templates") + +_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default" +_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_" -def chat_templates(user_choice: str): +def chat_templates(user_choice: str, tokenizer=None): """ Finds the correct chat_template for the tokenizer_config. 
@@ -21,6 +27,7 @@ def chat_templates(user_choice: str): templates = { "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. + "mistral": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] | trim + '\n\n' %}{% set messages = messages[1:] %}{% else %}{% set system_message = '' %}{% endif %}{{ bos_token + system_message }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] | trim + eos_token }}{% endif %}{% endfor %}", "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != 
(loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", @@ -28,6 +35,33 @@ def chat_templates(user_choice: str): "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' 
+ '\n'}}{% endif %}{% endfor %}", } + if user_choice == _DEFAULT_TEMPLATE_CHOICE: + if not tokenizer: + raise ValueError( + f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}" + ) + if not tokenizer.chat_template: + raise ValueError( + f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. " + f"Please add a chat_template in tokenizer config" + ) + return tokenizer.chat_template + + if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX): + if not tokenizer: + raise ValueError( + f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}" + ) + if tokenizer.chat_template: + return tokenizer.chat_template + + user_choice = user_choice[ + len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) : + ] + LOG.warning( + f"No chat template found on tokenizer, falling back to {user_choice}" + ) + if user_choice in templates: return templates[user_choice] diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 3cac4f839..8845abe1b 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -7,9 +7,16 @@ Module for pydantic models for configuration import logging import os from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union -from pydantic import BaseModel, Field, conlist, field_validator, model_validator +from pydantic import ( + BaseModel, + Field, + StringConstraints, + conlist, + field_validator, + model_validator, +) from transformers import SchedulerType from transformers.training_args import OptimizerNames @@ -179,6 +186,8 @@ class ChatTemplate(str, Enum): cohere = "cohere" # pylint: disable=invalid-name llama3 = "llama3" # pylint: disable=invalid-name 
phi_3 = "phi_3" # pylint: disable=invalid-name + mistral = "mistral" # pylint: disable=invalid-name + tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name class LoftQConfig(BaseModel): @@ -634,7 +643,12 @@ class AxolotlInputConfig( gpu_memory_limit: Optional[Union[int, str]] = None low_cpu_mem_usage: Optional[bool] = None - chat_template: Optional[ChatTemplate] = None + chat_template: Optional[ + Union[ + ChatTemplate, + Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")], + ] + ] = None default_system_message: Optional[str] = None # INTERNALS - document for now, generally not set externally diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index d479d425d..f9ae2f0ce 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -283,7 +283,7 @@ def load_tokenizer(cfg): LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") if cfg.chat_template: - chat_template_string = chat_templates(cfg.chat_template) + chat_template_string = chat_templates(cfg.chat_template, tokenizer=tokenizer) if cfg.default_system_message and cfg.chat_template == "chatml": chat_template_string = chat_template_string.replace( "You are a helpful assistant.", cfg.default_system_message