feat: add config for optional parameters in a chat message (#2260)

* feat: add config for optional parameters in a chat message

* chore: cleanup

* chore: fix nits and add light docs

* docs: update docs/dataset-formats/conversation.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* feat: configurable message mappings, jinja template analyzer

* chore: handle bradley terry

* docs: update docs

* refactor: change order of mappings, improve message transform

* refactor: make chat aware of property mappings

* chore: remove .python-version

* chore: revert change

* chore: add dataset validation to tests where appropriate

* chore: add dataset validation to tests where appropriate

* chore: clean up handling of ds_cfg

* chore: recursively serialize config

* make sure to use the return value from validate_config

* DefaultDict pickle/unpickle fix

* fix super call for override

* refactor: message fields

* chore: empty commit

* tests: validate config before using

* chore: add config validation to all e2e tests

* chore: remove unneeded logging

* chore: add missed config validation

* chore: pass field_messages to prompter

* test: fix borked test

* chore: remove unintended file

* chore: add deprecation warning and update chat_datasets script

* chore: lint

* refactor: message fields

* feat: update AxolotlInputConfig and test_models

- add `ConfigDict` import in `axolotl/utils/config/models/input/v0_4_1/__init__.py`
- remove unnecessary line breaks in the `SFTDataset`, `DPODataset`, `KTODataset`, and `StepwiseSupervisedDataset` classes
- update the `model_dump` method in `AxolotlInputConfig` to exclude `None` values
- correct a typo in a `test_models.py` comment

* feat: simplify DPODataset and KTODataset classes in config models

Removed several optional fields from the `DPODataset` and `KTODataset` classes in `axolotl/utils/config/models/input/v0_4_1`. This simplifies the configuration subsets for these datasets.

* feat: improve readability and structure in dataset configuration models

This commit enhances the readability and structure of the dataset configuration models in the `axolotl/utils/config/models/input/v0_4_1` module. It removes an unused `ConfigDict` import and adds line breaks to separate class definitions for better clarity. Additionally, a minor documentation fix ensures a newline at the end of the `stepwise_supervised.qmd` file.

* feat: change log level from info to debug in ChatTemplateStrategy

* feat(prompt_strategies): refactor ChatTemplatePrompter and ChatTemplateStrategy

- Make `chat_template` a required parameter in `ChatTemplatePrompter` constructor
- Add default value for `message_property_mappings` in `ChatTemplatePrompter` constructor
- Add `messages_array_name` property to `ChatTemplatePrompter`
- Change `processor` type to Optional in `ChatTemplatePrompter`
- Add TypeError check for `processor` in `ChatTemplatePrompter.build_prompt`
- Remove `_messages` property from `ChatTemplateStrategy`
- Make `prompter` a required parameter and add type hint in `ChatTemplateStrategy` constructor
- Remove `messages` getter and setter from `ChatTemplateStrategy`
- Use `prompter.messages_array_name` in `ChatTemplateStrategy.get_conversation_thread`
- Remove condition to set `messages` field in `load` function
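The new default for `message_property_mappings` can be pictured as a small helper — a paraphrase of the constructor logic, not the exact implementation:

```python
def resolve_property_mappings(mappings):
    # None or an empty dict falls back to the identity mapping, so
    # unconfigured datasets keep their "role"/"content" keys as-is
    if not mappings:
        return {"role": "role", "content": "content"}
    return mappings

print(resolve_property_mappings(None))
# -> {'role': 'role', 'content': 'content'}
print(resolve_property_mappings({"role": "from", "content": "value"}))
# -> {'role': 'from', 'content': 'value'}
```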

* feat(tests/utils): ignore type check in load_model call in test_models.py

* feat: improve type handling and test structure in chat templates

- Add return type hint for `get_chat_template` function in `chat_templates.py`
- Remove unnecessary assignment of `strategy.messages` in several test cases
- Add `messages_array_name` parameter to various test configurations in `test_chat_templates.py` and `test_chat_templates_advanced.py`
- Remove redundant `strategy.messages` assignment in `test_chat_templates_advanced.py`

* feat(axolotl): enhance chat strategy with DatasetConfig support

This commit introduces support for DatasetConfig in the ChatTemplateStrategy. It also refines the strategy loader to handle different types of ds_cfg inputs and improves the clarity of the code by formatting and reordering. The key changes include:

- Importing Union from typing and BaseModel from pydantic.
- Adding DatasetConfig as an optional type for ds_cfg in StrategyLoader.
- Adjusting the handling of ds_cfg in StrategyLoader to account for BaseModel instances.
- Refactoring the prompter_params and strategy_params for better readability.
- Changing the reference from prompt[self.messages] to prompt[self.prompter.messages_array_name] in the is_prompt_batched method.
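The `ds_cfg` normalization described above can be sketched with duck typing standing in for the `isinstance(ds_cfg, BaseModel)` check the actual loader uses:

```python
def normalize_ds_cfg(ds_cfg):
    # Accept None, a plain dict, or a pydantic-style model and
    # always hand a dict to the rest of the strategy loader
    if ds_cfg is None:
        return {}
    if hasattr(ds_cfg, "model_dump"):  # stands in for isinstance(ds_cfg, BaseModel)
        return ds_cfg.model_dump()
    return dict(ds_cfg)
```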

* feat: update message handling in BTChatTemplateStrategy

* Replace `self.messages` with direct string references to "chosen_messages" and "rejected_messages"
* Append system, user, and assistant content directly to "chosen_messages" and "rejected_messages"
* Add a new attribute "messages_array_name" to the `load` function parameters
* Remove the conditional attribute assignment for "field_messages" in the `load` function
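The deduplicated message building for the Bradley-Terry strategy can be pictured as follows — a simplified sketch of the pattern, not the strategy's exact code, which tokenizes each list in turn:

```python
def build_preference_messages(prompt):
    # system and user turns are shared; only the assistant turn differs
    base = []
    if prompt.get("system"):
        base.append({"role": "system", "content": prompt["system"]})
    base.append({"role": "user", "content": prompt["input"]})
    chosen = base + [{"role": "assistant", "content": prompt["chosen"]}]
    rejected = base + [{"role": "assistant", "content": prompt["rejected"]}]
    return chosen, rejected

sample = {"system": "", "input": "2+2?", "chosen": "4", "rejected": "5"}
print(build_preference_messages(sample))
```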

* feat: add config validation in test_kd.py

- Import `validate_config` from `axolotl.utils.config`
- Validate the configuration in `test_llama_kd` and another function in `TestKnowledgeDistillation` class

* feat: enhance config validation and capabilities handling

* Import `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
* Update `validate_config` function to create `KTODataset` and `SFTDataset` instances using `dict(ds_cfg)`
* Replace `capabilities` and `env_capabilities` with instances of `GPUCapabilities` and `EnvCapabilities` respectively in `AxolotlConfigWCapabilities` model dump

* feat: update config validation in axolotl utils

- Remove import of `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
- Update `validate_config` function to use `capabilities` and `env_capabilities` directly instead of creating new instances of `GPUCapabilities` and `EnvCapabilities`

* feat: refactor strategyloader in chat_template.py

- Extracted the creation of strategy parameters into a separate function, `_get_strategy_params(cfg, dataset_config)`
- Created a new function, `_get_strategy_cls()`, to obtain the strategy class
- Replaced `ChatTemplateStrategy` with `strategy_cls` for strategy instantiation

* trigger CI

* chore: revert dataset config changes for kto/dpo

* refactor: rename 'messages_array_name' to 'field_messages'

- Renamed 'messages_array_name' to 'field_messages' in 'ChatTemplatePrompter' class and its usages in 'chat_template.py'
- Updated 'load' function in 'bradley_terry/chat_template.py' to reflect the change
- Adjusted 'get_chat_template_msg_variables' and 'get_message_vars' methods in 'jinja_template_analyzer.py' to use the new variable name
- Modified 'StrategyLoader' in 'chat_template.py' to use 'field_messages'
- Updated tests in 'test_chat_templates.py' and 'test_chat_templates_advanced.py' to use 'field_messages' instead of 'messages_array_name'

* feat: refactor prompt strategies and update config models

* Remove redundant 'return None' in `axolotl/prompt_strategies/__init__.py`
* Simplify message handling in `axolotl/prompt_strategies/bradley_terry/chat_template.py` by using a single 'messages' list instead of separate 'chosen_messages' and 'rejected_messages' lists
* Update default 'message_property_mappings' in `axolotl/prompt_strategies/bradley_terry/chat_template.py`
* Add 'field_messages' field to `axolotl/utils/config/models/input/v0_4_1/__init__.py` configuration model
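As a concrete example of the new option, a ShareGPT-style dataset whose messages use `from`/`value` keys could be declared and remapped like this — the config keys are real, but the dataset path and the `map_message` helper are illustrative:

```python
# Hypothetical dataset entry using the new message_property_mappings option
ds_cfg = {
    "path": "my/sharegpt-style-dataset",  # illustrative path
    "type": "chat_template",
    "field_messages": "conversations",
    "message_property_mappings": {"role": "from", "content": "value"},
}

def map_message(message, mappings):
    # template property <- dataset property, mirroring transform_message
    return {key: message[src] for key, src in mappings.items() if src in message}

print(map_message({"from": "human", "value": "hi"}, ds_cfg["message_property_mappings"]))
# -> {'role': 'human', 'content': 'hi'}
```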

* chore: remove unused input

* chore: remove redundant type ignore

* fix: remove old configs and update examples

* fix: type check

* fix: remove loading old config in ChatMessage

* fix: update faq with potential new undefinederror

* fix: add debug if property mapped is not found

* chore: improve explanation for unmapped properties

* fix: update docs with new config

* chore: add note for deprecation config and del old config from dict

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
Authored by NJordan72 on 2025-02-17 21:59:27 -05:00; committed by GitHub
parent 3aac3b1da9, commit b194e17c28
51 changed files with 1190 additions and 230 deletions


@@ -41,10 +41,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
load_kwargs["ds_cfg"] = ds_cfg
if "processor" in sig.parameters:
load_kwargs["processor"] = processor
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
raise exc
return None


@@ -34,15 +34,12 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
max_length = self.prompter.max_length
self.messages = "chosen_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
prompt["messages"] = []
if prompt["system"]:
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
prompt["messages"].append({"role": "system", "content": prompt["system"]})
prompt["messages"].append({"role": "user", "content": prompt["input"]})
prompt["messages"].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super()._tokenize_single_prompt(prompt)
if len(chosen_tokenized["input_ids"]) > max_length:
@@ -55,17 +52,12 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
:max_length
]
self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
prompt["messages"] = []
if prompt["system"]:
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]}
)
prompt["messages"].append({"role": "system", "content": prompt["system"]})
prompt["messages"].append({"role": "user", "content": prompt["input"]})
prompt["messages"].append({"role": "assistant", "content": prompt["rejected"]})
rejected_tokenized = super()._tokenize_single_prompt(prompt)
if len(rejected_tokenized["input_ids"]) > max_length:
@@ -99,8 +91,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_property_mappings": ds_cfg.get(
"message_property_mappings",
{
"role": "role",
"content": "content",
},
),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", None
@@ -124,7 +121,4 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy


@@ -4,13 +4,16 @@ HF Chat Templates prompt strategy
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set, Union
from pydantic import BaseModel
from transformers import ProcessorMixin
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
# Configure the logger
LOG = logging.getLogger("axolotl")
@@ -23,16 +26,23 @@ class ChatTemplatePrompter(Prompter):
def __init__(
self,
tokenizer,
chat_template: str,
processor=None,
chat_template=None,
max_length=2048,
message_field_role: str = "role",
message_field_content: str = "content",
message_property_mappings: Optional[Dict[str, str]] = None,
message_field_training: Optional[str] = None,
message_field_training_detail: Optional[str] = None,
field_messages: str = "messages",
roles: Optional[Dict[str, List[str]]] = None,
drop_system_message: bool = False,
):
# check if message_property_mappings is None or empty dict
if message_property_mappings is None or (not message_property_mappings):
message_property_mappings = {
"role": "role",
"content": "content",
}
if roles:
self.roles = {s: t for t, sources in roles.items() for s in sources}
else:
@@ -45,18 +55,28 @@ class ChatTemplatePrompter(Prompter):
"tool": "tool",
}
self.message_field_role = message_field_role
self.message_field_content = message_field_content
self._chat_template_msg_variables = self.get_chat_template_msg_variables(
chat_template, field_messages
)
self.message_property_mappings = message_property_mappings
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.field_messages = field_messages
self.tokenizer = tokenizer
self.processor: ProcessorMixin = processor
self.processor: Optional[ProcessorMixin] = processor
self.chat_template = chat_template
self.max_length = max_length
self.drop_system_message = drop_system_message
@property
def chat_template_msg_variables(self) -> Set[str]:
return self._chat_template_msg_variables
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
if self.processor:
if not callable(self.processor):
raise TypeError("Processor must be callable")
text = self.processor.apply_chat_template(
conversation,
chat_template=self.chat_template,
@@ -184,17 +204,21 @@ class ChatTemplatePrompter(Prompter):
return adjusted_details
def get_chat_template_msg_variables(
self, chat_template: str, field_messages: str
) -> Set[str]:
template_analyzer = JinjaTemplateAnalyzer(chat_template)
return template_analyzer.get_message_vars(field_messages)
class ChatTemplateStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for instruction-based prompts.
"""
_messages = "messages"
def __init__(
self,
prompter: ChatTemplatePrompter,
prompter: "ChatTemplatePrompter",
tokenizer,
train_on_inputs,
sequence_len,
@@ -202,6 +226,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
train_on_eos=None,
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.prompter: ChatTemplatePrompter = prompter
self.roles_to_train = []
if roles_to_train:
@@ -213,13 +238,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self.train_on_eos = train_on_eos
self.images = "images"
@property
def messages(self):
return self._messages
@messages.setter
def messages(self, messages):
self._messages = messages
LOG.debug(
f"The chat template uses the following properties on the message: {self.prompter.chat_template_msg_variables}"
)
@property
def supports_batched(self) -> bool:
@@ -229,7 +250,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
try:
return all(isinstance(v, list) for v in prompt.values()) and all(
isinstance(v, list) for v in prompt[self.messages]
isinstance(v, list) for v in prompt[self.prompter.field_messages]
)
except KeyError:
return False
@@ -464,30 +485,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def get_conversation_thread(self, prompt):
turns = []
optional_keys = [
"tool_calls", # tool that 'assistant' calls
"name", # name of tool given by 'tool'
"tool_call_id", # mistral/mixtral requires this
]
for message in prompt[self.messages]:
for message in prompt[self.prompter.field_messages]:
transformed_message = self.transform_message(message)
turn = {
"role": self.prompter.roles[message[self.prompter.message_field_role]],
**transformed_message,
"training": message.get(self.prompter.message_field_training),
"training_detail": message.get(
self.prompter.message_field_training_detail
),
}
# do not add content if None as it may conflict with some templates due to tools
content = message.get(self.prompter.message_field_content, None)
if content is not None:
turn["content"] = content
for key in optional_keys:
value = message.get(key, None)
if value is not None:
turn[key] = value
turns.append(turn)
if self.prompter.drop_system_message and turns[0]["role"] == "system":
@@ -495,6 +503,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return turns
def transform_message(self, message):
# Build the initial transformed message from the mappings
transformed_message = {}
for key, value in self.prompter.message_property_mappings.items():
if message.get(value) is not None:
transformed_message[key] = message[value]
else:
LOG.debug(
f"Could not find value for property {value} in message: {message}"
)
# Map the role if necessary
if "role" in transformed_message:
transformed_message["role"] = self.prompter.roles.get(
transformed_message["role"], transformed_message["role"]
)
# Determine which keys in the original message were not mapped
mapped_values = set(self.prompter.message_property_mappings.values())
remaining_keys = set(message) - mapped_values
# Keep only the properties defined in the chat template
# and not already mapped
for key in self.prompter.chat_template_msg_variables:
if key in remaining_keys:
val = message.get(key)
if val is not None:
transformed_message[key] = val
return transformed_message
def get_images(self, prompt):
return prompt.get(self.images, None)
@@ -516,33 +555,46 @@ class StrategyLoader:
}
def __call__(
self, tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
self,
tokenizer,
cfg,
ds_cfg: Optional[Union[Dict[str, Any], DatasetConfig]] = None,
processor=None,
):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
if ds_cfg is None:
dataset_config = {}
elif isinstance(ds_cfg, BaseModel):
dataset_config = ds_cfg.model_dump()
else:
dataset_config = ds_cfg
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_property_mappings": dataset_config.get(
"message_property_mappings", {}
),
"message_field_training": dataset_config.get(
"message_field_training", None
),
"message_field_training_detail": dataset_config.get(
"message_field_training_detail",
None,
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
"field_messages": dataset_config.get("field_messages", "messages"),
"roles": dataset_config.get("roles"),
"drop_system_message": dataset_config.get("drop_system_message", False),
# we need to add one for detecting sequences that exceed the `sequence_len` limit.
"max_length": cfg.sequence_len + 1,
"processor": processor,
}
strategy_params = self._get_strategy_params(cfg, ds_cfg)
strategy_params = self._get_strategy_params(cfg, dataset_config)
strategy_cls = self._get_strategy_cls()
strategy = strategy_cls(
@@ -551,9 +603,6 @@ class StrategyLoader:
**strategy_params,
)
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy


@@ -3,20 +3,28 @@ DPO prompt strategies for using tokenizer chat templates.
"""
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic
def default(
cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
ds_cfg = cfg["datasets"][dataset_idx]
ds_cfg = handle_legacy_message_fields_logic(ds_cfg)
chat_template_choice, chat_template_jinja = extract_chat_template_args(
cfg=cfg, ds_cfg=ds_cfg
)
field_messages = ds_cfg.get("field_messages", "messages")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
field_message_role = ds_cfg.get("message_field_role", "role")
field_message_content = ds_cfg.get("message_field_content", "content")
message_property_mappings = ds_cfg.get(
"message_property_mappings",
{
"role": "role",
"content": "content",
},
)
role_map_inv = ds_cfg.get(
"roles",
{
@@ -40,18 +48,18 @@ def default(
messages = sample[field_messages]
messages = [
{
"role": role_map[m[field_message_role]],
"content": m[field_message_content],
"role": role_map[m[message_property_mappings["role"]]],
"content": m[message_property_mappings["content"]],
}
for m in messages
]
chosen = {
"role": role_map[sample[field_chosen][field_message_role]],
"content": sample[field_chosen][field_message_content],
"role": role_map[sample[field_chosen][message_property_mappings["role"]]],
"content": sample[field_chosen][message_property_mappings["content"]],
}
rejected = {
"role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content],
"role": role_map[sample[field_rejected][message_property_mappings["role"]]],
"content": sample[field_rejected][message_property_mappings["content"]],
}
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}


@@ -0,0 +1,318 @@
"""Module for inspecting Jinja templates for the variables they use"""
from typing import Dict, Optional, Set, TypedDict, Union
from jinja2 import Environment, meta, nodes
class JinjaTemplateAnalysis(TypedDict):
"""
Represents the detailed analysis of a Jinja template variable.
Attributes:
accessed_properties (Set[str]): A set of properties accessed from the variable
(e.g., `foo.bar` results in 'bar' being accessed for 'foo').
accessed_indices (Set[Union[int, float]]): A set of indices accessed from the variable.
is_iterated (bool): Indicates if the variable is used as an iteration source in a `for` loop.
is_conditional (bool): Indicates if the variable is referenced within a conditional statement (e.g., an `if` block).
iteration_source (Optional[str]): The name of the variable being iterated over, if applicable.
iteration_target (Optional[Union[str, list[str]]]): The loop target(s) assigned in the iteration.
"""
accessed_properties: Set[str]
accessed_indices: Set[Union[int, float]]
is_iterated: bool
is_conditional: bool
iteration_source: Optional[str]
iteration_target: Optional[Union[str, list[str]]]
class JinjaTemplateAnalyzer:
"""
Analyzes Jinja templates to extract information about variable usage,
including accessed properties, iteration, and conditional references.
Attributes:
env (jinja2.Environment): The Jinja2 environment used for parsing templates.
property_access (Dict[str, Set[str]]): Tracks accessed properties for variables.
iteration_targets (Dict[str, str]): Maps iteration target variables to their sources.
Methods:
get_template_variables(template: str) -> Dict[str, Set[str]]:
Parse a Jinja template and return a mapping of variables to their accessed properties.
analyze_template(template: str) -> Dict[str, JinjaTemplateAnalysis]:
Perform a detailed analysis of the template, including variable usage,
iteration, and conditional references.
Private Methods:
_visit_node(node) -> None:
Recursively visit AST nodes to detect attribute access and iteration targets.
_get_base_name(node) -> Optional[str]:
Extract the base variable name from a node.
_get_target_name(node) -> Optional[Union[str, list[str]]]:
Extract the target name(s) from a `For` node.
"""
def __init__(self, template: str):
self.env: Environment = Environment(autoescape=True)
self.property_access: Dict[str, Set[str]] = {}
self.iteration_targets: Dict[str, Union[str, list[str]]] = {}
self.index_access: Dict[str, Set[Union[int, float]]] = {}
self.ast: nodes.Node = self.env.parse(template)
self.template: str = template
self.variable_assignments: Dict[str, str] = {}
def _visit_node(self, node) -> None:
"""Recursively visit AST nodes to find attribute access."""
# Handle attribute access (dot notation)
if isinstance(node, nodes.Getattr):
base_name = self._get_base_name(node.node)
if base_name:
self.property_access.setdefault(base_name, set()).add(node.attr)
# Handle dictionary access (subscript notation)
elif isinstance(node, nodes.Getitem):
base_name = self._get_base_name(node.node)
if base_name and isinstance(node.arg, nodes.Const):
value = node.arg.value
if isinstance(value, (int, float)):
self.index_access.setdefault(base_name, set()).add(value)
else:
self.property_access.setdefault(base_name, set()).add(value)
elif isinstance(node, nodes.Test) and node.name == "defined":
base_name = self._get_base_name(node.node)
if base_name:
if isinstance(node.node, nodes.Getattr):
self.property_access.setdefault(base_name, set()).add(
node.node.attr
)
# Handle loop variables
elif isinstance(node, nodes.For):
iter_name = self._get_base_name(node.iter)
target_name = self._get_target_name(node.target)
if iter_name and target_name:
self.iteration_targets[target_name] = iter_name
self.property_access.setdefault(iter_name, set())
elif isinstance(node, nodes.Assign):
target_name = self._get_target_name(node.target)
source_name = self._get_base_name(node.node)
if target_name and source_name:
self.variable_assignments[target_name] = source_name
elif isinstance(node, nodes.Filter):
if node.name == "selectattr":
target = self._get_base_name(node.node)
if target:
self.variable_assignments[f"filtered_{target}"] = target
for child in node.iter_child_nodes():
self._visit_node(child)
def _get_target_name(self, node) -> Optional[str]:
"""Get the target variable name from a For node.
Args:
node: A Jinja AST node representing either a Name or Tuple node
Returns:
- str: For simple variable targets (e.g., "item" in "for item in items")
- None: If the node type is not recognized or is a tuple
"""
if isinstance(node, nodes.Name):
return node.name
return None
def _get_target_names(self, node) -> list[str]:
"""Get all target variable names from a For node, including tuple unpacking.
Args:
node: A Jinja AST node representing either a Name or Tuple node
Returns:
List of target variable names
"""
if isinstance(node, nodes.Name):
return [node.name]
if isinstance(node, nodes.Tuple):
names = []
for n in node.items:
if isinstance(n, nodes.Name):
names.append(n.name)
return names
return []
def _get_base_name(self, node) -> Optional[str]:
"""Get the base variable name from a node."""
if isinstance(node, nodes.Name):
return node.name
if isinstance(node, nodes.Getattr):
return self._get_base_name(node.node)
if isinstance(node, nodes.Getitem):
return self._get_base_name(node.node)
return None
def get_template_variables(self) -> Dict[str, Set[str]]:
"""
Parse a Jinja template and return both variables and their accessed properties.
Args:
template (str): The Jinja template string
Returns:
Dict[str, Set[str]]: Dictionary mapping variable names to sets of accessed properties
"""
# Parse the template
ast = self.env.parse(self.template)
# Get all undeclared variables
variables = meta.find_undeclared_variables(ast)
# Reset property access tracking
self.property_access = {}
# Visit all nodes to find property access
self._visit_node(ast)
# Create result dictionary
result: Dict[str, Set[str]] = {var: set() for var in variables}
# Merge in any discovered sub-properties
for var, props in self.property_access.items():
if var not in result:
result[var] = set()
result[var].update(props)
return result
def analyze_template(self) -> Dict[str, JinjaTemplateAnalysis]:
"""
Provide a detailed analysis of template variables and their usage.
"""
variables = self.get_template_variables()
self.iteration_targets = {}
analysis: Dict[str, JinjaTemplateAnalysis] = {
var: JinjaTemplateAnalysis(
accessed_properties=props,
accessed_indices=set(),
is_iterated=False,
is_conditional=False,
iteration_source=None,
iteration_target=None,
)
for var, props in variables.items()
}
for var, indices in self.index_access.items():
if var in analysis:
analysis[var]["accessed_indices"] = indices
def visit_node(node):
if isinstance(node, nodes.If):
def find_test_vars(test_node):
if isinstance(test_node, nodes.Name):
if test_node.name in analysis:
analysis[test_node.name]["is_conditional"] = True
for child in test_node.iter_child_nodes():
find_test_vars(child)
find_test_vars(node.test)
if isinstance(node, nodes.For):
iter_target = self._get_base_name(node.iter)
target_name = self._get_target_name(node.target)
if iter_target in analysis:
analysis[iter_target]["is_iterated"] = True
if target_name:
analysis[iter_target]["iteration_target"] = target_name
if isinstance(target_name, str) and target_name not in analysis:
analysis[target_name] = {
"accessed_properties": set(),
"is_iterated": False,
"is_conditional": False,
"iteration_source": iter_target,
"iteration_target": None,
}
for child in node.iter_child_nodes():
visit_node(child)
visit_node(self.ast)
return analysis
def get_downstream_properties(self, start_var: str) -> Dict[str, Set[str]]:
"""
Get all properties accessed on a variable and its downstream assignments.
Args:
start_var: The starting variable to trace
Returns:
Dict mapping variable names to their accessed properties
"""
visited = set()
properties = {}
def trace_variable(var_name: str):
if var_name in visited:
return
visited.add(var_name)
# Get direct properties
if var_name in self.property_access:
properties[var_name] = self.property_access[var_name]
# Get properties from iteration targets
if var_name in self.iteration_targets:
target = self.iteration_targets[var_name]
if isinstance(target, str):
trace_variable(target)
elif isinstance(target, list):
for t in target:
trace_variable(t)
# Follow assignments
for target, source in self.variable_assignments.items():
if source == var_name:
trace_variable(target)
# Check for array slicing
analysis = self.analyze_template()
if var_name in analysis:
var_info = analysis[var_name]
if var_info["accessed_indices"]:
# If this variable is sliced, follow the resulting assignment
slice_result = f"{var_name}_slice"
if slice_result in self.property_access:
trace_variable(slice_result)
trace_variable(start_var)
return properties
def get_message_vars(self, field_messages: str = "messages") -> Set[str]:
"""
Get all properties accessed on messages and derived variables.
"""
all_properties = self.get_downstream_properties(field_messages)
# Combine all properties from all related variables
combined_properties = set()
for properties in all_properties.values():
combined_properties.update(properties)
# Also include properties from the message iteration variable
analysis = self.analyze_template()
if "message" in analysis:
combined_properties.update(analysis["message"]["accessed_properties"])
return combined_properties
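The core idea of the analyzer — discovering which properties a chat template reads off each message — can be reproduced in a few lines with plain `jinja2`. This sketch skips the assignment/filter tracing the full class performs and only walks attribute access:

```python
from jinja2 import Environment, meta, nodes

template = (
    "{% for message in messages %}"
    "{{ message.role }}: {{ message.content }}"
    "{% endfor %}"
)
ast = Environment(autoescape=True).parse(template)

# undeclared top-level variables; loop targets like `message` are excluded
top_level = meta.find_undeclared_variables(ast)
# every dot-notation property read anywhere in the template
accessed = {n.attr for n in ast.find_all(nodes.Getattr)}

print(sorted(top_level))  # ['messages']
print(sorted(accessed))   # ['content', 'role']
```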


@@ -51,8 +51,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ds_cfg = ds_cfg or {}
field_messages = ds_cfg.get("field_messages")
message_field_role = ds_cfg.get("message_field_role")
message_field_content = ds_cfg.get("message_field_content")
message_property_mappings = ds_cfg.get("message_property_mappings")
message_field_role = (
message_property_mappings.get("role") if message_property_mappings else None
)
message_field_content = (
message_property_mappings.get("content") if message_property_mappings else None
)
message_field_training = ds_cfg.get("message_field_training")
builder_kwargs = {}


@@ -38,7 +38,7 @@ def get_chat_template(
     user_choice: str,
     jinja_template: Optional[str] = None,
     tokenizer: Optional["PreTrainedTokenizerBase"] = None,
-):
+) -> str:
     """
     Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.
@@ -70,7 +70,7 @@ def get_chat_template(
                 f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
                 f"Please add a chat_template in tokenizer config"
             )
-        return tokenizer.chat_template
+        return tokenizer.chat_template  # type: ignore

     if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
         if not tokenizer:
@@ -78,7 +78,7 @@ def get_chat_template(
                 f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
             )
         if tokenizer.chat_template:
-            return tokenizer.chat_template
+            return tokenizer.chat_template  # type: ignore
         user_choice = user_choice[
             len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
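The slice at the end of the hunk strips the fallback prefix so the remainder can be resolved as an ordinary template choice. Assuming the prefix constant is the literal `"tokenizer_default_fallback_"` (an assumption based on its name; the real constant lives in axolotl), the behavior is:

```python
# Assumed value of the prefix constant, for illustration only.
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"

user_choice = "tokenizer_default_fallback_chatml"
if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
    # Everything after the prefix names the template to fall back to.
    user_choice = user_choice[len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):]
assert user_choice == "chatml"
```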

View File

@@ -18,6 +18,7 @@ from axolotl.utils.config.models.input.v0_4_1 import (
 from axolotl.utils.config.models.input.v0_4_1 import (
     AxolotlInputConfig as AxolotlInputConfigBase,
 )
+from axolotl.utils.config.models.input.v0_4_1 import DPODataset, KTODataset, SFTDataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model_config
@@ -258,7 +259,7 @@ def validate_config(
     cfg: DictDefault,
     capabilities: Optional[dict] = None,
     env_capabilities: Optional[dict] = None,
-):
+) -> DictDefault:
     AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
     AxolotlInputConfig = AxolotlInputConfigBase
@@ -268,6 +269,16 @@ def validate_config(
         AxolotlInputConfig,  # pylint: disable=invalid-name
     ) = merge_input_args()

+    # Convert datasets to proper format if needed
+    if cfg.get("datasets"):
+        for idx, ds_cfg in enumerate(cfg["datasets"]):
+            if cfg.get("rl") == "dpo" and not isinstance(ds_cfg, DPODataset):
+                cfg["datasets"][idx] = DPODataset(**ds_cfg)
+            elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
+                cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
+            elif not isinstance(ds_cfg, SFTDataset):
+                cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))
+
     if capabilities or env_capabilities:
         if (capabilities and env_capabilities is None) or (
             env_capabilities and capabilities is None
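The conversion loop above normalizes each raw dataset dict into the matching pydantic model before validation. The core pattern — wrap plain dicts, pass models through — can be sketched with a hypothetical, much-reduced stand-in for `SFTDataset`:

```python
from typing import Optional

from pydantic import BaseModel


class SFTDatasetSketch(BaseModel):
    """Hypothetical, much-reduced stand-in for axolotl's SFTDataset."""

    path: Optional[str] = None
    type: Optional[str] = None
    field_messages: Optional[str] = None


raw = {"path": "my/dataset", "type": "chat_template"}
# Mirror of the conversion above: wrap plain dicts into the pydantic model.
ds = raw if isinstance(raw, SFTDatasetSketch) else SFTDatasetSketch(**raw)
assert ds.path == "my/dataset" and ds.field_messages is None
```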

View File

@@ -12,6 +12,7 @@ from pydantic import (
     Field,
     StringConstraints,
     conlist,
+    field_serializer,
     field_validator,
     model_validator,
 )
@@ -186,8 +187,13 @@ class SFTDataset(BaseModel):
     field_human: Optional[str] = None
     field_model: Optional[str] = None
     field_messages: Optional[str] = None
-    message_field_role: Optional[str] = None
-    message_field_content: Optional[str] = None
+    message_field_role: Optional[
+        str
+    ] = None  # deprecated, use message_property_mappings
+    message_field_content: Optional[
+        str
+    ] = None  # deprecated, use message_property_mappings
+    message_property_mappings: Optional[Dict[str, str]] = None
     message_field_training: Optional[str] = None
     message_field_training_detail: Optional[str] = None
     logprobs_field: Optional[str] = None
@@ -199,9 +205,18 @@ class SFTDataset(BaseModel):
     trust_remote_code: Optional[bool] = False
     revision: Optional[str] = None

+    @model_validator(mode="before")
+    @classmethod
+    def handle_legacy_message_fields(cls, data):
+        """Handle backwards compatibility between legacy message field mapping and new property mapping system."""
+        return handle_legacy_message_fields_logic(data)
+
     @model_validator(mode="before")
     @classmethod
     def check_chat_template_config(cls, data):
+        if isinstance(data, BaseModel):
+            data = data.model_dump()
+
         # Set chat_template to tokenizer_default if not set
         if data.get("type") == "chat_template" and not data.get("chat_template"):
             data["chat_template"] = ChatTemplate.tokenizer_default
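The `mode="before"` validators above run on the raw input before field validation, which is why they must tolerate both dicts and already-built models. A self-contained sketch of the same pattern with a hypothetical model:

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class DatasetSketch(BaseModel):
    """Hypothetical model mirroring the before-validator pattern above."""

    type: Optional[str] = None
    chat_template: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def default_chat_template(cls, data):
        # Before-validators may receive a model instance; normalize to a dict.
        if isinstance(data, BaseModel):
            data = data.model_dump()
        # Set chat_template to a default when the type requires one.
        if data.get("type") == "chat_template" and not data.get("chat_template"):
            data["chat_template"] = "tokenizer_default"
        return data


assert DatasetSketch(type="chat_template").chat_template == "tokenizer_default"
assert DatasetSketch(type="alpaca").chat_template is None
```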
@@ -241,6 +256,7 @@ class DPODataset(BaseModel):
     type: Optional[Union[UserDefinedDPOType, str]] = None
     data_files: Optional[List[str]] = None
     revision: Optional[str] = None
+    field_messages: Optional[str] = None


 class StepwiseSupervisedDataset(BaseModel):
@@ -277,6 +293,9 @@ class KTODataset(BaseModel):
     revision: Optional[str] = None


+DatasetConfig = Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset]
+
+
 class LoftQConfig(BaseModel):
     """LoftQ configuration subset"""
@@ -680,17 +699,15 @@ class AxolotlInputConfig(
     ] = None  # whether to use weighting in DPO trainer. If none, default is false in the trainer.
     dpo_use_logits_to_keep: Optional[bool] = None

-    datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None  # type: ignore
-    test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None  # type: ignore
+    datasets: Optional[conlist(DatasetConfig, min_length=1)] = None  # type: ignore
+    test_datasets: Optional[conlist(DatasetConfig, min_length=1)] = None  # type: ignore
     shuffle_merged_datasets: Optional[bool] = True
     dataset_prepared_path: Optional[str] = None
     dataset_shard_num: Optional[int] = None
     dataset_shard_idx: Optional[int] = None
     skip_prepare_dataset: Optional[bool] = False

-    pretraining_dataset: Optional[  # type: ignore
-        conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
-    ] = Field(
+    pretraining_dataset: Optional[conlist(Union[PretrainingDataset, SFTDataset], min_length=1)] = Field(  # type: ignore
         default=None,
         json_schema_extra={"description": "streaming dataset to use for pretraining"},
     )
@@ -895,10 +912,15 @@ class AxolotlInputConfig(
     @classmethod
     def deprecate_sharegpt_datasets(cls, datasets):
         for _, ds_cfg in enumerate(datasets):
-            if not ds_cfg.get("type"):
+            # Handle both dict and pydantic model cases
+            ds_type = (
+                ds_cfg.get("type")
+                if isinstance(ds_cfg, dict)
+                else getattr(ds_cfg, "type", None)
+            )
+            if not ds_type:
                 continue

-            ds_type = ds_cfg["type"]
             # skip if it's a dict (for custom user instruction prompt)
             if isinstance(ds_type, dict):
                 continue
@@ -910,6 +932,14 @@ class AxolotlInputConfig(
         return datasets

+    @field_serializer("datasets")
+    def datasets_serializer(
+        self, ds_configs: Optional[List[DatasetConfig]]
+    ) -> Optional[List[Dict[str, Any]]]:
+        if ds_configs:
+            return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
+        return None
+
     @model_validator(mode="before")
     @classmethod
     def check_batch_size_fields(cls, data):
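The `field_serializer` added above makes `model_dump()` emit datasets as plain dicts with `None` values stripped, so serialized configs stay compact. The same pattern with a hypothetical reduced model:

```python
from typing import List, Optional

from pydantic import BaseModel, field_serializer


class Item(BaseModel):
    name: str
    note: Optional[str] = None


class ConfigSketch(BaseModel):
    """Hypothetical model mirroring the serializer pattern above."""

    items: Optional[List[Item]] = None

    @field_serializer("items")
    def items_serializer(self, items):
        # Dump each sub-model as a dict, dropping unset/None fields.
        if items:
            return [item.model_dump(exclude_none=True) for item in items]
        return None


cfg = ConfigSketch(items=[Item(name="a")])
assert cfg.model_dump()["items"] == [{"name": "a"}]  # "note": None is omitted
```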
@@ -1762,3 +1792,77 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
         else:
             data["torch_compile"] = False
         return data
+
+
+def handle_legacy_message_fields_logic(data: dict) -> dict:
+    """
+    Handle backwards compatibility between legacy message field mapping and new property mapping system.
+
+    Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
+    - message_field_role: Mapped to the role field
+    - message_field_content: Mapped to the content field
+
+    The new system uses message_property_mappings to support arbitrary field mappings:
+    message_property_mappings:
+        role: source_role_field
+        content: source_content_field
+        additional_field: source_field
+
+    Args:
+        data: Dictionary containing configuration data
+
+    Returns:
+        Updated dictionary with message field mappings consolidated
+
+    Raises:
+        ValueError: If there are conflicts between legacy and new mappings
+    """
+    data = data.copy()  # Create a copy to avoid modifying the original
+
+    if data.get("message_property_mappings") is None:
+        data["message_property_mappings"] = {}
+
+    # Check for conflicts and handle role
+    if "message_field_role" in data:
+        LOG.warning(
+            "message_field_role is deprecated, use message_property_mappings instead. "
+            f"Example: message_property_mappings: {{role: {data['message_field_role']}}}"
+        )
+        if (
+            "role" in data["message_property_mappings"]
+            and data["message_property_mappings"]["role"] != data["message_field_role"]
+        ):
+            raise ValueError(
+                f"Conflicting message role fields: message_field_role='{data['message_field_role']}' "
+                f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'"
+            )
+        data["message_property_mappings"]["role"] = data["message_field_role"] or "role"
+        del data["message_field_role"]
+    elif "role" not in data["message_property_mappings"]:
+        data["message_property_mappings"]["role"] = "role"
+
+    # Check for conflicts and handle content
+    if "message_field_content" in data:
+        LOG.warning(
+            "message_field_content is deprecated, use message_property_mappings instead. "
+            f"Example: message_property_mappings: {{content: {data['message_field_content']}}}"
+        )
+        if (
+            "content" in data["message_property_mappings"]
+            and data["message_property_mappings"]["content"]
+            != data["message_field_content"]
+        ):
+            raise ValueError(
+                f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
+                f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
+            )
+        data["message_property_mappings"]["content"] = (
+            data["message_field_content"] or "content"
+        )
+        del data["message_field_content"]
+    elif "content" not in data["message_property_mappings"]:
+        data["message_property_mappings"]["content"] = "content"
+
+    return data
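The consolidation behavior implemented above can be demonstrated with a deliberately simplified sketch of the same logic: a legacy `message_field_role` is folded into `message_property_mappings`, missing keys default to identity mappings, and conflicting mappings raise:

```python
def consolidate(data: dict) -> dict:
    """Simplified sketch of handle_legacy_message_fields_logic (no warnings)."""
    data = data.copy()
    mappings = dict(data.get("message_property_mappings") or {})
    for legacy, prop in (
        ("message_field_role", "role"),
        ("message_field_content", "content"),
    ):
        if legacy in data:
            # Reject contradictory legacy + new mappings for the same property.
            if prop in mappings and mappings[prop] != data[legacy]:
                raise ValueError(f"conflicting mapping for {prop}")
            mappings[prop] = data.pop(legacy) or prop
        mappings.setdefault(prop, prop)  # identity mapping by default
    data["message_property_mappings"] = mappings
    return data


out = consolidate({"message_field_role": "from"})
assert out["message_property_mappings"] == {"role": "from", "content": "content"}
assert "message_field_role" not in out
```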

View File

@@ -180,6 +180,7 @@ def load_tokenized_prepared_datasets(
 ) -> Tuple[DatasetDict, List[Prompter]]:
     cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
+    tokenizer_name = cfg.tokenizer_config
     ds_hash = str(
         md5(
             (
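Folding the tokenizer name into the dataset hash means swapping tokenizers yields a different prepared-dataset cache key, so stale tokenizations are not reused. A toy illustration of the idea (the real hash covers many more config fields):

```python
from hashlib import md5


def prep_cache_key(tokenizer_name: str, dataset_path: str) -> str:
    # Hypothetical reduced version: hash tokenizer + dataset identity together.
    return md5(f"{tokenizer_name}::{dataset_path}".encode("utf-8")).hexdigest()


key_a = prep_cache_key("NousResearch/Llama-2-7b-hf", "my/dataset")
key_b = prep_cache_key("mistralai/Mistral-7B-v0.1", "my/dataset")
assert key_a != key_b  # changing the tokenizer invalidates the cache entry
```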

View File

@@ -13,3 +13,26 @@ class DictDefault(Dict):

     def __or__(self, other):
         return DictDefault(super().__ror__(other))
+
+    def __setitem__(self, name, value):
+        # workaround for pickle/unpickle issues and __frozen not being available
+        try:
+            is_frozen = hasattr(self, "__frozen") and object.__getattribute__(
+                self, "__frozen"
+            )
+        except AttributeError:
+            is_frozen = False
+        if is_frozen and name not in super().keys():
+            raise KeyError(name)
+        super(Dict, self).__setitem__(name, value)  # pylint: disable=bad-super-call
+        try:
+            p = object.__getattribute__(self, "__parent")
+            key = object.__getattribute__(self, "__key")
+        except AttributeError:
+            p = None
+            key = None
+        if p is not None:
+            p[key] = self
+            object.__delattr__(self, "__parent")
+            object.__delattr__(self, "__key")
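The try/except guards matter because unpickling repopulates a dict subclass through `__setitem__` before its instance attributes exist, so attribute lookups in the override must tolerate `AttributeError`. A minimal self-contained reproduction with a hypothetical stand-in class:

```python
import pickle


class GuardedDict(dict):
    """Hypothetical stand-in for DictDefault's frozen-key guard."""

    def __setitem__(self, name, value):
        try:
            frozen = object.__getattribute__(self, "_frozen")
        except AttributeError:
            frozen = False  # attribute not set yet, e.g. mid-unpickle
        if frozen and name not in self:
            raise KeyError(name)
        super().__setitem__(name, value)


d = GuardedDict()
d["a"] = 1
# pickle round-trip calls __setitem__ on a bare instance; the except branch
# keeps that from blowing up on the missing attribute.
restored = pickle.loads(pickle.dumps(d))
assert restored == {"a": 1}
```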