* adding pre-commit auto-update GH action and bumping plugin versions * running updated pre-commit plugins * sorry to revert, but pylint complained * Update .pre-commit-config.yaml Co-authored-by: Wing Lian <wing.lian@gmail.com> --------- Co-authored-by: Dan Saunders <dan@axolotl.ai> Co-authored-by: Wing Lian <wing.lian@gmail.com>
64 lines
1.9 KiB
Python
64 lines
1.9 KiB
Python
"""
|
|
tests for chat_template prompt strategy
|
|
"""
|
|
|
|
# pylint: disable=duplicate-code
|
|
import logging
|
|
import unittest
|
|
|
|
from axolotl.prompt_strategies.messages.chat import load
|
|
from axolotl.utils.dict import DictDefault
|
|
|
|
# Request DEBUG-level output from the root logger so the LOG.debug() calls made
# by the test in this module are not filtered out.
logging.basicConfig(level=logging.DEBUG)
# Shared logger (named "axolotl") used by the test in this module.
LOG = logging.getLogger("axolotl")
|
|
|
|
|
|
class TestMessagesChatLlama3:
    """
    Exercises the messages chat prompt strategy with the "llama3" chat template on an assistant style dataset.
    """

    def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
        """
        Build the strategy via ``load``, wrap the dataset, and compare the
        first row's ``input_ids`` against a hand-written token id sequence.

        ``llama3_tokenizer`` and ``assistant_dataset`` are supplied by the
        caller (presumably pytest fixtures defined elsewhere -- confirm).
        """
        LOG.info("Loading llama-3 tokenizer with assistant dataset")

        # The two option mappings handed to load(), in positional order.
        train_opts = DictDefault(
            {
                "train_on_inputs": False,
                "sequence_len": 512,
            }
        )
        template_opts = DictDefault(
            {
                "chat_template": "llama3",
                "message_field_role": "role",
                "message_field_content": "content",
                "field_messages": "messages",
            }
        )
        prompt_strategy = load(llama3_tokenizer, train_opts, template_opts)

        wrapped = prompt_strategy.wrap_dataset(assistant_dataset)
        input_ids = wrapped[0]["input_ids"]

        # Expected sequence, assembled from its repeating segments.
        bos = [128000]
        user_header = [128006, 882, 128007]
        assistant_header = [128006, 78191, 128007]
        first_exchange_body = [271, 15339, 128009]  # content + eot
        second_exchange_body = [271, 19045, 29474, 128009]  # content + eot
        expected_input_ids = (
            bos
            + user_header + first_exchange_body
            + assistant_header + first_exchange_body
            + user_header + second_exchange_body
            + assistant_header + second_exchange_body
        )

        LOG.debug(f"Expected input_ids: {expected_input_ids}")
        LOG.debug(f"Actual input_ids: {input_ids}")

        assert input_ids == expected_input_ids, (
            f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
        )
|
|
|
|
|
|
# Script entry point: hands control to unittest's command-line runner when this
# file is executed directly.
# NOTE(review): TestMessagesChatLlama3 does not subclass unittest.TestCase and
# its test method takes extra arguments, so unittest.main() likely collects
# nothing here; the test looks intended for a fixture-aware runner (presumably
# pytest) -- confirm before relying on this entry point.
if __name__ == "__main__":
    unittest.main()
|