Address review comments and add docs
This commit is contained in:
@@ -141,9 +141,16 @@ test_datasets:
|
|||||||
# use RL training: 'dpo', 'ipo', 'kto'
|
# use RL training: 'dpo', 'ipo', 'kto'
|
||||||
rl:
|
rl:
|
||||||
|
|
||||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
# The name of the chat template to use for training, following values are supported:
|
||||||
# Currently supports chatml and inst (mistral/mixtral)
|
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
|
||||||
chat_template: chatml
|
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
|
||||||
|
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
|
||||||
|
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
|
||||||
|
# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
|
||||||
|
# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
|
||||||
|
chat_template: tokenizer_default
|
||||||
|
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
|
||||||
|
chat_template_jinja: null
|
||||||
# Changes the default system message
|
# Changes the default system message
|
||||||
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
|
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
|
||||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||||
|
|||||||
@@ -69,3 +69,152 @@ creates a chat where bot is asked to tell a joke, then explain why the joke is f
|
|||||||
```{.json filename="data.jsonl"}
|
```{.json filename="data.jsonl"}
|
||||||
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
|
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## chat_template
|
||||||
|
|
||||||
|
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Usually this chat template is stored in tokenizer_config.json under the key `chat_template`.
|
||||||
|
|
||||||
|
Conversational data would normally look like follows:
|
||||||
|
|
||||||
|
```{.json filename="data.jsonl"}
|
||||||
|
{"messages": [{"role": "...", "content": "..."}]}
|
||||||
|
```
|
||||||
|
|
||||||
|
with roles usually being system, user, assistant, etc.
|
||||||
|
However, all fields can be customized using the following configuration:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
datasets:
|
||||||
|
- path: ...
|
||||||
|
# Set type to `chat_template` to use this strategy
|
||||||
|
type: chat_template
|
||||||
|
# Specify the name of the chat template to use
|
||||||
|
# The name of the chat template to use for training, following values are supported:
|
||||||
|
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
|
||||||
|
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
|
||||||
|
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
|
||||||
|
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
|
||||||
|
chat_template: tokenizer_default
|
||||||
|
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
|
||||||
|
chat_template_jinja: null
|
||||||
|
# The key in the data example that contains the messages. Default is "conversations".
|
||||||
|
field_messages: conversations
|
||||||
|
# The key in the message turn that contains the role. Default is "from".
|
||||||
|
message_field_role: from
|
||||||
|
# The key in the message turn that contains the content. Default is "value".
|
||||||
|
message_field_content: value
|
||||||
|
# Role mapping for the messages. This can be useful if you are combining data from multiple sources and the roles are different.
|
||||||
|
roles:
|
||||||
|
human: user
|
||||||
|
user: user
|
||||||
|
assistant: assistant
|
||||||
|
gpt: assistant
|
||||||
|
system: system
|
||||||
|
# Roles to train on. The tokens from these roles will be considered for the loss. Default is ["gpt", "assistant"]
|
||||||
|
roles_to_train: ["gpt", "assistant"]
|
||||||
|
# Which EOS tokens to train on in the conversation. Possible values are:
|
||||||
|
# - all: train on all EOS tokens
|
||||||
|
# - turn: train on the EOS token at the end of each trainable turn
|
||||||
|
# - last: train on the last EOS token in the conversation
|
||||||
|
# - none: do not train on EOS tokens
|
||||||
|
# Default is "turn".
|
||||||
|
train_on_eos: turn
|
||||||
|
# The key in the message turn that indicates if tokens of a turn should be considered for training. This is an advanced option useful to selectively train on certain turns besides the `roles_to_train`. Default is "training".
|
||||||
|
message_field_training: training
|
||||||
|
# The key in the message turn that contains the training details. This is an advanced option useful to selectively train on certain tokens in a turn. Default is "train_detail".
|
||||||
|
message_field_training_detail: train_detail
|
||||||
|
```
|
||||||
|
|
||||||
|
### Examples
|
||||||
|
|
||||||
|
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
datasets:
|
||||||
|
- path: ...
|
||||||
|
type: chat_template
|
||||||
|
chat_template: tokenizer_default
|
||||||
|
field_messages: messages
|
||||||
|
message_field_role: role
|
||||||
|
message_field_content: content
|
||||||
|
roles:
|
||||||
|
user: user
|
||||||
|
assistant: assistant
|
||||||
|
human: user
|
||||||
|
gpt: assistant
|
||||||
|
system: system
|
||||||
|
roles_to_train: ["assistant"]
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Using a custom jinja template on OpenAI messages format
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
datasets:
|
||||||
|
- path: ...
|
||||||
|
type: chat_template
|
||||||
|
chat_template: jinja
|
||||||
|
chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
|
||||||
|
field_messages: messages
|
||||||
|
message_field_role: role
|
||||||
|
message_field_content: content
|
||||||
|
roles:
|
||||||
|
user: user
|
||||||
|
assistant: assistant
|
||||||
|
human: user
|
||||||
|
gpt: assistant
|
||||||
|
system: system
|
||||||
|
roles_to_train: ["assistant"]
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Using fine-grained control over tokens and turns to train in a conversation
|
||||||
|
|
||||||
|
|
||||||
|
For a data sample that looks like:
|
||||||
|
|
||||||
|
```{.json filename="data.jsonl"}
|
||||||
|
{
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "You are an AI assistant.", "train": false},
|
||||||
|
{"from": "human", "value": "Hello", "train": false},
|
||||||
|
{"from": "assistant", "value": "Hello", "train": true},
|
||||||
|
{"from": "human", "value": "How are you?", "train": true},
|
||||||
|
{
|
||||||
|
"from": "assistant",
|
||||||
|
"value": "I'm doing very well, thank you!",
|
||||||
|
"train_detail": [
|
||||||
|
{"begin_offset": 0, "end_offset": 8, "train": false},
|
||||||
|
{"begin_offset": 9, "end_offset": 18, "train": true},
|
||||||
|
{"begin_offset": 19, "end_offset": 30, "train": false},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "human",
|
||||||
|
"value": "I'm doing very well, thank you!",
|
||||||
|
"train": true,
|
||||||
|
},
|
||||||
|
{"from": "assistant", "value": "Hi there!", "train": true}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The configuration would look like:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
datasets:
|
||||||
|
- path: ...
|
||||||
|
chat_template: tokenizer_default
|
||||||
|
field_messages: conversations
|
||||||
|
message_field_role: from
|
||||||
|
message_field_content: value
|
||||||
|
roles:
|
||||||
|
human: human
|
||||||
|
user: human
|
||||||
|
assistant: assistant
|
||||||
|
gpt: assistant
|
||||||
|
system: system
|
||||||
|
roles_to_train: []
|
||||||
|
train_on_eos: turn
|
||||||
|
message_field_training: train
|
||||||
|
message_field_training_detail: train_detail
|
||||||
|
```
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ _JINJA_TEMPALTE_CHOICE = "jinja"
|
|||||||
_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
|
_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
|
||||||
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"
|
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"
|
||||||
|
|
||||||
_TEMPLATES = {
|
_CHAT_TEMPLATES = {
|
||||||
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
||||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||||
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
@@ -78,18 +78,18 @@ def get_chat_template(
|
|||||||
f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
|
f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_choice in _TEMPLATES:
|
if user_choice in _CHAT_TEMPLATES:
|
||||||
return _TEMPLATES[user_choice]
|
return _CHAT_TEMPLATES[user_choice]
|
||||||
|
|
||||||
raise ValueError(f"Template '{user_choice}' not found.")
|
raise ValueError(f"Template '{user_choice}' not found.")
|
||||||
|
|
||||||
|
|
||||||
def extract_chat_template_args(cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
def extract_chat_template_args(cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
if ds_cfg and ds_cfg.get("chat_template"):
|
if ds_cfg and ds_cfg.get("chat_template"):
|
||||||
chat_template_choice = ds_cfg.get("chat_template") or "chatml"
|
chat_template_choice = ds_cfg.get("chat_template") or "tokenizer_default"
|
||||||
chat_template_jinja = ds_cfg.get("chat_template_jinja")
|
chat_template_jinja = ds_cfg.get("chat_template_jinja")
|
||||||
else:
|
else:
|
||||||
chat_template_choice = cfg.get("chat_template") or "chatml"
|
chat_template_choice = cfg.get("chat_template") or "tokenizer_default"
|
||||||
chat_template_jinja = cfg.get("chat_template_jinja")
|
chat_template_jinja = cfg.get("chat_template_jinja")
|
||||||
return chat_template_choice, chat_template_jinja
|
return chat_template_choice, chat_template_jinja
|
||||||
|
|
||||||
@@ -99,7 +99,6 @@ def get_chat_template_from_config(
|
|||||||
ds_cfg: Optional[Dict[str, Any]] = None,
|
ds_cfg: Optional[Dict[str, Any]] = None,
|
||||||
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
ds_cfg = ds_cfg or {}
|
|
||||||
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||||
cfg=cfg, ds_cfg=ds_cfg
|
cfg=cfg, ds_cfg=ds_cfg
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ class SFTDataset(BaseModel):
|
|||||||
chat_template: Union[
|
chat_template: Union[
|
||||||
ChatTemplate,
|
ChatTemplate,
|
||||||
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
|
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
|
||||||
] = ChatTemplate.chatml
|
] = ChatTemplate.tokenizer_default
|
||||||
chat_template_jinja: Optional[str] = None
|
chat_template_jinja: Optional[str] = None
|
||||||
data_files: Optional[Union[str, List[str]]] = None
|
data_files: Optional[Union[str, List[str]]] = None
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
@@ -153,6 +153,7 @@ class SFTDataset(BaseModel):
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_chat_template_config(cls, data):
|
def check_chat_template_config(cls, data):
|
||||||
|
# if chat_template is set to jinja, chat_template_jinja is required
|
||||||
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
||||||
"chat_template_jinja"
|
"chat_template_jinja"
|
||||||
):
|
):
|
||||||
@@ -160,6 +161,10 @@ class SFTDataset(BaseModel):
|
|||||||
"chat_template_jinja is required when chat_template is set to jinja"
|
"chat_template_jinja is required when chat_template is set to jinja"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If chat_template_jinja is set, set chat_template to jinja
|
||||||
|
if data.get("chat_template_jinja") and not data.get("chat_template"):
|
||||||
|
data["chat_template"] = ChatTemplate.jinja
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@@ -815,6 +820,7 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_chat_template_config(cls, data):
|
def check_chat_template_config(cls, data):
|
||||||
|
# if chat_template is set to jinja, chat_template_jinja is required
|
||||||
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
||||||
"chat_template_jinja"
|
"chat_template_jinja"
|
||||||
):
|
):
|
||||||
@@ -822,6 +828,10 @@ class AxolotlInputConfig(
|
|||||||
"chat_template_jinja is required when chat_template is set to jinja"
|
"chat_template_jinja is required when chat_template is set to jinja"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If chat_template_jinja is set, set chat_template to jinja
|
||||||
|
if data.get("chat_template_jinja") and not data.get("chat_template"):
|
||||||
|
data["chat_template"] = ChatTemplate.jinja
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import pytest
|
|||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from axolotl.utils.chat_templates import (
|
from axolotl.utils.chat_templates import (
|
||||||
_TEMPLATES,
|
_CHAT_TEMPLATES,
|
||||||
extract_chat_template_args,
|
extract_chat_template_args,
|
||||||
get_chat_template,
|
get_chat_template,
|
||||||
)
|
)
|
||||||
@@ -27,7 +27,7 @@ class TestGetChatTemplateUtils:
|
|||||||
|
|
||||||
def test_known_chat_template(self):
|
def test_known_chat_template(self):
|
||||||
chat_template_str = get_chat_template("llama3")
|
chat_template_str = get_chat_template("llama3")
|
||||||
assert chat_template_str == _TEMPLATES["llama3"]
|
assert chat_template_str == _CHAT_TEMPLATES["llama3"]
|
||||||
|
|
||||||
def test_invalid_chat_template(self):
|
def test_invalid_chat_template(self):
|
||||||
with pytest.raises(ValueError) as exc:
|
with pytest.raises(ValueError) as exc:
|
||||||
|
|||||||
Reference in New Issue
Block a user