feat: add config for optional parameters in a chat message (#2260)

* feat: add config for optional parameters in a chat message

* chore: cleanup

* chore: fix nits and add light docs

* docs: update docs/dataset-formats/conversation.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* feat: configurable message mappings, jinja template analyzer

* chore: handle bradley terry

* docs: update docs

* refactor: change order of mappings, improve message transform

* refactor: make chat aware of property mappings

* chore: remove .python-version

* chore: revert change

* chore: add dataset validation to tests where appropriate

* chore: add dataset validation to tests where appropriate

* chore: clean up handling of ds_cfg

* chore: recursively serialize config

* make sure to use the return value from validate_config

* DefaultDict pickle/unpickle fix

* fix super call for override

* refactor: message fields

* chore: empty commit

* tests: validate config before using

* chore: add config validation to all e2e tests

* chore: add unneeded logging

* chore: add missed config validation

* chore: pass field_messages to prompter

* test: fix borked test

* chore: remove unintended file

* chore: add deprecation warning and update chat_datasets script

* chore: lint

* refactor: message fields

* feat: update axolotlinputconfig and test_models

- add `ConfigDict` import in axolotl/utils/config/models/input/v0_4_1/__init__.py
- remove unnecessary line breaks in `SFTDataset`, `DPODataset`, `KTODataset`, `StepwiseSupervisedDataset` classes
- update `model_dump` method in `AxolotlInputConfig` to exclude None values
- correct typo in test_models.py comment

* feat: simplify dpodataset and ktodataset classes in config models

Removed several optional fields from the `DPODataset` and `KTODataset` classes in axolotl/utils/config/models/input/v0_4_1. This simplifies the configuration subsets for these datasets.

* feat: improve readability and structure in dataset configuration models

This commit enhances the readability and structure of the dataset configuration models in the `axolotl/utils/config/models/input/v0_4_1` module. It removes an unused `ConfigDict` import and adds line breaks to separate class definitions for better clarity. Additionally, a minor documentation fix ensures a newline at the end of the `stepwise_supervised.qmd` file.

* feat: change log level from info to debug in chattemplatestrategy

* feat(prompt_strategies): refactor chattemplateprompter and chattemplatestrategy

- Make `chat_template` a required parameter in `ChatTemplatePrompter` constructor
- Add default value for `message_property_mappings` in `ChatTemplatePrompter` constructor
- Add `messages_array_name` property to `ChatTemplatePrompter`
- Change `processor` type to Optional in `ChatTemplatePrompter`
- Add TypeError check for `processor` in `ChatTemplatePrompter.build_prompt`
- Remove `_messages` property from `ChatTemplateStrategy`
- Make `prompter` a required parameter and add type hint in `ChatTemplateStrategy` constructor
- Remove `messages` getter and setter from `ChatTemplateStrategy`
- Use `prompter.messages_array_name` in `ChatTemplateStrategy.get_conversation_thread`
- Remove condition to set `messages` field in `load` function

* feat(tests/utils): ignore type check in load_model call in test_models.py

* feat: improve type handling and test structure in chat templates

- Add return type hint for `get_chat_template` function in `chat_templates.py`
- Remove unnecessary assignment of `strategy.messages` in several test cases
- Add `messages_array_name` parameter to various test configurations in `test_chat_templates.py` and `test_chat_templates_advanced.py`
- Remove redundant `strategy.messages` assignment in `test_chat_templates_advanced.py`

* feat(axolotl): enhance chat strategy with datasetconfig support

This commit introduces support for DatasetConfig in the ChatTemplateStrategy. It also refines the strategy loader to handle different types of ds_cfg inputs and improves the clarity of the code by formatting and reordering. The key changes include:

- Importing Union from typing and BaseModel from pydantic.
- Adding DatasetConfig as an optional type for ds_cfg in StrategyLoader.
- Adjusting the handling of ds_cfg in StrategyLoader to account for BaseModel instances.
- Refactoring the prompter_params and strategy_params for better readability.
- Changing the reference from prompt[self.messages] to prompt[self.prompter.messages_array_name] in the is_prompt_batched method.

* feat: update message handling in btchattemplatestrategy

* Replace `self.messages` with direct string references to "chosen_messages" and "rejected_messages"
* Append system, user, and assistant content directly to "chosen_messages" and "rejected_messages"
* Add a new attribute "messages_array_name" to the `load` function parameters
* Remove the conditional attribute assignment for "field_messages" in the `load` function

* feat: add config validation in test_kd.py

- Import `validate_config` from `axolotl.utils.config`
- Validate the configuration in `test_llama_kd` and another function in `TestKnowledgeDistillation` class

* feat: enhance config validation and capabilities handling

* Import `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
* Update `validate_config` function to create `KTODataset` and `SFTDataset` instances using `dict(ds_cfg)`
* Replace `capabilities` and `env_capabilities` with instances of `GPUCapabilities` and `EnvCapabilities` respectively in `AxolotlConfigWCapabilities` model dump

* feat: update config validation in axolotl utils

- Remove import of `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
- Update `validate_config` function to use `capabilities` and `env_capabilities` directly instead of creating new instances of `GPUCapabilities` and `EnvCapabilities`

* feat: refactor strategyloader in chat_template.py

- Extracted the creation of strategy parameters into a separate function, `_get_strategy_params(cfg, dataset_config)`
- Created a new function, `_get_strategy_cls()`, to obtain the strategy class
- Replaced `ChatTemplateStrategy` with `strategy_cls` for strategy instantiation

* trigger CI

* chore: revert dataset config changes for kto/dpo

* refactor: rename 'messages_array_name' to 'field_messages'

- Renamed 'messages_array_name' to 'field_messages' in 'ChatTemplatePrompter' class and its usages in 'chat_template.py'
- Updated 'load' function in 'bradley_terry/chat_template.py' to reflect the change
- Adjusted 'get_chat_template_msg_variables' and 'get_message_vars' methods in 'jinja_template_analyzer.py' to use the new variable name
- Modified 'StrategyLoader' in 'chat_template.py' to use 'field_messages'
- Updated tests in 'test_chat_templates.py' and 'test_chat_templates_advanced.py' to use 'field_messages' instead of 'messages_array_name'

* feat: refactor prompt strategies and update config models

* Remove redundant 'return None' in `axolotl/prompt_strategies/__init__.py`
* Simplify message handling in `axolotl/prompt_strategies/bradley_terry/chat_template.py` by using a single 'messages' list instead of separate 'chosen_messages' and 'rejected_messages' lists
* Update default 'message_property_mappings' in `axolotl/prompt_strategies/bradley_terry/chat_template.py`
* Add 'field_messages' field to `axolotl/utils/config/models/input/v0_4_1/__init__.py` configuration model

* chore: remove unused input

* chore: remove redundant type ignore

* fix: remove old configs and update examples

* fix: type check

* fix: remove loading old config in ChatMessage

* fix: update faq with potential new undefinederror

* fix: add debug if property mapped is not found

* chore: improve explanation for unmapped properties

* fix: update docs with new config

* chore: add note for deprecation config and del old config from dict
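
In config terms, the change collapses the two per-field keys into a single mapping; a minimal before/after sketch (the dataset path is a placeholder):

```yaml
datasets:
  - path: org/dataset  # placeholder
    type: chat_template
    field_messages: conversations
    # deprecated:
    #   message_field_role: from
    #   message_field_content: value
    message_property_mappings:
      role: from
      content: value
```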

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Commit b194e17c28 (parent 3aac3b1da9)
Author: NJordan72
Date: 2025-02-17 21:59:27 -05:00
Committed by: GitHub
51 changed files with 1190 additions and 230 deletions


@@ -142,10 +142,19 @@ datasets:
     # Key containing the messages (default: "messages")
     field_messages: messages

-    # Key for role in each message (default: "role")
-    message_field_role: role

-    # Key for content in each message (default: "content")
-    message_field_content: content
+    # Mapping of properties from the input dataset to the chat template.
+    # (default: message_property_mappings={'role':'role', 'content':'content'})
+    # If a property exists in the template but not in this mapping, the system will attempt
+    # to load it directly from the message using the property name as the key.
+    # Example: In the mapping below, 'from' is loaded from input dataset and used as 'role',
+    # while 'value' is loaded and used as 'content' in the chat template.
+    message_property_mappings:
+      role: from
+      content: value

     # ...

     # Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
     roles:


@@ -42,8 +42,9 @@ datasets:
     type: chat_template
     field_messages: conversations
-    message_field_role: from
-    message_field_content: value
+    message_property_mappings:
+      role: from
+      content: value

     # new (if setting a new chat_template like chatml, gemma, etc)
     chat_template: chatml
@@ -52,8 +53,9 @@ datasets:
     type: chat_template
     field_messages: conversations
-    message_field_role: from
-    message_field_content: value
+    message_property_mappings:
+      role: from
+      content: value
 ```

 We recommend checking the below examples for other usecases.

@@ -138,8 +140,9 @@ datasets:
     type: chat_template
     chat_template: tokenizer_default
     field_messages: conversations
-    message_field_role: from
-    message_field_content: value
+    message_property_mappings:
+      role: from
+      content: value
     roles_to_train: []
     train_on_eos: turn
     message_field_training: train


@@ -114,7 +114,7 @@ A flow chart is as follows:

 4. Is your dataset in an "instruct" format, containing `{ instruction, response }`? If yes, check [Instruction Dataset](#instruction-dataset)

-If you went through the flow chart and did not find one that matches, it is recommended to preprocess your dataset into one of the above or create a Github Discussion.
+If you went through the flow chart and did not find one that matches, it is recommended to preprocess your dataset into one of the above or create a thread on Github Discussion.

 ::: {.callout-tip}

 You can mix and match within each approach or across approaches to train a model on a variety of datasets.
@@ -289,9 +289,10 @@ If your dataset format is different, here are the keys you should check (with th

 ```yaml
 datasets:
   ...
-  field_messages: messages
-  message_field_role: role
-  message_field_content: content
+  field_messages: messages # this should point to the key containing the list of conversations
+  message_property_mappings: # this is a mapping from keys in your dataset to keys in chat_template
+    role: role
+    content: content
 ```
In some `chat_templates` (e.g. [Gemma](https://huggingface.co/google/gemma-2b-it/blob/main/tokenizer_config.json#L1507)), the roles are hardcoded to `user` and `assistant`. Consequently, you may find it necessary to map the roles in your dataset to these above. We currently have some defaults that should work for common datasets, but if you get a `KeyError`, it would be necessary to add mapping for your roles. Here is an example of how it would look like:
@@ -348,13 +349,14 @@ datasets:
   - path: A.jsonl
     type: chat_template
-  # step 1
+    # step 1
     chat_template: chatml
-  # step 2
-  field_messages: messages
-  message_field_role: role
-  message_field_content: content
+    # step 2
+    field_messages: messages
+    message_property_mappings:
+      role: role
+      content: content
     roles:
       assistant:
@@ -365,8 +367,8 @@ datasets:
         - human
         - user
-  # step 3
-  roles_to_train: ["assistant"]
+    # step 3
+    roles_to_train: ["assistant"]
     train_on_eos: "turn"

 special_tokens:


@@ -23,3 +23,7 @@ description: Frequently asked questions

 **Q: The codes is stuck on saving preprocessed datasets.**

 > A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
+
+**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
+
+> A: This means that the property mapping for the stated attribute does not exist when building `chat_template` prompt. For example, if `no attribute 'content'`, please check you have added the correct mapping for `content` under `message_property_mappings`.
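
For instance, if your dataset's messages store the role under `from` and the text under `value`, a mapping along these lines (the dataset path is a placeholder) resolves the error:

```yaml
datasets:
  - path: org/dataset  # placeholder
    type: chat_template
    field_messages: conversations
    message_property_mappings:
      role: from      # template's 'role' is read from the 'from' key
      content: value  # template's 'content' is read from the 'value' key
```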


@@ -229,8 +229,9 @@ datasets:
     field_messages: "messages"
     field_chosen: "chosen"
     field_rejected: "rejected"
-    message_field_role: "role"
-    message_field_content: "content"
+    message_property_mappings:
+      role: role
+      content: content
     roles:
       user: ["user"]
       assistant: ["assistant"]


@@ -21,8 +21,9 @@ datasets:
     type: chat_template
     split: train[:20%]
     field_messages: conversations
-    message_field_role: from
-    message_field_content: value
+    message_property_mappings:
+      role: from
+      content: value
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0


@@ -16,8 +16,9 @@ datasets:
     type: chat_template
     drop_system_message: true
     field_messages: conversations
-    message_field_role: from
-    message_field_content: value
+    message_property_mappings:
+      role: from
+      content: value
 val_set_size: 0.0
 output_dir: ./outputs/out


@@ -13,8 +13,9 @@ datasets:
     type: chat_template
     drop_system_message: true
     field_messages: conversations
-    message_field_role: from
-    message_field_content: value
+    message_property_mappings:
+      role: from
+      content: value
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0


@@ -17,8 +17,9 @@ datasets:
     type: chat_template
     split: train[:20%]
     field_messages: conversations
-    message_field_role: from
-    message_field_content: value
+    message_property_mappings:
+      role: from
+      content: value
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.02


@@ -17,8 +17,9 @@ datasets:
     field_messages: conversation
     field_chosen: chosen
     field_rejected: rejected
-    message_field_role: role
-    message_field_content: content
+    message_property_mappings:
+      role: role
+      content: content
     roles:
       system:
         - system


@@ -14,8 +14,9 @@ datasets:
   - path: fozziethebeat/alpaca_messages_2k_test
     type: chat_template
     field_messages: messages
-    message_field_role: role
-    message_field_content: content
+    message_property_mappings:
+      role: role
+      content: content
     roles:
       user:
         - user


@@ -17,8 +17,9 @@ datasets:
     field_messages: conversation
     field_chosen: chosen
     field_rejected: rejected
-    message_field_role: role
-    message_field_content: content
+    message_property_mappings:
+      role: role
+      content: content
     roles:
       system:
         - system
@@ -31,8 +32,9 @@ datasets:
     field_messages: conversation
     field_chosen: chosen
     field_rejected: rejected
-    message_field_role: role
-    message_field_content: content
+    message_property_mappings:
+      role: role
+      content: content
     roles:
       system:
         - system


@@ -22,8 +22,9 @@ datasets:
     field_messages: conversation
     field_chosen: chosen
     field_rejected: rejected
-    message_field_role: role
-    message_field_content: content
+    message_property_mappings:
+      role: role
+      content: content
 dataset_prepared_path:
 val_set_size: 0.05


@@ -14,8 +14,9 @@ datasets:
   - path: fozziethebeat/alpaca_messages_2k_test
     type: chat_template
     field_messages: messages
-    message_field_role: role
-    message_field_content: content
+    message_property_mappings:
+      role: role
+      content: content
     roles:
       user:
         - user


@@ -12,8 +12,9 @@ datasets:
     field_messages: conversation
     field_chosen: chosen
     field_rejected: rejected
-    message_field_role: role
-    message_field_content: content
+    message_property_mappings:
+      role: role
+      content: content
     roles:
       system:
         - system


@@ -31,27 +31,26 @@ def parse_dataset(dataset=None, split="train"):
     ds_cfg["field_messages"] = field_messages

     message_fields = features[field_messages][0].keys()
-    message_field_role = None
+    message_property_mappings = {"role": None, "content": None}

     for key in ["from", "role"]:
         if key in message_fields:
-            message_field_role = key
+            message_property_mappings["role"] = key
             break
-    if not message_field_role:
+    if not message_property_mappings["role"]:
         raise ValueError(
             f'No role field found in messages: {", ".join(message_fields)}'
         )
-    ds_cfg["message_field_role"] = message_field_role

-    message_field_content = None
     for key in ["content", "text", "value"]:
         if key in message_fields:
-            message_field_content = key
+            message_property_mappings["content"] = key
             break
-    if not message_field_content:
+    if not message_property_mappings["content"]:
         raise ValueError(
             f'No content field found in messages: {", ".join(message_fields)}'
         )
-    ds_cfg["message_field_content"] = message_field_content
+    ds_cfg["message_property_mappings"] = message_property_mappings

     print(yaml.dump({"datasets": [ds_cfg]}))
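
The detection loop in the hunk above can be sketched as a standalone helper (simplified; the function name is ours, and the candidate-key lists mirror the diff):

```python
def detect_message_property_mappings(message_fields):
    """Guess which dataset keys map to 'role' and 'content'.

    Mirrors the candidate-key scan in the updated parse_dataset:
    the first matching key wins for each property.
    """
    mappings = {"role": None, "content": None}
    for key in ["from", "role"]:
        if key in message_fields:
            mappings["role"] = key
            break
    for key in ["content", "text", "value"]:
        if key in message_fields:
            mappings["content"] = key
            break
    for prop, key in mappings.items():
        if key is None:
            raise ValueError(
                f"No {prop} field found in messages: {', '.join(message_fields)}"
            )
    return mappings

# A ShareGPT-style message uses 'from'/'value':
print(detect_message_property_mappings(["from", "value"]))
# {'role': 'from', 'content': 'value'}
```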


@@ -41,10 +41,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
             load_kwargs["ds_cfg"] = ds_cfg
         if "processor" in sig.parameters:
             load_kwargs["processor"] = processor
         return func(tokenizer, cfg, **load_kwargs)
     except ModuleNotFoundError:
         return None
     except Exception as exc:  # pylint: disable=broad-exception-caught
         LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
         raise exc
-    return None


@@ -34,15 +34,12 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
         max_length = self.prompter.max_length

-        self.messages = "chosen_messages"
         # pylint: disable=duplicate-code
-        prompt[self.messages] = []
+        prompt["messages"] = []
         if prompt["system"]:
-            prompt[self.messages].append(
-                {"role": "system", "content": prompt["system"]}
-            )
-        prompt[self.messages].append({"role": "user", "content": prompt["input"]})
-        prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
+            prompt["messages"].append({"role": "system", "content": prompt["system"]})
+        prompt["messages"].append({"role": "user", "content": prompt["input"]})
+        prompt["messages"].append({"role": "assistant", "content": prompt["chosen"]})

         chosen_tokenized = super()._tokenize_single_prompt(prompt)
         if len(chosen_tokenized["input_ids"]) > max_length:
@@ -55,17 +52,12 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
             :max_length
         ]

-        self.messages = "rejected_messages"
         # pylint: disable=duplicate-code
-        prompt[self.messages] = []
+        prompt["messages"] = []
         if prompt["system"]:
-            prompt[self.messages].append(
-                {"role": "system", "content": prompt["system"]}
-            )
-        prompt[self.messages].append({"role": "user", "content": prompt["input"]})
-        prompt[self.messages].append(
-            {"role": "assistant", "content": prompt["rejected"]}
-        )
+            prompt["messages"].append({"role": "system", "content": prompt["system"]})
+        prompt["messages"].append({"role": "user", "content": prompt["input"]})
+        prompt["messages"].append({"role": "assistant", "content": prompt["rejected"]})

         rejected_tokenized = super()._tokenize_single_prompt(prompt)
         if len(rejected_tokenized["input_ids"]) > max_length:
@@ -99,8 +91,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     prompter_params = {
         "tokenizer": tokenizer,
         "chat_template": chat_template_string,
-        "message_field_role": ds_cfg.get("message_field_role", "role"),
-        "message_field_content": ds_cfg.get("message_field_content", "content"),
+        "message_property_mappings": ds_cfg.get(
+            "message_property_mappings",
+            {
+                "role": "role",
+                "content": "content",
+            },
+        ),
         "message_field_training": ds_cfg.get("message_field_training", None),
         "message_field_training_detail": ds_cfg.get(
             "message_field_training_detail", None
@@ -124,7 +121,4 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
     )

-    if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
-        strategy.messages = ds_cfg["field_messages"]
-
     return strategy


@@ -4,13 +4,16 @@ HF Chat Templates prompt strategy
 import logging
 from collections import defaultdict
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Set, Union

+from pydantic import BaseModel
 from transformers import ProcessorMixin

+from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
 from axolotl.prompt_tokenizers import PromptTokenizingStrategy
 from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
 from axolotl.utils.chat_templates import get_chat_template_from_config
+from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig

 # Configure the logger
 LOG = logging.getLogger("axolotl")
@@ -23,16 +26,23 @@ class ChatTemplatePrompter(Prompter):
     def __init__(
         self,
         tokenizer,
+        chat_template: str,
         processor=None,
-        chat_template=None,
         max_length=2048,
-        message_field_role: str = "role",
-        message_field_content: str = "content",
+        message_property_mappings: Optional[Dict[str, str]] = None,
         message_field_training: Optional[str] = None,
         message_field_training_detail: Optional[str] = None,
         field_messages: str = "messages",
         roles: Optional[Dict[str, List[str]]] = None,
         drop_system_message: bool = False,
     ):
+        # check if message_property_mappings is None or empty dict
+        if message_property_mappings is None or (not message_property_mappings):
+            message_property_mappings = {
+                "role": "role",
+                "content": "content",
+            }
+
         if roles:
             self.roles = {s: t for t, sources in roles.items() for s in sources}
         else:
@@ -45,18 +55,28 @@ class ChatTemplatePrompter(Prompter):
             "tool": "tool",
         }

-        self.message_field_role = message_field_role
-        self.message_field_content = message_field_content
+        self._chat_template_msg_variables = self.get_chat_template_msg_variables(
+            chat_template, field_messages
+        )
+        self.message_property_mappings = message_property_mappings
         self.message_field_training = message_field_training
         self.message_field_training_detail = message_field_training_detail
         self.field_messages = field_messages
         self.tokenizer = tokenizer
-        self.processor: ProcessorMixin = processor
+        self.processor: Optional[ProcessorMixin] = processor
         self.chat_template = chat_template
         self.max_length = max_length
         self.drop_system_message = drop_system_message

+    @property
+    def chat_template_msg_variables(self) -> Set[str]:
+        return self._chat_template_msg_variables
+
     def build_prompt(self, conversation, add_generation_prompt=False, images=None):
         if self.processor:
+            if not callable(self.processor):
+                raise TypeError("Processor must be callable")
             text = self.processor.apply_chat_template(
                 conversation,
                 chat_template=self.chat_template,
@@ -184,17 +204,21 @@ class ChatTemplatePrompter(Prompter):
         return adjusted_details

+    def get_chat_template_msg_variables(
+        self, chat_template: str, field_messages: str
+    ) -> Set[str]:
+        template_analyzer = JinjaTemplateAnalyzer(chat_template)
+        return template_analyzer.get_message_vars(field_messages)
+

 class ChatTemplateStrategy(PromptTokenizingStrategy):
     """
     Tokenizing strategy for instruction-based prompts.
     """

-    _messages = "messages"
-
     def __init__(
         self,
-        prompter: ChatTemplatePrompter,
+        prompter: "ChatTemplatePrompter",
         tokenizer,
         train_on_inputs,
         sequence_len,
@@ -202,6 +226,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         train_on_eos=None,
     ):
         super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
+        self.prompter: ChatTemplatePrompter = prompter

         self.roles_to_train = []
         if roles_to_train:
@@ -213,13 +238,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         self.train_on_eos = train_on_eos
         self.images = "images"

-    @property
-    def messages(self):
-        return self._messages
-
-    @messages.setter
-    def messages(self, messages):
-        self._messages = messages
+        LOG.debug(
+            f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
+        )

     @property
     def supports_batched(self) -> bool:
@@ -229,7 +250,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
     def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
         try:
             return all(isinstance(v, list) for v in prompt.values()) and all(
-                isinstance(v, list) for v in prompt[self.messages]
+                isinstance(v, list) for v in prompt[self.prompter.field_messages]
             )
         except KeyError:
             return False
@@ -464,30 +485,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
     def get_conversation_thread(self, prompt):
         turns = []
-        optional_keys = [
-            "tool_calls",  # tool that 'assistant' calls
-            "name",  # name of tool given by 'tool'
-            "tool_call_id",  # mistral/mixtral requires this
-        ]
-        for message in prompt[self.messages]:
+        for message in prompt[self.prompter.field_messages]:
+            transformed_message = self.transform_message(message)
             turn = {
-                "role": self.prompter.roles[message[self.prompter.message_field_role]],
+                **transformed_message,
                 "training": message.get(self.prompter.message_field_training),
                 "training_detail": message.get(
                     self.prompter.message_field_training_detail
                 ),
             }
-            # do not add content if None as it may conflict with some templates due to tools
-            content = message.get(self.prompter.message_field_content, None)
-            if content is not None:
-                turn["content"] = content
-
-            for key in optional_keys:
-                value = message.get(key, None)
-                if value is not None:
-                    turn[key] = value

             turns.append(turn)

         if self.prompter.drop_system_message and turns[0]["role"] == "system":
@@ -495,6 +503,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
         return turns

+    def transform_message(self, message):
+        # Build the initial transformed message from the mappings
+        transformed_message = {}
+        for key, value in self.prompter.message_property_mappings.items():
+            if message.get(value) is not None:
+                transformed_message[key] = message[value]
+            else:
+                LOG.debug(
+                    f"Could not find value for property {value} in message: {message}"
+                )
+
+        # Map the role if necessary
+        if "role" in transformed_message:
+            transformed_message["role"] = self.prompter.roles.get(
+                transformed_message["role"], transformed_message["role"]
+            )
+
+        # Determine which keys in the original message were not mapped
+        mapped_values = set(self.prompter.message_property_mappings.values())
+        remaining_keys = set(message) - mapped_values
+
+        # Keep only the properties defined in the chat template
+        # and not already mapped
+        for key in self.prompter.chat_template_msg_variables:
+            if key in remaining_keys:
+                val = message.get(key)
+                if val is not None:
+                    transformed_message[key] = val
+
+        return transformed_message
+
     def get_images(self, prompt):
         return prompt.get(self.images, None)
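
Outside the strategy class, the `transform_message` logic above behaves roughly like this standalone sketch (the helper and its signature are illustrative, not the project's API):

```python
def transform_message(message, property_mappings, template_vars, roles=None):
    """Apply property mappings, then pull through any unmapped keys
    that the chat template actually references."""
    roles = roles or {}
    transformed = {}
    for target, source in property_mappings.items():
        if message.get(source) is not None:
            transformed[target] = message[source]
    # Normalize the role name if a roles map is provided
    if "role" in transformed:
        transformed["role"] = roles.get(transformed["role"], transformed["role"])
    # Carry over template-referenced keys not consumed by a mapping
    mapped_sources = set(property_mappings.values())
    for key in template_vars:
        if key in message and key not in mapped_sources and message[key] is not None:
            transformed[key] = message[key]
    return transformed

msg = {"from": "human", "value": "hi", "tool_call_id": "abc", "extra": 1}
out = transform_message(
    msg,
    {"role": "from", "content": "value"},
    template_vars={"role", "content", "tool_call_id"},
    roles={"human": "user"},
)
# out == {"role": "user", "content": "hi", "tool_call_id": "abc"}
```

Note how `extra` is dropped because the template never references it, while `tool_call_id` survives because it does.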
@@ -516,33 +555,46 @@ class StrategyLoader:
         }

     def __call__(
-        self, tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
+        self,
+        tokenizer,
+        cfg,
+        ds_cfg: Optional[Union[Dict[str, Any], DatasetConfig]] = None,
+        processor=None,
     ):
         # pylint: disable=duplicate-code
-        ds_cfg = ds_cfg or {}
+        if ds_cfg is None:
+            dataset_config = {}
+        elif isinstance(ds_cfg, BaseModel):
+            dataset_config = ds_cfg.model_dump()
+        else:
+            dataset_config = ds_cfg

         chat_template_string = get_chat_template_from_config(
-            cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
+            cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
         )
         LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")

         prompter_params = {
             "tokenizer": tokenizer,
             "chat_template": chat_template_string,
-            "message_field_role": ds_cfg.get("message_field_role", "role"),
-            "message_field_content": ds_cfg.get("message_field_content", "content"),
-            "message_field_training": ds_cfg.get("message_field_training", None),
-            "message_field_training_detail": ds_cfg.get(
+            "message_property_mappings": dataset_config.get(
+                "message_property_mappings", {}
+            ),
+            "message_field_training": dataset_config.get(
+                "message_field_training", None
+            ),
+            "message_field_training_detail": dataset_config.get(
                 "message_field_training_detail",
                 None,
             ),
-            "roles": ds_cfg.get("roles"),
-            "drop_system_message": ds_cfg.get("drop_system_message", False),
+            "field_messages": dataset_config.get("field_messages", "messages"),
+            "roles": dataset_config.get("roles"),
+            "drop_system_message": dataset_config.get("drop_system_message", False),
             # we need to add one for detecting sequences with exceeding the `sequence_len` limit.
             "max_length": cfg.sequence_len + 1,
             "processor": processor,
         }

-        strategy_params = self._get_strategy_params(cfg, ds_cfg)
+        strategy_params = self._get_strategy_params(cfg, dataset_config)
         strategy_cls = self._get_strategy_cls()

         strategy = strategy_cls(
@@ -551,9 +603,6 @@ class StrategyLoader:
             **strategy_params,
         )

-        if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
-            strategy.messages = ds_cfg["field_messages"]
-
         return strategy
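
The `ds_cfg` normalization at the top of `__call__` can be exercised in isolation; this sketch duck-types pydantic's `model_dump` rather than importing pydantic, and the names are illustrative:

```python
def normalize_ds_cfg(ds_cfg):
    """Accept None, a plain dict, or a pydantic-style model and
    always return a plain dict (mirrors the StrategyLoader change)."""
    if ds_cfg is None:
        return {}
    if hasattr(ds_cfg, "model_dump"):  # pydantic BaseModel duck-typing
        return ds_cfg.model_dump()
    return ds_cfg

class FakeDatasetConfig:
    """Stand-in for a pydantic DatasetConfig."""
    def __init__(self, **fields):
        self._fields = fields

    def model_dump(self):
        return dict(self._fields)

cfg = normalize_ds_cfg(FakeDatasetConfig(field_messages="conversations"))
# cfg == {"field_messages": "conversations"}
```

Downstream code can then call `.get(...)` uniformly regardless of how the dataset config was supplied.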


@@ -3,20 +3,28 @@ DPO prompt strategies for using tokenizer chat templates.
 """

 from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
+from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic


 def default(
     cfg, dataset_idx=0, **kwargs
 ):  # pylint: disable=possibly-unused-variable,unused-argument
     ds_cfg = cfg["datasets"][dataset_idx]
+    ds_cfg = handle_legacy_message_fields_logic(ds_cfg)
     chat_template_choice, chat_template_jinja = extract_chat_template_args(
         cfg=cfg, ds_cfg=ds_cfg
     )
     field_messages = ds_cfg.get("field_messages", "messages")
     field_chosen = ds_cfg.get("field_chosen", "chosen")
     field_rejected = ds_cfg.get("field_rejected", "rejected")
-    field_message_role = ds_cfg.get("message_field_role", "role")
-    field_message_content = ds_cfg.get("message_field_content", "content")
+    message_property_mappings = ds_cfg.get(
+        "message_property_mappings",
+        {
+            "role": "role",
+            "content": "content",
+        },
+    )
     role_map_inv = ds_cfg.get(
         "roles",
         {
@@ -40,18 +48,18 @@ def default(
         messages = sample[field_messages]
         messages = [
             {
-                "role": role_map[m[field_message_role]],
-                "content": m[field_message_content],
+                "role": role_map[m[message_property_mappings["role"]]],
+                "content": m[message_property_mappings["content"]],
             }
             for m in messages
         ]
         chosen = {
-            "role": role_map[sample[field_chosen][field_message_role]],
-            "content": sample[field_chosen][field_message_content],
+            "role": role_map[sample[field_chosen][message_property_mappings["role"]]],
+            "content": sample[field_chosen][message_property_mappings["content"]],
         }
         rejected = {
-            "role": role_map[sample[field_rejected][field_message_role]],
-            "content": sample[field_rejected][field_message_content],
+            "role": role_map[sample[field_rejected][message_property_mappings["role"]]],
+            "content": sample[field_rejected][message_property_mappings["content"]],
         }
         dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}

@@ -0,0 +1,318 @@
"""Module for inspect jinja templates for the variables they use"""
from typing import Dict, Optional, Set, TypedDict, Union
from jinja2 import Environment, meta, nodes
class JinjaTemplateAnalysis(TypedDict):
"""
Represents the detailed analysis of a Jinja template variable.
Attributes:
accessed_properties (Set[str]): A set of properties accessed from the variable
(e.g., `foo.bar` results in 'bar' being accessed for 'foo').
accessed_indices (Set[Union[int, float]]): A set of indices accessed from the variable.
is_iterated (bool): Indicates if the variable is used as an iteration source in a `for` loop.
is_conditional (bool): Indicates if the variable is referenced within a conditional statement (e.g., an `if` block).
iteration_source (Optional[str]): The name of the variable being iterated over, if applicable.
iteration_target (Optional[Union[str, list[str]]]): The loop target(s) assigned in the iteration.
"""
accessed_properties: Set[str]
accessed_indices: Set[Union[int, float]]
is_iterated: bool
is_conditional: bool
iteration_source: Optional[str]
iteration_target: Optional[Union[str, list[str]]]
class JinjaTemplateAnalyzer:
"""
Analyzes Jinja templates to extract information about variable usage,
including accessed properties, iteration, and conditional references.
Attributes:
env (jinja2.Environment): The Jinja2 environment used for parsing templates.
property_access (Dict[str, Set[str]]): Tracks accessed properties for variables.
iteration_targets (Dict[str, str]): Maps iteration target variables to their sources.
Methods:
get_template_variables() -> Dict[str, Set[str]]:
Parse the template supplied at construction and return a mapping of variables to their accessed properties.
analyze_template() -> Dict[str, JinjaTemplateAnalysis]:
Perform a detailed analysis of the template, including variable usage,
iteration, and conditional references.
Private Methods:
_visit_node(node) -> None:
Recursively visit AST nodes to detect attribute access and iteration targets.
_get_base_name(node) -> Optional[str]:
Extract the base variable name from a node.
_get_target_name(node) -> Optional[str]:
Extract the target name from a `For` node; returns None for tuple targets.
"""
def __init__(self, template: str):
self.env: Environment = Environment(autoescape=True)
self.property_access: Dict[str, Set[str]] = {}
self.iteration_targets: Dict[str, Union[str, list[str]]] = {}
self.index_access: Dict[str, Set[Union[int, float]]] = {}
self.ast: nodes.Node = self.env.parse(template)
self.template: str = template
self.variable_assignments: Dict[str, str] = {}
def _visit_node(self, node) -> None:
"""Recursively visit AST nodes to find attribute access."""
# Handle attribute access (dot notation)
if isinstance(node, nodes.Getattr):
base_name = self._get_base_name(node.node)
if base_name:
self.property_access.setdefault(base_name, set()).add(node.attr)
# Handle dictionary access (subscript notation)
elif isinstance(node, nodes.Getitem):
base_name = self._get_base_name(node.node)
if base_name and isinstance(node.arg, nodes.Const):
value = node.arg.value
if isinstance(value, (int, float)):
self.index_access.setdefault(base_name, set()).add(value)
else:
self.property_access.setdefault(base_name, set()).add(value)
elif isinstance(node, nodes.Test) and node.name == "defined":
base_name = self._get_base_name(node.node)
if base_name:
if isinstance(node.node, nodes.Getattr):
self.property_access.setdefault(base_name, set()).add(
node.node.attr
)
# Handle loop variables
elif isinstance(node, nodes.For):
iter_name = self._get_base_name(node.iter)
target_name = self._get_target_name(node.target)
if iter_name and target_name:
self.iteration_targets[target_name] = iter_name
self.property_access.setdefault(iter_name, set())
elif isinstance(node, nodes.Assign):
target_name = self._get_target_name(node.target)
source_name = self._get_base_name(node.node)
if target_name and source_name:
self.variable_assignments[target_name] = source_name
elif isinstance(node, nodes.Filter):
if node.name == "selectattr":
target = self._get_base_name(node.node)
if target:
self.variable_assignments[f"filtered_{target}"] = target
for child in node.iter_child_nodes():
self._visit_node(child)
def _get_target_name(self, node) -> Optional[str]:
"""Get the target variable name from a For node.
Args:
node: A Jinja AST node representing either a Name or Tuple node
Returns:
- str: For simple variable targets (e.g., "item" in "for item in items")
- None: If the node type is not recognized or is a tuple
"""
if isinstance(node, nodes.Name):
return node.name
return None
def _get_target_names(self, node) -> list[str]:
"""Get all target variable names from a For node, including tuple unpacking.
Args:
node: A Jinja AST node representing either a Name or Tuple node
Returns:
List of target variable names
"""
if isinstance(node, nodes.Name):
return [node.name]
if isinstance(node, nodes.Tuple):
names = []
for n in node.items:
if isinstance(n, nodes.Name):
names.append(n.name)
return names
return []
def _get_base_name(self, node) -> Optional[str]:
"""Get the base variable name from a node."""
if isinstance(node, nodes.Name):
return node.name
if isinstance(node, nodes.Getattr):
return self._get_base_name(node.node)
if isinstance(node, nodes.Getitem):
return self._get_base_name(node.node)
return None
def get_template_variables(self) -> Dict[str, Set[str]]:
"""
Parse the Jinja template supplied at construction and return both variables and their accessed properties.
Returns:
Dict[str, Set[str]]: Dictionary mapping variable names to sets of accessed properties
"""
# Parse the template
ast = self.env.parse(self.template)
# Get all undeclared variables
variables = meta.find_undeclared_variables(ast)
# Reset property access tracking
self.property_access = {}
# Visit all nodes to find property access
self._visit_node(ast)
# Create result dictionary
result: Dict[str, Set[str]] = {var: set() for var in variables}
# Merge in any discovered sub-properties
for var, props in self.property_access.items():
if var not in result:
result[var] = set()
result[var].update(props)
return result
def analyze_template(self) -> Dict[str, JinjaTemplateAnalysis]:
"""
Provide a detailed analysis of template variables and their usage.
"""
variables = self.get_template_variables()
self.iteration_targets = {}
analysis: Dict[str, JinjaTemplateAnalysis] = {
var: JinjaTemplateAnalysis(
accessed_properties=props,
accessed_indices=set(),
is_iterated=False,
is_conditional=False,
iteration_source=None,
iteration_target=None,
)
for var, props in variables.items()
}
for var, indices in self.index_access.items():
if var in analysis:
analysis[var]["accessed_indices"] = indices
def visit_node(node):
if isinstance(node, nodes.If):
def find_test_vars(test_node):
if isinstance(test_node, nodes.Name):
if test_node.name in analysis:
analysis[test_node.name]["is_conditional"] = True
for child in test_node.iter_child_nodes():
find_test_vars(child)
find_test_vars(node.test)
if isinstance(node, nodes.For):
iter_target = self._get_base_name(node.iter)
target_name = self._get_target_name(node.target)
if iter_target in analysis:
analysis[iter_target]["is_iterated"] = True
if target_name:
analysis[iter_target]["iteration_target"] = target_name
if isinstance(target_name, str) and target_name not in analysis:
analysis[target_name] = {
"accessed_properties": set(),
"accessed_indices": set(),
"is_iterated": False,
"is_conditional": False,
"iteration_source": iter_target,
"iteration_target": None,
}
for child in node.iter_child_nodes():
visit_node(child)
visit_node(self.ast)
return analysis
def get_downstream_properties(self, start_var: str) -> Dict[str, Set[str]]:
"""
Get all properties accessed on a variable and its downstream assignments.
Args:
start_var: The starting variable to trace
Returns:
Dict mapping variable names to their accessed properties
"""
visited = set()
properties = {}
def trace_variable(var_name: str):
if var_name in visited:
return
visited.add(var_name)
# Get direct properties
if var_name in self.property_access:
properties[var_name] = self.property_access[var_name]
# Get properties from iteration targets
if var_name in self.iteration_targets:
target = self.iteration_targets[var_name]
if isinstance(target, str):
trace_variable(target)
elif isinstance(target, list):
for t in target:
trace_variable(t)
# Follow assignments
for target, source in self.variable_assignments.items():
if source == var_name:
trace_variable(target)
# Check for array slicing
analysis = self.analyze_template()
if var_name in analysis:
var_info = analysis[var_name]
if var_info["accessed_indices"]:
# If this variable is sliced, follow the resulting assignment
slice_result = f"{var_name}_slice"
if slice_result in self.property_access:
trace_variable(slice_result)
trace_variable(start_var)
return properties
def get_message_vars(self, field_messages: str = "messages") -> Set[str]:
"""
Get all properties accessed on messages and derived variables.
"""
all_properties = self.get_downstream_properties(field_messages)
# Combine all properties from all related variables
combined_properties = set()
for properties in all_properties.values():
combined_properties.update(properties)
# Also include properties from the message iteration variable
analysis = self.analyze_template()
if "message" in analysis:
combined_properties.update(analysis["message"]["accessed_properties"])
return combined_properties
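The core technique the analyzer relies on — parsing the template into a Jinja AST, walking it for `Getattr` nodes, and mapping for-loop targets back to their iteration sources — can be sketched standalone. `template_message_properties` below is a hypothetical condensed helper, not the PR's class:

```python
from jinja2 import Environment, nodes


def template_message_properties(template: str, field_messages: str = "messages") -> set:
    """Collect properties accessed on `field_messages` and on any loop variable iterating it."""
    env = Environment(autoescape=True)
    ast = env.parse(template)
    targets: dict = {}  # loop target -> iteration source, e.g. message -> messages
    props: dict = {}    # variable -> set of attributes accessed on it

    def visit(node):
        if isinstance(node, nodes.For) and isinstance(node.iter, nodes.Name):
            if isinstance(node.target, nodes.Name):
                targets[node.target.name] = node.iter.name
        if isinstance(node, nodes.Getattr) and isinstance(node.node, nodes.Name):
            props.setdefault(node.node.name, set()).add(node.attr)
        for child in node.iter_child_nodes():
            visit(child)

    visit(ast)
    out = set(props.get(field_messages, set()))
    # Fold in properties accessed via the loop variable (message.role -> messages)
    for target, source in targets.items():
        if source == field_messages:
            out |= props.get(target, set())
    return out


tmpl = "{% for message in messages %}{{ message.role }}: {{ message.content }}{% endfor %}"
print(sorted(template_message_properties(tmpl)))  # ['content', 'role']
```

This is what lets the loader warn when a configured `message_property_mappings` key is never consumed by the chosen chat template.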

View File

@@ -51,8 +51,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
field_messages = ds_cfg.get("field_messages")
message_field_role = ds_cfg.get("message_field_role")
message_field_content = ds_cfg.get("message_field_content")
message_property_mappings = ds_cfg.get("message_property_mappings")
message_field_role = (
message_property_mappings.get("role") if message_property_mappings else None
)
message_field_content = (
message_property_mappings.get("content") if message_property_mappings else None
)
message_field_training = ds_cfg.get("message_field_training")
builder_kwargs = {}

View File

@@ -38,7 +38,7 @@ def get_chat_template(
user_choice: str,
jinja_template: Optional[str] = None,
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
):
) -> str:
"""
Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.
@@ -70,7 +70,7 @@ def get_chat_template(
f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
f"Please add a chat_template in tokenizer config"
)
return tokenizer.chat_template
return tokenizer.chat_template # type: ignore
if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
if not tokenizer:
@@ -78,7 +78,7 @@ def get_chat_template(
f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
)
if tokenizer.chat_template:
return tokenizer.chat_template
return tokenizer.chat_template # type: ignore
user_choice = user_choice[
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :

View File

@@ -18,6 +18,7 @@ from axolotl.utils.config.models.input.v0_4_1 import (
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlInputConfig as AxolotlInputConfigBase,
)
from axolotl.utils.config.models.input.v0_4_1 import DPODataset, KTODataset, SFTDataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
@@ -258,7 +259,7 @@ def validate_config(
cfg: DictDefault,
capabilities: Optional[dict] = None,
env_capabilities: Optional[dict] = None,
):
) -> DictDefault:
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
AxolotlInputConfig = AxolotlInputConfigBase
@@ -268,6 +269,16 @@ def validate_config(
AxolotlInputConfig, # pylint: disable=invalid-name
) = merge_input_args()
# Convert datasets to proper format if needed
if cfg.get("datasets"):
for idx, ds_cfg in enumerate(cfg["datasets"]):
if cfg.get("rl") == "dpo" and not isinstance(ds_cfg, DPODataset):
cfg["datasets"][idx] = DPODataset(**ds_cfg)
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
elif not isinstance(ds_cfg, SFTDataset):
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))
if capabilities or env_capabilities:
if (capabilities and env_capabilities is None) or (
env_capabilities and capabilities is None

View File

@@ -12,6 +12,7 @@ from pydantic import (
Field,
StringConstraints,
conlist,
field_serializer,
field_validator,
model_validator,
)
@@ -186,8 +187,13 @@ class SFTDataset(BaseModel):
field_human: Optional[str] = None
field_model: Optional[str] = None
field_messages: Optional[str] = None
message_field_role: Optional[str] = None
message_field_content: Optional[str] = None
message_field_role: Optional[
str
] = None # deprecated, use message_property_mappings
message_field_content: Optional[
str
] = None # deprecated, use message_property_mappings
message_property_mappings: Optional[Dict[str, str]] = None
message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None
logprobs_field: Optional[str] = None
@@ -199,9 +205,18 @@ class SFTDataset(BaseModel):
trust_remote_code: Optional[bool] = False
revision: Optional[str] = None
@model_validator(mode="before")
@classmethod
def handle_legacy_message_fields(cls, data):
"""Handle backwards compatibility between legacy message field mapping and new property mapping system."""
return handle_legacy_message_fields_logic(data)
@model_validator(mode="before")
@classmethod
def check_chat_template_config(cls, data):
if isinstance(data, BaseModel):
data = data.model_dump()
# Set chat_template to tokenizer_default if not set
if data.get("type") == "chat_template" and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.tokenizer_default
@@ -241,6 +256,7 @@ class DPODataset(BaseModel):
type: Optional[Union[UserDefinedDPOType, str]] = None
data_files: Optional[List[str]] = None
revision: Optional[str] = None
field_messages: Optional[str] = None
class StepwiseSupervisedDataset(BaseModel):
@@ -277,6 +293,9 @@ class KTODataset(BaseModel):
revision: Optional[str] = None
DatasetConfig = Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset]
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
@@ -680,17 +699,15 @@ class AxolotlInputConfig(
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
dpo_use_logits_to_keep: Optional[bool] = None
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
datasets: Optional[conlist(DatasetConfig, min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(DatasetConfig, min_length=1)] = None # type: ignore
shuffle_merged_datasets: Optional[bool] = True
dataset_prepared_path: Optional[str] = None
dataset_shard_num: Optional[int] = None
dataset_shard_idx: Optional[int] = None
skip_prepare_dataset: Optional[bool] = False
pretraining_dataset: Optional[ # type: ignore
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
] = Field(
pretraining_dataset: Optional[conlist(Union[PretrainingDataset, SFTDataset], min_length=1)] = Field( # type: ignore
default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"},
)
@@ -895,10 +912,15 @@ class AxolotlInputConfig(
@classmethod
def deprecate_sharegpt_datasets(cls, datasets):
for _, ds_cfg in enumerate(datasets):
if not ds_cfg.get("type"):
# Handle both dict and pydantic model cases
ds_type = (
ds_cfg.get("type")
if isinstance(ds_cfg, dict)
else getattr(ds_cfg, "type", None)
)
if not ds_type:
continue
ds_type = ds_cfg["type"]
# skip if it's a dict (for custom user instruction prompt)
if isinstance(ds_type, dict):
continue
@@ -910,6 +932,14 @@ class AxolotlInputConfig(
return datasets
@field_serializer("datasets")
def datasets_serializer(
self, ds_configs: Optional[List[DatasetConfig]]
) -> Optional[List[Dict[str, Any]]]:
if ds_configs:
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
return None
@model_validator(mode="before")
@classmethod
def check_batch_size_fields(cls, data):
@@ -1762,3 +1792,77 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
else:
data["torch_compile"] = False
return data
def handle_legacy_message_fields_logic(data: dict) -> dict:
"""
Handle backwards compatibility between legacy message field mapping and new property mapping system.
Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
- message_field_role: Mapped to the role field
- message_field_content: Mapped to the content field
The new system uses message_property_mappings to support arbitrary field mappings:
message_property_mappings:
role: source_role_field
content: source_content_field
additional_field: source_field
Args:
data: Dictionary containing configuration data
Returns:
Updated dictionary with message field mappings consolidated
Raises:
ValueError: If there are conflicts between legacy and new mappings
"""
data = data.copy() # Create a copy to avoid modifying the original
if data.get("message_property_mappings") is None:
data["message_property_mappings"] = {}
# Check for conflicts and handle role
if "message_field_role" in data:
LOG.warning(
"message_field_role is deprecated, use message_property_mappings instead. "
f"Example: message_property_mappings: {{role: {data['message_field_role']}}}"
)
if (
"role" in data["message_property_mappings"]
and data["message_property_mappings"]["role"] != data["message_field_role"]
):
raise ValueError(
f"Conflicting message role fields: message_field_role='{data['message_field_role']}' "
f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'"
)
data["message_property_mappings"]["role"] = data["message_field_role"] or "role"
del data["message_field_role"]
elif "role" not in data["message_property_mappings"]:
data["message_property_mappings"]["role"] = "role"
# Check for conflicts and handle content
if "message_field_content" in data:
LOG.warning(
"message_field_content is deprecated, use message_property_mappings instead. "
f"Example: message_property_mappings: {{content: {data['message_field_content']}}}"
)
if (
"content" in data["message_property_mappings"]
and data["message_property_mappings"]["content"]
!= data["message_field_content"]
):
raise ValueError(
f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
)
data["message_property_mappings"]["content"] = (
data["message_field_content"] or "content"
)
del data["message_field_content"]
elif "content" not in data["message_property_mappings"]:
data["message_property_mappings"]["content"] = "content"
return data
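The consolidation rule above can be exercised end-to-end with a condensed sketch (`consolidate` is a hypothetical compressed version of the function, without the deprecation warnings):

```python
def consolidate(data: dict) -> dict:
    """Fold legacy message_field_role/content into message_property_mappings."""
    data = dict(data)
    mappings = dict(data.pop("message_property_mappings", None) or {})
    for legacy, prop in (("message_field_role", "role"), ("message_field_content", "content")):
        if legacy in data:
            value = data.pop(legacy)
            if prop in mappings and mappings[prop] != value:
                raise ValueError(f"Conflicting message {prop} fields")
            mappings[prop] = value or prop
        else:
            # default mapping when neither legacy key nor explicit mapping is set
            mappings.setdefault(prop, prop)
    data["message_property_mappings"] = mappings
    return data


cfg = {"type": "chat_template", "message_field_role": "from"}
print(consolidate(cfg))
# {'type': 'chat_template', 'message_property_mappings': {'role': 'from', 'content': 'content'}}
```

So a legacy config keeps working, an explicit mapping wins when both agree, and a disagreement fails loudly instead of silently picking one.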

View File

@@ -180,6 +180,7 @@ def load_tokenized_prepared_datasets(
) -> Tuple[DatasetDict, List[Prompter]]:
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
tokenizer_name = cfg.tokenizer_config
ds_hash = str(
md5(
(

View File

@@ -13,3 +13,26 @@ class DictDefault(Dict):
def __or__(self, other):
return DictDefault(super().__ror__(other))
def __setitem__(self, name, value):
# workaround for pickle/unpickle issues and __frozen not being available
try:
isFrozen = hasattr( # pylint: disable=invalid-name
self, "__frozen"
) and object.__getattribute__(self, "__frozen")
except AttributeError:
isFrozen = False # pylint: disable=invalid-name
if isFrozen and name not in super().keys():
raise KeyError(name)
super(Dict, self).__setitem__(name, value) # pylint: disable=bad-super-call
try:
p = object.__getattribute__(self, "__parent")
key = object.__getattribute__(self, "__key")
except AttributeError:
p = None
key = None
if p is not None:
p[key] = self
object.__delattr__(self, "__parent")
object.__delattr__(self, "__key")
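The try/except guard matters because pickle restores a dict subclass's items via `__setitem__` before instance attributes like `__frozen` are rebuilt. A minimal sketch of the same guard on a hypothetical `GuardedDict` (not the real `DictDefault`):

```python
import pickle


class GuardedDict(dict):
    """Dict whose __setitem__ tolerates missing instance attrs during unpickling."""

    def __setitem__(self, name, value):
        try:
            frozen = object.__getattribute__(self, "__frozen")
        except AttributeError:
            # Mid-unpickle (or never frozen): attrs absent, fall back to unfrozen
            frozen = False
        if frozen and name not in self:
            raise KeyError(name)
        super().__setitem__(name, value)


d = GuardedDict(a=1)
# Unpickling calls __setitem__ for each item before any state is restored
restored = pickle.loads(pickle.dumps(d))
restored["b"] = 2
print(restored)  # {'a': 1, 'b': 2}
```

Without the except branch, `object.__getattribute__` would raise during the round-trip and the unpickle would fail.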

View File

@@ -9,7 +9,7 @@ from e2e.utils import check_tensorboard, require_torch_2_5_1
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
@@ -79,6 +79,7 @@ class TestKnowledgeDistillation:
def test_llama_kd(self, temp_dir, kd_min_cfg):
cfg = DictDefault(kd_min_cfg)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -109,6 +110,7 @@ class TestKnowledgeDistillation:
| kd_min_cfg
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()

View File

@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard
@@ -76,7 +76,9 @@ class TestFAXentropyLlama:
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -10,7 +10,7 @@ from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -73,6 +73,8 @@ class TestReLoraLlama(unittest.TestCase):
"use_tensorboard": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_preference_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -108,6 +110,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -153,6 +157,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -198,6 +204,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -242,6 +250,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -289,6 +299,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -353,6 +365,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -56,6 +56,8 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"use_tensorboard": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -65,6 +65,8 @@ class TestFalcon(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -118,6 +120,8 @@ class TestFalcon(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -157,6 +161,8 @@ class TestFalcon(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -10,7 +10,7 @@ from e2e.utils import check_model_output_exists
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
@@ -56,6 +56,8 @@ class TestLlama:
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -99,6 +101,8 @@ class TestLlama:
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -138,6 +142,8 @@ class TestLlama:
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -10,7 +10,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard
@@ -69,6 +69,8 @@ class TestPretrainLlama:
"use_tensorboard": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -62,6 +62,8 @@ class TestLlamaVision(unittest.TestCase):
"bf16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -59,6 +59,8 @@ class TestLoraLlama(unittest.TestCase):
"max_steps": 20,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -11,7 +11,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -59,6 +59,8 @@ class TestMamba(unittest.TestCase):
"save_safetensors": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,8 @@ class TestMistral(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -106,6 +108,8 @@ class TestMistral(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -69,6 +69,8 @@ class TestMixtral(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -123,6 +125,8 @@ class TestMixtral(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -180,6 +184,8 @@ class TestMixtral(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -233,6 +239,8 @@ class TestMixtral(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
@@ -281,6 +289,8 @@ class TestMixtral(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir
@@ -59,6 +59,8 @@ class TestCustomOptimizers(unittest.TestCase):
"lr_scheduler": "cosine",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -103,6 +105,8 @@ class TestCustomOptimizers(unittest.TestCase):
"lr_scheduler": "cosine",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -139,6 +143,8 @@ class TestCustomOptimizers(unittest.TestCase):
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
@@ -59,6 +59,8 @@ class TestPackedLlama(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -61,6 +61,7 @@ class TestPhi(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -40,8 +40,10 @@ class TestE2eQwen:
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"message_property_mappings": {
"role": "role",
"content": "content",
},
"roles": {
"system": ["system"],
"user": ["user"],

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -66,6 +66,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
"use_tensorboard": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

@@ -7,6 +7,7 @@ from datasets import Dataset
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.utils.chat_templates import _CHAT_TEMPLATES
@@ -174,3 +175,32 @@ def fixture_llama3_2_vision_with_hardcoded_date() -> str:
modified_template = template.replace(old_date_logic, new_date_logic)
return modified_template
@pytest.fixture(name="chat_template_jinja_with_optional_fields")
def fixture_chat_template_jinja_with_optional_fields() -> str:
return """{% for message in messages %}
{{'<|im_start|>'}}{{ message['role'] }}
{% if message['thoughts'] is defined %}[Thoughts: {{ message['thoughts'] }}]{% endif %}
{% if message['tool_calls'] is defined %}[Tool: {{ message['tool_calls'][0]['type'] }}]{% endif %}
{{ message['content'] }}{{'<|im_end|>'}}
{% endfor %}"""
@pytest.fixture(name="basic_jinja_template_analyzer")
def basic_jinja_template_analyzer():
return JinjaTemplateAnalyzer(
"""{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>
' + message['content'] + '<|end|>
'}}{% elif message['role'] == 'user' %}{{'<|user|>
' + message['content'] + '<|end|>
'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>
' + message['content'] + '<|end|>
'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>
' }}{% else %}{{ eos_token }}{% endif %}"""
)
@pytest.fixture(name="mistral_jinja_template_analyzer")
def mistral_jinja_template_analyzer(mistralv03_tokenizer_chat_template_jinja):
return JinjaTemplateAnalyzer(mistralv03_tokenizer_chat_template_jinja)

@@ -38,6 +38,10 @@ class TestAssistantChatTemplateLlama3:
"chat_template": "llama3",
"message_field_role": "role",
"message_field_content": "content",
"message_property_mappings": {
"role": "role",
"content": "content",
},
"roles": {
"user": ["user"],
"assistant": ["assistant"],
@@ -74,8 +78,10 @@ class TestAssistantChatTemplateLlama3:
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
message_field_role="role",
message_field_content="content",
message_property_mappings={
"role": "role",
"content": "content",
},
roles={
"user": ["user"],
"assistant": ["assistant"],
@@ -86,7 +92,7 @@ class TestAssistantChatTemplateLlama3:
train_on_inputs=False,
sequence_len=512,
)
strategy.messages = "messages"
res = strategy.tokenize_prompt(assistant_dataset[0])
input_ids = res["input_ids"]
# fmt: off
@@ -114,8 +120,10 @@ class TestAssistantChatTemplateLlama3:
ChatTemplatePrompter(
phi35_tokenizer,
chat_template=get_chat_template("phi_35"),
message_field_role="role",
message_field_content="content",
message_property_mappings={
"role": "role",
"content": "content",
},
roles={
"user": ["user"],
"assistant": ["assistant"],
@@ -126,7 +134,7 @@ class TestAssistantChatTemplateLlama3:
train_on_inputs=False,
sequence_len=512,
)
strategy.messages = "messages"
res = strategy.tokenize_prompt(assistant_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
@@ -170,9 +178,11 @@ class TestAssistantChatTemplateLlama3:
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
message_field_role="role",
message_field_content="content",
message_field_training="training",
message_property_mappings={
"role": "role",
"content": "content",
},
roles={
"user": ["user"],
"assistant": ["assistant"],
@@ -185,7 +195,7 @@ class TestAssistantChatTemplateLlama3:
sequence_len=512,
roles_to_train=["assistant"],
)
strategy.messages = "messages"
prompt_tokens = strategy.prompter.build_prompt(
assistant_dataset[0]["messages"], False
)
@@ -230,8 +240,11 @@ class TestSharegptChatTemplateLlama3:
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
message_field_role="from",
message_field_content="value",
message_property_mappings={
"role": "from",
"content": "value",
},
field_messages="conversations",
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -239,7 +252,7 @@ class TestSharegptChatTemplateLlama3:
sequence_len=512,
roles_to_train=["gpt"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(sharegpt_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
@@ -287,8 +300,11 @@ class TestSharegptChatTemplateLlama3:
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
message_field_role="from",
message_field_content="value",
message_property_mappings={
"role": "from",
"content": "value",
},
field_messages="conversations",
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -296,7 +312,7 @@ class TestSharegptChatTemplateLlama3:
sequence_len=512,
roles_to_train=["human"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(sharegpt_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
@@ -344,8 +360,11 @@ class TestSharegptChatTemplateLlama3:
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
message_field_role="from",
message_field_content="value",
message_property_mappings={
"role": "from",
"content": "value",
},
field_messages="conversations",
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -353,7 +372,7 @@ class TestSharegptChatTemplateLlama3:
sequence_len=512,
roles_to_train=["system", "human"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
@@ -417,8 +436,7 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
chat_template=get_chat_template(
"jinja", jinja_template=llama3_2_vision_chat_template_jinja
),
message_field_role="role",
message_field_content="content",
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -486,8 +504,7 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
chat_template=get_chat_template(
"jinja", jinja_template=llama3_2_vision_chat_template_jinja
),
message_field_role="role",
message_field_content="content",
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
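Concretely, a mapping like `{"role": "from", "content": "value"}` renames sharegpt-style dataset keys into the role/content shape the chat template expects. A standalone sketch of that transform (a hypothetical helper, not axolotl's internal implementation):

```python
def remap_message(message: dict, mappings: dict[str, str]) -> dict:
    """Rename dataset keys to template keys, e.g. 'from' -> 'role' (sketch)."""
    remapped = {prop: message[field] for prop, field in mappings.items() if field in message}
    # keep any extra keys (e.g. an optional 'thoughts' field) untouched
    extras = {k: v for k, v in message.items() if k not in mappings.values()}
    return {**extras, **remapped}

msg = {"from": "human", "value": "hello"}
assert remap_message(msg, {"role": "from", "content": "value"}) == {"role": "human", "content": "hello"}
```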

@@ -3,7 +3,6 @@ tests for chat_template prompt strategy
"""
import logging
import unittest
from copy import deepcopy
import pytest
@@ -123,15 +122,15 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=True,
sequence_len=512,
roles_to_train=["assistant"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
turns = strategy.get_conversation_thread(basic_dataset[0])
labels = res["labels"]
@@ -180,15 +179,15 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
turns = strategy.get_conversation_thread(basic_dataset[0])
labels = res["labels"]
@@ -241,20 +240,15 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant", "human"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
turns = strategy.get_conversation_thread(basic_dataset[0])
labels = res["labels"]
@@ -307,15 +301,15 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=True,
sequence_len=512,
roles_to_train=["human", "assistant"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
turns = strategy.get_conversation_thread(basic_dataset[0])
labels = res["labels"]
@@ -360,8 +354,8 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
@@ -369,7 +363,7 @@ class TestChatTemplateConfigurations:
roles_to_train=[],
train_on_eos="none", # Add this line
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
@@ -400,8 +394,8 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
@@ -409,7 +403,7 @@ class TestChatTemplateConfigurations:
roles_to_train=["assistant"],
train_on_eos="all",
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
@@ -446,8 +440,8 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
@@ -455,7 +449,6 @@ class TestChatTemplateConfigurations:
roles_to_train=["assistant"],
train_on_eos="turn",
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
turns = strategy.get_conversation_thread(basic_dataset[0])
labels = res["labels"]
@@ -526,8 +519,8 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
@@ -535,7 +528,7 @@ class TestChatTemplateConfigurations:
roles_to_train=["assistant"],
train_on_eos="last",
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
@@ -578,8 +571,8 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
@@ -587,7 +580,7 @@ class TestChatTemplateConfigurations:
roles_to_train=["assistant"],
train_on_eos="none",
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
@@ -624,15 +617,15 @@ class TestChatTemplateConfigurations:
chat_template, jinja_template=chat_template_jinja
),
drop_system_message=True,
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
input_ids = res["input_ids"]
@@ -668,8 +661,7 @@ class TestChatTemplateConfigurations:
chat_template, jinja_template=chat_template_jinja
),
roles=custom_roles,
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
),
tokenizer=tokenizer,
train_on_inputs=False,
@@ -741,8 +733,7 @@ class TestChatTemplateConfigurations:
),
message_field_training="train",
message_field_training_detail="train_detail",
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
),
tokenizer=tokenizer,
train_on_inputs=False,
@@ -911,6 +902,64 @@ class TestChatTemplateConfigurations:
LOG.debug(f"Final labels: {labels}")
LOG.debug(f"Final input_ids: {input_ids}")
def test_get_chat_template_variables(
self, tokenizer, chat_template, chat_template_jinja, eos_token, request
):
LOG.info("Testing get_chat_template_variables")
actual_tokenizer, actual_jinja_template = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, eos_token, request
)
prompter = ChatTemplatePrompter(
actual_tokenizer,
chat_template=get_chat_template(
chat_template, jinja_template=actual_jinja_template
),
message_property_mappings={"from": "role", "value": "content"},
)
variables = prompter.get_chat_template_msg_variables(
actual_jinja_template
if actual_jinja_template
else actual_tokenizer.get_chat_template(),
"messages",
)
if chat_template == "llama3":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "chatml":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
assert variables == {"role", "content", "tool_call_id", "tool_calls"}, (
f"Expected variables: {'role', 'content', 'tool_call_id', 'tool_calls'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "jinja" and tokenizer == "gemma2_tokenizer":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "phi_35":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
else:
LOG.warning(
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
)
raise ValueError(
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
)
if __name__ == "__main__":
unittest.main()

@@ -0,0 +1,159 @@
"""
tests for jinja_template_analyzer
"""
import logging
import pytest
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
class TestJinjaTemplateAnalyzer:
"""
tests for jinja_template_analyzer
"""
def test_basic_variable_extraction(self, basic_jinja_template_analyzer):
"""Test that all top-level variables are correctly extracted."""
LOG.info("Testing basic variable extraction")
variables = basic_jinja_template_analyzer.get_template_variables()
expected_vars = {"messages", "add_generation_prompt", "eos_token", "message"}
assert set(variables.keys()) == expected_vars
def test_mixtral_variable_extraction(self, mistral_jinja_template_analyzer):
"""Test that all top-level variables are correctly extracted."""
LOG.info("Testing mistral variable extraction")
variables = mistral_jinja_template_analyzer.get_template_variables()
expected_vars = {
"messages",
"content",
"eos_token",
"message",
"tools",
"system_message",
"loop_messages",
"ns",
"tool_call",
"tool",
"loop",
"bos_token",
"raise_exception",
}
assert set(variables.keys()) == expected_vars
message_vars = variables["message"]
assert message_vars == {"role", "content", "tool_calls", "tool_call_id"}
def test_message_property_access(self, basic_jinja_template_analyzer):
"""Test that properties accessed on 'message' variable are correctly identified."""
LOG.info("Testing message property access")
variables = basic_jinja_template_analyzer.get_template_variables()
assert "messages" in variables
assert "message" in variables
assert "role" in variables["message"]
assert "content" in variables["message"]
def test_detailed_analysis(self, basic_jinja_template_analyzer):
"""Test the detailed analysis of variable usage."""
LOG.info("Testing detailed analysis")
analysis = basic_jinja_template_analyzer.analyze_template()
assert analysis["messages"]["is_iterated"] is True
assert "role" in analysis["message"]["accessed_properties"]
assert "content" in analysis["message"]["accessed_properties"]
assert analysis["add_generation_prompt"]["is_conditional"] is True
assert len(analysis["add_generation_prompt"]["accessed_properties"]) == 0
assert not analysis["eos_token"]["is_iterated"]
assert len(analysis["eos_token"]["accessed_properties"]) == 0
def test_nested_property_access(self):
"""Test handling of nested property access."""
LOG.info("Testing nested property access")
template = """{{ user.profile.name }}{{ user.settings['preference'] }}"""
analyzer = JinjaTemplateAnalyzer(template)
variables = analyzer.get_template_variables()
assert "user" in variables
assert "profile" in variables["user"]
assert "settings" in variables["user"]
def test_loop_variable_handling(self):
"""Test handling of loop variables and their properties."""
LOG.info("Testing loop variable handling")
template = """
{% for item in items %}
{{ item.name }}
{% for subitem in item.subitems %}
{{ subitem.value }}
{% endfor %}
{% endfor %}
"""
analyzer = JinjaTemplateAnalyzer(template)
analysis = analyzer.analyze_template()
assert analysis["items"]["is_iterated"]
assert "name" in analysis["item"]["accessed_properties"]
assert "subitems" in analysis["item"]["accessed_properties"]
def test_conditional_variable_usage(self):
"""Test detection of variables used in conditional statements."""
LOG.info("Testing conditional variable usage")
template = """
{% if user.is_admin and config.debug_mode %}
{{ debug_info }}
{% endif %}
"""
analyzer = JinjaTemplateAnalyzer(template)
analysis = analyzer.analyze_template()
assert analysis["user"]["is_conditional"]
assert analysis["config"]["is_conditional"]
assert "is_admin" in analysis["user"]["accessed_properties"]
assert "debug_mode" in analysis["config"]["accessed_properties"]
def test_complex_expressions(self):
"""Test handling of complex expressions and filters."""
LOG.info("Testing complex expressions and filters")
template = """
{{ user.name | upper }}
{{ messages | length > 0 and messages[0].content }}
{{ data['key'].nested['value'] }}
"""
analyzer = JinjaTemplateAnalyzer(template)
variables = analyzer.get_template_variables()
assert "user" in variables
assert "name" in variables["user"]
assert "messages" in variables
assert "content" in variables["messages"]
assert "data" in variables
def test_basic_msg_vars(self, basic_jinja_template_analyzer):
"""Test that the basic message variables are correctly identified."""
LOG.info("Testing basic message variables")
variables = basic_jinja_template_analyzer.get_message_vars()
assert variables == {"role", "content"}
def test_mixtral_msg_vars(self, mistral_jinja_template_analyzer):
"""Test that the mistral message variables are correctly identified."""
LOG.info("Testing mistral message variables")
variables = mistral_jinja_template_analyzer.get_message_vars()
assert variables == {"role", "content", "tool_calls", "tool_call_id"}
if __name__ == "__main__":
pytest.main([__file__])
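The tests above hinge on pulling out which properties a chat template reads from each `message`. A rough approximation of that idea (a regex sketch over the raw template text, not the actual `JinjaTemplateAnalyzer`, which walks the Jinja AST):

```python
import re

def message_vars(template: str, var: str = "message") -> set[str]:
    """Collect properties read off `var` via var['x'] or var.x (regex sketch)."""
    subscripts = re.findall(rf"{var}\[['\"](\w+)['\"]\]", template)
    attributes = re.findall(rf"{var}\.(\w+)", template)
    return set(subscripts) | set(attributes)

chatml_like = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}"
    "{% endfor %}"
)
assert message_vars(chatml_like) == {"role", "content"}
```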

@@ -302,3 +302,22 @@ class TestValidationCheckDatasetConfig(BaseValidation):
)
validate_config(cfg)
def test_message_property_mappings(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"message_property_mappings": {
"role": "role",
"content": "content",
},
}
],
}
)
validate_config(cfg)

@@ -76,7 +76,7 @@ class TestModelsUtils:
mocked_load_model_config.return_value = {}
with pytest.raises(ValueError) as exc:
# Should error before hitting tokenizer, so we pass in an empty str
load_model(cfg, tokenizer="")
load_model(cfg, tokenizer="") # type: ignore
assert (
"shifted-sparse attention does not currently support sample packing"
in str(exc.value)
@@ -116,3 +116,79 @@ class TestModelsUtils:
assert self.model_loader.model_kwargs.get(
"quantization_config", BitsAndBytesConfig
)
def test_message_property_mapping(self):
"""Test message property mapping configuration validation"""
from axolotl.utils.config.models.input.v0_4_1 import SFTDataset
# Test legacy fields are mapped correctly
dataset = SFTDataset(
path="test_path",
message_field_role="role_field",
message_field_content="content_field",
)
assert dataset.message_property_mappings == {
"role": "role_field",
"content": "content_field",
}
# Test direct message_property_mapping works
dataset = SFTDataset(
path="test_path",
message_property_mappings={
"role": "custom_role",
"content": "custom_content",
},
)
assert dataset.message_property_mappings == {
"role": "custom_role",
"content": "custom_content",
}
# Test both legacy and new fields work when they match
dataset = SFTDataset(
path="test_path",
message_field_role="same_role",
message_property_mappings={"role": "same_role"},
)
assert dataset.message_property_mappings == {
"role": "same_role",
"content": "content",
}
# Test both legacy and new fields work when they don't overlap
dataset = SFTDataset(
path="test_path",
message_field_role="role_field",
message_property_mappings={"content": "content_field"},
)
assert dataset.message_property_mappings == {
"role": "role_field",
"content": "content_field",
}
# Test no role or content provided
dataset = SFTDataset(
path="test_path",
)
assert dataset.message_property_mappings == {
"role": "role",
"content": "content",
}
# Test error when legacy and new fields conflict
with pytest.raises(ValueError) as exc_info:
SFTDataset(
path="test_path",
message_field_role="legacy_role",
message_property_mappings={"role": "different_role"},
)
assert "Conflicting message role fields" in str(exc_info.value)
with pytest.raises(ValueError) as exc_info:
SFTDataset(
path="test_path",
message_field_content="legacy_content",
message_property_mappings={"content": "different_content"},
)
assert "Conflicting message content fields" in str(exc_info.value)