From 8a84408fc72f63ab5f37353022969d0803fe04b0 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Tue, 27 Aug 2024 04:25:44 +0530 Subject: [PATCH] Address review comments and add docs --- docs/config.qmd | 13 +- docs/dataset-formats/conversation.qmd | 149 ++++++++++++++++++ src/axolotl/utils/chat_templates.py | 11 +- .../config/models/input/v0_4_1/__init__.py | 12 +- .../test_chat_template_utils.py | 4 +- 5 files changed, 177 insertions(+), 12 deletions(-) diff --git a/docs/config.qmd b/docs/config.qmd index e85999978..0e536d858 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -141,9 +141,16 @@ test_datasets: # use RL training: 'dpo', 'ipo', 'kto' rl: -# Saves the desired chat template to the tokenizer_config.json for easier inferencing -# Currently supports chatml and inst (mistral/mixtral) -chat_template: chatml +# The name of the chat template to use for training, following values are supported: +# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value. +# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py +# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer. +# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field. +# The selected chat template will be saved to the tokenizer_config.json for easier inferencing +# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template. +chat_template: tokenizer_default +# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null. +chat_template_jinja: null # Changes the default system message default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml. # Axolotl attempts to save the dataset as an arrow after packing the data together so diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd index 28d13c987..8263c53af 100644 --- a/docs/dataset-formats/conversation.qmd +++ b/docs/dataset-formats/conversation.qmd @@ -69,3 +69,152 @@ creates a chat where bot is asked to tell a joke, then explain why the joke is f ```{.json filename="data.jsonl"} {"conversations": [{"title": "...", "text": "...", "explanation": "..."}]} ``` + + +## chat_template + +Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Usually this chat template is stored in tokenizer_config.json under the key `chat_template`. + +Conversational data would normally look like follows: + +```{.json filename="data.jsonl"} +{"messages": [{"role": "...", "content": "..."}]} +``` + +with roles usually being system, user, assistant, etc. +However, all fields can be customized using the following configuration: + +```yaml +datasets: + - path: ... + # Set type to `chat_template` to use this strategy + type: chat_template + # Specify the name of the chat template to use + # The name of the chat template to use for training, following values are supported: + # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value. + # - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py + # - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer. + # - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field. + chat_template: tokenizer_default + # custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null. + chat_template_jinja: null + # The key in the data example that contains the messages. Default is "conversations". + field_messages: conversations + # The key in the message turn that contains the role. Default is "from". + message_field_role: from + # The key in the message turn that contains the content. Default is "value". + message_field_content: value + # Role mapping for the messages. This can be useful if you are combining data from multiple sources and the roles are different. + roles: + human: user + user: user + assistant: assistant + gpt: assistant + system: system + # Roles to train on. The tokens from these roles will be considered for the loss. Default is ["gpt", "assistant"] + roles_to_train: ["gpt", "assistant"] + # Which EOS tokens to train on in the conversation. Possible values are: + # - all: train on all EOS tokens + # - turn: train on the EOS token at the end of each trainable turn + # - last: train on the last EOS token in the conversation + # - none: do not train on EOS tokens + # Default is "turn". + train_on_eos: turn + # The key in the message turn that indicates if tokens of a turn should be considered for training. This is an advanced option useful to selectively train on certain turns besides the `roles_to_train`. Default is "training". + message_field_training: training + # The key in the message turn that contains the training details. This is an advanced option useful to selectively train on certain tokens in a turn. Default is "train_detail". + message_field_training_detail: train_detail +``` + +### Examples + +1. Using the default chat template in the tokenizer_config.json on OpenAI messages format + +```yaml +datasets: + - path: ... + type: chat_template + chat_template: tokenizer_default + field_messages: messages + message_field_role: role + message_field_content: content + roles: + user: user + assistant: assistant + human: user + gpt: assistant + system: system + roles_to_train: ["assistant"] +``` + +2. Using a custom jinja template on OpenAI messages format + +```yaml +datasets: + - path: ... + type: chat_template + chat_template: jinja + chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}" + field_messages: messages + message_field_role: role + message_field_content: content + roles: + user: user + assistant: assistant + human: user + gpt: assistant + system: system + roles_to_train: ["assistant"] +``` + +3. Using fine-grained control over tokens and turns to train in a conversation + + +For a data sample that looks like: + +```{.json filename="data.jsonl"} +{ + "conversations": [ + {"from": "system", "value": "You are an AI assistant.", "train": false}, + {"from": "human", "value": "Hello", "train": false}, + {"from": "assistant", "value": "Hello", "train": true}, + {"from": "human", "value": "How are you?", "train": true}, + { + "from": "assistant", + "value": "I'm doing very well, thank you!", + "train_detail": [ + {"begin_offset": 0, "end_offset": 8, "train": false}, + {"begin_offset": 9, "end_offset": 18, "train": true}, + {"begin_offset": 19, "end_offset": 30, "train": false}, + ], + }, + { + "from": "human", + "value": "I'm doing very well, thank you!", + "train": true, + }, + {"from": "assistant", "value": "Hi there!", "train": true} + ] +} +``` + +The configuration would look like: + +```yaml +datasets: + - path: ... + chat_template: tokenizer_default + field_messages: conversations + message_field_role: from + message_field_content: value + roles: + human: human + user: human + assistant: assistant + gpt: assistant + system: system + roles_to_train: [] + train_on_eos: turn + message_field_training: train + message_field_training_detail: train_detail +``` diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 55d70fd9e..f96b7a52f 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -14,7 +14,7 @@ _JINJA_TEMPALTE_CHOICE = "jinja" _DEFAULT_TEMPLATE_CHOICE = "tokenizer_default" _DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_" -_TEMPLATES = { +_CHAT_TEMPLATES = { "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", @@ -78,18 +78,18 @@ def get_chat_template( f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template." ) - if user_choice in _TEMPLATES: - return _TEMPLATES[user_choice] + if user_choice in _CHAT_TEMPLATES: + return _CHAT_TEMPLATES[user_choice] raise ValueError(f"Template '{user_choice}' not found.") def extract_chat_template_args(cfg, ds_cfg: Optional[Dict[str, Any]] = None): if ds_cfg and ds_cfg.get("chat_template"): - chat_template_choice = ds_cfg.get("chat_template") or "chatml" + chat_template_choice = ds_cfg.get("chat_template") or "tokenizer_default" chat_template_jinja = ds_cfg.get("chat_template_jinja") else: - chat_template_choice = cfg.get("chat_template") or "chatml" + chat_template_choice = cfg.get("chat_template") or "tokenizer_default" chat_template_jinja = cfg.get("chat_template_jinja") return chat_template_choice, chat_template_jinja @@ -99,7 +99,6 @@ def get_chat_template_from_config( ds_cfg: Optional[Dict[str, Any]] = None, tokenizer: Optional["PreTrainedTokenizerBase"] = None, ) -> str: - ds_cfg = ds_cfg or {} chat_template_choice, chat_template_jinja = extract_chat_template_args( cfg=cfg, ds_cfg=ds_cfg ) diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 89cd36784..f147c645b 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -130,7 +130,7 @@ class SFTDataset(BaseModel): chat_template: Union[ ChatTemplate, Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")], - ] = ChatTemplate.chatml + ] = ChatTemplate.tokenizer_default chat_template_jinja: Optional[str] = None data_files: Optional[Union[str, List[str]]] = None name: Optional[str] = None @@ -153,6 +153,7 @@ class SFTDataset(BaseModel): @model_validator(mode="before") @classmethod def check_chat_template_config(cls, data): + # if chat_template is set to jinja, chat_template_jinja is required if data.get("chat_template") == ChatTemplate.jinja and not data.get( "chat_template_jinja" ): @@ -160,6 +161,10 @@ class SFTDataset(BaseModel): "chat_template_jinja is required when chat_template is set to jinja" ) + # If chat_template_jinja is set, set chat_template to jinja + if data.get("chat_template_jinja") and not data.get("chat_template"): + data["chat_template"] = ChatTemplate.jinja + return data @@ -815,6 +820,7 @@ class AxolotlInputConfig( @model_validator(mode="before") @classmethod def check_chat_template_config(cls, data): + # if chat_template is set to jinja, chat_template_jinja is required if data.get("chat_template") == ChatTemplate.jinja and not data.get( "chat_template_jinja" ): @@ -822,6 +828,10 @@ class AxolotlInputConfig( "chat_template_jinja is required when chat_template is set to jinja" ) + # If chat_template_jinja is set, set chat_template to jinja + if data.get("chat_template_jinja") and not data.get("chat_template"): + data["chat_template"] = ChatTemplate.jinja + return data @model_validator(mode="before") diff --git a/tests/prompt_strategies/test_chat_template_utils.py b/tests/prompt_strategies/test_chat_template_utils.py index e220ed13c..b63c9aa17 100644 --- a/tests/prompt_strategies/test_chat_template_utils.py +++ b/tests/prompt_strategies/test_chat_template_utils.py @@ -7,7 +7,7 @@ import pytest from transformers import AutoTokenizer from axolotl.utils.chat_templates import ( - _TEMPLATES, + _CHAT_TEMPLATES, extract_chat_template_args, get_chat_template, ) @@ -27,7 +27,7 @@ class TestGetChatTemplateUtils: def test_known_chat_template(self): chat_template_str = get_chat_template("llama3") - assert chat_template_str == _TEMPLATES["llama3"] + assert chat_template_str == _CHAT_TEMPLATES["llama3"] def test_invalid_chat_template(self): with pytest.raises(ValueError) as exc: