Compare commits

...

46 Commits

Author SHA1 Message Date
NanoCode012
28e7e444ee fix: update bradleyterry to use new chat_template 2024-10-16 20:42:14 +07:00
NanoCode012
207e7627f9 fix(doc): formatting 2024-10-15 00:41:50 +07:00
NanoCode012
7eb62ae5a9 fix: update dummy message to prevent potential overlap with real content 2024-10-14 23:50:35 +07:00
NanoCode012
95805cf850 chore: lint 2024-10-14 23:43:30 +07:00
NanoCode012
4aafb7e600 fix: imported name incorrectly updated on merge 2024-10-14 23:41:17 +07:00
NanoCode012
17bc4c8b36 fix: update test based on new defaults 2024-10-14 18:03:35 +07:00
NanoCode012
d101cfc125 feat: handles chat_template requiring specific user/assistant order 2024-10-14 14:00:55 +07:00
NanoCode012
e5cd55cff9 feat: add example using fallback 2024-10-14 12:22:22 +07:00
NanoCode012
24aa6b15a0 feat: handle sharegpt deprecation better in docs 2024-10-14 12:21:58 +07:00
NanoCode012
9dfc5fa8b8 fix: remove default setting on edge case where chat template overriden in dataset section 2024-10-14 11:48:40 +07:00
NanoCode012
0c3255288f Merge branch 'main' into cj_tokenizer_default_prompt_template 2024-10-14 10:36:08 +07:00
Chirag Jain
82b5dc9328 Merge branch 'main' into cj_tokenizer_default_prompt_template 2024-10-13 16:27:10 +05:30
Chirag Jain
ec57918fcd Merge pull request #7 from NanoCode012/cj_tokenizer_default_prompt_template
Feat: merge latest, update docs, fix dropped config bug, added unit test
2024-10-11 14:44:25 +05:30
NanoCode012
dd87d8c438 feat: add test for levy's dpo case 2024-10-11 12:56:46 +07:00
NanoCode012
ef942b6efc fix: rename var after merge 2024-10-11 12:30:43 +07:00
NanoCode012
3c6a6c61be Merge branch 'main' into cj_tokenizer_default_prompt_template 2024-10-11 12:29:34 +07:00
NanoCode012
7b4b665e99 chore: skip duplicate 2024-10-11 11:42:36 +07:00
NanoCode012
21326e4ef3 chore: lint 2024-10-11 11:40:42 +07:00
NanoCode012
de23dab4fc fix: config being dropped and unittest to catch that 2024-10-11 11:40:32 +07:00
NanoCode012
e3efa29cf5 fix: test 2024-10-11 11:11:19 +07:00
NanoCode012
2038255052 Merge branch 'main' into cj_tokenizer_default_prompt_template 2024-10-10 20:25:37 +07:00
NanoCode012
dab2590e4d chore: refactor 2024-10-10 18:07:00 +07:00
NanoCode012
e5162b7a41 chore: added example for non-default template 2024-10-10 18:04:33 +07:00
NanoCode012
b6321d2220 chore: clarify doc 2024-10-10 18:01:33 +07:00
NanoCode012
6b3cdfdb8e feat(doc): updated config with chat template options and clarified examples 2024-10-10 17:57:11 +07:00
NanoCode012
203ae28704 fix: refactor artifact left from main merge 2024-10-10 17:16:41 +07:00
NanoCode012
ed3a33c9fb fix: re-arrange enum declaration position 2024-10-10 16:18:15 +07:00
NanoCode012
f61e2fc7dc chore: remove redundant function 2024-10-10 16:15:15 +07:00
NanoCode012
b8056d04d9 Merge branch 'main' into cj_tokenizer_default_prompt_template 2024-10-10 16:11:07 +07:00
NanoCode012
88658c0570 fix: set default to tokenizer template 2024-10-10 15:38:19 +07:00
Chirag Jain
260ca97f2c Merge branch 'main' into cj_tokenizer_default_prompt_template 2024-09-13 00:33:49 +05:30
Chirag Jain
b1bb2accb9 Merge branch 'main' into cj_tokenizer_default_prompt_template 2024-08-28 13:34:20 +05:30
Chirag Jain
efeaa00bb4 Update docs/dataset-formats/conversation.qmd
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2024-08-27 19:08:54 +05:30
Chirag Jain
8a84408fc7 Address review comments and add docs 2024-08-27 04:30:35 +05:30
Chirag Jain
4805f3ca0a Merge branch 'main' of https://github.com/OpenAccess-AI-Collective/axolotl into cj_tokenizer_default_prompt_template 2024-08-27 02:35:58 +05:30
Chirag Jain
8ee30f5954 Merge branch 'main' into cj_tokenizer_default_prompt_template 2024-08-23 03:44:25 +05:30
Chirag Jain
6ef76f1ace remove custom mistral template 2024-08-19 15:56:47 +05:30
Chirag Jain
2e758aed6f Merge branch 'main' into cj_tokenizer_default_prompt_template 2024-08-19 15:52:04 +05:30
Chirag Jain
21a2302538 Merge branch 'main' into cj_tokenizer_default_prompt_template 2024-08-12 10:24:02 +05:30
Chirag Jain
89f382a13a Merge branch 'main' into cj_tokenizer_default_prompt_template 2024-08-06 21:23:14 +05:30
Chirag Jain
eb188acbd4 Add option chat_template_jinja to provide a jinja template 2024-07-31 01:43:40 +05:30
Chirag Jain
34ea51dcf3 Fix lint and bug post merge from main 2024-07-30 23:59:38 +05:30
Chirag Jain
fd7538dca7 Merge branch 'main' into cj_tokenizer_default_prompt_template 2024-07-30 23:48:43 +05:30
Chirag Jain
99b3bc7fbd Merge branch 'main' into cj_tokenizer_default_prompt_template 2024-07-23 17:16:49 +05:30
Chirag Jain
4e38cea6b8 Add tests 2024-07-12 09:04:59 +05:30
Chirag Jain
5edaad5b8b Allow using tokenizer's default chat template with fallbacks
Summary of changes:

1. Adds `tokenizer_default` as option for `chat_template` in
   `chat_template` prompt strategy that allows using the chat template
   from tokenizer's config.json
2. Allows falling back to chat templates available in axolotl if
   tokenizer does not have a chat template
3. Adds a mistral chat template which supports system message - taken
   from https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja

---

Why?

Many popular models are not trained with the chatml format. As a result, for
the model to correctly learn chatml, we have to turn on train_on_inputs,
which requires more compute and time. If we can use the chat template the
model has already learned, we can just train on the output tokens.

---

Todo:

- Write tests
2024-07-12 08:42:26 +05:30
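
Reduced to plain `transformers` calls, the idea in this summary reads roughly as follows. This is a minimal sketch, not axolotl's implementation: the chatml string is a simplified stand-in for axolotl's built-in template, and the model name is only an example of a base model whose tokenizer ships without a chat template.

```python
from transformers import AutoTokenizer

# Simplified stand-in for axolotl's built-in chatml template.
CHATML = (
    "{% for message in messages %}"
    "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}"
    "{% endfor %}"
)

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")  # example base model

# In effect, `tokenizer_default_fallback_chatml`: prefer the template shipped
# in tokenizer_config.json and fall back to a named template when it is missing.
template = tokenizer.chat_template or CHATML
```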
20 changed files with 900 additions and 118 deletions

View File

@@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
  - typescript
    type: ... # unimplemented custom format
-  # fastchat conversation (deprecation soon, use chat_template)
+  # fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template)
  # See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
  - path: ...
    type: sharegpt

View File

@@ -83,7 +83,7 @@ lora_on_cpu: true
datasets:
  # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
  - path: vicgalle/alpaca-gpt4
    # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
    type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
    ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
    data_files: # Optional[str] path to source data files
@@ -124,6 +124,48 @@ datasets:
    # For `completion` datsets only, uses the provided field instead of `text` column
    field:

  # Using chat template
  - path: ...
    # Set type to `chat_template` to use this strategy
    type: chat_template
    # Specify the name of the chat template to use
    # The name of the chat template to use for training, following values are supported:
    # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
    # - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
    # - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
    # - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
    chat_template: tokenizer_default
    # Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).
    chat_template_jinja:
    # The key in the data example that contains the messages. Default is "messages".
    field_messages: messages
    # The key in the message turn that contains the role. Default is "role".
    message_field_role: role
    # The key in the message turn that contains the content. Default is "content".
    message_field_content: content
    # Optional[Dict[str, List]]. Roles mapping for the messages.
    roles:
      user: ["human", "user"]
      assistant: ["gpt", "assistant", "ai"]
      system: ["system"]

    ## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.

    # Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
    roles_to_train: ["gpt", "assistant"]
    # Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
    # - all: train on all EOS tokens
    # - turn: train on the EOS token at the end of each trainable turn
    # - last: train on the last EOS token in the conversation
    train_on_eos: last
    # The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
    message_field_training: training
    # The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
    # The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
    # See example at `docs/dataset-formats/conversation.qmd`
    message_field_training_detail: train_detail

# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
@@ -142,9 +184,16 @@ test_datasets:
# use RL training: 'dpo', 'ipo', 'kto'
rl:

-# Saves the desired chat template to the tokenizer_config.json for easier inferencing
-# Currently supports chatml and inst (mistral/mixtral)
-chat_template: chatml
+# The name of the chat template to use for training, following values are supported:
+# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
+# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
+# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
+# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
+# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
+# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
+chat_template: tokenizer_default
+# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
+chat_template_jinja: null
# Changes the default system message
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so

View File

@@ -6,6 +6,8 @@ order: 3
## sharegpt

UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below.

conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)

```{.json filename="data.jsonl"}
@@ -69,3 +71,138 @@ creates a chat where bot is asked to tell a joke, then explain why the joke is f
```{.json filename="data.jsonl"}
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
```
## chat_template

Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. It supports using the tokenizer's template, a supported named template, or a custom jinja2 template.

```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "content": "..."}]}
```

See `config.qmd` for full configs and supported templates.
### Migrating from sharegpt

Most configs can be adapted as follows:

```yaml
# old
chat_template: chatml
datasets:
  - path: ...
    type: sharegpt
    conversation: chatml

# new (if using tokenizer's chat_template)
datasets:
  - path: ...
    type: chat_template
    field_messages: conversations
    message_field_role: from
    message_field_content: value

# new (if setting a new chat_template like chatml, gemma, etc)
chat_template: chatml
datasets:
  - path: ...
    type: chat_template
    field_messages: conversations
    message_field_role: from
    message_field_content: value
```
We recommend checking the examples below for other use cases.
### Examples

1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.

```yaml
datasets:
  - path: ...
    type: chat_template
```

2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.

```yaml
chat_template: gemma # this overwrites the tokenizer's chat_template
datasets:
  - path: ...
    type: chat_template
    roles_to_train: ["assistant"]
```

3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.

```yaml
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
datasets:
  - path: ...
    type: chat_template
    roles_to_train: ["assistant"]
```

4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.

```yaml
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"

datasets:
  - path: ...
    type: chat_template
    roles_to_train: ["assistant"]
```
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation

For a data sample that looks like:

```{.json filename="data.jsonl"}
{
  "conversations": [
    {"from": "system", "value": "You are an AI assistant.", "train": false},
    {"from": "human", "value": "Hello", "train": false},
    {"from": "assistant", "value": "Hello", "train": true},
    {"from": "human", "value": "How are you?", "train": true},
    {
      "from": "assistant",
      "value": "I'm doing very well, thank you!",
      "train_detail": [
        {"begin_offset": 0, "end_offset": 8, "train": false},
        {"begin_offset": 9, "end_offset": 18, "train": true},
        {"begin_offset": 19, "end_offset": 30, "train": false},
      ],
    },
    {
      "from": "human",
      "value": "I'm doing very well, thank you!",
      "train": true,
    },
    {"from": "assistant", "value": "Hi there!", "train": true}
  ]
}
```
The configuration would look like:

```yaml
datasets:
  - path: ...
    type: chat_template
    chat_template: tokenizer_default
    field_messages: conversations
    message_field_role: from
    message_field_content: value
    roles_to_train: []
    train_on_eos: turn
    message_field_training: train
    message_field_training_detail: train_detail
```

Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at a time.
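
As a rough illustration of how the `train_detail` offsets above select characters (plain Python, no axolotl imports; the inclusive `end_offset` semantics are an assumption read off the example, not taken from axolotl's source):

```python
content = "I'm doing very well, thank you!"
train_detail = [
    {"begin_offset": 0, "end_offset": 8, "train": False},
    {"begin_offset": 9, "end_offset": 18, "train": True},
    {"begin_offset": 19, "end_offset": 30, "train": False},
]

# Keep only the spans marked trainable; end_offset is treated as inclusive here.
trainable_spans = [
    content[d["begin_offset"] : d["end_offset"] + 1]
    for d in train_detail
    if d["train"]
]
print(trainable_spans)  # [' very well']
```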

View File

@@ -30,7 +30,7 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
    normalize_cfg_datasets,
@@ -272,7 +272,7 @@ def do_inference_gradio(
            importlib.import_module("axolotl.prompters"), prompter
        )
    elif cfg.chat_template:
-        chat_template_str = chat_templates(cfg.chat_template)
+        chat_template_str = get_chat_template(cfg.chat_template)

    model = model.to(cfg.device, dtype=cfg.torch_dtype)

View File

@@ -63,7 +63,7 @@ from axolotl.utils.callbacks import (
    log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.collators import (
    BatchSamplerDataCollatorForSeq2Seq,
    DataCollatorForSeq2Seq,
@@ -1556,7 +1556,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template: if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = chat_templates( training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template self.cfg.chat_template
) )

View File

@@ -6,7 +6,7 @@ import logging
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig

-LOG = logging.getLogger("axolotl.prompt_strategies")
+LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")

def load(strategy, tokenizer, cfg, ds_cfg):

View File

@@ -2,13 +2,18 @@
Bradley-Terry model with chat template prompt strategy.
"""

import logging
from typing import Any, Dict, Optional

from axolotl.prompt_strategies.chat_template import (
    ChatTemplatePrompter,
    ChatTemplateStrategy,
)
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template_from_config

# Configure the logger
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
LOG.setLevel(logging.INFO)

class BTChatTemplateStrategy(ChatTemplateStrategy):
@@ -27,18 +32,24 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
        # pylint: disable=duplicate-code
        prompt[self.messages] = []
        if prompt["system"]:
-            prompt[self.messages].append({"from": "system", "value": prompt["system"]})
-        prompt[self.messages].append({"from": "user", "value": prompt["input"]})
-        prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]})
+            prompt[self.messages].append(
+                {"role": "system", "content": prompt["system"]}
+            )
+        prompt[self.messages].append({"role": "user", "content": prompt["input"]})
+        prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
        chosen_tokenized = super().tokenize_prompt(prompt)

        self.messages = "rejected_messages"
        # pylint: disable=duplicate-code
        prompt[self.messages] = []
        if prompt["system"]:
-            prompt[self.messages].append({"from": "system", "value": prompt["system"]})
-        prompt[self.messages].append({"from": "user", "value": prompt["input"]})
-        prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]})
+            prompt[self.messages].append(
+                {"role": "system", "content": prompt["system"]}
+            )
+        prompt[self.messages].append({"role": "user", "content": prompt["input"]})
+        prompt[self.messages].append(
+            {"role": "assistant", "content": prompt["rejected"]}
+        )
        rejected_tokenized = super().tokenize_prompt(prompt)

        return {
@@ -53,15 +64,18 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    ds_cfg = ds_cfg or {}

+    chat_template_string = get_chat_template_from_config(
+        cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
+    )

    prompter_params = {
        "tokenizer": tokenizer,
-        "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
-        "message_field_role": ds_cfg.get("message_field_role", "from"),
-        "message_field_content": ds_cfg.get("message_field_content", "value"),
-        "message_field_training": ds_cfg.get("message_field_training", "training"),
+        "chat_template": chat_template_string,
+        "message_field_role": ds_cfg.get("message_field_role", "role"),
+        "message_field_content": ds_cfg.get("message_field_content", "content"),
+        "message_field_training": ds_cfg.get("message_field_training", None),
        "message_field_training_detail": ds_cfg.get(
-            "message_field_training_detail", "train_detail"
+            "message_field_training_detail", None
        ),
        "roles": ds_cfg.get("roles"),
        "drop_system_message": ds_cfg.get("drop_system_message", False),
@@ -74,8 +88,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    strategy_params = {
        "train_on_inputs": cfg.train_on_inputs,
        "sequence_len": cfg.sequence_len,
-        "roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
-        "train_on_eos": ds_cfg.get("train_on_eos", "turn"),
+        "roles_to_train": ds_cfg.get("roles_to_train", []),
+        "train_on_eos": ds_cfg.get("train_on_eos", None),
    }

    strategy = BTChatTemplateStrategy(
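
For reference, a hypothetical sample with the keys the strategy above reads; the values are illustrative only, not taken from the repository:

```python
# "system" and "input" form the shared prompt; "chosen" and "rejected" are the
# two candidate completions the Bradley-Terry reward model compares.
sample = {
    "system": "You are a helpful assistant.",
    "input": "Summarize the plot of Hamlet in one sentence.",
    "chosen": "A Danish prince feigns madness while avenging his father's murder.",
    "rejected": "Hamlet is a play.",
}
```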

View File

@@ -9,7 +9,7 @@ from transformers import ProcessorMixin
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template_from_config

# Configure the logger
LOG = logging.getLogger("axolotl")
@@ -405,10 +405,14 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
    # pylint: disable=duplicate-code
    ds_cfg = ds_cfg or {}
+    chat_template_string = get_chat_template_from_config(
+        cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
+    )
+    LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")

    prompter_params = {
        "tokenizer": tokenizer,
-        "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
+        "chat_template": chat_template_string,
        "message_field_role": ds_cfg.get("message_field_role", "role"),
        "message_field_content": ds_cfg.get("message_field_content", "content"),
        "message_field_training": ds_cfg.get("message_field_training", None),

View File

@@ -2,15 +2,16 @@
DPO prompt strategies for using tokenizer chat templates.
"""

-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template

def default(
    cfg, dataset_idx=0, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    ds_cfg = cfg["datasets"][dataset_idx]
-    chat_template_str = chat_templates(cfg.chat_template)
+    chat_template_choice, chat_template_jinja = extract_chat_template_args(
+        cfg=cfg, ds_cfg=ds_cfg
+    )
    field_messages = ds_cfg.get("field_messages", "messages")
    field_chosen = ds_cfg.get("field_chosen", "chosen")
    field_rejected = ds_cfg.get("field_rejected", "rejected")
@@ -30,6 +31,12 @@ def default(
            role_map[source] = target

    def transform_fn(sample, tokenizer=None):
+        chat_template_string = get_chat_template(
+            user_choice=chat_template_choice,
+            jinja_template=chat_template_jinja,
+            tokenizer=tokenizer,
+        )
        messages = sample[field_messages]
        messages = [
            {
@@ -46,28 +53,29 @@ def default(
"role": role_map[sample[field_rejected][field_message_role]], "role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content], "content": sample[field_rejected][field_message_content],
} }
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
result = {} result = {}
result["prompt"] = tokenizer.apply_chat_template( result["prompt"] = tokenizer.apply_chat_template(
messages, messages,
add_generation_prompt=True, add_generation_prompt=True,
chat_template=chat_template_str, chat_template=chat_template_string,
tokenize=False, tokenize=False,
) )
result["chosen"] = tokenizer.apply_chat_template( result["chosen"] = tokenizer.apply_chat_template(
[chosen], [dummy_user_message, chosen],
add_generation_prompt=False, add_generation_prompt=False,
chat_template=chat_template_str, chat_template=chat_template_string,
tokenize=False, tokenize=False,
) )
chosen_strip_index = result["chosen"].find(chosen["content"]) chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip() result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template( result["rejected"] = tokenizer.apply_chat_template(
[rejected], [dummy_user_message, rejected],
add_generation_prompt=False, add_generation_prompt=False,
chat_template=chat_template_str, chat_template=chat_template_string,
tokenize=False, tokenize=False,
) )
rejected_strip_index = result["rejected"].find(rejected["content"]) rejected_strip_index = result["rejected"].find(rejected["content"])
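
The dummy-message change above can be reproduced in isolation with plain `transformers` (a minimal sketch; the model name is only an example of a tokenizer whose chat template enforces strict user/assistant alternation):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")  # example model

chosen = {"role": "assistant", "content": "Paris."}
# A lone assistant turn would make strict templates raise an exception; the
# placeholder user turn satisfies the expected ordering.
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}

text = tokenizer.apply_chat_template(
    [dummy_user_message, chosen],
    add_generation_prompt=False,
    tokenize=False,
)
# Slice off everything before the assistant content, dropping the dummy turn.
text = text[text.find(chosen["content"]) :].rstrip()
print(text)
```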

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
from axolotl.prompters import Prompter
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template_from_config

class Message(BaseModel):
@@ -28,18 +28,13 @@ def load(
""" """
chatml transforms for datasets with system, input, chosen, rejected chatml transforms for datasets with system, input, chosen, rejected
""" """
chat_template_string = get_chat_template_from_config(
chat_template = chat_templates("chatml") cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
if ds_cfg and "chat_template" in ds_cfg: )
chat_template = ds_cfg["chat_template"] tokenizer.chat_template = chat_template_string
try:
chat_template = chat_templates(chat_template)
except ValueError:
pass
tokenizer.chat_template = chat_template
return ORPOTokenizingStrategy( return ORPOTokenizingStrategy(
ORPOPrompter(chat_template, tokenizer), ORPOPrompter(chat_template_string, tokenizer),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
@@ -248,28 +243,30 @@ class ORPOPrompter(Prompter):
def argilla(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    dataset_parser = ORPODatasetParsingStrategy()

-    chat_template_str = chat_templates(cfg.chat_template)

    def transform_fn(sample, tokenizer=None):
        res = {}
+        chat_template_string = get_chat_template_from_config(
+            cfg=cfg, tokenizer=tokenizer
+        )

        res["prompt"] = tokenizer.apply_chat_template(
            [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
            add_generation_prompt=True,
-            chat_template=chat_template_str,
+            chat_template=chat_template_string,
            tokenize=False,
        )
        prompt_str_len = len(res["prompt"])
        res["chosen"] = tokenizer.apply_chat_template(
            [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
            add_generation_prompt=False,
-            chat_template=chat_template_str,
+            chat_template=chat_template_string,
            tokenize=False,
        )[prompt_str_len:]
        res["rejected"] = tokenizer.apply_chat_template(
            [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
            add_generation_prompt=False,
-            chat_template=chat_template_str,
+            chat_template=chat_template_string,
            tokenize=False,
        )[prompt_str_len:]

View File

@@ -62,7 +62,7 @@ def build_loader(
):
    def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
        LOG.warning(
-            "sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead.",
+            "sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template",
        )
        conversation = (
            ds_cfg["conversation"]

View File

@@ -2,8 +2,19 @@
This module provides functionality for selecting chat templates based on user choices.
These templates are used for formatting messages in a conversation.
"""

import logging
from typing import TYPE_CHECKING, Any, Dict, Optional

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizerBase

LOG = logging.getLogger("axolotl.utils.chat_templates")

_JINJA_TEMPALTE_CHOICE = "jinja"
_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"

-CHAT_TEMPLATES = {
+_CHAT_TEMPLATES = {
    "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
    "mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1...
    "mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # V3: Mistral 7B V3, Small, Large...
@@ -21,12 +32,18 @@ CHAT_TEMPLATES = {
}

-def chat_templates(user_choice: str):
+def get_chat_template(
+    user_choice: str,
+    jinja_template: Optional[str] = None,
+    tokenizer: Optional["PreTrainedTokenizerBase"] = None,
+):
    """
-    Finds the correct chat_template for the tokenizer_config.
+    Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.

    Args:
        user_choice (str): The user's choice of template.
        jinja_template (Optional[str], optional): The jinja template string. Defaults to None.
        tokenizer (Optional[PreTrainedTokenizerBase], optional): The tokenizer. Defaults to None.

    Returns:
        str: The chosen template string.
@@ -34,13 +51,71 @@ def chat_templates(user_choice: str):
    Raises:
        ValueError: If the user_choice is not found in the templates.
    """
    if user_choice == _JINJA_TEMPALTE_CHOICE:
        if not jinja_template:
            raise ValueError(
                f"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPALTE_CHOICE}"
            )
        return jinja_template

-    if user_choice in CHAT_TEMPLATES:
-        return CHAT_TEMPLATES[user_choice]

    if user_choice == _DEFAULT_TEMPLATE_CHOICE:
        if not tokenizer:
            raise ValueError(
                f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}"
            )
        if not tokenizer.chat_template:
            raise ValueError(
                f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
                f"Please add a chat_template in tokenizer config"
            )
        return tokenizer.chat_template

    if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
        if not tokenizer:
            raise ValueError(
                f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
            )
        if tokenizer.chat_template:
            return tokenizer.chat_template

        user_choice = user_choice[
            len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
        ]
        LOG.warning(
            f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
        )

    if user_choice in _CHAT_TEMPLATES:
        return _CHAT_TEMPLATES[user_choice]

    raise ValueError(f"Template '{user_choice}' not found.")

def extract_chat_template_args(cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    if ds_cfg and ds_cfg.get("chat_template"):
        chat_template_choice = ds_cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE
        chat_template_jinja = ds_cfg.get("chat_template_jinja")
    else:
        chat_template_choice = cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE
        chat_template_jinja = cfg.get("chat_template_jinja")
    return chat_template_choice, chat_template_jinja

def get_chat_template_from_config(
    cfg,
    ds_cfg: Optional[Dict[str, Any]] = None,
    tokenizer: Optional["PreTrainedTokenizerBase"] = None,
) -> str:
    chat_template_choice, chat_template_jinja = extract_chat_template_args(
        cfg=cfg, ds_cfg=ds_cfg
    )
    return get_chat_template(
        user_choice=chat_template_choice,
        jinja_template=chat_template_jinja,
        tokenizer=tokenizer,
    )

def register_chat_template(template_name: str, chat_template: str):
    """
    Registers chat templates.
@@ -50,7 +125,7 @@ def register_chat_template(template_name: str, chat_template: str):
        chat_template (str): The template string.
    """
-    if template_name in CHAT_TEMPLATES:
+    if template_name in _CHAT_TEMPLATES:
        raise ValueError(f"Template '{template_name}' already exists.")
-    CHAT_TEMPLATES[template_name] = chat_template
+    _CHAT_TEMPLATES[template_name] = chat_template
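
A short usage sketch of the selection logic above (assuming axolotl is importable; `NousResearch/Meta-Llama-3-8B` is the base model the new unit tests also use, whose tokenizer ships without a chat template):

```python
from transformers import AutoTokenizer

from axolotl.utils.chat_templates import get_chat_template

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")

# Falls back to the built-in chatml template and logs a warning,
# because this tokenizer has no chat_template of its own.
template = get_chat_template("tokenizer_default_fallback_chatml", tokenizer=tokenizer)

# Raises ValueError: tokenizer_default requires the tokenizer to ship a template.
# get_chat_template("tokenizer_default", tokenizer=tokenizer)
```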

View File

@@ -228,6 +228,7 @@ def normalize_cfg_datasets(cfg):
f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template" f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template"
) )
cfg.datasets[idx].chat_template = cfg.chat_template cfg.datasets[idx].chat_template = cfg.chat_template
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None): def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):

View File

@@ -8,9 +8,16 @@ import logging
import os
from enum import Enum
from importlib.metadata import version
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union

-from pydantic import BaseModel, Field, conlist, field_validator, model_validator
+from pydantic import (
+    BaseModel,
+    Field,
+    StringConstraints,
+    conlist,
+    field_validator,
+    model_validator,
+)
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
@@ -21,6 +28,37 @@ LOG = logging.getLogger("axolotl.utils.config.models.input")
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}

class RLType(str, Enum):
    """RL trainer type configuration subset"""

    dpo = "dpo"  # pylint: disable=invalid-name
    ipo = "ipo"  # pylint: disable=invalid-name
    orpo = "orpo"  # pylint: disable=invalid-name
    kto = "kto"  # pylint: disable=invalid-name
    simpo = "simpo"  # pylint: disable=invalid-name

class ChatTemplate(str, Enum):
    """Chat templates configuration subset"""

    alpaca = "alpaca"  # pylint: disable=invalid-name
    chatml = "chatml"  # pylint: disable=invalid-name
    mistral_v1 = "mistral_v1"  # pylint: disable=invalid-name
    mistral_v2v3 = "mistral_v2v3"  # pylint: disable=invalid-name
    mistral_v3_tekken = "mistral_v3_tekken"  # pylint: disable=invalid-name
    gemma = "gemma"  # pylint: disable=invalid-name
    cohere = "cohere"  # pylint: disable=invalid-name
    llama3 = "llama3"  # pylint: disable=invalid-name
    llama3_2_vision = "llama3_2_vision"  # pylint: disable=invalid-name
    phi_3 = "phi_3"  # pylint: disable=invalid-name
    phi_35 = "phi_35"  # pylint: disable=invalid-name
    deepseek_v2 = "deepseek_v2"  # pylint: disable=invalid-name
    jamba = "jamba"  # pylint: disable=invalid-name
    jinja = "jinja"  # pylint: disable=invalid-name
    qwen_25 = "qwen_25"  # pylint: disable=invalid-name
    tokenizer_default = "tokenizer_default"  # pylint: disable=invalid-name

class DeprecatedParameters(BaseModel):
    """configurations that are deprecated"""
@@ -105,13 +143,19 @@ class SFTDataset(BaseModel):
    input_transform: Optional[str] = None
    shards: Optional[int] = None
    conversation: Optional[str] = None
-    chat_template: Optional[str] = None
+    # Do not make this too strict or it will break the validator to choose different dataset class
+    chat_template: Optional[
+        Union[
+            ChatTemplate,
+            str,
+        ]
+    ] = None
+    chat_template_jinja: Optional[str] = None
    data_files: Optional[Union[str, List[str]]] = None
    input_format: Optional[str] = None
    name: Optional[str] = None
    ds_type: Optional[str] = None
    train_on_split: Optional[str] = None
    field: Optional[str] = None
    field_human: Optional[str] = None
    field_model: Optional[str] = None
@@ -122,13 +166,32 @@ class SFTDataset(BaseModel):
    message_field_training_detail: Optional[str] = None
    roles_to_train: Optional[List[str]] = None
    train_on_eos: Optional[str] = None
    roles: Optional[Dict[str, List[str]]] = None
    drop_system_message: Optional[bool] = None
    trust_remote_code: Optional[bool] = False
    revision: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def check_chat_template_config(cls, data):
        # Set chat_template to tokenizer_default if not set
        if data.get("type") == "chat_template" and not data.get("chat_template"):
            data["chat_template"] = ChatTemplate.tokenizer_default
        # if chat_template is set to jinja, chat_template_jinja is required
        if data.get("chat_template") == ChatTemplate.jinja and not data.get(
            "chat_template_jinja"
        ):
            raise ValueError(
                "chat_template_jinja is required when chat_template is set to jinja"
            )
        # If chat_template_jinja is set, set chat_template to jinja
        if data.get("chat_template_jinja") and not data.get("chat_template"):
            data["chat_template"] = ChatTemplate.jinja
        return data

class UserDefinedDPOType(BaseModel):
    """User defined typing for DPO"""
@@ -174,35 +237,6 @@ class KTODataset(BaseModel):
    revision: Optional[str] = None

-class RLType(str, Enum):
-    """RL trainer type configuration subset"""
-
-    dpo = "dpo"  # pylint: disable=invalid-name
-    ipo = "ipo"  # pylint: disable=invalid-name
-    orpo = "orpo"  # pylint: disable=invalid-name
-    kto = "kto"  # pylint: disable=invalid-name
-    simpo = "simpo"  # pylint: disable=invalid-name
-
-class ChatTemplate(str, Enum):
-    """Chat templates configuration subset"""
-
-    alpaca = "alpaca"  # pylint: disable=invalid-name
-    chatml = "chatml"  # pylint: disable=invalid-name
-    mistral_v1 = "mistral_v1"  # pylint: disable=invalid-name
-    mistral_v2v3 = "mistral_v2v3"  # pylint: disable=invalid-name
-    mistral_v3_tekken = "mistral_v3_tekken"  # pylint: disable=invalid-name
-    gemma = "gemma"  # pylint: disable=invalid-name
-    cohere = "cohere"  # pylint: disable=invalid-name
-    llama3 = "llama3"  # pylint: disable=invalid-name
-    llama3_2_vision = "llama3_2_vision"  # pylint: disable=invalid-name
-    phi_3 = "phi_3"  # pylint: disable=invalid-name
-    phi_35 = "phi_35"  # pylint: disable=invalid-name
-    deepseek_v2 = "deepseek_v2"  # pylint: disable=invalid-name
-    jamba = "jamba"  # pylint: disable=invalid-name
-    qwen_25 = "qwen_25"  # pylint: disable=invalid-name

class LoftQConfig(BaseModel):
    """LoftQ configuration subset"""
@@ -718,7 +752,13 @@ class AxolotlInputConfig(
    gpu_memory_limit: Optional[Union[int, str]] = None
    low_cpu_mem_usage: Optional[bool] = None
-    chat_template: Optional[ChatTemplate] = None
+    chat_template: Optional[
+        Union[
+            ChatTemplate,
+            Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
+        ]
+    ] = None
+    chat_template_jinja: Optional[str] = None
    default_system_message: Optional[str] = None
    fix_untrained_tokens: Optional[bool] = None
@@ -827,6 +867,23 @@ class AxolotlInputConfig(
        return data

    @model_validator(mode="before")
    @classmethod
    def check_chat_template_config(cls, data):
        # if chat_template is set to jinja, chat_template_jinja is required
        if data.get("chat_template") == ChatTemplate.jinja and not data.get(
            "chat_template_jinja"
        ):
            raise ValueError(
                "chat_template_jinja is required when chat_template is set to jinja"
            )
        # If chat_template_jinja is set, set chat_template to jinja
        if data.get("chat_template_jinja") and not data.get("chat_template"):
            data["chat_template"] = ChatTemplate.jinja
        return data

    @model_validator(mode="before")
    @classmethod
    def check_sample_packing_wo_flash(cls, data):
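
Restated in plain Python (dict in, dict out), the two `check_chat_template_config` rules above behave like this sketch; string literals stand in for the `ChatTemplate` enum members:

```python
def check_chat_template_config(data: dict) -> dict:
    # jinja mode requires an explicit template string
    if data.get("chat_template") == "jinja" and not data.get("chat_template_jinja"):
        raise ValueError(
            "chat_template_jinja is required when chat_template is set to jinja"
        )
    # a bare jinja template implies chat_template: jinja
    if data.get("chat_template_jinja") and not data.get("chat_template"):
        data["chat_template"] = "jinja"
    return data

print(check_chat_template_config({"chat_template_jinja": "{{ messages }}"}))
# {'chat_template_jinja': '{{ messages }}', 'chat_template': 'jinja'}
```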

View File

@@ -50,7 +50,7 @@ from axolotl.monkeypatch.multipack import (
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
@@ -293,7 +293,10 @@ def load_tokenizer(cfg):
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if cfg.chat_template: if cfg.chat_template:
chat_template_string = chat_templates(cfg.chat_template) chat_template_string = get_chat_template_from_config(
cfg=cfg,
tokenizer=tokenizer,
)
if cfg.default_system_message and cfg.chat_template == "chatml": if cfg.default_system_message and cfg.chat_template == "chatml":
chat_template_string = chat_template_string.replace( chat_template_string = chat_template_string.replace(
"You are a helpful assistant.", cfg.default_system_message "You are a helpful assistant.", cfg.default_system_message

View File

@@ -0,0 +1,125 @@
"""
Tests for utils in axolotl.utils.chat_templates
"""
import unittest
import pytest
from transformers import AutoTokenizer
from axolotl.utils.chat_templates import (
_CHAT_TEMPLATES,
extract_chat_template_args,
get_chat_template,
)
@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
return tokenizer
class TestGetChatTemplateUtils:
"""
Tests the get_chat_template function.
"""
def test_known_chat_template(self):
chat_template_str = get_chat_template("llama3")
assert chat_template_str == _CHAT_TEMPLATES["llama3"]
def test_invalid_chat_template(self):
with pytest.raises(ValueError) as exc:
get_chat_template("invalid_template")
assert str(exc) == "Template 'invalid_template' not found."
def test_tokenizer_default_no_tokenizer(self):
with pytest.raises(ValueError):
get_chat_template("tokenizer_default", tokenizer=None)
def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
with pytest.raises(ValueError):
get_chat_template("tokenizer_default", tokenizer=llama3_tokenizer)
def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
llama3_tokenizer.chat_template = "test_template"
chat_template_str = get_chat_template(
"tokenizer_default", tokenizer=llama3_tokenizer
)
assert chat_template_str == "test_template"
def test_tokenizer_default_fallback_no_tokenizer(self):
with pytest.raises(ValueError):
get_chat_template("tokenizer_default_fallback_test", tokenizer=None)
def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
self, llama3_tokenizer
):
chat_template_str = get_chat_template(
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
)
assert chat_template_str == get_chat_template("chatml")
def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
self, llama3_tokenizer
):
llama3_tokenizer.chat_template = "test_template"
chat_template_str = get_chat_template(
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
)
assert chat_template_str == "test_template"
def test_jinja_template_mode(self):
jinja_template = "example_jinja_template"
chat_template_str = get_chat_template("jinja", jinja_template=jinja_template)
assert chat_template_str == jinja_template
def test_jinja_template_mode_no_jinja_template(self):
with pytest.raises(ValueError):
get_chat_template("jinja", jinja_template=None)
def test_extract_chat_template_args(self):
# No ds_cfg
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={"chat_template": "chatml"},
)
assert chat_template_choice == "chatml"
assert chat_template_jinja is None
# ds_cfg provided
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={
"chat_template": "jinja",
"chat_template_jinja": "global_jinja_template",
},
ds_cfg={"chat_template": "llama3", "chat_template_jinja": None},
)
assert chat_template_choice == "llama3"
assert chat_template_jinja is None
# ds_cfg provided with jinja template
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={"chat_template": "chatml", "chat_template_jinja": None},
ds_cfg={
"chat_template": "jinja",
"chat_template_jinja": "ds_jinja_template",
},
)
assert chat_template_choice == "jinja"
assert chat_template_jinja == "ds_jinja_template"
# ds_cfg provided with no chat_template
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg={
"chat_template": "jinja",
"chat_template_jinja": "global_jinja_template",
},
ds_cfg={"chat_template": None, "chat_template_jinja": "ds_jinja_template"},
)
assert chat_template_choice == "jinja"
assert chat_template_jinja == "global_jinja_template"
if __name__ == "__main__":
unittest.main()

View File

@@ -11,7 +11,7 @@ from axolotl.prompt_strategies.chat_template import (
    load,
)
from axolotl.prompters import IGNORE_TOKEN_ID
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.dict import DictDefault

logging.basicConfig(level=logging.DEBUG)
@@ -73,7 +73,7 @@ class TestAssistantChatTemplateLlama3:
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                llama3_tokenizer,
-                chat_template=chat_templates("llama3"),
+                chat_template=get_chat_template("llama3"),
                message_field_role="role",
                message_field_content="content",
                roles={
@@ -113,7 +113,7 @@ class TestAssistantChatTemplateLlama3:
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                phi35_tokenizer,
-                chat_template=chat_templates("phi_35"),
+                chat_template=get_chat_template("phi_35"),
                message_field_role="role",
                message_field_content="content",
                roles={
@@ -171,7 +171,7 @@ class TestAssistantChatTemplateLlama3:
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                llama3_tokenizer,
-                chat_template=chat_templates("llama3"),
+                chat_template=get_chat_template("llama3"),
                message_field_role="role",
                message_field_content="content",
                message_field_training="training",
@@ -230,7 +230,7 @@ class TestSharegptChatTemplateLlama3:
        # pylint: disable=duplicate-code
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
-                llama3_tokenizer, chat_template=chat_templates("llama3")
+                llama3_tokenizer, chat_template=get_chat_template("llama3")
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,
@@ -283,7 +283,7 @@ class TestSharegptChatTemplateLlama3:
        # pylint: disable=duplicate-code
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
-                llama3_tokenizer, chat_template=chat_templates("llama3")
+                llama3_tokenizer, chat_template=get_chat_template("llama3")
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,
@@ -336,7 +336,7 @@ class TestSharegptChatTemplateLlama3:
        # pylint: disable=duplicate-code
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
-                llama3_tokenizer, chat_template=chat_templates("llama3")
+                llama3_tokenizer, chat_template=get_chat_template("llama3")
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,

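The hunks above are a mechanical rename of the lookup helper, chat_templates(...) becoming get_chat_template(...), with the call sites otherwise unchanged. A minimal sketch of the assumed usage after the rename (the helper's return value is not shown in this diff; it is assumed to be the registered Jinja template string):

    # Assumed usage of the renamed helper; "llama3" is one of the template
    # names exercised by these tests.
    from axolotl.utils.chat_templates import get_chat_template

    llama3_template = get_chat_template("llama3")
    print(llama3_template[:80])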

@@ -12,7 +12,7 @@ from axolotl.prompt_strategies.chat_template import (
     ChatTemplateStrategy,
 )
 from axolotl.prompters import IGNORE_TOKEN_ID
-from axolotl.utils.chat_templates import chat_templates
+from axolotl.utils.chat_templates import get_chat_template
 
 logging.basicConfig(level=logging.DEBUG)
 LOG = logging.getLogger("axolotl")
@@ -35,7 +35,7 @@ class TestChatTemplateConfigurations:
         LOG.info("Testing with train_on_inputs=True")
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
-                llama3_tokenizer, chat_template=chat_templates("llama3")
+                llama3_tokenizer, chat_template=get_chat_template("llama3")
             ),
             tokenizer=llama3_tokenizer,
             train_on_inputs=True,
@@ -80,7 +80,7 @@ class TestChatTemplateConfigurations:
         LOG.info("Testing with train_on_inputs=False")
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
-                llama3_tokenizer, chat_template=chat_templates("llama3")
+                llama3_tokenizer, chat_template=get_chat_template("llama3")
            ),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
@@ -123,7 +123,7 @@ class TestChatTemplateConfigurations:
         LOG.info("Testing roles_to_train with assistant only")
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
-                llama3_tokenizer, chat_template=chat_templates("llama3")
+                llama3_tokenizer, chat_template=get_chat_template("llama3")
             ),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
@@ -151,7 +151,7 @@ class TestChatTemplateConfigurations:
         LOG.info("Testing roles_to_train with all roles")
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
-                llama3_tokenizer, chat_template=chat_templates("llama3")
+                llama3_tokenizer, chat_template=get_chat_template("llama3")
             ),
             tokenizer=llama3_tokenizer,
             train_on_inputs=True,
@@ -184,7 +184,7 @@ class TestChatTemplateConfigurations:
         LOG.info("Testing with empty roles_to_train")
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
-                llama3_tokenizer, chat_template=chat_templates("llama3")
+                llama3_tokenizer, chat_template=get_chat_template("llama3")
             ),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
@@ -205,7 +205,7 @@ class TestChatTemplateConfigurations:
         LOG.info("Testing with train_on_eos='all'")
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
-                llama3_tokenizer, chat_template=chat_templates("llama3")
+                llama3_tokenizer, chat_template=get_chat_template("llama3")
             ),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
@@ -232,7 +232,7 @@ class TestChatTemplateConfigurations:
         LOG.info("Testing with train_on_eos='turn'")
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
-                llama3_tokenizer, chat_template=chat_templates("llama3")
+                llama3_tokenizer, chat_template=get_chat_template("llama3")
             ),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
@@ -282,7 +282,7 @@ class TestChatTemplateConfigurations:
         LOG.info("Testing with train_on_eos='last'")
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
-                llama3_tokenizer, chat_template=chat_templates("llama3")
+                llama3_tokenizer, chat_template=get_chat_template("llama3")
             ),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
@@ -315,7 +315,7 @@ class TestChatTemplateConfigurations:
         LOG.info("Testing with train_on_eos='none'")
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
-                llama3_tokenizer, chat_template=chat_templates("llama3")
+                llama3_tokenizer, chat_template=get_chat_template("llama3")
             ),
             tokenizer=llama3_tokenizer,
             train_on_inputs=False,
@@ -343,7 +343,7 @@ class TestChatTemplateConfigurations:
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
                 llama3_tokenizer,
-                chat_template=chat_templates("llama3"),
+                chat_template=get_chat_template("llama3"),
                 drop_system_message=True,
             ),
             tokenizer=llama3_tokenizer,
@@ -371,7 +371,7 @@ class TestChatTemplateConfigurations:
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
                 llama3_tokenizer,
-                chat_template=chat_templates("llama3"),
+                chat_template=get_chat_template("llama3"),
                 roles=custom_roles,
             ),
             tokenizer=llama3_tokenizer,
@@ -424,7 +424,7 @@ class TestChatTemplateConfigurations:
         strategy = ChatTemplateStrategy(
             ChatTemplatePrompter(
                 llama3_tokenizer,
-                chat_template=chat_templates("llama3"),
+                chat_template=get_chat_template("llama3"),
                 message_field_training="train",
                 message_field_training_detail="train_detail",
             ),

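Every hunk in this file is the same one-line substitution inside a different test case (train_on_inputs, roles_to_train, the train_on_eos modes, drop_system_message, custom roles, and the training-detail fields). A sketch of the shared setup these tests repeat, assuming only the constructor keywords visible above (the arguments after train_on_inputs are cut off in this view, and the tokenizer checkpoint is an assumption, since the fixture's model id is not shown):

    from transformers import AutoTokenizer
    from axolotl.prompt_strategies.chat_template import (
        ChatTemplatePrompter,
        ChatTemplateStrategy,
    )
    from axolotl.utils.chat_templates import get_chat_template

    # Assumed checkpoint for the llama3_tokenizer fixture.
    llama3_tokenizer = AutoTokenizer.from_pretrained(
        "NousResearch/Meta-Llama-3-8B-Instruct"
    )
    strategy = ChatTemplateStrategy(
        ChatTemplatePrompter(
            llama3_tokenizer, chat_template=get_chat_template("llama3")
        ),
        tokenizer=llama3_tokenizer,
        train_on_inputs=False,
    )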

@@ -86,6 +86,20 @@ def fixture_llama3_tokenizer():
     return tokenizer
 
 
+@pytest.fixture(name="phi3_tokenizer")
+def fixture_phi3_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
+    return tokenizer
+
+
+@pytest.fixture(name="gemma_tokenizer")
+def fixture_gemma_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a")
+    return tokenizer
+
+
 class TestAssistantDPOChatTemplateLlama3:
     """
     Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
@@ -99,7 +113,7 @@ class TestAssistantDPOChatTemplateLlama3:
                     "chat_template": "llama3",
                     "datasets": [
                         {
-                            "chat_template": "llama3",
+                            "type": "chat_template",
                         }
                     ],
                 }
@@ -124,7 +138,7 @@ class TestAssistantDPOChatTemplateLlama3:
                     "chat_template": "llama3",
                     "datasets": [
                         {
-                            "chat_template": "llama3",
+                            "type": "chat_template",
                             "field_messages": "conversation",
                             "field_chosen": "better",
                             "field_rejected": "worse",
@@ -152,5 +166,65 @@
         assert result["rejected"] == "party on<|eot_id|>"
 
 
+class TestAssistantDPOChatTemplatePhi3:
+    """
+    Test class for assistant style datasets with phi-3 prompts using the tokenizer's chat_template strategy.
+    """
+
+    def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
+        # pylint: disable=duplicate-code
+        transform_fn = default(
+            DictDefault(
+                {
+                    "chat_template": "tokenizer_default",
+                    "datasets": [
+                        {
+                            "type": "chat_template",
+                        }
+                    ],
+                }
+            )
+        )
+        result = transform_fn(assistant_dataset[0], tokenizer=phi3_tokenizer)
+        assert result["prompt"] == (
+            "<|user|>\nhello<|end|>\n"
+            + "<|assistant|>\nhello<|end|>\n"
+            + "<|user|>\ngoodbye<|end|>\n"
+            + "<|assistant|>\n"
+        )
+        assert result["chosen"] == "goodbye<|end|>"
+        assert result["rejected"] == "party on<|end|>"
+
+
+class TestAssistantDPOChatTemplateGemma:
+    """
+    Test class for assistant style datasets with gemma prompts using the tokenizer's chat_template strategy.
+    """
+
+    def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
+        # pylint: disable=duplicate-code
+        transform_fn = default(
+            DictDefault(
+                {
+                    "chat_template": "tokenizer_default",
+                    "datasets": [
+                        {
+                            "type": "chat_template",
+                        }
+                    ],
+                }
+            )
+        )
+        result = transform_fn(assistant_dataset[0], tokenizer=gemma_tokenizer)
+        assert result["prompt"] == (
+            "<bos><start_of_turn>user\nhello<end_of_turn>\n"
+            + "<start_of_turn>model\nhello<end_of_turn>\n"
+            + "<start_of_turn>user\ngoodbye<end_of_turn>\n"
+            + "<start_of_turn>model\n"
+        )
+        assert result["chosen"] == "goodbye<end_of_turn>"
+        assert result["rejected"] == "party on<end_of_turn>"
+
+
 if __name__ == "__main__":
     unittest.main()

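The two new test classes pin down the "tokenizer_default" fallback: instead of a named axolotl template, the prompt is rendered by the tokenizer's own chat template. A sketch of the equivalence the phi-3 assertions encode, assuming the strategy simply defers to the tokenizer under this setting:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/Phi-3-medium-128k-instruct"
    )
    messages = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "hello"},
        {"role": "user", "content": "goodbye"},
    ]
    # add_generation_prompt=True produces the trailing "<|assistant|>\n"
    # seen in the expected prompt above.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    print(prompt)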

@@ -0,0 +1,238 @@
"""Module for testing the validation module for the dataset config"""
import warnings
from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
from axolotl.utils.dict import DictDefault
warnings.filterwarnings("error")
@pytest.fixture(name="minimal_cfg")
def fixture_cfg():
return DictDefault(
{
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
"learning_rate": 0.000001,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
    )


# pylint: disable=too-many-public-methods (duplicate-code)
class BaseValidation:
    """
    Base validation module to setup the log capture
    """

    _caplog: Optional[pytest.LogCaptureFixture] = None

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        self._caplog = caplog


class TestValidationCheckDatasetConfig(BaseValidation):
    """
    Test the validation for the dataset config to ensure no correct parameters are dropped
    """

    def test_dataset_config_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "sharegpt",
"conversation": "chatml",
"shards": 10,
}
]
}
)

        checked_cfg = validate_config(cfg)

        def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.datasets[0].conversation == cfg.datasets[0].conversation
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
        _check_config()

        checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
        _check_config()

    def test_dataset_default_chat_template_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"field_messages": "conversations",
"shards": 10,
"message_field_role": "from",
"message_field_content": "value",
}
],
}
)

        checked_cfg = validate_config(cfg)

        def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.chat_template is None
assert (
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
)
assert (
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
)
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
assert (
checked_cfg.datasets[0].message_field_role
== cfg.datasets[0].message_field_role
)
assert (
checked_cfg.datasets[0].message_field_content
== cfg.datasets[0].message_field_content
)
        _check_config()

        checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
        _check_config()

    def test_dataset_partial_default_chat_template_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"chat_template": "chatml",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"field_messages": "conversations",
"shards": 10,
"message_field_role": "from",
"message_field_content": "value",
}
],
}
)

        checked_cfg = validate_config(cfg)

        def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.chat_template == ChatTemplate.chatml
assert (
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
)
assert (
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
)
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
assert (
checked_cfg.datasets[0].message_field_role
== cfg.datasets[0].message_field_role
)
assert (
checked_cfg.datasets[0].message_field_content
== cfg.datasets[0].message_field_content
)
        _check_config()

        checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
        _check_config()

    def test_dataset_chatml_chat_template_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"chat_template": "chatml",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"chat_template": "gemma",
"field_messages": "conversations",
"shards": 10,
"message_field_role": "from",
"message_field_content": "value",
}
],
}
)

        checked_cfg = validate_config(cfg)

        def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.chat_template == cfg.chat_template
assert (
checked_cfg.datasets[0].chat_template == cfg.datasets[0].chat_template
)
assert (
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
)
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
assert (
checked_cfg.datasets[0].message_field_role
== cfg.datasets[0].message_field_role
)
assert (
checked_cfg.datasets[0].message_field_content
== cfg.datasets[0].message_field_content
)
        _check_config()

        checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
)
_check_config()
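
Taken together, the four tests fix the precedence rules: a dataset-level chat_template defaults to tokenizer_default when the dataset uses type: chat_template, and an explicit dataset-level value (here "gemma") survives validation even when a different global template ("chatml") is set. A minimal sketch of that last case, reusing only calls that appear in the tests above:

    from axolotl.utils.config import validate_config
    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
            "learning_rate": 0.000001,
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "chat_template": "chatml",
            "datasets": [
                {
                    "path": "LDJnr/Puffin",
                    "type": "chat_template",
                    "chat_template": "gemma",
                }
            ],
        }
    )

    checked_cfg = validate_config(cfg)
    # Per the assertions above, the dataset-level override is preserved.
    assert checked_cfg.datasets[0].chat_template == cfg.datasets[0].chat_template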