* feat: add config for optional parameters in a chat message
* chore: cleanup
* chore: fix nits and add light docs
* docs: update docs/dataset-formats/conversation.qmd

  Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
* feat: configurable message mappings, jinja template analyzer
* chore: handle bradley terry
* docs: update docs
* refactor: change order of mappings, improve message transform
* refactor: make chat aware of property mappings
* chore: remove .python-version
* chore: revert change
* chore: add dataset validation to tests where appropriate
* chore: add dataset validation to tests where appropriate
* chore: clean up handling of ds_cfg
* chore: recursively serialize config
* make sure to use the return value from validate_config
* DefaultDict pickle/unpickle fix
* fix super call for override
* refactor: message fields
* chore: empty commit
* tests: validate config before using
* chore: add config validation to all e2e tests
* chore: add unneeded logging
* chore: add missed config validation
* chore: pass field_messages to prompter
* test: fix borked test
* chore: remove unintended file
* chore: add deprecation warning and update chat_datasets script
* chore: lint
* refactor: message fields
* feat: update axolotlinputconfig and test_models
  - add configdict import in axolotl/utils/config/models/input/v0_4_1/__init__.py
  - remove unnecessary line breaks in sftdataset, dpodataset, ktodataset, stepwisesuperviseddataset classes
  - update model_dump method in axolotlinputconfig to exclude none values
  - correct typo in test_models.py comment
* feat: simplify dpodataset and ktodataset classes in config models

  removed several optional fields from dpodataset and ktodataset classes in axolotl/utils/config/models/input/v0_4_1. this simplifies the configuration subsets for these datasets.
* feat: improve readability and structure in dataset configuration models

  this commit enhances the readability and structure of the dataset configuration models in the `axolotl/utils/config/models/input/v0_4_1` module. it removes unused `configdict` import and adds line breaks to separate class definitions for better clarity. additionally, a minor documentation fix is included to ensure a newline at the end of the `stepwise_supervised.qmd` file.
* feat: change log level from info to debug in chattemplatestrategy
* feat(prompt_strategies): refactor chattemplateprompter and chattemplatestrategy
  - Make `chat_template` a required parameter in `ChatTemplatePrompter` constructor
  - Add default value for `message_property_mappings` in `ChatTemplatePrompter` constructor
  - Add `messages_array_name` property to `ChatTemplatePrompter`
  - Change `processor` type to Optional in `ChatTemplatePrompter`
  - Add TypeError check for `processor` in `ChatTemplatePrompter.build_prompt`
  - Remove `_messages` property from `ChatTemplateStrategy`
  - Make `prompter` a required parameter and add type hint in `ChatTemplateStrategy` constructor
  - Remove `messages` getter and setter from `ChatTemplateStrategy`
  - Use `prompter.messages_array_name` in `ChatTemplateStrategy.get_conversation_thread`
  - Remove condition to set `messages` field in `load` function
* feat(tests/utils): ignore type check in load_model call in test_models.py
* feat: improve type handling and test structure in chat templates
  - Add return type hint for `get_chat_template` function in `chat_templates.py`
  - Remove unnecessary assignment of `strategy.messages` in several test cases
  - Add `messages_array_name` parameter to various test configurations in `test_chat_templates.py` and `test_chat_templates_advanced.py`
  - Remove redundant `strategy.messages` assignment in `test_chat_templates_advanced.py`
* feat(axolotl): enhance chat strategy with datasetconfig support

  This commit introduces support for DatasetConfig in the ChatTemplateStrategy. It also refines the strategy loader to handle different types of ds_cfg inputs and improves the clarity of the code by formatting and reordering. The key changes include:
  - Importing Union from typing and BaseModel from pydantic.
  - Adding DatasetConfig as an optional type for ds_cfg in StrategyLoader.
  - Adjusting the handling of ds_cfg in StrategyLoader to account for BaseModel instances.
  - Refactoring the prompter_params and strategy_params for better readability.
  - Changing the reference from prompt[self.messages] to prompt[self.prompter.messages_array_name] in the is_prompt_batched method.
* feat: update message handling in btchattemplatestrategy
  * Replace `self.messages` with direct string references to "chosen_messages" and "rejected_messages"
  * Append system, user, and assistant content directly to "chosen_messages" and "rejected_messages"
  * Add a new attribute "messages_array_name" to the `load` function parameters
  * Remove the conditional attribute assignment for "field_messages" in the `load` function
* feat: add config validation in test_kd.py
  - Import `validate_config` from `axolotl.utils.config`
  - Validate the configuration in `test_llama_kd` and another function in `TestKnowledgeDistillation` class
* feat: enhance config validation and capabilities handling
  * Import `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
  * Update `validate_config` function to create `KTODataset` and `SFTDataset` instances using `dict(ds_cfg)`
  * Replace `capabilities` and `env_capabilities` with instances of `GPUCapabilities` and `EnvCapabilities` respectively in `AxolotlConfigWCapabilities` model dump
* feat: update config validation in axolotl utils
  - Remove import of `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
  - Update `validate_config` function to use `capabilities` and `env_capabilities` directly instead of creating new instances of `GPUCapabilities` and `EnvCapabilities`
* feat: refactor strategyloader in chat_template.py
  - Extracted the creation of strategy parameters into a separate function, `_get_strategy_params(cfg, dataset_config)`
  - Created a new function, `_get_strategy_cls()`, to obtain the strategy class
  - Replaced `ChatTemplateStrategy` with `strategy_cls` for strategy instantiation
* trigger CI
* chore: revert dataset config changes for kto/dpo
* subject: refactor: rename 'messages_array_name' to 'field_messages'

  Body:
  - Renamed 'messages_array_name' to 'field_messages' in 'ChatTemplatePrompter' class and its usages in 'chat_template.py'
  - Updated 'load' function in 'bradley_terry/chat_template.py' to reflect the change
  - Adjusted 'get_chat_template_msg_variables' and 'get_message_vars' methods in 'jinja_template_analyzer.py' to use the new variable name
  - Modified 'StrategyLoader' in 'chat_template.py' to use 'field_messages'
  - Updated tests in 'test_chat_templates.py' and 'test_chat_templates_advanced.py' to use 'field_messages' instead of 'messages_array_name'
* feat: refactor prompt strategies and update config models
  * Remove redundant 'return None' in `axolotl/prompt_strategies/__init__.py`
  * Simplify message handling in `axolotl/prompt_strategies/bradley_terry/chat_template.py` by using a single 'messages' list instead of separate 'chosen_messages' and 'rejected_messages' lists
  * Update default 'message_property_mappings' in `axolotl/prompt_strategies/bradley_terry/chat_template.py`
  * Add 'field_messages' field to `axolotl/utils/config/models/input/v0_4_1/__init__.py` configuration model
* chore: remove unused input
* chore: remove redundant type ignore
* fix: remove old configs and update examples
* fix: type check
* fix: remove loading old config in ChatMessage
* fix: update faq with potential new undefinederror
* fix: add debug if property mapped is not found
* chore: improve explanation for unmapped properties
* fix: update docs with new config
* chore: add note for deprecation config and del old config from dict

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
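The change above makes message property names configurable (`message_property_mappings`), so a ShareGPT-style dataset such as the `sharegpt_dataset` fixture in this file, which uses `from`/`value`, can feed the same chat-template code path as the `role`/`content` format. The sketch below illustrates only that idea, in isolation. `remap_messages`, `SHAREGPT_PROPERTY_MAPPINGS` and `ROLE_ALIASES` are hypothetical names, and the mapping direction (canonical key to dataset key) is an assumption, not a description of axolotl's internals.

```python
from typing import Any

# Hypothetical mapping (assumption): canonical property name -> key used in the source dataset.
SHAREGPT_PROPERTY_MAPPINGS = {"role": "from", "content": "value"}

# Hypothetical role aliases: source role value -> canonical role name.
ROLE_ALIASES = {"human": "user", "gpt": "assistant"}


def remap_messages(
    messages: list[dict[str, Any]],
    property_mappings: dict[str, str],
    role_aliases: dict[str, str],
) -> list[dict[str, Any]]:
    """Copy each mapped property under its canonical name and normalize roles.

    A property absent from a message is skipped instead of raising KeyError.
    """
    remapped = []
    for message in messages:
        out: dict[str, Any] = {}
        for canonical_key, source_key in property_mappings.items():
            if source_key in message:
                out[canonical_key] = message[source_key]
        if "role" in out:
            # Roles without an alias (e.g. "system") are kept unchanged.
            out["role"] = role_aliases.get(out["role"], out["role"])
        remapped.append(out)
    return remapped


if __name__ == "__main__":
    conversation = [
        {"from": "human", "value": "hello"},
        {"from": "gpt", "value": "hello"},
    ]
    print(remap_messages(conversation, SHAREGPT_PROPERTY_MAPPINGS, ROLE_ALIASES))
    # [{'role': 'user', 'content': 'hello'}, {'role': 'assistant', 'content': 'hello'}]
```

This sketch silently skips a mapped property that a message lacks. The commit notes say axolotl logs a debug message in that case, which is not reproduced here.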
207 lines
11 KiB
Python
"""
shared fixtures for prompt strategies tests
"""

import pytest
from datasets import Dataset
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.utils.chat_templates import _CHAT_TEMPLATES


@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
    return Dataset.from_list(
        [
            {
                "messages": [
                    {"role": "user", "content": "hello"},
                    {"role": "assistant", "content": "hello"},
                    {"role": "user", "content": "goodbye"},
                    {"role": "assistant", "content": "goodbye"},
                ]
            }
        ]
    )


@pytest.fixture(name="sharegpt_dataset")
def fixture_sharegpt_dataset():
    # pylint: disable=duplicate-code
    return Dataset.from_list(
        [
            {
                "conversations": [
                    {"from": "human", "value": "hello"},
                    {"from": "gpt", "value": "hello"},
                    {"from": "human", "value": "goodbye"},
                    {"from": "gpt", "value": "goodbye"},
                ]
            }
        ]
    )


@pytest.fixture(name="basic_dataset")
def fixture_basic_dataset():
    # pylint: disable=duplicate-code
    return Dataset.from_list(
        [
            {
                "conversations": [
                    {"from": "system", "value": "You are an AI assistant."},
                    {"from": "human", "value": "Hello"},
                    {"from": "assistant", "value": "Hi there!"},
                    {"from": "human", "value": "How are you?"},
                    {"from": "assistant", "value": "I'm doing well, thank you!"},
                ]
            }
        ]
    )


@pytest.fixture(name="toolcalling_dataset")
def fixture_toolcalling_dataset():
    # pylint: disable=duplicate-code
    return Dataset.from_list(
        [
            {
                "messages": [
                    {
                        "role": "system",
                        "content": "You are a bot that responds to weather queries. You should reply with the unit used in the queried location.",
                    },
                    {
                        "role": "user",
                        "content": "Hey, what's the temperature in Paris right now?",
                    },
                    {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "type": "function",
                                "function": {
                                    "name": "get_current_temperature",
                                    "arguments": {
                                        "location": "Paris, France",
                                        "unit": "celsius",
                                    },
                                },
                            }
                        ],
                    },
                    {
                        "role": "tool",
                        "name": "get_current_temperature",
                        "content": "22.0",
                    },
                    {
                        "role": "assistant",
                        "content": "The temperature in Paris is 22.0 degrees Celsius.",
                    },
                ]
            }
        ]
    )


@pytest.fixture(name="llama3_tokenizer", scope="session", autouse=True)
def fixture_llama3_tokenizer():
    hf_hub_download(
        repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
        filename="special_tokens_map.json",
    )
    hf_hub_download(
        repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
        filename="tokenizer_config.json",
    )
    hf_hub_download(
        repo_id="NousResearch/Meta-Llama-3-8B-Instruct", filename="tokenizer.json"
    )
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")

    return tokenizer


@pytest.fixture(name="mistralv03_tokenizer", scope="session", autouse=True)
def fixture_mistralv03_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(
        "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
    )
    return tokenizer


@pytest.fixture(name="phi35_tokenizer", scope="session", autouse=True)
def fixture_phi35_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")
    return tokenizer


@pytest.fixture(name="gemma2_tokenizer", scope="session", autouse=True)
def fixture_gemma2_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("mlx-community/gemma-2-9b-it-4bit")

    return tokenizer


@pytest.fixture(name="mistralv03_tokenizer_chat_template_jinja")
def fixture_mistralv03_chat_template_jinja_w_system() -> str:
    return '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message["role"] == "user") != (ns.index % 2 == 0) %}\n {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message["role"] == "user" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- "[AVAILABLE_TOOLS] [" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- \'{"type": "function", "function": {\' }}\n {%- for key, val in tool.items() if key != "return" %}\n {%- if val is string %}\n {{- \'"\' + key + \'": "\' + val + \'"\' }}\n {%- else %}\n {{- \'"\' + key + \'": \' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- "}}" }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" }}\n {%- endif %}\n {%- endfor %}\n {{- "[/AVAILABLE_TOOLS]" }}\n {%- endif %}\n {%- if loop.first and system_message is defined %}\n {{- "[INST] " + system_message + "\\n\\n" + message["content"] + "[/INST]" }}\n {%- else %}\n {{- "[INST] " + message["content"] + "[/INST]" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- "[TOOL_CALLS] [" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \', "id": "\' + tool_call.id + \'"}\' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- else %}\n {{- "]" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message["role"] == "assistant" %}\n {{- " " + message["content"]|trim + eos_token}}\n {%- elif message["role"] == "tool_results" or message["role"] == "tool" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- \'[TOOL_RESULTS] {"content": \' + content|string + ", " }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}\n {%- endif %}\n {{- \'"call_id": "\' + message.tool_call_id + \'"}[/TOOL_RESULTS]\' }}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}\n'


@pytest.fixture(name="gemma2_tokenizer_chat_template_jinja")
def fixture_gemma2_chat_template_jinja_w_system() -> str:
    return "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"


@pytest.fixture(name="llama3_2_vision_chat_template_jinja")
def fixture_llama3_2_vision_with_hardcoded_date() -> str:
    """Hardcodes the date in the template to avoid the need for date logic in the prompt"""

    template = _CHAT_TEMPLATES["llama3_2_vision"]

    old_date_logic = """{%- if not date_string is defined %}
    {%- if strftime_now is defined %}
        {%- set date_string = strftime_now("%d %b %Y") %}
    {%- else %}
        {%- set date_string = "26 Jul 2024" %}
    {%- endif %}
{%- endif %}"""

    new_date_logic = """{%- set date_string = "17 Dec 2024" %}"""

    modified_template = template.replace(old_date_logic, new_date_logic)

    return modified_template


@pytest.fixture(name="chat_template_jinja_with_optional_fields")
def fixture_chat_template_jinja_with_optional_fields() -> str:
    return """{% for message in messages %}
{{'<|im_start|>'}}{{ message['role'] }}
{% if message['thoughts'] is defined %}[Thoughts: {{ message['thoughts'] }}]{% endif %}
{% if message['tool_calls'] is defined %}[Tool: {{ message['tool_calls'][0]['type'] }}]{% endif %}
{{ message['content'] }}{{'<|im_end|>'}}
{% endfor %}"""


@pytest.fixture(name="basic_jinja_template_analyzer")
def basic_jinja_template_analyzer():
    return JinjaTemplateAnalyzer(
        """{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>
' + message['content'] + '<|end|>
'}}{% elif message['role'] == 'user' %}{{'<|user|>
' + message['content'] + '<|end|>
'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>
' + message['content'] + '<|end|>
'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>
' }}{% else %}{{ eos_token }}{% endif %}"""
    )


@pytest.fixture(name="mistral_jinja_template_analyzer")
def mistral_jinja_template_analyzer(mistralv03_tokenizer_chat_template_jinja):
    return JinjaTemplateAnalyzer(mistralv03_tokenizer_chat_template_jinja)