Feat: Add support for tokenizer’s or custom jinja chat_template (#1970)
* Allow using tokenizer's default chat template with fallbacks Summary of changes: 1. Adds `tokenizer_default` as option for `chat_template` in `chat_template` prompt strategy that allows using the chat template from tokenizer's config.json 2. Allows falling back to chat templates available in axolotl if tokenizer does not have a chat template 3. Adds a mistral chat template which supports system message - taken from https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja --- Why? Many popular models are not trained with chatml format. As a result for the model to correctly learn chatml we have to turn on train_on_inputs which requires more compute and time. If we can use the model's already learned chat template we can just learn the output tokens --- Todo: - Write tests * Add tests * Fix lint and bug post merge from main * Add option `chat_template_jinja` to provide a jinja template * remove custom mistral template * Address review comments and add docs * Update docs/dataset-formats/conversation.qmd Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * fix: set default to tokenizer template * Merge branch 'main' into cj_tokenizer_default_prompt_template * chore: remove redundant function * fix: re-arrange enum declaration position * fix: refactor artifact left from main merge * feat(doc): updated config with chat template options and clarified examples * chore: clarify doc * chore: added example for non-default template * chore: refactor * fix: test * fix: config being dropped and unittest to catch that * chore: lint * chore: skip duplicate * fix: rename var after merge * feat: add test for levy's dpo case * fix: remove default setting on edge case where chat template overriden in dataset section * feat: handle sharegpt deprecation better in docs * feat: add example using fallback * feat: handles chat_template requiring specific user/assistant order * fix: update test based on new defaults * fix: imported name incorrectly updated on merge * chore: lint * fix: update dummy message to prevent potential overlap with real content * fix(doc): formatting * fix: update bradleyterry to use new chat_template --------- Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>
This commit is contained in:
125
tests/prompt_strategies/test_chat_template_utils.py
Normal file
125
tests/prompt_strategies/test_chat_template_utils.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Tests for utils in axolotl.utils.chat_templates
|
||||
"""
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.utils.chat_templates import (
|
||||
_CHAT_TEMPLATES,
|
||||
extract_chat_template_args,
|
||||
get_chat_template,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="llama3_tokenizer")
|
||||
def fixture_llama3_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
class TestGetChatTemplateUtils:
|
||||
"""
|
||||
Tests the get_chat_template function.
|
||||
"""
|
||||
|
||||
def test_known_chat_template(self):
|
||||
chat_template_str = get_chat_template("llama3")
|
||||
assert chat_template_str == _CHAT_TEMPLATES["llama3"]
|
||||
|
||||
def test_invalid_chat_template(self):
|
||||
with pytest.raises(ValueError) as exc:
|
||||
get_chat_template("invalid_template")
|
||||
assert str(exc) == "Template 'invalid_template' not found."
|
||||
|
||||
def test_tokenizer_default_no_tokenizer(self):
|
||||
with pytest.raises(ValueError):
|
||||
get_chat_template("tokenizer_default", tokenizer=None)
|
||||
|
||||
def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer):
|
||||
with pytest.raises(ValueError):
|
||||
get_chat_template("tokenizer_default", tokenizer=llama3_tokenizer)
|
||||
|
||||
def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer):
|
||||
llama3_tokenizer.chat_template = "test_template"
|
||||
chat_template_str = get_chat_template(
|
||||
"tokenizer_default", tokenizer=llama3_tokenizer
|
||||
)
|
||||
assert chat_template_str == "test_template"
|
||||
|
||||
def test_tokenizer_default_fallback_no_tokenizer(self):
|
||||
with pytest.raises(ValueError):
|
||||
get_chat_template("tokenizer_default_fallback_test", tokenizer=None)
|
||||
|
||||
def test_tokenizer_default_fallback_no_chat_template_on_tokenizer(
|
||||
self, llama3_tokenizer
|
||||
):
|
||||
chat_template_str = get_chat_template(
|
||||
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
|
||||
)
|
||||
assert chat_template_str == get_chat_template("chatml")
|
||||
|
||||
def test_tokenizer_default_fallback_with_chat_template_on_tokenizer(
|
||||
self, llama3_tokenizer
|
||||
):
|
||||
llama3_tokenizer.chat_template = "test_template"
|
||||
chat_template_str = get_chat_template(
|
||||
"tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer
|
||||
)
|
||||
assert chat_template_str == "test_template"
|
||||
|
||||
def test_jinja_template_mode(self):
|
||||
jinja_template = "example_jinja_template"
|
||||
chat_template_str = get_chat_template("jinja", jinja_template=jinja_template)
|
||||
assert chat_template_str == jinja_template
|
||||
|
||||
def test_jinja_template_mode_no_jinja_template(self):
|
||||
with pytest.raises(ValueError):
|
||||
get_chat_template("jinja", jinja_template=None)
|
||||
|
||||
def test_extract_chat_template_args(self):
|
||||
# No ds_cfg
|
||||
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||
cfg={"chat_template": "chatml"},
|
||||
)
|
||||
assert chat_template_choice == "chatml"
|
||||
assert chat_template_jinja is None
|
||||
|
||||
# ds_cfg provided
|
||||
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||
cfg={
|
||||
"chat_template": "jinja",
|
||||
"chat_template_jinja": "global_jinja_template",
|
||||
},
|
||||
ds_cfg={"chat_template": "llama3", "chat_template_jinja": None},
|
||||
)
|
||||
assert chat_template_choice == "llama3"
|
||||
assert chat_template_jinja is None
|
||||
|
||||
# ds_cfg provided with jinja template
|
||||
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||
cfg={"chat_template": "chatml", "chat_template_jinja": None},
|
||||
ds_cfg={
|
||||
"chat_template": "jinja",
|
||||
"chat_template_jinja": "ds_jinja_template",
|
||||
},
|
||||
)
|
||||
assert chat_template_choice == "jinja"
|
||||
assert chat_template_jinja == "ds_jinja_template"
|
||||
|
||||
# ds_cfg provided with no chat_template
|
||||
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||
cfg={
|
||||
"chat_template": "jinja",
|
||||
"chat_template_jinja": "global_jinja_template",
|
||||
},
|
||||
ds_cfg={"chat_template": None, "chat_template_jinja": "ds_jinja_template"},
|
||||
)
|
||||
assert chat_template_choice == "jinja"
|
||||
assert chat_template_jinja == "global_jinja_template"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -11,7 +11,7 @@ from axolotl.prompt_strategies.chat_template import (
|
||||
load,
|
||||
)
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.chat_templates import get_chat_template
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@@ -73,7 +73,7 @@ class TestAssistantChatTemplateLlama3:
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=chat_templates("llama3"),
|
||||
chat_template=get_chat_template("llama3"),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
roles={
|
||||
@@ -113,7 +113,7 @@ class TestAssistantChatTemplateLlama3:
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
phi35_tokenizer,
|
||||
chat_template=chat_templates("phi_35"),
|
||||
chat_template=get_chat_template("phi_35"),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
roles={
|
||||
@@ -171,7 +171,7 @@ class TestAssistantChatTemplateLlama3:
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=chat_templates("llama3"),
|
||||
chat_template=get_chat_template("llama3"),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
message_field_training="training",
|
||||
@@ -230,7 +230,7 @@ class TestSharegptChatTemplateLlama3:
|
||||
# pylint: disable=duplicate-code
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -283,7 +283,7 @@ class TestSharegptChatTemplateLlama3:
|
||||
# pylint: disable=duplicate-code
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -336,7 +336,7 @@ class TestSharegptChatTemplateLlama3:
|
||||
# pylint: disable=duplicate-code
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
|
||||
@@ -12,7 +12,7 @@ from axolotl.prompt_strategies.chat_template import (
|
||||
ChatTemplateStrategy,
|
||||
)
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.chat_templates import get_chat_template
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
LOG = logging.getLogger("axolotl")
|
||||
@@ -35,7 +35,7 @@ class TestChatTemplateConfigurations:
|
||||
LOG.info("Testing with train_on_inputs=True")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=True,
|
||||
@@ -80,7 +80,7 @@ class TestChatTemplateConfigurations:
|
||||
LOG.info("Testing with train_on_inputs=False")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -123,7 +123,7 @@ class TestChatTemplateConfigurations:
|
||||
LOG.info("Testing roles_to_train with assistant only")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -151,7 +151,7 @@ class TestChatTemplateConfigurations:
|
||||
LOG.info("Testing roles_to_train with all roles")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=True,
|
||||
@@ -184,7 +184,7 @@ class TestChatTemplateConfigurations:
|
||||
LOG.info("Testing with empty roles_to_train")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -205,7 +205,7 @@ class TestChatTemplateConfigurations:
|
||||
LOG.info("Testing with train_on_eos='all'")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -232,7 +232,7 @@ class TestChatTemplateConfigurations:
|
||||
LOG.info("Testing with train_on_eos='turn'")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -282,7 +282,7 @@ class TestChatTemplateConfigurations:
|
||||
LOG.info("Testing with train_on_eos='last'")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -315,7 +315,7 @@ class TestChatTemplateConfigurations:
|
||||
LOG.info("Testing with train_on_eos='none'")
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer, chat_template=chat_templates("llama3")
|
||||
llama3_tokenizer, chat_template=get_chat_template("llama3")
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -343,7 +343,7 @@ class TestChatTemplateConfigurations:
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=chat_templates("llama3"),
|
||||
chat_template=get_chat_template("llama3"),
|
||||
drop_system_message=True,
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
@@ -371,7 +371,7 @@ class TestChatTemplateConfigurations:
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=chat_templates("llama3"),
|
||||
chat_template=get_chat_template("llama3"),
|
||||
roles=custom_roles,
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
@@ -424,7 +424,7 @@ class TestChatTemplateConfigurations:
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=chat_templates("llama3"),
|
||||
chat_template=get_chat_template("llama3"),
|
||||
message_field_training="train",
|
||||
message_field_training_detail="train_detail",
|
||||
),
|
||||
|
||||
@@ -86,6 +86,20 @@ def fixture_llama3_tokenizer():
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="phi3_tokenizer")
|
||||
def fixture_phi3_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="gemma_tokenizer")
|
||||
def fixture_gemma_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
class TestAssistantDPOChatTemplateLlama3:
|
||||
"""
|
||||
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
||||
@@ -99,7 +113,7 @@ class TestAssistantDPOChatTemplateLlama3:
|
||||
"chat_template": "llama3",
|
||||
"datasets": [
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
"type": "chat_template",
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -124,7 +138,7 @@ class TestAssistantDPOChatTemplateLlama3:
|
||||
"chat_template": "llama3",
|
||||
"datasets": [
|
||||
{
|
||||
"chat_template": "llama3",
|
||||
"type": "chat_template",
|
||||
"field_messages": "conversation",
|
||||
"field_chosen": "better",
|
||||
"field_rejected": "worse",
|
||||
@@ -152,5 +166,65 @@ class TestAssistantDPOChatTemplateLlama3:
|
||||
assert result["rejected"] == "party on<|eot_id|>"
|
||||
|
||||
|
||||
class TestAssistantDPOChatTemplatePhi3:
|
||||
"""
|
||||
Test class for assistant style datasets with phi-3 prompts using the tokenizer's chat_template strategy.
|
||||
"""
|
||||
|
||||
def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn = default(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "tokenizer_default",
|
||||
"datasets": [
|
||||
{
|
||||
"type": "chat_template",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
result = transform_fn(assistant_dataset[0], tokenizer=phi3_tokenizer)
|
||||
assert result["prompt"] == (
|
||||
"<|user|>\nhello<|end|>\n"
|
||||
+ "<|assistant|>\nhello<|end|>\n"
|
||||
+ "<|user|>\ngoodbye<|end|>\n"
|
||||
+ "<|assistant|>\n"
|
||||
)
|
||||
assert result["chosen"] == "goodbye<|end|>"
|
||||
assert result["rejected"] == "party on<|end|>"
|
||||
|
||||
|
||||
class TestAssistantDPOChatTemplateGemma:
|
||||
"""
|
||||
Test class for assistant style datasets with gemma prompts using the tokenizer's chat_template strategy.
|
||||
"""
|
||||
|
||||
def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn = default(
|
||||
DictDefault(
|
||||
{
|
||||
"chat_template": "tokenizer_default",
|
||||
"datasets": [
|
||||
{
|
||||
"type": "chat_template",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
result = transform_fn(assistant_dataset[0], tokenizer=gemma_tokenizer)
|
||||
assert result["prompt"] == (
|
||||
"<bos><start_of_turn>user\nhello<end_of_turn>\n"
|
||||
+ "<start_of_turn>model\nhello<end_of_turn>\n"
|
||||
+ "<start_of_turn>user\ngoodbye<end_of_turn>\n"
|
||||
+ "<start_of_turn>model\n"
|
||||
)
|
||||
assert result["chosen"] == "goodbye<end_of_turn>"
|
||||
assert result["rejected"] == "party on<end_of_turn>"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
238
tests/test_validation_dataset.py
Normal file
238
tests/test_validation_dataset.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Module for testing the validation module for the dataset config"""
|
||||
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
warnings.filterwarnings("error")
|
||||
|
||||
|
||||
@pytest.fixture(name="minimal_cfg")
|
||||
def fixture_cfg():
|
||||
return DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
||||
"learning_rate": 0.000001,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods (duplicate-code)
|
||||
class BaseValidation:
|
||||
"""
|
||||
Base validation module to setup the log capture
|
||||
"""
|
||||
|
||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
self._caplog = caplog
|
||||
|
||||
|
||||
class TestValidationCheckDatasetConfig(BaseValidation):
|
||||
"""
|
||||
Test the validation for the dataset config to ensure no correct parameters are dropped
|
||||
"""
|
||||
|
||||
def test_dataset_config_no_drop_param(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"datasets": [
|
||||
{
|
||||
"path": "LDJnr/Puffin",
|
||||
"type": "sharegpt",
|
||||
"conversation": "chatml",
|
||||
"shards": 10,
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
checked_cfg = validate_config(cfg)
|
||||
|
||||
def _check_config():
|
||||
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
|
||||
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
|
||||
assert checked_cfg.datasets[0].conversation == cfg.datasets[0].conversation
|
||||
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
|
||||
|
||||
_check_config()
|
||||
|
||||
checked_cfg = validate_config(
|
||||
cfg,
|
||||
capabilities={
|
||||
"bf16": "false",
|
||||
"n_gpu": 1,
|
||||
"compute_capability": "8.0",
|
||||
},
|
||||
)
|
||||
|
||||
_check_config()
|
||||
|
||||
def test_dataset_default_chat_template_no_drop_param(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"datasets": [
|
||||
{
|
||||
"path": "LDJnr/Puffin",
|
||||
"type": "chat_template",
|
||||
"field_messages": "conversations",
|
||||
"shards": 10,
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
checked_cfg = validate_config(cfg)
|
||||
|
||||
def _check_config():
|
||||
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
|
||||
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
|
||||
assert checked_cfg.chat_template is None
|
||||
assert (
|
||||
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
|
||||
)
|
||||
assert (
|
||||
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
|
||||
)
|
||||
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
|
||||
assert (
|
||||
checked_cfg.datasets[0].message_field_role
|
||||
== cfg.datasets[0].message_field_role
|
||||
)
|
||||
assert (
|
||||
checked_cfg.datasets[0].message_field_content
|
||||
== cfg.datasets[0].message_field_content
|
||||
)
|
||||
|
||||
_check_config()
|
||||
|
||||
checked_cfg = validate_config(
|
||||
cfg,
|
||||
capabilities={
|
||||
"bf16": "false",
|
||||
"n_gpu": 1,
|
||||
"compute_capability": "8.0",
|
||||
},
|
||||
)
|
||||
|
||||
_check_config()
|
||||
|
||||
def test_dataset_partial_default_chat_template_no_drop_param(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"chat_template": "chatml",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "LDJnr/Puffin",
|
||||
"type": "chat_template",
|
||||
"field_messages": "conversations",
|
||||
"shards": 10,
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
checked_cfg = validate_config(cfg)
|
||||
|
||||
def _check_config():
|
||||
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
|
||||
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
|
||||
assert checked_cfg.chat_template == ChatTemplate.chatml
|
||||
assert (
|
||||
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
|
||||
)
|
||||
assert (
|
||||
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
|
||||
)
|
||||
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
|
||||
assert (
|
||||
checked_cfg.datasets[0].message_field_role
|
||||
== cfg.datasets[0].message_field_role
|
||||
)
|
||||
assert (
|
||||
checked_cfg.datasets[0].message_field_content
|
||||
== cfg.datasets[0].message_field_content
|
||||
)
|
||||
|
||||
_check_config()
|
||||
|
||||
checked_cfg = validate_config(
|
||||
cfg,
|
||||
capabilities={
|
||||
"bf16": "false",
|
||||
"n_gpu": 1,
|
||||
"compute_capability": "8.0",
|
||||
},
|
||||
)
|
||||
|
||||
_check_config()
|
||||
|
||||
def test_dataset_chatml_chat_template_no_drop_param(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"chat_template": "chatml",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "LDJnr/Puffin",
|
||||
"type": "chat_template",
|
||||
"chat_template": "gemma",
|
||||
"field_messages": "conversations",
|
||||
"shards": 10,
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
checked_cfg = validate_config(cfg)
|
||||
|
||||
def _check_config():
|
||||
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
|
||||
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
|
||||
assert checked_cfg.chat_template == cfg.chat_template
|
||||
assert (
|
||||
checked_cfg.datasets[0].chat_template == cfg.datasets[0].chat_template
|
||||
)
|
||||
assert (
|
||||
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
|
||||
)
|
||||
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
|
||||
assert (
|
||||
checked_cfg.datasets[0].message_field_role
|
||||
== cfg.datasets[0].message_field_role
|
||||
)
|
||||
assert (
|
||||
checked_cfg.datasets[0].message_field_content
|
||||
== cfg.datasets[0].message_field_content
|
||||
)
|
||||
|
||||
_check_config()
|
||||
|
||||
checked_cfg = validate_config(
|
||||
cfg,
|
||||
capabilities={
|
||||
"bf16": "false",
|
||||
"n_gpu": 1,
|
||||
"compute_capability": "8.0",
|
||||
},
|
||||
)
|
||||
|
||||
_check_config()
|
||||
Reference in New Issue
Block a user