* feat: add config for optional parameters in a chat message * chore: cleanup * chore: fix nits and add light docs * docs: update docs/dataset-formats/conversation.qmd Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * feat: configurable message mappings, jinja template analyzer * chore: handle bradley terry * docs: update docs * refactor: change order of mappings, improve message transform * refactor: make chat awware of property mappings * chore: remove .python-version * chore: revert change * chore: add dataset validation to tests where appropriate * chore: add dataset validation to tests where appropriate * chore: clean up handling of ds_cfg * chore: recursively serialize config * make sure to use the return value from validate_config * DefaultDict pickle/unpickle fix * fix super call for override * refactor: message fields * chore: empty commit * tests: validate config before using * chore: add config validation to all e2e tests * chore: add unneeded logging * chore: add missed config validation * chore: pass field_messages to prompter * test: fix borked test * chore: remove uninteded file * chore: add deprecation warning and update chat_datasets script * chore: lint * refactor: message fields * feat: update axolotlinputconfig and test_models - add configdict import in axolotl/utils/config/models/input/v0_4_1/__init__.py - remove unnecessary line breaks in sftdataset, dpodataset, ktodataset, stepwisesuperviseddataset classes - update model_dump method in axolotlinputconfig to exclude none values - correct typo in test_models.py comment * feat: simplify dpodataset and ktodataset classes in config models removed several optional fields from dpodataset and ktodataset classes in axolotl/utils/config/models/input/v0_4_1. this simplifies the configuration subsets for these datasets. 
* feat: improve readability and structure in dataset configuration models this commit enhances the readability and structure of the dataset configuration models in the `axolotl/utils/config/models/input/v0_4_1` module. it removes unused `configdict` import and adds line breaks to separate class definitions for better clarity. additionally, a minor documentation fix is included to ensure a newline at the end of the `stepwise_supervised.qmd` file. * feat: change log level from info to debug in chattemplatestrategy * feat(prompt_strategies): refactor chattemplateprompter and chattemplatestrategy - Make `chat_template` a required parameter in `ChatTemplatePrompter` constructor - Add default value for `message_property_mappings` in `ChatTemplatePrompter` constructor - Add `messages_array_name` property to `ChatTemplatePrompter` - Change `processor` type to Optional in `ChatTemplatePrompter` - Add TypeError check for `processor` in `ChatTemplatePrompter.build_prompt` - Remove `_messages` property from `ChatTemplateStrategy` - Make `prompter` a required parameter and add type hint in `ChatTemplateStrategy` constructor - Remove `messages` getter and setter from `ChatTemplateStrategy` - Use `prompter.messages_array_name` in `ChatTemplateStrategy.get_conversation_thread` - Remove condition to set `messages` field in `load` function * feat(tests/utils): ignore type check in load_model call in test_models.py * feat: improve type handling and test structure in chat templates - Add return type hint for `get_chat_template` function in `chat_templates.py` - Remove unnecessary assignment of `strategy.messages` in several test cases - Add `messages_array_name` parameter to various test configurations in `test_chat_templates.py` and `test_chat_templates_advanced.py` - Remove redundant `strategy.messages` assignment in `test_chat_templates_advanced.py` * feat(axolotl): enhance chat strategy with datasetconfig support This commit introduces support for DatasetConfig in the 
ChatTemplateStrategy. It also refines the strategy loader to handle different types of ds_cfg inputs and improves the clarity of the code by formatting and reordering. The key changes include: - Importing Union from typing and BaseModel from pydantic. - Adding DatasetConfig as an optional type for ds_cfg in StrategyLoader. - Adjusting the handling of ds_cfg in StrategyLoader to account for BaseModel instances. - Refactoring the prompter_params and strategy_params for better readability. - Changing the reference from prompt[self.messages] to prompt[self.prompter.messages_array_name] in the is_prompt_batched method. * feat: update message handling in btchattemplatestrategy * Replace `self.messages` with direct string references to "chosen_messages" and "rejected_messages" * Append system, user, and assistant content directly to "chosen_messages" and "rejected_messages" * Add a new attribute "messages_array_name" to the `load` function parameters * Remove the conditional attribute assignment for "field_messages" in the `load` function * feat: add config validation in test_kd.py - Import `validate_config` from `axolotl.utils.config` - Validate the configuration in `test_llama_kd` and another function in `TestKnowledgeDistillation` class * feat: enhance config validation and capabilities handling * Import `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals` * Update `validate_config` function to create `KTODataset` and `SFTDataset` instances using `dict(ds_cfg)` * Replace `capabilities` and `env_capabilities` with instances of `GPUCapabilities` and `EnvCapabilities` respectively in `AxolotlConfigWCapabilities` model dump * feat: update config validation in axolotl utils - Remove import of `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals` - Update `validate_config` function to use `capabilities` and `env_capabilities` directly instead of creating new instances of `GPUCapabilities` and `EnvCapabilities` * 
feat: refactor strategyloader in chat_template.py - Extracted the creation of strategy parameters into a separate function, `_get_strategy_params(cfg, dataset_config)` - Created a new function, `_get_strategy_cls()`, to obtain the strategy class - Replaced `ChatTemplateStrategy` with `strategy_cls` for strategy instantiation * trigger CI * chore: revert dataset config changes for kto/dpo * subject: refactor: rename 'messages_array_name' to 'field_messages' Body: - Renamed 'messages_array_name' to 'field_messages' in 'ChatTemplatePrompter' class and its usages in 'chat_template.py' - Updated 'load' function in 'bradley_terry/chat_template.py' to reflect the change - Adjusted 'get_chat_template_msg_variables' and 'get_message_vars' methods in 'jinja_template_analyzer.py' to use the new variable name - Modified 'StrategyLoader' in 'chat_template.py' to use 'field_messages' - Updated tests in 'test_chat_templates.py' and 'test_chat_templates_advanced.py' to use 'field_messages' instead of 'messages_array_name' * feat: refactor prompt strategies and update config models * Remove redundant 'return None' in `axolotl/prompt_strategies/__init__.py` * Simplify message handling in `axolotl/prompt_strategies/bradley_terry/chat_template.py` by using a single 'messages' list instead of separate 'chosen_messages' and 'rejected_messages' lists * Update default 'message_property_mappings' in `axolotl/prompt_strategies/bradley_terry/chat_template.py` * Add 'field_messages' field to `axolotl/utils/config/models/input/v0_4_1/__init__.py` configuration model * chore: remove unused input * chore: remove redundant type ignore * fix: remove old configs and update examples * fix: type check * fix: remove loading old config in ChatMessage * fix: update faq with potential new undefinederror * fix: add debug if property mapped is not found * chore: improve explanation for unmapped properties * fix: update docs with new config * chore: add note for deprecation config and del old config from 
dict --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: NanoCode012 <nano@axolotl.ai>
564 lines
25 KiB
Python
"""
tests for chat_template prompt strategy
"""

import logging
import unittest

from axolotl.prompt_strategies.chat_template import (
    ChatTemplatePrompter,
    ChatTemplateStrategy,
    load,
)
from axolotl.prompters import IGNORE_TOKEN_ID
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.dict import DictDefault

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")


class TestAssistantChatTemplateLlama3:
    """
    Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
    """

    def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
        LOG.info("Loading llama-3 tokenizer with assistant dataset")
        strategy = load(
            llama3_tokenizer,
            DictDefault(
                {
                    "train_on_inputs": False,
                    "sequence_len": 512,
                }
            ),
            DictDefault(
                {
                    "chat_template": "llama3",
                    "message_field_role": "role",
                    "message_field_content": "content",
                    "message_property_mappings": {
                        "role": "role",
                        "content": "content",
                    },
                    "roles": {
                        "user": ["user"],
                        "assistant": ["assistant"],
                        "system": ["system"],
                    },
                    "field_messages": "messages",
                }
            ),
        )
        res = strategy.tokenize_prompt(assistant_dataset[0])
        input_ids = res["input_ids"]
        # fmt: off
        expected_input_ids = [
            128000, # bos
            128006, 882, 128007, # user header
            271, 15339, 128009, # user prompt eot
            128006, 78191, 128007, # assistant header
            271, 15339, 128009, # assistant response eot
            128006, 882, 128007,
            271, 19045, 29474, 128009,
            128006, 78191, 128007,
            271, 19045, 29474, 128009,
        ]
        # fmt: on
        LOG.debug(f"Expected input_ids: {expected_input_ids}")
        LOG.debug(f"Actual input_ids: {input_ids}")
        assert (
            input_ids == expected_input_ids
        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"

    def test_llama3(self, llama3_tokenizer, assistant_dataset):
        LOG.info("Testing llama-3 with assistant dataset")
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                llama3_tokenizer,
                chat_template=get_chat_template("llama3"),
                message_property_mappings={
                    "role": "role",
                    "content": "content",
                },
                roles={
                    "user": ["user"],
                    "assistant": ["assistant"],
                    "system": ["system"],
                },
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,
            sequence_len=512,
        )

        res = strategy.tokenize_prompt(assistant_dataset[0])
        input_ids = res["input_ids"]
        # fmt: off
        expected_input_ids = [
            128000, # bos
            128006, 882, 128007, # user header
            271, 15339, 128009, # user prompt eot
            128006, 78191, 128007, # assistant header
            271, 15339, 128009, # assistant response eot
            128006, 882, 128007,
            271, 19045, 29474, 128009,
            128006, 78191, 128007,
            271, 19045, 29474, 128009,
        ]
        # fmt: on
        LOG.debug(f"Expected input_ids: {expected_input_ids}")
        LOG.debug(f"Actual input_ids: {input_ids}")
        assert (
            input_ids == expected_input_ids
        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"

    def test_phi35(self, phi35_tokenizer, assistant_dataset):
        LOG.info("Testing phi-3.5 with assistant dataset")
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                phi35_tokenizer,
                chat_template=get_chat_template("phi_35"),
                message_property_mappings={
                    "role": "role",
                    "content": "content",
                },
                roles={
                    "user": ["user"],
                    "assistant": ["assistant"],
                    "system": ["system"],
                },
            ),
            tokenizer=phi35_tokenizer,
            train_on_inputs=False,
            sequence_len=512,
        )

        res = strategy.tokenize_prompt(assistant_dataset[0])
        input_ids = res["input_ids"]
        labels = res["labels"]
        # fmt: off
        expected_input_ids = [
            32010, # user
            22172, 32007, # user eot
            32001, # assistant
            22172, 32007, # assistant eot
            32010, # user
            1781, 26966, 32007, # user eot
            32001, # assistant
            1781, 26966, 32007, # assistant eot
        ]
        expected_labels = [
            -100, # user
            -100, -100, # user eot
            -100, # assistant
            -100, -100, # assistant eot
            -100, # user
            -100, -100, -100, # user eot
            -100, # assistant
            1781, 26966, 32007, # assistant eot
        ]
        # fmt: on
        LOG.debug(f"Expected input_ids: {expected_input_ids}")
        LOG.debug(f"Actual input_ids: {input_ids}")
        assert (
            input_ids == expected_input_ids
        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"

        LOG.debug(f"Expected labels: {expected_labels}")
        LOG.debug(f"Actual labels: {labels}")
        assert (
            labels == expected_labels
        ), f"Labels mismatch: {labels} != {expected_labels}"

    def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
        LOG.info("Testing llama-3 with assistant dataset including training data")
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                llama3_tokenizer,
                chat_template=get_chat_template("llama3"),
                message_field_training="training",
                message_property_mappings={
                    "role": "role",
                    "content": "content",
                },
                roles={
                    "user": ["user"],
                    "assistant": ["assistant"],
                    "system": ["system"],
                },
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,
            train_on_eos="none",
            sequence_len=512,
            roles_to_train=["assistant"],
        )

        prompt_tokens = strategy.prompter.build_prompt(
            assistant_dataset[0]["messages"], False
        )
        prompt = llama3_tokenizer.decode(prompt_tokens, skip_special_tokens=False)
        LOG.debug(f"Generated prompt: {prompt}")
        res = strategy.tokenize_prompt(assistant_dataset[0])
        labels = res["labels"]
        input_ids = res["input_ids"]
        # fmt: off
        expected_labels = [
            IGNORE_TOKEN_ID, # bos
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
            IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
        ]
        # fmt: on

        LOG.debug(f"Expected labels: {expected_labels}")
        LOG.debug(f"Actual labels: {labels}")
        assert labels == expected_labels, (
            f"Labels mismatch:\n"
            f"Expected: {expected_labels}\n"
            f"Actual: {labels}\n"
            f"Input IDs: {input_ids}\n"
        )


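# Sketch (hypothetical helper, not axolotl's implementation): the expected_labels in
# test_llama3_with_training_data follow a simple per-turn rule. Assuming each turn has
# already been split into header tokens (128006, <role>, 128007 plus the 271 separator
# that follows), content tokens and end-of-turn tokens (128009), the header is always
# masked, content is kept only for roles in roles_to_train, and the end-of-turn token is
# kept only for trained roles when train_on_eos == "turn" (as in the llama-3.2 vision
# tests further down). ChatTemplateStrategy locates these boundaries itself; this sketch
# only illustrates the masking rule the tests assert.
_SketchTurn = tuple[str, list[int], list[int], list[int]]  # (role, header, content, eot)


def _sketch_mask_labels(
    prefix: list[int],
    turns: list[_SketchTurn],
    roles_to_train: list[str],
    train_on_eos: str = "none",
    ignore_id: int = -100,  # same value the phi-3.5 test uses for masked labels
) -> tuple[list[int], list[int]]:
    """Return (input_ids, labels) for a prefix followed by pre-split turns."""
    input_ids = list(prefix)
    labels = [ignore_id] * len(prefix)  # prefix tokens such as bos are never trained on
    for role, header, content, eot in turns:
        trained = role in roles_to_train
        input_ids += header + content + eot
        labels += [ignore_id] * len(header)
        labels += content if trained else [ignore_id] * len(content)
        keep_eot = trained and train_on_eos == "turn"
        labels += eot if keep_eot else [ignore_id] * len(eot)
    return input_ids, labels


# Usage, with the turns from test_llama3_with_training_data:
#   ids, labels = _sketch_mask_labels([128000], turns, ["assistant"], train_on_eos="none")
# gives labels that keep only 15339 and 19045, 29474, matching expected_labels there.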
class TestSharegptChatTemplateLlama3:
    """
    Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
    """

    def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
        LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
        # pylint: disable=duplicate-code
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                llama3_tokenizer,
                chat_template=get_chat_template("llama3"),
                message_property_mappings={
                    "role": "from",
                    "content": "value",
                },
                field_messages="conversations",
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,
            train_on_eos="none",
            sequence_len=512,
            roles_to_train=["gpt"],
        )

        res = strategy.tokenize_prompt(sharegpt_dataset[0])
        input_ids = res["input_ids"]
        labels = res["labels"]
        # fmt: off
        expected_input_ids = [
            128000, # bos
            128006, 882, 128007, # user header
            271, 15339, 128009, # user prompt eot
            128006, 78191, 128007, # assistant header
            271, 15339, 128009, # assistant response eot
            128006, 882, 128007,
            271, 19045, 29474, 128009,
            128006, 78191, 128007,
            271, 19045, 29474, 128009,
        ]
        expected_labels = [
            IGNORE_TOKEN_ID, # bos
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
            IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
        ]
        # fmt: on

        LOG.debug(f"Expected input_ids: {expected_input_ids}")
        LOG.debug(f"Actual input_ids: {input_ids}")
        LOG.debug(f"Expected labels: {expected_labels}")
        LOG.debug(f"Actual labels: {labels}")

        assert (
            input_ids == expected_input_ids
        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
        assert (
            labels == expected_labels
        ), f"Labels mismatch: {labels} != {expected_labels}"

    def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
        LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
        # pylint: disable=duplicate-code
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                llama3_tokenizer,
                chat_template=get_chat_template("llama3"),
                message_property_mappings={
                    "role": "from",
                    "content": "value",
                },
                field_messages="conversations",
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,
            train_on_eos="none",
            sequence_len=512,
            roles_to_train=["human"],
        )

        res = strategy.tokenize_prompt(sharegpt_dataset[0])
        input_ids = res["input_ids"]
        labels = res["labels"]
        # fmt: off
        expected_input_ids = [
            128000, # bos
            128006, 882, 128007, # user header
            271, 15339, 128009, # user prompt eot
            128006, 78191, 128007, # assistant header
            271, 15339, 128009, # assistant response eot
            128006, 882, 128007,
            271, 19045, 29474, 128009,
            128006, 78191, 128007,
            271, 19045, 29474, 128009,
        ]
        expected_labels = [
            IGNORE_TOKEN_ID, # bos
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
            IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # user prompt eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
        ]
        # fmt: on

        LOG.debug(f"Expected input_ids: {expected_input_ids}")
        LOG.debug(f"Actual input_ids: {input_ids}")
        LOG.debug(f"Expected labels: {expected_labels}")
        LOG.debug(f"Actual labels: {labels}")

        assert (
            input_ids == expected_input_ids
        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
        assert (
            labels == expected_labels
        ), f"Labels mismatch: {labels} != {expected_labels}"

    def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
        LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
        # pylint: disable=duplicate-code
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                llama3_tokenizer,
                chat_template=get_chat_template("llama3"),
                message_property_mappings={
                    "role": "from",
                    "content": "value",
                },
                field_messages="conversations",
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,
            train_on_eos="none",
            sequence_len=512,
            roles_to_train=["system", "human"],
        )

        res = strategy.tokenize_prompt(basic_dataset[0])
        input_ids = res["input_ids"]
        labels = res["labels"]
        # fmt: off
        expected_input_ids = [
            128000, # bos
            128006, 9125, 128007,
            271, 2675, 527, 459, 15592, 18328, 13, 128009,
            128006, 882, 128007, # user header
            271, 9906, 128009, # user prompt eot
            128006, 78191, 128007, # assistant header
            271, 13347, 1070, 0, 128009, # assistant response eot
            128006, 882, 128007,
            271, 4438, 527, 499, 30, 128009,
            128006, 78191, 128007,
            271, 40, 2846, 3815, 1664, 11, 9901, 499, 0, 128009,
        ]
        expected_labels = [
            IGNORE_TOKEN_ID, # bos
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header
            IGNORE_TOKEN_ID, 2675, 527, 459, 15592, 18328, 13, IGNORE_TOKEN_ID, # system prompt eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
            IGNORE_TOKEN_ID, 9906, IGNORE_TOKEN_ID, # user prompt eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, 4438, 527, 499, 30, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
        ]
        # fmt: on

        LOG.debug(f"Expected input_ids: {expected_input_ids}")
        LOG.debug(f"Actual input_ids: {input_ids}")
        LOG.debug(f"Expected labels: {expected_labels}")
        LOG.debug(f"Actual labels: {labels}")

        assert (
            input_ids == expected_input_ids
        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
        assert (
            labels == expected_labels
        ), f"Labels mismatch: {labels} != {expected_labels}"


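# Sketch (hypothetical, not axolotl's implementation) of what the ShareGPT tests above rely
# on: message_property_mappings={"role": "from", "content": "value"} lets a raw message such
# as {"from": "human", "value": "hello"} be read as {"role": "human", "content": "hello"}.
# Their expected_input_ids then render human turns under the "user" header (token 882) and
# gpt turns under "assistant" (78191), so some alias step is applied. The alias table below
# is an assumed stand-in for that step, shaped like the `roles` argument used earlier
# (canonical role -> list of source role names).
_SKETCH_ROLE_ALIASES = {
    "user": ["human", "user"],
    "assistant": ["gpt", "assistant"],
    "system": ["system"],
}


def _sketch_map_message(
    raw: dict, property_mappings: dict[str, str], roles: dict[str, list[str]]
) -> dict:
    """Rename source fields to template fields, then map role aliases to canonical roles."""
    # property_mappings goes target field -> source field; missing source fields are skipped.
    message = {
        target: raw[source]
        for target, source in property_mappings.items()
        if source in raw
    }
    alias_to_role = {alias: role for role, aliases in roles.items() for alias in aliases}
    if "role" in message:
        # Role names not listed in the alias table pass through unchanged.
        message["role"] = alias_to_role.get(message["role"], message["role"])
    return message


# Usage:
#   _sketch_map_message(
#       {"from": "gpt", "value": "hi"}, {"role": "from", "content": "value"},
#       _SKETCH_ROLE_ALIASES,
#   )
#   -> {"role": "assistant", "content": "hi"}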
class TestAssistantToolCallingChatTemplateLlama32Vision:
    """
    Test class for assistant style datasets with tool_calling prompts using the llama-3.2 vision chat template.
    """

def test_llama32vision_train_on_assistant(
|
|
self, llama3_tokenizer, toolcalling_dataset, llama3_2_vision_chat_template_jinja
|
|
):
|
|
LOG.info(
|
|
"Testing assistant style datasets with tool_calling with llama-32 chat template, training on assistant"
|
|
)
|
|
|
|
strategy = ChatTemplateStrategy(
|
|
ChatTemplatePrompter(
|
|
llama3_tokenizer,
|
|
chat_template=get_chat_template(
|
|
"jinja", jinja_template=llama3_2_vision_chat_template_jinja
|
|
),
|
|
message_property_mappings={"role": "role", "content": "content"},
|
|
),
|
|
tokenizer=llama3_tokenizer,
|
|
train_on_inputs=False,
|
|
train_on_eos="turn",
|
|
sequence_len=512,
|
|
roles_to_train=["assistant"],
|
|
)
|
|
|
|
res = strategy.tokenize_prompt(toolcalling_dataset[0])
|
|
|
|
input_ids = res["input_ids"]
|
|
labels = res["labels"]
|
|
|
|
# fmt: off
|
|
expected_input_ids = [
|
|
128000, # bos
|
|
128006, 9125, 128007, 271, # system header
|
|
38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1114, 3799, 220, 2366, 19, 271, # system date prompt
|
|
2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 1472, 1288, 10052, 449, 279, 5089, 1511, 304, 279, 79002, 3813, 13, 128009, # system message
|
|
128006, 882, 128007, 271, # user header
|
|
19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009, # user message
|
|
128006, 78191, 128007, 271, # assistant header
|
|
5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
|
|
128006, 23799, 4690, 128007, 271, # tool header
|
|
1, 1313, 13, 15, 1, 128009, # tool message
|
|
128006, 78191, 128007, 271, # assistant header
|
|
791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
|
|
]
|
|
|
|
expected_labels = [
|
|
IGNORE_TOKEN_ID, # bos
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system date prompt
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system message
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user message
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
|
|
5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # tool header
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # tool message
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
|
|
791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
|
|
]
|
|
# fmt: on
|
|
|
|
assert (
|
|
input_ids == expected_input_ids
|
|
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
|
|
|
assert (
|
|
labels == expected_labels
|
|
), f"Labels mismatch: {labels} != {expected_labels}"
|
|
|
|
def test_llama32vision_train_on_tools(
|
|
self, llama3_tokenizer, toolcalling_dataset, llama3_2_vision_chat_template_jinja
|
|
):
|
|
LOG.info(
|
|
"Testing assistant style datasets with tool_calling with llama-32 chat template, training on tools"
|
|
)
|
|
# pylint: disable=duplicate-code
|
|
|
|
strategy = ChatTemplateStrategy(
|
|
ChatTemplatePrompter(
|
|
llama3_tokenizer,
|
|
chat_template=get_chat_template(
|
|
"jinja", jinja_template=llama3_2_vision_chat_template_jinja
|
|
),
|
|
message_property_mappings={"role": "role", "content": "content"},
|
|
),
|
|
tokenizer=llama3_tokenizer,
|
|
train_on_inputs=False,
|
|
train_on_eos="turn",
|
|
sequence_len=512,
|
|
roles_to_train=["assistant", "tool"],
|
|
)
|
|
|
|
res = strategy.tokenize_prompt(toolcalling_dataset[0])
|
|
|
|
input_ids = res["input_ids"]
|
|
labels = res["labels"]
|
|
|
|
# fmt: off
|
|
expected_input_ids = [
|
|
128000, # bos
|
|
128006, 9125, 128007, 271, # system header
|
|
38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1114, 3799, 220, 2366, 19, 271, # system date prompt
|
|
2675, 527, 264, 11164, 430, 31680, 311, 9282, 20126, 13, 1472, 1288, 10052, 449, 279, 5089, 1511, 304, 279, 79002, 3813, 13, 128009, # system message
|
|
128006, 882, 128007, 271, # user header
|
|
19182, 11, 1148, 596, 279, 9499, 304, 12366, 1314, 1457, 30, 128009, # user message
|
|
128006, 78191, 128007, 271, # assistant header
|
|
5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
|
|
128006, 23799, 4690, 128007, 271, # tool header
|
|
1, 1313, 13, 15, 1, 128009, # tool message
|
|
128006, 78191, 128007, 271, # assistant header
|
|
791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
|
|
]
|
|
|
|
expected_labels = [
|
|
IGNORE_TOKEN_ID, # bos
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system date prompt
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system message
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user message
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
|
|
5018, 609, 794, 330, 456, 11327, 54625, 498, 330, 14105, 794, 5324, 2588, 794, 330, 60704, 11, 9822, 498, 330, 3928, 794, 330, 66, 41347, 32075, 128009, # assistant message
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # tool header
|
|
IGNORE_TOKEN_ID, 1313, 13, 15, IGNORE_TOKEN_ID, 128009, # tool message
|
|
IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
|
|
791, 9499, 304, 12366, 374, 220, 1313, 13, 15, 12628, 62447, 13, 128009 # assistant message
|
|
]
|
|
# fmt: on
|
|
|
|
assert (
|
|
input_ids == expected_input_ids
|
|
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
|
|
|
assert (
|
|
labels == expected_labels
|
|
), f"Labels mismatch: {labels} != {expected_labels}"
|
|
|
|
|
|
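# Sketch (hypothetical helper, not part of axolotl): the expected_labels lists in these
# tests follow the usual causal-LM masking invariant. labels has the same length as
# input_ids, and every position is either the ignore index or a copy of the input token at
# that position (e.g. the tool message 1, 1313, 13, 15, 1, 128009 is labelled
# IGNORE_TOKEN_ID, 1313, 13, 15, IGNORE_TOKEN_ID, 128009). A check like this could run
# before the exact-match asserts to give a more targeted failure message.
def _sketch_check_label_alignment(
    input_ids: list[int], labels: list[int], ignore_id: int = -100
) -> list[int]:
    """Return the positions whose label is trained on; raise if alignment is broken."""
    if len(input_ids) != len(labels):
        raise ValueError(
            f"length mismatch: {len(input_ids)} input_ids vs {len(labels)} labels"
        )
    trained = []
    for pos, (tok, lab) in enumerate(zip(input_ids, labels)):
        if lab == ignore_id:
            continue  # masked position, nothing to check
        if lab != tok:
            raise ValueError(f"label {lab} at position {pos} does not match input id {tok}")
        trained.append(pos)
    return trained


# Usage, with the phi-3.5 expectations from test_phi35:
#   _sketch_check_label_alignment(expected_input_ids, expected_labels)
# returns [11, 12, 13], the positions of the final assistant turn.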
if __name__ == "__main__":
    unittest.main()