feat: add config for optional parameters in a chat message (#2260)
* feat: add config for optional parameters in a chat message * chore: cleanup * chore: fix nits and add light docs * docs: update docs/dataset-formats/conversation.qmd Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * feat: configurable message mappings, jinja template analyzer * chore: handle bradley terry * docs: update docs * refactor: change order of mappings, improve message transform * refactor: make chat awware of property mappings * chore: remove .python-version * chore: revert change * chore: add dataset validation to tests where appropriate * chore: add dataset validation to tests where appropriate * chore: clean up handling of ds_cfg * chore: recursively serialize config * make sure to use the return value from validate_config * DefaultDict pickle/unpickle fix * fix super call for override * refactor: message fields * chore: empty commit * tests: validate config before using * chore: add config validation to all e2e tests * chore: add unneeded logging * chore: add missed config validation * chore: pass field_messages to prompter * test: fix borked test * chore: remove uninteded file * chore: add deprecation warning and update chat_datasets script * chore: lint * refactor: message fields * feat: update axolotlinputconfig and test_models - add configdict import in axolotl/utils/config/models/input/v0_4_1/__init__.py - remove unnecessary line breaks in sftdataset, dpodataset, ktodataset, stepwisesuperviseddataset classes - update model_dump method in axolotlinputconfig to exclude none values - correct typo in test_models.py comment * feat: simplify dpodataset and ktodataset classes in config models removed several optional fields from dpodataset and ktodataset classes in axolotl/utils/config/models/input/v0_4_1. this simplifies the configuration subsets for these datasets. * feat: improve readability and structure in dataset configuration models this commit enhances the readability and structure of the dataset configuration models in the `axolotl/utils/config/models/input/v0_4_1` module. it removes unused `configdict` import and adds line breaks to separate class definitions for better clarity. additionally, a minor documentation fix is included to ensure a newline at the end of the `stepwise_supervised.qmd` file. * feat: change log level from info to debug in chattemplatestrategy * feat(prompt_strategies): refactor chattemplateprompter and chattemplatestrategy - Make `chat_template` a required parameter in `ChatTemplatePrompter` constructor - Add default value for `message_property_mappings` in `ChatTemplatePrompter` constructor - Add `messages_array_name` property to `ChatTemplatePrompter` - Change `processor` type to Optional in `ChatTemplatePrompter` - Add TypeError check for `processor` in `ChatTemplatePrompter.build_prompt` - Remove `_messages` property from `ChatTemplateStrategy` - Make `prompter` a required parameter and add type hint in `ChatTemplateStrategy` constructor - Remove `messages` getter and setter from `ChatTemplateStrategy` - Use `prompter.messages_array_name` in `ChatTemplateStrategy.get_conversation_thread` - Remove condition to set `messages` field in `load` function * feat(tests/utils): ignore type check in load_model call in test_models.py * feat: improve type handling and test structure in chat templates - Add return type hint for `get_chat_template` function in `chat_templates.py` - Remove unnecessary assignment of `strategy.messages` in several test cases - Add `messages_array_name` parameter to various test configurations in `test_chat_templates.py` and `test_chat_templates_advanced.py` - Remove redundant `strategy.messages` assignment in `test_chat_templates_advanced.py` * feat(axolotl): enhance chat strategy with datasetconfig support This commit introduces support for DatasetConfig in the ChatTemplateStrategy. It also refines the strategy loader to handle different types of ds_cfg inputs and improves the clarity of the code by formatting and reordering. The key changes include: - Importing Union from typing and BaseModel from pydantic. - Adding DatasetConfig as an optional type for ds_cfg in StrategyLoader. - Adjusting the handling of ds_cfg in StrategyLoader to account for BaseModel instances. - Refactoring the prompter_params and strategy_params for better readability. - Changing the reference from prompt[self.messages] to prompt[self.prompter.messages_array_name] in the is_prompt_batched method. * feat: update message handling in btchattemplatestrategy * Replace `self.messages` with direct string references to "chosen_messages" and "rejected_messages" * Append system, user, and assistant content directly to "chosen_messages" and "rejected_messages" * Add a new attribute "messages_array_name" to the `load` function parameters * Remove the conditional attribute assignment for "field_messages" in the `load` function * feat: add config validation in test_kd.py - Import `validate_config` from `axolotl.utils.config` - Validate the configuration in `test_llama_kd` and another function in `TestKnowledgeDistillation` class * feat: enhance config validation and capabilities handling * Import `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals` * Update `validate_config` function to create `KTODataset` and `SFTDataset` instances using `dict(ds_cfg)` * Replace `capabilities` and `env_capabilities` with instances of `GPUCapabilities` and `EnvCapabilities` respectively in `AxolotlConfigWCapabilities` model dump * feat: update config validation in axolotl utils - Remove import of `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals` - Update `validate_config` function to use `capabilities` and `env_capabilities` directly instead of creating new instances of `GPUCapabilities` and `EnvCapabilities` * feat: refactor strategyloader in chat_template.py - Extracted the creation of strategy parameters into a separate function, `_get_strategy_params(cfg, dataset_config)` - Created a new function, `_get_strategy_cls()`, to obtain the strategy class - Replaced `ChatTemplateStrategy` with `strategy_cls` for strategy instantiation * trigger CI * chore: revert dataset config changes for kto/dpo * subject: refactor: rename 'messages_array_name' to 'field_messages' Body: - Renamed 'messages_array_name' to 'field_messages' in 'ChatTemplatePrompter' class and its usages in 'chat_template.py' - Updated 'load' function in 'bradley_terry/chat_template.py' to reflect the change - Adjusted 'get_chat_template_msg_variables' and 'get_message_vars' methods in 'jinja_template_analyzer.py' to use the new variable name - Modified 'StrategyLoader' in 'chat_template.py' to use 'field_messages' - Updated tests in 'test_chat_templates.py' and 'test_chat_templates_advanced.py' to use 'field_messages' instead of 'messages_array_name' * feat: refactor prompt strategies and update config models * Remove redundant 'return None' in `axolotl/prompt_strategies/__init__.py` * Simplify message handling in `axolotl/prompt_strategies/bradley_terry/chat_template.py` by using a single 'messages' list instead of separate 'chosen_messages' and 'rejected_messages' lists * Update default 'message_property_mappings' in `axolotl/prompt_strategies/bradley_terry/chat_template.py` * Add 'field_messages' field to `axolotl/utils/config/models/input/v0_4_1/__init__.py` configuration model * chore: remove unused input * chore: remove redundant type ignore * fix: remove old configs and update examples * fix: type check * fix: remove loading old config in ChatMessage * fix: update faq with potential new undefinederror * fix: add debug if property mapped is not found * chore: improve explanation for unmapped properties * fix: update docs with new config * chore: add note for deprecation config and del old config from dict --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -142,10 +142,19 @@ datasets:
|
||||
|
||||
# Key containing the messages (default: "messages")
|
||||
field_messages: messages
|
||||
# Key for role in each message (default: "role")
|
||||
message_field_role: role
|
||||
# Key for content in each message (default: "content")
|
||||
message_field_content: content
|
||||
|
||||
# Mapping of properties from the input dataset to the chat template.
|
||||
# (default: message_property_mappings={'role':'role', 'content':'content'})
|
||||
# If a property exists in the template but not in this mapping, the system will attempt
|
||||
# to load it directly from the message using the property name as the key.
|
||||
# Example: In the mapping below, 'from' is loaded from input dataset and used as 'role',
|
||||
# while 'value' is loaded and used as 'content' in the chat template.
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
# ...
|
||||
|
||||
message_property_mappings:
|
||||
|
||||
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
|
||||
roles:
|
||||
|
||||
@@ -42,8 +42,9 @@ datasets:
|
||||
type: chat_template
|
||||
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
# new (if setting a new chat_template like chatml, gemma, etc)
|
||||
chat_template: chatml
|
||||
@@ -52,8 +53,9 @@ datasets:
|
||||
type: chat_template
|
||||
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
```
|
||||
|
||||
We recommend checking the below examples for other usecases.
|
||||
@@ -138,8 +140,9 @@ datasets:
|
||||
type: chat_template
|
||||
chat_template: tokenizer_default
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
roles_to_train: []
|
||||
train_on_eos: turn
|
||||
message_field_training: train
|
||||
|
||||
@@ -114,7 +114,7 @@ A flow chart is as follows:
|
||||
|
||||
4. Is your dataset in an "instruct" format, containing `{ instruction, response }`? If yes, check [Instruction Dataset](#instruction-dataset)
|
||||
|
||||
If you went through the flow chart and did not find one that matches, it is recommended to preprocess your dataset into one of the above or create a Github Discussion.
|
||||
If you went through the flow chart and did not find one that matches, it is recommended to preprocess your dataset into one of the above or create a thread on Github Discussion.
|
||||
|
||||
::: {.callout-tip}
|
||||
You can mix and match within each approach or across approaches to train a model on a variety of datasets.
|
||||
@@ -289,9 +289,10 @@ If your dataset format is different, here are the keys you should check (with th
|
||||
```yaml
|
||||
datasets:
|
||||
...
|
||||
field_messages: messages
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
field_messages: messages # this should point to the key containing the list of conversations
|
||||
message_property_mappings: # this is a mapping from keys in your dataset to keys in chat_template
|
||||
role: role
|
||||
content: content
|
||||
```
|
||||
|
||||
In some `chat_templates` (e.g. [Gemma](https://huggingface.co/google/gemma-2b-it/blob/main/tokenizer_config.json#L1507)), the roles are hardcoded to `user` and `assistant`. Consequently, you may find it necessary to map the roles in your dataset to these above. We currently have some defaults that should work for common datasets, but if you get a `KeyError`, it would be necessary to add mapping for your roles. Here is an example of how it would look like:
|
||||
@@ -348,13 +349,14 @@ datasets:
|
||||
- path: A.jsonl
|
||||
type: chat_template
|
||||
|
||||
# step 1
|
||||
# step 1
|
||||
chat_template: chatml
|
||||
|
||||
# step 2
|
||||
field_messages: messages
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
# step 2
|
||||
field_messages: messages
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
roles:
|
||||
assistant:
|
||||
@@ -365,8 +367,8 @@ datasets:
|
||||
- human
|
||||
- user
|
||||
|
||||
# step 3
|
||||
roles_to_train: ["assistant"]
|
||||
# step 3
|
||||
roles_to_train: ["assistant"]
|
||||
train_on_eos: "turn"
|
||||
|
||||
special_tokens:
|
||||
|
||||
@@ -23,3 +23,7 @@ description: Frequently asked questions
|
||||
**Q: The codes is stuck on saving preprocessed datasets.**
|
||||
|
||||
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
|
||||
|
||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||
|
||||
> A: This means that the property mapping for the stated attribute does not exist when building `chat_template` prompt. For example, if `no attribute 'content'`, please check you have added the correct mapping for `content` under `message_property_mappings`.
|
||||
|
||||
@@ -229,8 +229,9 @@ datasets:
|
||||
field_messages: "messages"
|
||||
field_chosen: "chosen"
|
||||
field_rejected: "rejected"
|
||||
message_field_role: "role"
|
||||
message_field_content: "content"
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user: ["user"]
|
||||
assistant: ["assistant"]
|
||||
|
||||
@@ -21,8 +21,9 @@ datasets:
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -16,8 +16,9 @@ datasets:
|
||||
type: chat_template
|
||||
drop_system_message: true
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
@@ -13,8 +13,9 @@ datasets:
|
||||
type: chat_template
|
||||
drop_system_message: true
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -17,8 +17,9 @@ datasets:
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
|
||||
@@ -17,8 +17,9 @@ datasets:
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
|
||||
@@ -14,8 +14,9 @@ datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user:
|
||||
- user
|
||||
|
||||
@@ -17,8 +17,9 @@ datasets:
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
@@ -31,8 +32,9 @@ datasets:
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
|
||||
@@ -22,8 +22,9 @@ datasets:
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
|
||||
@@ -14,8 +14,9 @@ datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
field_messages: messages
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
user:
|
||||
- user
|
||||
|
||||
@@ -12,8 +12,9 @@ datasets:
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
|
||||
@@ -31,27 +31,26 @@ def parse_dataset(dataset=None, split="train"):
|
||||
ds_cfg["field_messages"] = field_messages
|
||||
|
||||
message_fields = features[field_messages][0].keys()
|
||||
message_field_role = None
|
||||
|
||||
message_property_mappings = {"role": None, "content": None}
|
||||
for key in ["from", "role"]:
|
||||
if key in message_fields:
|
||||
message_field_role = key
|
||||
message_property_mappings["role"] = key
|
||||
break
|
||||
if not message_field_role:
|
||||
if not message_property_mappings["role"]:
|
||||
raise ValueError(
|
||||
f'No role field found in messages: {", ".join(message_fields)}'
|
||||
)
|
||||
ds_cfg["message_field_role"] = message_field_role
|
||||
|
||||
message_field_content = None
|
||||
for key in ["content", "text", "value"]:
|
||||
if key in message_fields:
|
||||
message_field_content = key
|
||||
message_property_mappings["content"] = key
|
||||
break
|
||||
if not message_field_content:
|
||||
if not message_property_mappings["content"]:
|
||||
raise ValueError(
|
||||
f'No content field found in messages: {", ".join(message_fields)}'
|
||||
)
|
||||
ds_cfg["message_field_content"] = message_field_content
|
||||
ds_cfg["message_property_mappings"] = message_property_mappings
|
||||
|
||||
print(yaml.dump({"datasets": [ds_cfg]}))
|
||||
|
||||
|
||||
@@ -41,10 +41,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
||||
load_kwargs["ds_cfg"] = ds_cfg
|
||||
if "processor" in sig.parameters:
|
||||
load_kwargs["processor"] = processor
|
||||
|
||||
return func(tokenizer, cfg, **load_kwargs)
|
||||
except ModuleNotFoundError:
|
||||
return None
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
||||
raise exc
|
||||
return None
|
||||
|
||||
@@ -34,15 +34,12 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||
|
||||
max_length = self.prompter.max_length
|
||||
|
||||
self.messages = "chosen_messages"
|
||||
# pylint: disable=duplicate-code
|
||||
prompt[self.messages] = []
|
||||
prompt["messages"] = []
|
||||
if prompt["system"]:
|
||||
prompt[self.messages].append(
|
||||
{"role": "system", "content": prompt["system"]}
|
||||
)
|
||||
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
|
||||
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
|
||||
prompt["messages"].append({"role": "system", "content": prompt["system"]})
|
||||
prompt["messages"].append({"role": "user", "content": prompt["input"]})
|
||||
prompt["messages"].append({"role": "assistant", "content": prompt["chosen"]})
|
||||
chosen_tokenized = super()._tokenize_single_prompt(prompt)
|
||||
|
||||
if len(chosen_tokenized["input_ids"]) > max_length:
|
||||
@@ -55,17 +52,12 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||
:max_length
|
||||
]
|
||||
|
||||
self.messages = "rejected_messages"
|
||||
# pylint: disable=duplicate-code
|
||||
prompt[self.messages] = []
|
||||
prompt["messages"] = []
|
||||
if prompt["system"]:
|
||||
prompt[self.messages].append(
|
||||
{"role": "system", "content": prompt["system"]}
|
||||
)
|
||||
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
|
||||
prompt[self.messages].append(
|
||||
{"role": "assistant", "content": prompt["rejected"]}
|
||||
)
|
||||
prompt["messages"].append({"role": "system", "content": prompt["system"]})
|
||||
prompt["messages"].append({"role": "user", "content": prompt["input"]})
|
||||
prompt["messages"].append({"role": "assistant", "content": prompt["rejected"]})
|
||||
rejected_tokenized = super()._tokenize_single_prompt(prompt)
|
||||
|
||||
if len(rejected_tokenized["input_ids"]) > max_length:
|
||||
@@ -99,8 +91,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
prompter_params = {
|
||||
"tokenizer": tokenizer,
|
||||
"chat_template": chat_template_string,
|
||||
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
||||
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
||||
"message_property_mappings": ds_cfg.get(
|
||||
"message_property_mappings",
|
||||
{
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
),
|
||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||
"message_field_training_detail": ds_cfg.get(
|
||||
"message_field_training_detail", None
|
||||
@@ -124,7 +121,4 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
|
||||
)
|
||||
|
||||
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
|
||||
return strategy
|
||||
|
||||
@@ -4,13 +4,16 @@ HF Chat Templates prompt strategy
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from transformers import ProcessorMixin
|
||||
|
||||
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
|
||||
|
||||
# Configure the logger
|
||||
LOG = logging.getLogger("axolotl")
|
||||
@@ -23,16 +26,23 @@ class ChatTemplatePrompter(Prompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer,
|
||||
chat_template: str,
|
||||
processor=None,
|
||||
chat_template=None,
|
||||
max_length=2048,
|
||||
message_field_role: str = "role",
|
||||
message_field_content: str = "content",
|
||||
message_property_mappings: Optional[Dict[str, str]] = None,
|
||||
message_field_training: Optional[str] = None,
|
||||
message_field_training_detail: Optional[str] = None,
|
||||
field_messages: str = "messages",
|
||||
roles: Optional[Dict[str, List[str]]] = None,
|
||||
drop_system_message: bool = False,
|
||||
):
|
||||
# check if message_property_mappings is None or empty dict
|
||||
if message_property_mappings is None or (not message_property_mappings):
|
||||
message_property_mappings = {
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
}
|
||||
|
||||
if roles:
|
||||
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
||||
else:
|
||||
@@ -45,18 +55,28 @@ class ChatTemplatePrompter(Prompter):
|
||||
"tool": "tool",
|
||||
}
|
||||
|
||||
self.message_field_role = message_field_role
|
||||
self.message_field_content = message_field_content
|
||||
self._chat_template_msg_variables = self.get_chat_template_msg_variables(
|
||||
chat_template, field_messages
|
||||
)
|
||||
self.message_property_mappings = message_property_mappings
|
||||
self.message_field_training = message_field_training
|
||||
self.message_field_training_detail = message_field_training_detail
|
||||
self.field_messages = field_messages
|
||||
self.tokenizer = tokenizer
|
||||
self.processor: ProcessorMixin = processor
|
||||
self.processor: Optional[ProcessorMixin] = processor
|
||||
self.chat_template = chat_template
|
||||
self.max_length = max_length
|
||||
self.drop_system_message = drop_system_message
|
||||
|
||||
@property
|
||||
def chat_template_msg_variables(self) -> Set[str]:
|
||||
return self._chat_template_msg_variables
|
||||
|
||||
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
|
||||
if self.processor:
|
||||
if not callable(self.processor):
|
||||
raise TypeError("Processor must be callable")
|
||||
|
||||
text = self.processor.apply_chat_template(
|
||||
conversation,
|
||||
chat_template=self.chat_template,
|
||||
@@ -184,17 +204,21 @@ class ChatTemplatePrompter(Prompter):
|
||||
|
||||
return adjusted_details
|
||||
|
||||
def get_chat_template_msg_variables(
|
||||
self, chat_template: str, field_messages: str
|
||||
) -> Set[str]:
|
||||
template_analyzer = JinjaTemplateAnalyzer(chat_template)
|
||||
return template_analyzer.get_message_vars(field_messages)
|
||||
|
||||
|
||||
class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for instruction-based prompts.
|
||||
"""
|
||||
|
||||
_messages = "messages"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompter: ChatTemplatePrompter,
|
||||
prompter: "ChatTemplatePrompter",
|
||||
tokenizer,
|
||||
train_on_inputs,
|
||||
sequence_len,
|
||||
@@ -202,6 +226,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
train_on_eos=None,
|
||||
):
|
||||
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
||||
self.prompter: ChatTemplatePrompter = prompter
|
||||
|
||||
self.roles_to_train = []
|
||||
if roles_to_train:
|
||||
@@ -213,13 +238,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
self.train_on_eos = train_on_eos
|
||||
self.images = "images"
|
||||
|
||||
@property
|
||||
def messages(self):
|
||||
return self._messages
|
||||
|
||||
@messages.setter
|
||||
def messages(self, messages):
|
||||
self._messages = messages
|
||||
LOG.debug(
|
||||
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_batched(self) -> bool:
|
||||
@@ -229,7 +250,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
|
||||
try:
|
||||
return all(isinstance(v, list) for v in prompt.values()) and all(
|
||||
isinstance(v, list) for v in prompt[self.messages]
|
||||
isinstance(v, list) for v in prompt[self.prompter.field_messages]
|
||||
)
|
||||
except KeyError:
|
||||
return False
|
||||
@@ -464,30 +485,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
turns = []
|
||||
optional_keys = [
|
||||
"tool_calls", # tool that 'assistant' calls
|
||||
"name", # name of tool given by 'tool'
|
||||
"tool_call_id", # mistral/mixtral requires this
|
||||
]
|
||||
for message in prompt[self.messages]:
|
||||
for message in prompt[self.prompter.field_messages]:
|
||||
transformed_message = self.transform_message(message)
|
||||
|
||||
turn = {
|
||||
"role": self.prompter.roles[message[self.prompter.message_field_role]],
|
||||
**transformed_message,
|
||||
"training": message.get(self.prompter.message_field_training),
|
||||
"training_detail": message.get(
|
||||
self.prompter.message_field_training_detail
|
||||
),
|
||||
}
|
||||
|
||||
# do not add content if None as it may conflict with some templates due to tools
|
||||
content = message.get(self.prompter.message_field_content, None)
|
||||
if content is not None:
|
||||
turn["content"] = content
|
||||
|
||||
for key in optional_keys:
|
||||
value = message.get(key, None)
|
||||
if value is not None:
|
||||
turn[key] = value
|
||||
|
||||
turns.append(turn)
|
||||
|
||||
if self.prompter.drop_system_message and turns[0]["role"] == "system":
|
||||
@@ -495,6 +503,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
return turns
|
||||
|
||||
def transform_message(self, message):
|
||||
# Build the initial transformed message from the mappings
|
||||
transformed_message = {}
|
||||
for key, value in self.prompter.message_property_mappings.items():
|
||||
if message.get(value) is not None:
|
||||
transformed_message[key] = message[value]
|
||||
else:
|
||||
LOG.debug(
|
||||
f"Could not find value for property {value} in message: {message}"
|
||||
)
|
||||
|
||||
# Map the role if necessary
|
||||
if "role" in transformed_message:
|
||||
transformed_message["role"] = self.prompter.roles.get(
|
||||
transformed_message["role"], transformed_message["role"]
|
||||
)
|
||||
|
||||
# Determine which keys in the original message were not mapped
|
||||
mapped_values = set(self.prompter.message_property_mappings.values())
|
||||
remaining_keys = set(message) - mapped_values
|
||||
|
||||
# Keep only the properties defined in the chat template
|
||||
# and not already mapped
|
||||
for key in self.prompter.chat_template_msg_variables:
|
||||
if key in remaining_keys:
|
||||
val = message.get(key)
|
||||
if val is not None:
|
||||
transformed_message[key] = val
|
||||
|
||||
return transformed_message
|
||||
|
||||
def get_images(self, prompt):
|
||||
return prompt.get(self.images, None)
|
||||
|
||||
@@ -516,33 +555,46 @@ class StrategyLoader:
|
||||
}
|
||||
|
||||
def __call__(
|
||||
self, tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
|
||||
self,
|
||||
tokenizer,
|
||||
cfg,
|
||||
ds_cfg: Optional[Union[Dict[str, Any], DatasetConfig]] = None,
|
||||
processor=None,
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
ds_cfg = ds_cfg or {}
|
||||
if ds_cfg is None:
|
||||
dataset_config = {}
|
||||
elif isinstance(ds_cfg, BaseModel):
|
||||
dataset_config = ds_cfg.model_dump()
|
||||
else:
|
||||
dataset_config = ds_cfg
|
||||
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
||||
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
|
||||
)
|
||||
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
|
||||
|
||||
prompter_params = {
|
||||
"tokenizer": tokenizer,
|
||||
"chat_template": chat_template_string,
|
||||
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
||||
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||
"message_field_training_detail": ds_cfg.get(
|
||||
"message_property_mappings": dataset_config.get(
|
||||
"message_property_mappings", {}
|
||||
),
|
||||
"message_field_training": dataset_config.get(
|
||||
"message_field_training", None
|
||||
),
|
||||
"message_field_training_detail": dataset_config.get(
|
||||
"message_field_training_detail",
|
||||
None,
|
||||
),
|
||||
"roles": ds_cfg.get("roles"),
|
||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||
"field_messages": dataset_config.get("field_messages", "messages"),
|
||||
"roles": dataset_config.get("roles"),
|
||||
"drop_system_message": dataset_config.get("drop_system_message", False),
|
||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||
"max_length": cfg.sequence_len + 1,
|
||||
"processor": processor,
|
||||
}
|
||||
|
||||
strategy_params = self._get_strategy_params(cfg, ds_cfg)
|
||||
strategy_params = self._get_strategy_params(cfg, dataset_config)
|
||||
strategy_cls = self._get_strategy_cls()
|
||||
|
||||
strategy = strategy_cls(
|
||||
@@ -551,9 +603,6 @@ class StrategyLoader:
|
||||
**strategy_params,
|
||||
)
|
||||
|
||||
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
|
||||
return strategy
|
||||
|
||||
|
||||
|
||||
@@ -3,20 +3,28 @@ DPO prompt strategies for using tokenizer chat templates.
|
||||
"""
|
||||
|
||||
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
||||
from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic
|
||||
|
||||
|
||||
def default(
|
||||
cfg, dataset_idx=0, **kwargs
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
ds_cfg = cfg["datasets"][dataset_idx]
|
||||
ds_cfg = handle_legacy_message_fields_logic(ds_cfg)
|
||||
|
||||
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||
cfg=cfg, ds_cfg=ds_cfg
|
||||
)
|
||||
field_messages = ds_cfg.get("field_messages", "messages")
|
||||
field_chosen = ds_cfg.get("field_chosen", "chosen")
|
||||
field_rejected = ds_cfg.get("field_rejected", "rejected")
|
||||
field_message_role = ds_cfg.get("message_field_role", "role")
|
||||
field_message_content = ds_cfg.get("message_field_content", "content")
|
||||
message_property_mappings = ds_cfg.get(
|
||||
"message_property_mappings",
|
||||
{
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
)
|
||||
role_map_inv = ds_cfg.get(
|
||||
"roles",
|
||||
{
|
||||
@@ -40,18 +48,18 @@ def default(
|
||||
messages = sample[field_messages]
|
||||
messages = [
|
||||
{
|
||||
"role": role_map[m[field_message_role]],
|
||||
"content": m[field_message_content],
|
||||
"role": role_map[m[message_property_mappings["role"]]],
|
||||
"content": m[message_property_mappings["content"]],
|
||||
}
|
||||
for m in messages
|
||||
]
|
||||
chosen = {
|
||||
"role": role_map[sample[field_chosen][field_message_role]],
|
||||
"content": sample[field_chosen][field_message_content],
|
||||
"role": role_map[sample[field_chosen][message_property_mappings["role"]]],
|
||||
"content": sample[field_chosen][message_property_mappings["content"]],
|
||||
}
|
||||
rejected = {
|
||||
"role": role_map[sample[field_rejected][field_message_role]],
|
||||
"content": sample[field_rejected][field_message_content],
|
||||
"role": role_map[sample[field_rejected][message_property_mappings["role"]]],
|
||||
"content": sample[field_rejected][message_property_mappings["content"]],
|
||||
}
|
||||
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
|
||||
|
||||
|
||||
318
src/axolotl/prompt_strategies/jinja_template_analyzer.py
Normal file
318
src/axolotl/prompt_strategies/jinja_template_analyzer.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""Module for inspect jinja templates for the variables they use"""
|
||||
from typing import Dict, Optional, Set, TypedDict, Union
|
||||
|
||||
from jinja2 import Environment, meta, nodes
|
||||
|
||||
|
||||
class JinjaTemplateAnalysis(TypedDict):
|
||||
"""
|
||||
Represents the detailed analysis of a Jinja template variable.
|
||||
|
||||
Attributes:
|
||||
accessed_properties (Set[str]): A set of properties accessed from the variable
|
||||
(e.g., `foo.bar` results in 'bar' being accessed for 'foo').
|
||||
accessed_indices (Set[Union[int, float]]): A set of indices accessed from the variable.
|
||||
is_iterated (bool): Indicates if the variable is used as an iteration source in a `for` loop.
|
||||
is_conditional (bool): Indicates if the variable is referenced within a conditional statement (e.g., an `if` block).
|
||||
iteration_source (Optional[str]): The name of the variable being iterated over, if applicable.
|
||||
iteration_target (Optional[Union[str, list[str]]]): The loop target(s) assigned in the iteration.
|
||||
"""
|
||||
|
||||
accessed_properties: Set[str]
|
||||
accessed_indices: Set[Union[int, float]]
|
||||
is_iterated: bool
|
||||
is_conditional: bool
|
||||
iteration_source: Optional[str]
|
||||
iteration_target: Optional[Union[str, list[str]]]
|
||||
|
||||
|
||||
class JinjaTemplateAnalyzer:
|
||||
"""
|
||||
Analyzes Jinja templates to extract information about variable usage,
|
||||
including accessed properties, iteration, and conditional references.
|
||||
|
||||
Attributes:
|
||||
env (jinja2.Environment): The Jinja2 environment used for parsing templates.
|
||||
property_access (Dict[str, Set[str]]): Tracks accessed properties for variables.
|
||||
iteration_targets (Dict[str, str]): Maps iteration target variables to their sources.
|
||||
|
||||
Methods:
|
||||
get_template_variables(template: str) -> Dict[str, Set[str]]:
|
||||
Parse a Jinja template and return a mapping of variables to their accessed properties.
|
||||
|
||||
analyze_template(template: str) -> Dict[str, JinjaTemplateAnalysis]:
|
||||
Perform a detailed analysis of the template, including variable usage,
|
||||
iteration, and conditional references.
|
||||
|
||||
Private Methods:
|
||||
_visit_node(node) -> None:
|
||||
Recursively visit AST nodes to detect attribute access and iteration targets.
|
||||
|
||||
_get_base_name(node) -> Optional[str]:
|
||||
Extract the base variable name from a node.
|
||||
|
||||
_get_target_name(node) -> Optional[Union[str, list[str]]]:
|
||||
Extract the target name(s) from a `For` node.
|
||||
"""
|
||||
|
||||
def __init__(self, template: str):
|
||||
self.env: Environment = Environment(autoescape=True)
|
||||
self.property_access: Dict[str, Set[str]] = {}
|
||||
self.iteration_targets: Dict[str, Union[str, list[str]]] = {}
|
||||
self.index_access: Dict[str, Set[Union[int, float]]] = {}
|
||||
self.ast: nodes.Node = self.env.parse(template)
|
||||
self.template: str = template
|
||||
self.variable_assignments: Dict[str, str] = {}
|
||||
|
||||
def _visit_node(self, node) -> None:
|
||||
"""Recursively visit AST nodes to find attribute access."""
|
||||
# Handle attribute access (dot notation)
|
||||
if isinstance(node, nodes.Getattr):
|
||||
base_name = self._get_base_name(node.node)
|
||||
if base_name:
|
||||
self.property_access.setdefault(base_name, set()).add(node.attr)
|
||||
|
||||
# Handle dictionary access (subscript notation)
|
||||
elif isinstance(node, nodes.Getitem):
|
||||
base_name = self._get_base_name(node.node)
|
||||
if base_name and isinstance(node.arg, nodes.Const):
|
||||
value = node.arg.value
|
||||
if isinstance(value, (int, float)):
|
||||
self.index_access.setdefault(base_name, set()).add(value)
|
||||
else:
|
||||
self.property_access.setdefault(base_name, set()).add(value)
|
||||
|
||||
elif isinstance(node, nodes.Test) and node.name == "defined":
|
||||
base_name = self._get_base_name(node.node)
|
||||
if base_name:
|
||||
if isinstance(node.node, nodes.Getattr):
|
||||
self.property_access.setdefault(base_name, set()).add(
|
||||
node.node.attr
|
||||
)
|
||||
|
||||
# Handle loop variables
|
||||
elif isinstance(node, nodes.For):
|
||||
iter_name = self._get_base_name(node.iter)
|
||||
target_name = self._get_target_name(node.target)
|
||||
if iter_name and target_name:
|
||||
self.iteration_targets[target_name] = iter_name
|
||||
self.property_access.setdefault(iter_name, set())
|
||||
|
||||
elif isinstance(node, nodes.Assign):
|
||||
target_name = self._get_target_name(node.target)
|
||||
source_name = self._get_base_name(node.node)
|
||||
if target_name and source_name:
|
||||
self.variable_assignments[target_name] = source_name
|
||||
|
||||
elif isinstance(node, nodes.Filter):
|
||||
if node.name == "selectattr":
|
||||
target = self._get_base_name(node.node)
|
||||
if target:
|
||||
self.variable_assignments[f"filtered_{target}"] = target
|
||||
|
||||
for child in node.iter_child_nodes():
|
||||
self._visit_node(child)
|
||||
|
||||
def _get_target_name(self, node) -> Optional[str]:
|
||||
"""Get the target variable name from a For node.
|
||||
|
||||
Args:
|
||||
node: A Jinja AST node representing either a Name or Tuple node
|
||||
|
||||
Returns:
|
||||
- str: For simple variable targets (e.g., "item" in "for item in items")
|
||||
- None: If the node type is not recognized or is a tuple
|
||||
"""
|
||||
if isinstance(node, nodes.Name):
|
||||
return node.name
|
||||
return None
|
||||
|
||||
def _get_target_names(self, node) -> list[str]:
|
||||
"""Get all target variable names from a For node, including tuple unpacking.
|
||||
|
||||
Args:
|
||||
node: A Jinja AST node representing either a Name or Tuple node
|
||||
|
||||
Returns:
|
||||
List of target variable names
|
||||
"""
|
||||
if isinstance(node, nodes.Name):
|
||||
return [node.name]
|
||||
|
||||
if isinstance(node, nodes.Tuple):
|
||||
names = []
|
||||
for n in node.items:
|
||||
if isinstance(n, nodes.Name):
|
||||
names.append(n.name)
|
||||
return names
|
||||
|
||||
return []
|
||||
|
||||
def _get_base_name(self, node) -> Optional[str]:
|
||||
"""Get the base variable name from a node."""
|
||||
if isinstance(node, nodes.Name):
|
||||
return node.name
|
||||
|
||||
if isinstance(node, nodes.Getattr):
|
||||
return self._get_base_name(node.node)
|
||||
|
||||
if isinstance(node, nodes.Getitem):
|
||||
return self._get_base_name(node.node)
|
||||
|
||||
return None
|
||||
|
||||
def get_template_variables(self) -> Dict[str, Set[str]]:
|
||||
"""
|
||||
Parse a Jinja template and return both variables and their accessed properties.
|
||||
|
||||
Args:
|
||||
template (str): The Jinja template string
|
||||
|
||||
Returns:
|
||||
Dict[str, Set[str]]: Dictionary mapping variable names to sets of accessed properties
|
||||
"""
|
||||
# Parse the template
|
||||
ast = self.env.parse(self.template)
|
||||
|
||||
# Get all undeclared variables
|
||||
variables = meta.find_undeclared_variables(ast)
|
||||
|
||||
# Reset property access tracking
|
||||
self.property_access = {}
|
||||
|
||||
# Visit all nodes to find property access
|
||||
self._visit_node(ast)
|
||||
|
||||
# Create result dictionary
|
||||
result: Dict[str, Set[str]] = {var: set() for var in variables}
|
||||
# Merge in any discovered sub-properties
|
||||
for var, props in self.property_access.items():
|
||||
if var not in result:
|
||||
result[var] = set()
|
||||
result[var].update(props)
|
||||
|
||||
return result
|
||||
|
||||
def analyze_template(self) -> Dict[str, JinjaTemplateAnalysis]:
|
||||
"""
|
||||
Provide a detailed analysis of template variables and their usage.
|
||||
"""
|
||||
variables = self.get_template_variables()
|
||||
self.iteration_targets = {}
|
||||
|
||||
analysis: Dict[str, JinjaTemplateAnalysis] = {
|
||||
var: JinjaTemplateAnalysis(
|
||||
accessed_properties=props,
|
||||
accessed_indices=set(),
|
||||
is_iterated=False,
|
||||
is_conditional=False,
|
||||
iteration_source=None,
|
||||
iteration_target=None,
|
||||
)
|
||||
for var, props in variables.items()
|
||||
}
|
||||
|
||||
for var, indices in self.index_access.items():
|
||||
if var in analysis:
|
||||
analysis[var]["accessed_indices"] = indices
|
||||
|
||||
def visit_node(node):
|
||||
if isinstance(node, nodes.If):
|
||||
|
||||
def find_test_vars(test_node):
|
||||
if isinstance(test_node, nodes.Name):
|
||||
if test_node.name in analysis:
|
||||
analysis[test_node.name]["is_conditional"] = True
|
||||
for child in test_node.iter_child_nodes():
|
||||
find_test_vars(child)
|
||||
|
||||
find_test_vars(node.test)
|
||||
|
||||
if isinstance(node, nodes.For):
|
||||
iter_target = self._get_base_name(node.iter)
|
||||
target_name = self._get_target_name(node.target)
|
||||
if iter_target in analysis:
|
||||
analysis[iter_target]["is_iterated"] = True
|
||||
if target_name:
|
||||
analysis[iter_target]["iteration_target"] = target_name
|
||||
if isinstance(target_name, str) and target_name not in analysis:
|
||||
analysis[target_name] = {
|
||||
"accessed_properties": set(),
|
||||
"is_iterated": False,
|
||||
"is_conditional": False,
|
||||
"iteration_source": iter_target,
|
||||
"iteration_target": None,
|
||||
}
|
||||
|
||||
for child in node.iter_child_nodes():
|
||||
visit_node(child)
|
||||
|
||||
visit_node(self.ast)
|
||||
return analysis
|
||||
|
||||
def get_downstream_properties(self, start_var: str) -> Dict[str, Set[str]]:
|
||||
"""
|
||||
Get all properties accessed on a variable and its downstream assignments.
|
||||
|
||||
Args:
|
||||
start_var: The starting variable to trace
|
||||
|
||||
Returns:
|
||||
Dict mapping variable names to their accessed properties
|
||||
"""
|
||||
visited = set()
|
||||
properties = {}
|
||||
|
||||
def trace_variable(var_name: str):
|
||||
if var_name in visited:
|
||||
return
|
||||
visited.add(var_name)
|
||||
|
||||
# Get direct properties
|
||||
if var_name in self.property_access:
|
||||
properties[var_name] = self.property_access[var_name]
|
||||
|
||||
# Get properties from iteration targets
|
||||
if var_name in self.iteration_targets:
|
||||
target = self.iteration_targets[var_name]
|
||||
if isinstance(target, str):
|
||||
trace_variable(target)
|
||||
elif isinstance(target, list):
|
||||
for t in target:
|
||||
trace_variable(t)
|
||||
|
||||
# Follow assignments
|
||||
for target, source in self.variable_assignments.items():
|
||||
if source == var_name:
|
||||
trace_variable(target)
|
||||
|
||||
# Check for array slicing
|
||||
analysis = self.analyze_template()
|
||||
if var_name in analysis:
|
||||
var_info = analysis[var_name]
|
||||
if var_info["accessed_indices"]:
|
||||
# If this variable is sliced, follow the resulting assignment
|
||||
slice_result = f"{var_name}_slice"
|
||||
if slice_result in self.property_access:
|
||||
trace_variable(slice_result)
|
||||
|
||||
trace_variable(start_var)
|
||||
return properties
|
||||
|
||||
def get_message_vars(self, field_messages: str = "messages") -> Set[str]:
|
||||
"""
|
||||
Get all properties accessed on messages and derived variables.
|
||||
"""
|
||||
all_properties = self.get_downstream_properties(field_messages)
|
||||
|
||||
# Combine all properties from all related variables
|
||||
combined_properties = set()
|
||||
for properties in all_properties.values():
|
||||
combined_properties.update(properties)
|
||||
|
||||
# Also include properties from the message iteration variable
|
||||
analysis = self.analyze_template()
|
||||
if "message" in analysis:
|
||||
combined_properties.update(analysis["message"]["accessed_properties"])
|
||||
|
||||
return combined_properties
|
||||
@@ -51,8 +51,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
ds_cfg = ds_cfg or {}
|
||||
|
||||
field_messages = ds_cfg.get("field_messages")
|
||||
message_field_role = ds_cfg.get("message_field_role")
|
||||
message_field_content = ds_cfg.get("message_field_content")
|
||||
message_property_mappings = ds_cfg.get("message_property_mappings")
|
||||
message_field_role = (
|
||||
message_property_mappings.get("role") if message_property_mappings else None
|
||||
)
|
||||
message_field_content = (
|
||||
message_property_mappings.get("content") if message_property_mappings else None
|
||||
)
|
||||
message_field_training = ds_cfg.get("message_field_training")
|
||||
|
||||
builder_kwargs = {}
|
||||
|
||||
@@ -38,7 +38,7 @@ def get_chat_template(
|
||||
user_choice: str,
|
||||
jinja_template: Optional[str] = None,
|
||||
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
||||
):
|
||||
) -> str:
|
||||
"""
|
||||
Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.
|
||||
|
||||
@@ -70,7 +70,7 @@ def get_chat_template(
|
||||
f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
|
||||
f"Please add a chat_template in tokenizer config"
|
||||
)
|
||||
return tokenizer.chat_template
|
||||
return tokenizer.chat_template # type: ignore
|
||||
|
||||
if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
|
||||
if not tokenizer:
|
||||
@@ -78,7 +78,7 @@ def get_chat_template(
|
||||
f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
|
||||
)
|
||||
if tokenizer.chat_template:
|
||||
return tokenizer.chat_template
|
||||
return tokenizer.chat_template # type: ignore
|
||||
|
||||
user_choice = user_choice[
|
||||
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
|
||||
|
||||
@@ -18,6 +18,7 @@ from axolotl.utils.config.models.input.v0_4_1 import (
|
||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||
)
|
||||
from axolotl.utils.config.models.input.v0_4_1 import DPODataset, KTODataset, SFTDataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model_config
|
||||
|
||||
@@ -258,7 +259,7 @@ def validate_config(
|
||||
cfg: DictDefault,
|
||||
capabilities: Optional[dict] = None,
|
||||
env_capabilities: Optional[dict] = None,
|
||||
):
|
||||
) -> DictDefault:
|
||||
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
|
||||
AxolotlInputConfig = AxolotlInputConfigBase
|
||||
|
||||
@@ -268,6 +269,16 @@ def validate_config(
|
||||
AxolotlInputConfig, # pylint: disable=invalid-name
|
||||
) = merge_input_args()
|
||||
|
||||
# Convert datasets to proper format if needed
|
||||
if cfg.get("datasets"):
|
||||
for idx, ds_cfg in enumerate(cfg["datasets"]):
|
||||
if cfg.get("rl") == "dpo" and not isinstance(ds_cfg, DPODataset):
|
||||
cfg["datasets"][idx] = DPODataset(**ds_cfg)
|
||||
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
|
||||
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
|
||||
elif not isinstance(ds_cfg, SFTDataset):
|
||||
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))
|
||||
|
||||
if capabilities or env_capabilities:
|
||||
if (capabilities and env_capabilities is None) or (
|
||||
env_capabilities and capabilities is None
|
||||
|
||||
@@ -12,6 +12,7 @@ from pydantic import (
|
||||
Field,
|
||||
StringConstraints,
|
||||
conlist,
|
||||
field_serializer,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
@@ -186,8 +187,13 @@ class SFTDataset(BaseModel):
|
||||
field_human: Optional[str] = None
|
||||
field_model: Optional[str] = None
|
||||
field_messages: Optional[str] = None
|
||||
message_field_role: Optional[str] = None
|
||||
message_field_content: Optional[str] = None
|
||||
message_field_role: Optional[
|
||||
str
|
||||
] = None # deprecated, use message_property_mappings
|
||||
message_field_content: Optional[
|
||||
str
|
||||
] = None # deprecated, use message_property_mappings
|
||||
message_property_mappings: Optional[Dict[str, str]] = None
|
||||
message_field_training: Optional[str] = None
|
||||
message_field_training_detail: Optional[str] = None
|
||||
logprobs_field: Optional[str] = None
|
||||
@@ -199,9 +205,18 @@ class SFTDataset(BaseModel):
|
||||
trust_remote_code: Optional[bool] = False
|
||||
revision: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def handle_legacy_message_fields(cls, data):
|
||||
"""Handle backwards compatibility between legacy message field mapping and new property mapping system."""
|
||||
return handle_legacy_message_fields_logic(data)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_chat_template_config(cls, data):
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.model_dump()
|
||||
|
||||
# Set chat_template to tokenizer_default if not set
|
||||
if data.get("type") == "chat_template" and not data.get("chat_template"):
|
||||
data["chat_template"] = ChatTemplate.tokenizer_default
|
||||
@@ -241,6 +256,7 @@ class DPODataset(BaseModel):
|
||||
type: Optional[Union[UserDefinedDPOType, str]] = None
|
||||
data_files: Optional[List[str]] = None
|
||||
revision: Optional[str] = None
|
||||
field_messages: Optional[str] = None
|
||||
|
||||
|
||||
class StepwiseSupervisedDataset(BaseModel):
|
||||
@@ -277,6 +293,9 @@ class KTODataset(BaseModel):
|
||||
revision: Optional[str] = None
|
||||
|
||||
|
||||
DatasetConfig = Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset]
|
||||
|
||||
|
||||
class LoftQConfig(BaseModel):
|
||||
"""LoftQ configuration subset"""
|
||||
|
||||
@@ -680,17 +699,15 @@ class AxolotlInputConfig(
|
||||
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||
dpo_use_logits_to_keep: Optional[bool] = None
|
||||
|
||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
||||
datasets: Optional[conlist(DatasetConfig, min_length=1)] = None # type: ignore
|
||||
test_datasets: Optional[conlist(DatasetConfig, min_length=1)] = None # type: ignore
|
||||
shuffle_merged_datasets: Optional[bool] = True
|
||||
dataset_prepared_path: Optional[str] = None
|
||||
dataset_shard_num: Optional[int] = None
|
||||
dataset_shard_idx: Optional[int] = None
|
||||
skip_prepare_dataset: Optional[bool] = False
|
||||
|
||||
pretraining_dataset: Optional[ # type: ignore
|
||||
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
|
||||
] = Field(
|
||||
pretraining_dataset: Optional[conlist(Union[PretrainingDataset, SFTDataset], min_length=1)] = Field( # type: ignore
|
||||
default=None,
|
||||
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
||||
)
|
||||
@@ -895,10 +912,15 @@ class AxolotlInputConfig(
|
||||
@classmethod
|
||||
def deprecate_sharegpt_datasets(cls, datasets):
|
||||
for _, ds_cfg in enumerate(datasets):
|
||||
if not ds_cfg.get("type"):
|
||||
# Handle both dict and pydantic model cases
|
||||
ds_type = (
|
||||
ds_cfg.get("type")
|
||||
if isinstance(ds_cfg, dict)
|
||||
else getattr(ds_cfg, "type", None)
|
||||
)
|
||||
if not ds_type:
|
||||
continue
|
||||
|
||||
ds_type = ds_cfg["type"]
|
||||
# skip if it's a dict (for custom user instruction prompt)
|
||||
if isinstance(ds_type, dict):
|
||||
continue
|
||||
@@ -910,6 +932,14 @@ class AxolotlInputConfig(
|
||||
|
||||
return datasets
|
||||
|
||||
@field_serializer("datasets")
|
||||
def datasets_serializer(
|
||||
self, ds_configs: Optional[List[DatasetConfig]]
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
if ds_configs:
|
||||
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
|
||||
return None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_batch_size_fields(cls, data):
|
||||
@@ -1762,3 +1792,77 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
else:
|
||||
data["torch_compile"] = False
|
||||
return data
|
||||
|
||||
|
||||
def handle_legacy_message_fields_logic(data: dict) -> dict:
|
||||
"""
|
||||
Handle backwards compatibility between legacy message field mapping and new property mapping system.
|
||||
|
||||
Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
|
||||
- message_field_role: Mapped to the role field
|
||||
- message_field_content: Mapped to the content field
|
||||
|
||||
The new system uses message_property_mappings to support arbitrary field mappings:
|
||||
message_property_mappings:
|
||||
role: source_role_field
|
||||
content: source_content_field
|
||||
additional_field: source_field
|
||||
|
||||
Args:
|
||||
data: Dictionary containing configuration data
|
||||
|
||||
Returns:
|
||||
Updated dictionary with message field mappings consolidated
|
||||
|
||||
Raises:
|
||||
ValueError: If there are conflicts between legacy and new mappings
|
||||
"""
|
||||
data = data.copy() # Create a copy to avoid modifying the original
|
||||
|
||||
if data.get("message_property_mappings") is None:
|
||||
data["message_property_mappings"] = {}
|
||||
|
||||
# Check for conflicts and handle role
|
||||
if "message_field_role" in data:
|
||||
LOG.warning(
|
||||
"message_field_role is deprecated, use message_property_mappings instead. "
|
||||
f"Example: message_property_mappings: {{role: {data['message_field_role']}}}"
|
||||
)
|
||||
if (
|
||||
"role" in data["message_property_mappings"]
|
||||
and data["message_property_mappings"]["role"] != data["message_field_role"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Conflicting message role fields: message_field_role='{data['message_field_role']}' "
|
||||
f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'"
|
||||
)
|
||||
data["message_property_mappings"]["role"] = data["message_field_role"] or "role"
|
||||
|
||||
del data["message_field_role"]
|
||||
elif "role" not in data["message_property_mappings"]:
|
||||
data["message_property_mappings"]["role"] = "role"
|
||||
|
||||
# Check for conflicts and handle content
|
||||
if "message_field_content" in data:
|
||||
LOG.warning(
|
||||
"message_field_content is deprecated, use message_property_mappings instead. "
|
||||
f"Example: message_property_mappings: {{content: {data['message_field_content']}}}"
|
||||
)
|
||||
if (
|
||||
"content" in data["message_property_mappings"]
|
||||
and data["message_property_mappings"]["content"]
|
||||
!= data["message_field_content"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
|
||||
f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
|
||||
)
|
||||
data["message_property_mappings"]["content"] = (
|
||||
data["message_field_content"] or "content"
|
||||
)
|
||||
|
||||
del data["message_field_content"]
|
||||
elif "content" not in data["message_property_mappings"]:
|
||||
data["message_property_mappings"]["content"] = "content"
|
||||
|
||||
return data
|
||||
|
||||
@@ -180,6 +180,7 @@ def load_tokenized_prepared_datasets(
|
||||
) -> Tuple[DatasetDict, List[Prompter]]:
|
||||
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
|
||||
tokenizer_name = cfg.tokenizer_config
|
||||
|
||||
ds_hash = str(
|
||||
md5(
|
||||
(
|
||||
|
||||
@@ -13,3 +13,26 @@ class DictDefault(Dict):
|
||||
|
||||
def __or__(self, other):
|
||||
return DictDefault(super().__ror__(other))
|
||||
|
||||
def __setitem__(self, name, value):
|
||||
# workaround for pickle/unpickle issues and __frozen not being available
|
||||
try:
|
||||
isFrozen = hasattr( # pylint: disable=invalid-name
|
||||
self, "__frozen"
|
||||
) and object.__getattribute__(self, "__frozen")
|
||||
except AttributeError:
|
||||
isFrozen = False # pylint: disable=invalid-name
|
||||
|
||||
if isFrozen and name not in super().keys():
|
||||
raise KeyError(name)
|
||||
super(Dict, self).__setitem__(name, value) # pylint: disable=bad-super-call
|
||||
try:
|
||||
p = object.__getattribute__(self, "__parent")
|
||||
key = object.__getattribute__(self, "__key")
|
||||
except AttributeError:
|
||||
p = None
|
||||
key = None
|
||||
if p is not None:
|
||||
p[key] = self
|
||||
object.__delattr__(self, "__parent")
|
||||
object.__delattr__(self, "__key")
|
||||
|
||||
@@ -9,7 +9,7 @@ from e2e.utils import check_tensorboard, require_torch_2_5_1
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@@ -79,6 +79,7 @@ class TestKnowledgeDistillation:
|
||||
def test_llama_kd(self, temp_dir, kd_min_cfg):
|
||||
cfg = DictDefault(kd_min_cfg)
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
@@ -109,6 +110,7 @@ class TestKnowledgeDistillation:
|
||||
| kd_min_cfg
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
|
||||
@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, check_tensorboard
|
||||
@@ -76,7 +76,9 @@ class TestFAXentropyLlama:
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from pathlib import Path
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
@@ -73,6 +73,8 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -12,7 +12,7 @@ import pytest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_preference_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -63,6 +63,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -108,6 +110,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -153,6 +157,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -198,6 +204,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -242,6 +250,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -289,6 +299,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -353,6 +365,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
@@ -56,6 +56,8 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -65,6 +65,8 @@ class TestFalcon(unittest.TestCase):
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -118,6 +120,8 @@ class TestFalcon(unittest.TestCase):
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -157,6 +161,8 @@ class TestFalcon(unittest.TestCase):
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -10,7 +10,7 @@ from e2e.utils import check_model_output_exists
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
@@ -56,6 +56,8 @@ class TestLlama:
|
||||
"save_safetensors": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -99,6 +101,8 @@ class TestLlama:
|
||||
"save_safetensors": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -138,6 +142,8 @@ class TestLlama:
|
||||
"save_safetensors": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -10,7 +10,7 @@ import pytest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, check_tensorboard
|
||||
@@ -69,6 +69,8 @@ class TestPretrainLlama:
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -62,6 +62,8 @@ class TestLlamaVision(unittest.TestCase):
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -59,6 +59,8 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -11,7 +11,7 @@ import pytest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -59,6 +59,8 @@ class TestMamba(unittest.TestCase):
|
||||
"save_safetensors": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -63,6 +63,8 @@ class TestMistral(unittest.TestCase):
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -106,6 +108,8 @@ class TestMistral(unittest.TestCase):
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -69,6 +69,8 @@ class TestMixtral(unittest.TestCase):
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -123,6 +125,8 @@ class TestMixtral(unittest.TestCase):
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -180,6 +184,8 @@ class TestMixtral(unittest.TestCase):
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -233,6 +239,8 @@ class TestMixtral(unittest.TestCase):
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
if is_torch_bf16_gpu_available():
|
||||
cfg.bf16 = True
|
||||
@@ -281,6 +289,8 @@ class TestMixtral(unittest.TestCase):
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir
|
||||
@@ -59,6 +59,8 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"lr_scheduler": "cosine",
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -103,6 +105,8 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"lr_scheduler": "cosine",
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
@@ -139,6 +143,8 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_tensorboard, with_temp_dir
|
||||
@@ -59,6 +59,8 @@ class TestPackedLlama(unittest.TestCase):
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
@@ -61,6 +61,7 @@ class TestPhi(unittest.TestCase):
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -40,8 +40,10 @@ class TestE2eQwen:
|
||||
"field_messages": "conversation",
|
||||
"field_chosen": "chosen",
|
||||
"field_rejected": "rejected",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"message_property_mappings": {
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
"roles": {
|
||||
"system": ["system"],
|
||||
"user": ["user"],
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
@@ -66,6 +66,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
|
||||
"use_tensorboard": True,
|
||||
}
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -7,6 +7,7 @@ from datasets import Dataset
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
|
||||
from axolotl.utils.chat_templates import _CHAT_TEMPLATES
|
||||
|
||||
|
||||
@@ -174,3 +175,32 @@ def fixture_llama3_2_vision_with_hardcoded_date() -> str:
|
||||
modified_template = template.replace(old_date_logic, new_date_logic)
|
||||
|
||||
return modified_template
|
||||
|
||||
|
||||
@pytest.fixture(name="chat_template_jinja_with_optional_fields")
|
||||
def fixture_chat_template_jinja_with_optional_fields() -> str:
|
||||
return """{% for message in messages %}
|
||||
{{'<|im_start|>'}}{{ message['role'] }}
|
||||
{% if message['thoughts'] is defined %}[Thoughts: {{ message['thoughts'] }}]{% endif %}
|
||||
{% if message['tool_calls'] is defined %}[Tool: {{ message['tool_calls'][0]['type'] }}]{% endif %}
|
||||
{{ message['content'] }}{{'<|im_end|>'}}
|
||||
{% endfor %}"""
|
||||
|
||||
|
||||
@pytest.fixture(name="basic_jinja_template_analyzer")
|
||||
def basic_jinja_template_analyzer():
|
||||
return JinjaTemplateAnalyzer(
|
||||
"""{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>
|
||||
' + message['content'] + '<|end|>
|
||||
'}}{% elif message['role'] == 'user' %}{{'<|user|>
|
||||
' + message['content'] + '<|end|>
|
||||
'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>
|
||||
' + message['content'] + '<|end|>
|
||||
'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>
|
||||
' }}{% else %}{{ eos_token }}{% endif %}"""
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="mistral_jinja_template_analyzer")
|
||||
def mistral_jinja_template_analyzer(mistralv03_tokenizer_chat_template_jinja):
|
||||
return JinjaTemplateAnalyzer(mistralv03_tokenizer_chat_template_jinja)
|
||||
|
||||
@@ -38,6 +38,10 @@ class TestAssistantChatTemplateLlama3:
|
||||
"chat_template": "llama3",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"message_property_mappings": {
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
"roles": {
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
@@ -74,8 +78,10 @@ class TestAssistantChatTemplateLlama3:
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=get_chat_template("llama3"),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
message_property_mappings={
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
roles={
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
@@ -86,7 +92,7 @@ class TestAssistantChatTemplateLlama3:
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
)
|
||||
strategy.messages = "messages"
|
||||
|
||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
# fmt: off
|
||||
@@ -114,8 +120,10 @@ class TestAssistantChatTemplateLlama3:
|
||||
ChatTemplatePrompter(
|
||||
phi35_tokenizer,
|
||||
chat_template=get_chat_template("phi_35"),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
message_property_mappings={
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
roles={
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
@@ -126,7 +134,7 @@ class TestAssistantChatTemplateLlama3:
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
)
|
||||
strategy.messages = "messages"
|
||||
|
||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
labels = res["labels"]
|
||||
@@ -170,9 +178,11 @@ class TestAssistantChatTemplateLlama3:
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=get_chat_template("llama3"),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
message_field_training="training",
|
||||
message_property_mappings={
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
roles={
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
@@ -185,7 +195,7 @@ class TestAssistantChatTemplateLlama3:
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
strategy.messages = "messages"
|
||||
|
||||
prompt_tokens = strategy.prompter.build_prompt(
|
||||
assistant_dataset[0]["messages"], False
|
||||
)
|
||||
@@ -230,8 +240,11 @@ class TestSharegptChatTemplateLlama3:
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=get_chat_template("llama3"),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={
|
||||
"role": "from",
|
||||
"content": "value",
|
||||
},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -239,7 +252,7 @@ class TestSharegptChatTemplateLlama3:
|
||||
sequence_len=512,
|
||||
roles_to_train=["gpt"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(sharegpt_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
labels = res["labels"]
|
||||
@@ -287,8 +300,11 @@ class TestSharegptChatTemplateLlama3:
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=get_chat_template("llama3"),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={
|
||||
"role": "from",
|
||||
"content": "value",
|
||||
},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -296,7 +312,7 @@ class TestSharegptChatTemplateLlama3:
|
||||
sequence_len=512,
|
||||
roles_to_train=["human"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(sharegpt_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
labels = res["labels"]
|
||||
@@ -344,8 +360,11 @@ class TestSharegptChatTemplateLlama3:
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
chat_template=get_chat_template("llama3"),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={
|
||||
"role": "from",
|
||||
"content": "value",
|
||||
},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -353,7 +372,7 @@ class TestSharegptChatTemplateLlama3:
|
||||
sequence_len=512,
|
||||
roles_to_train=["system", "human"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
labels = res["labels"]
|
||||
@@ -417,8 +436,7 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
|
||||
chat_template=get_chat_template(
|
||||
"jinja", jinja_template=llama3_2_vision_chat_template_jinja
|
||||
),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
message_property_mappings={"role": "role", "content": "content"},
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -486,8 +504,7 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
|
||||
chat_template=get_chat_template(
|
||||
"jinja", jinja_template=llama3_2_vision_chat_template_jinja
|
||||
),
|
||||
message_field_role="role",
|
||||
message_field_content="content",
|
||||
message_property_mappings={"role": "role", "content": "content"},
|
||||
),
|
||||
tokenizer=llama3_tokenizer,
|
||||
train_on_inputs=False,
|
||||
|
||||
@@ -3,7 +3,6 @@ tests for chat_template prompt strategy
|
||||
"""
|
||||
|
||||
import logging
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
@@ -123,15 +122,15 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=True,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
@@ -180,15 +179,15 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
@@ -241,20 +240,15 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant", "human"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
strategy.messages = "conversations"
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
@@ -307,15 +301,15 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=True,
|
||||
sequence_len=512,
|
||||
roles_to_train=["human", "assistant"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
@@ -360,8 +354,8 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -369,7 +363,7 @@ class TestChatTemplateConfigurations:
|
||||
roles_to_train=[],
|
||||
train_on_eos="none", # Add this line
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
|
||||
@@ -400,8 +394,8 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -409,7 +403,7 @@ class TestChatTemplateConfigurations:
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eos="all",
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
@@ -446,8 +440,8 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -455,7 +449,6 @@ class TestChatTemplateConfigurations:
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eos="turn",
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
@@ -526,8 +519,8 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -535,7 +528,7 @@ class TestChatTemplateConfigurations:
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eos="last",
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
@@ -578,8 +571,8 @@ class TestChatTemplateConfigurations:
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -587,7 +580,7 @@ class TestChatTemplateConfigurations:
|
||||
roles_to_train=["assistant"],
|
||||
train_on_eos="none",
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
labels = res["labels"]
|
||||
input_ids = res["input_ids"]
|
||||
@@ -624,15 +617,15 @@ class TestChatTemplateConfigurations:
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
drop_system_message=True,
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
field_messages="conversations",
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
sequence_len=512,
|
||||
roles_to_train=["assistant"],
|
||||
)
|
||||
strategy.messages = "conversations"
|
||||
|
||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||
input_ids = res["input_ids"]
|
||||
|
||||
@@ -668,8 +661,7 @@ class TestChatTemplateConfigurations:
|
||||
chat_template, jinja_template=chat_template_jinja
|
||||
),
|
||||
roles=custom_roles,
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -741,8 +733,7 @@ class TestChatTemplateConfigurations:
|
||||
),
|
||||
message_field_training="train",
|
||||
message_field_training_detail="train_detail",
|
||||
message_field_role="from",
|
||||
message_field_content="value",
|
||||
message_property_mappings={"role": "from", "content": "value"},
|
||||
),
|
||||
tokenizer=tokenizer,
|
||||
train_on_inputs=False,
|
||||
@@ -911,6 +902,64 @@ class TestChatTemplateConfigurations:
|
||||
LOG.debug(f"Final labels: {labels}")
|
||||
LOG.debug(f"Final input_ids: {input_ids}")
|
||||
|
||||
def test_get_chat_template_variables(
|
||||
self, tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||
):
|
||||
LOG.info("Testing get_chat_template_variables")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
actual_tokenizer, actual_jinja_template = self.setup_tokenizer(
|
||||
tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||
)
|
||||
|
||||
prompter = ChatTemplatePrompter(
|
||||
actual_tokenizer,
|
||||
chat_template=get_chat_template(
|
||||
chat_template, jinja_template=actual_jinja_template
|
||||
),
|
||||
message_property_mappings={"from": "role", "value": "content"},
|
||||
)
|
||||
|
||||
variables = prompter.get_chat_template_msg_variables(
|
||||
actual_jinja_template
|
||||
if actual_jinja_template
|
||||
else actual_tokenizer.get_chat_template(),
|
||||
"messages",
|
||||
)
|
||||
|
||||
if chat_template == "llama3":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "chatml":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
|
||||
assert variables == {"role", "content", "tool_call_id", "tool_calls"}, (
|
||||
f"Expected variables: {'role', 'content', 'tool_call_id', 'tool_calls'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "jinja" and tokenizer == "gemma2_tokenizer":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
elif chat_template == "phi_35":
|
||||
assert variables == {"role", "content"}, (
|
||||
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||
f"Got: {variables}\n"
|
||||
f"Chat template: {actual_jinja_template}"
|
||||
)
|
||||
else:
|
||||
LOG.warning(
|
||||
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
||||
)
|
||||
|
||||
159
tests/prompt_strategies/test_jinja_template_analyzer.py
Normal file
159
tests/prompt_strategies/test_jinja_template_analyzer.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
tests for jinja_template_analyzer
|
||||
"""
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
class TestJinjaTemplateAnalyzer:
|
||||
"""
|
||||
tests for jinja_template_analyzer
|
||||
"""
|
||||
|
||||
def test_basic_variable_extraction(self, basic_jinja_template_analyzer):
|
||||
"""Test that all top-level variables are correctly extracted."""
|
||||
LOG.info("Testing with train_on_inputs=True")
|
||||
|
||||
variables = basic_jinja_template_analyzer.get_template_variables()
|
||||
expected_vars = {"messages", "add_generation_prompt", "eos_token", "message"}
|
||||
assert set(variables.keys()) == expected_vars
|
||||
|
||||
def test_mixtral_variable_extraction(self, mistral_jinja_template_analyzer):
|
||||
"""Test that all top-level variables are correctly extracted."""
|
||||
LOG.info("Testing with train_on_inputs=True")
|
||||
|
||||
variables = mistral_jinja_template_analyzer.get_template_variables()
|
||||
expected_vars = {
|
||||
"messages",
|
||||
"content",
|
||||
"eos_token",
|
||||
"message",
|
||||
"tools",
|
||||
"system_message",
|
||||
"loop_messages",
|
||||
"ns",
|
||||
"tool_call",
|
||||
"tool",
|
||||
"loop",
|
||||
"bos_token",
|
||||
"raise_exception",
|
||||
}
|
||||
assert set(variables.keys()) == expected_vars
|
||||
message_vars = variables["message"]
|
||||
assert message_vars == {"role", "content", "tool_calls", "tool_call_id"}
|
||||
|
||||
def test_message_property_access(self, basic_jinja_template_analyzer):
|
||||
"""Test that properties accessed on 'message' variable are correctly identified."""
|
||||
LOG.info("Testing message property access")
|
||||
|
||||
variables = basic_jinja_template_analyzer.get_template_variables()
|
||||
assert "messages" in variables
|
||||
assert "message" in variables
|
||||
assert "role" in variables["message"]
|
||||
assert "content" in variables["message"]
|
||||
|
||||
def test_detailed_analysis(self, basic_jinja_template_analyzer):
|
||||
"""Test the detailed analysis of variable usage."""
|
||||
LOG.info("Testing detailed analysis")
|
||||
|
||||
analysis = basic_jinja_template_analyzer.analyze_template()
|
||||
|
||||
assert analysis["messages"]["is_iterated"] is True
|
||||
assert "role" in analysis["message"]["accessed_properties"]
|
||||
assert "content" in analysis["message"]["accessed_properties"]
|
||||
|
||||
assert analysis["add_generation_prompt"]["is_conditional"] is True
|
||||
assert len(analysis["add_generation_prompt"]["accessed_properties"]) == 0
|
||||
|
||||
assert not analysis["eos_token"]["is_iterated"]
|
||||
assert len(analysis["eos_token"]["accessed_properties"]) == 0
|
||||
|
||||
def test_nested_property_access(self):
|
||||
"""Test handling of nested property access."""
|
||||
LOG.info("Testing nested property access")
|
||||
|
||||
template = """{{ user.profile.name }}{{ user.settings['preference'] }}"""
|
||||
analyzer = JinjaTemplateAnalyzer(template)
|
||||
variables = analyzer.get_template_variables()
|
||||
|
||||
assert "user" in variables
|
||||
assert "profile" in variables["user"]
|
||||
assert "settings" in variables["user"]
|
||||
|
||||
def test_loop_variable_handling(self):
|
||||
"""Test handling of loop variables and their properties."""
|
||||
LOG.info("Testing loop variable handling")
|
||||
|
||||
template = """
|
||||
{% for item in items %}
|
||||
{{ item.name }}
|
||||
{% for subitem in item.subitems %}
|
||||
{{ subitem.value }}
|
||||
{% endfor %}
|
||||
{% endfor %}
|
||||
"""
|
||||
analyzer = JinjaTemplateAnalyzer(template)
|
||||
analysis = analyzer.analyze_template()
|
||||
|
||||
assert analysis["items"]["is_iterated"]
|
||||
assert "name" in analysis["item"]["accessed_properties"]
|
||||
assert "subitems" in analysis["item"]["accessed_properties"]
|
||||
|
||||
def test_conditional_variable_usage(self):
|
||||
"""Test detection of variables used in conditional statements."""
|
||||
LOG.info("Testing conditional variable usage")
|
||||
|
||||
template = """
|
||||
{% if user.is_admin and config.debug_mode %}
|
||||
{{ debug_info }}
|
||||
{% endif %}
|
||||
"""
|
||||
analyzer = JinjaTemplateAnalyzer(template)
|
||||
analysis = analyzer.analyze_template()
|
||||
|
||||
assert analysis["user"]["is_conditional"]
|
||||
assert analysis["config"]["is_conditional"]
|
||||
assert "is_admin" in analysis["user"]["accessed_properties"]
|
||||
assert "debug_mode" in analysis["config"]["accessed_properties"]
|
||||
|
||||
def test_complex_expressions(self):
|
||||
"""Test handling of complex expressions and filters."""
|
||||
LOG.info("Testing complex expressions and filters")
|
||||
|
||||
template = """
|
||||
{{ user.name | upper }}
|
||||
{{ messages | length > 0 and messages[0].content }}
|
||||
{{ data['key'].nested['value'] }}
|
||||
"""
|
||||
analyzer = JinjaTemplateAnalyzer(template)
|
||||
variables = analyzer.get_template_variables()
|
||||
|
||||
assert "user" in variables
|
||||
assert "name" in variables["user"]
|
||||
assert "messages" in variables
|
||||
assert "content" in variables["messages"]
|
||||
assert "data" in variables
|
||||
|
||||
def test_basic_msg_vars(self, basic_jinja_template_analyzer):
|
||||
"""Test that the basic message variables are correctly identified."""
|
||||
LOG.info("Testing basic message variables")
|
||||
|
||||
variables = basic_jinja_template_analyzer.get_message_vars()
|
||||
assert variables == {"role", "content"}
|
||||
|
||||
def test_mixtral_msg_vars(self, mistral_jinja_template_analyzer):
|
||||
"""Test that the mixtral message variables are correctly identified."""
|
||||
LOG.info("Testing mixtral message variables")
|
||||
|
||||
variables = mistral_jinja_template_analyzer.get_message_vars()
|
||||
assert variables == {"role", "content", "tool_calls", "tool_call_id"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
@@ -302,3 +302,22 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_message_property_mappings(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
"message_property_mappings": {
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
@@ -76,7 +76,7 @@ class TestModelsUtils:
|
||||
mocked_load_model_config.return_value = {}
|
||||
with pytest.raises(ValueError) as exc:
|
||||
# Should error before hitting tokenizer, so we pass in an empty str
|
||||
load_model(cfg, tokenizer="")
|
||||
load_model(cfg, tokenizer="") # type: ignore
|
||||
assert (
|
||||
"shifted-sparse attention does not currently support sample packing"
|
||||
in str(exc.value)
|
||||
@@ -116,3 +116,79 @@ class TestModelsUtils:
|
||||
assert self.model_loader.model_kwargs.get(
|
||||
"quantization_config", BitsAndBytesConfig
|
||||
)
|
||||
|
||||
def test_message_property_mapping(self):
|
||||
"""Test message property mapping configuration validation"""
|
||||
from axolotl.utils.config.models.input.v0_4_1 import SFTDataset
|
||||
|
||||
# Test legacy fields are mapped orrectly
|
||||
dataset = SFTDataset(
|
||||
path="test_path",
|
||||
message_field_role="role_field",
|
||||
message_field_content="content_field",
|
||||
)
|
||||
assert dataset.message_property_mappings == {
|
||||
"role": "role_field",
|
||||
"content": "content_field",
|
||||
}
|
||||
|
||||
# Test direct message_property_mapping works
|
||||
dataset = SFTDataset(
|
||||
path="test_path",
|
||||
message_property_mappings={
|
||||
"role": "custom_role",
|
||||
"content": "custom_content",
|
||||
},
|
||||
)
|
||||
assert dataset.message_property_mappings == {
|
||||
"role": "custom_role",
|
||||
"content": "custom_content",
|
||||
}
|
||||
|
||||
# Test both legacy and new fields work when they match
|
||||
dataset = SFTDataset(
|
||||
path="test_path",
|
||||
message_field_role="same_role",
|
||||
message_property_mappings={"role": "same_role"},
|
||||
)
|
||||
assert dataset.message_property_mappings == {
|
||||
"role": "same_role",
|
||||
"content": "content",
|
||||
}
|
||||
|
||||
# Test both legacy and new fields work when they don't overlap
|
||||
dataset = SFTDataset(
|
||||
path="test_path",
|
||||
message_field_role="role_field",
|
||||
message_property_mappings={"content": "content_field"},
|
||||
)
|
||||
assert dataset.message_property_mappings == {
|
||||
"role": "role_field",
|
||||
"content": "content_field",
|
||||
}
|
||||
|
||||
# Test no role or content provided
|
||||
dataset = SFTDataset(
|
||||
path="test_path",
|
||||
)
|
||||
assert dataset.message_property_mappings == {
|
||||
"role": "role",
|
||||
"content": "content",
|
||||
}
|
||||
|
||||
# Test error when legacy and new fields conflict
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
SFTDataset(
|
||||
path="test_path",
|
||||
message_field_role="legacy_role",
|
||||
message_property_mappings={"role": "different_role"},
|
||||
)
|
||||
assert "Conflicting message role fields" in str(exc_info.value)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
SFTDataset(
|
||||
path="test_path",
|
||||
message_field_content="legacy_content",
|
||||
message_property_mappings={"content": "different_content"},
|
||||
)
|
||||
assert "Conflicting message content fields" in str(exc_info.value)
|
||||
|
||||
Reference in New Issue
Block a user