* feat: add config for optional parameters in a chat message * chore: cleanup * chore: fix nits and add light docs * docs: update docs/dataset-formats/conversation.qmd Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * feat: configurable message mappings, jinja template analyzer * chore: handle bradley terry * docs: update docs * refactor: change order of mappings, improve message transform * refactor: make chat awware of property mappings * chore: remove .python-version * chore: revert change * chore: add dataset validation to tests where appropriate * chore: add dataset validation to tests where appropriate * chore: clean up handling of ds_cfg * chore: recursively serialize config * make sure to use the return value from validate_config * DefaultDict pickle/unpickle fix * fix super call for override * refactor: message fields * chore: empty commit * tests: validate config before using * chore: add config validation to all e2e tests * chore: add unneeded logging * chore: add missed config validation * chore: pass field_messages to prompter * test: fix borked test * chore: remove uninteded file * chore: add deprecation warning and update chat_datasets script * chore: lint * refactor: message fields * feat: update axolotlinputconfig and test_models - add configdict import in axolotl/utils/config/models/input/v0_4_1/__init__.py - remove unnecessary line breaks in sftdataset, dpodataset, ktodataset, stepwisesuperviseddataset classes - update model_dump method in axolotlinputconfig to exclude none values - correct typo in test_models.py comment * feat: simplify dpodataset and ktodataset classes in config models removed several optional fields from dpodataset and ktodataset classes in axolotl/utils/config/models/input/v0_4_1. this simplifies the configuration subsets for these datasets. 
* feat: improve readability and structure in dataset configuration models this commit enhances the readability and structure of the dataset configuration models in the `axolotl/utils/config/models/input/v0_4_1` module. it removes unused `configdict` import and adds line breaks to separate class definitions for better clarity. additionally, a minor documentation fix is included to ensure a newline at the end of the `stepwise_supervised.qmd` file. * feat: change log level from info to debug in chattemplatestrategy * feat(prompt_strategies): refactor chattemplateprompter and chattemplatestrategy - Make `chat_template` a required parameter in `ChatTemplatePrompter` constructor - Add default value for `message_property_mappings` in `ChatTemplatePrompter` constructor - Add `messages_array_name` property to `ChatTemplatePrompter` - Change `processor` type to Optional in `ChatTemplatePrompter` - Add TypeError check for `processor` in `ChatTemplatePrompter.build_prompt` - Remove `_messages` property from `ChatTemplateStrategy` - Make `prompter` a required parameter and add type hint in `ChatTemplateStrategy` constructor - Remove `messages` getter and setter from `ChatTemplateStrategy` - Use `prompter.messages_array_name` in `ChatTemplateStrategy.get_conversation_thread` - Remove condition to set `messages` field in `load` function * feat(tests/utils): ignore type check in load_model call in test_models.py * feat: improve type handling and test structure in chat templates - Add return type hint for `get_chat_template` function in `chat_templates.py` - Remove unnecessary assignment of `strategy.messages` in several test cases - Add `messages_array_name` parameter to various test configurations in `test_chat_templates.py` and `test_chat_templates_advanced.py` - Remove redundant `strategy.messages` assignment in `test_chat_templates_advanced.py` * feat(axolotl): enhance chat strategy with datasetconfig support This commit introduces support for DatasetConfig in the 
ChatTemplateStrategy. It also refines the strategy loader to handle different types of ds_cfg inputs and improves the clarity of the code by formatting and reordering. The key changes include: - Importing Union from typing and BaseModel from pydantic. - Adding DatasetConfig as an optional type for ds_cfg in StrategyLoader. - Adjusting the handling of ds_cfg in StrategyLoader to account for BaseModel instances. - Refactoring the prompter_params and strategy_params for better readability. - Changing the reference from prompt[self.messages] to prompt[self.prompter.messages_array_name] in the is_prompt_batched method. * feat: update message handling in btchattemplatestrategy * Replace `self.messages` with direct string references to "chosen_messages" and "rejected_messages" * Append system, user, and assistant content directly to "chosen_messages" and "rejected_messages" * Add a new attribute "messages_array_name" to the `load` function parameters * Remove the conditional attribute assignment for "field_messages" in the `load` function * feat: add config validation in test_kd.py - Import `validate_config` from `axolotl.utils.config` - Validate the configuration in `test_llama_kd` and another function in `TestKnowledgeDistillation` class * feat: enhance config validation and capabilities handling * Import `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals` * Update `validate_config` function to create `KTODataset` and `SFTDataset` instances using `dict(ds_cfg)` * Replace `capabilities` and `env_capabilities` with instances of `GPUCapabilities` and `EnvCapabilities` respectively in `AxolotlConfigWCapabilities` model dump * feat: update config validation in axolotl utils - Remove import of `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals` - Update `validate_config` function to use `capabilities` and `env_capabilities` directly instead of creating new instances of `GPUCapabilities` and `EnvCapabilities` * 
feat: refactor strategyloader in chat_template.py - Extracted the creation of strategy parameters into a separate function, `_get_strategy_params(cfg, dataset_config)` - Created a new function, `_get_strategy_cls()`, to obtain the strategy class - Replaced `ChatTemplateStrategy` with `strategy_cls` for strategy instantiation * trigger CI * chore: revert dataset config changes for kto/dpo * subject: refactor: rename 'messages_array_name' to 'field_messages' Body: - Renamed 'messages_array_name' to 'field_messages' in 'ChatTemplatePrompter' class and its usages in 'chat_template.py' - Updated 'load' function in 'bradley_terry/chat_template.py' to reflect the change - Adjusted 'get_chat_template_msg_variables' and 'get_message_vars' methods in 'jinja_template_analyzer.py' to use the new variable name - Modified 'StrategyLoader' in 'chat_template.py' to use 'field_messages' - Updated tests in 'test_chat_templates.py' and 'test_chat_templates_advanced.py' to use 'field_messages' instead of 'messages_array_name' * feat: refactor prompt strategies and update config models * Remove redundant 'return None' in `axolotl/prompt_strategies/__init__.py` * Simplify message handling in `axolotl/prompt_strategies/bradley_terry/chat_template.py` by using a single 'messages' list instead of separate 'chosen_messages' and 'rejected_messages' lists * Update default 'message_property_mappings' in `axolotl/prompt_strategies/bradley_terry/chat_template.py` * Add 'field_messages' field to `axolotl/utils/config/models/input/v0_4_1/__init__.py` configuration model * chore: remove unused input * chore: remove redundant type ignore * fix: remove old configs and update examples * fix: type check * fix: remove loading old config in ChatMessage * fix: update faq with potential new undefinederror * fix: add debug if property mapped is not found * chore: improve explanation for unmapped properties * fix: update docs with new config * chore: add note for deprecation config and del old config from 
dict --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: NanoCode012 <nano@axolotl.ai>
610 lines
22 KiB
Python
"""
HF Chat Templates prompt strategy
"""

import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Union

from pydantic import BaseModel
from transformers import ProcessorMixin

from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig

# Configure the logger
LOG = logging.getLogger("axolotl")
LOG.setLevel(logging.INFO)


class ChatTemplatePrompter(Prompter):
    """Prompter for HF chat templates"""

    def __init__(
        self,
        tokenizer,
        chat_template: str,
        processor=None,
        max_length=2048,
        message_property_mappings: Optional[Dict[str, str]] = None,
        message_field_training: Optional[str] = None,
        message_field_training_detail: Optional[str] = None,
        field_messages: str = "messages",
        roles: Optional[Dict[str, List[str]]] = None,
        drop_system_message: bool = False,
    ):
        # check if message_property_mappings is None or empty dict
        if message_property_mappings is None or (not message_property_mappings):
            message_property_mappings = {
                "role": "role",
                "content": "content",
            }

        if roles:
            self.roles = {s: t for t, sources in roles.items() for s in sources}
        else:
            self.roles = {
                "human": "user",
                "user": "user",
                "assistant": "assistant",
                "gpt": "assistant",
                "system": "system",
                "tool": "tool",
            }

        self._chat_template_msg_variables = self.get_chat_template_msg_variables(
            chat_template, field_messages
        )
        self.message_property_mappings = message_property_mappings
        self.message_field_training = message_field_training
        self.message_field_training_detail = message_field_training_detail
        self.field_messages = field_messages
        self.tokenizer = tokenizer
        self.processor: Optional[ProcessorMixin] = processor
        self.chat_template = chat_template
        self.max_length = max_length
        self.drop_system_message = drop_system_message

    @property
    def chat_template_msg_variables(self) -> Set[str]:
        return self._chat_template_msg_variables

    def build_prompt(self, conversation, add_generation_prompt=False, images=None):
        if self.processor:
            if not callable(self.processor):
                raise TypeError("Processor must be callable")

            text = self.processor.apply_chat_template(
                conversation,
                chat_template=self.chat_template,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
            )
            batch = self.processor(
                text=text,
                images=images,
                return_tensors="pt",
            )
            # workaround since processor works in batches instead of single examples
            for k, val in batch.items():
                if k in ["pixel_values"]:
                    batch[k] = val.tolist()
                else:
                    batch[k] = val.squeeze().tolist()
            return batch

        return self.tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=add_generation_prompt,
            chat_template=self.chat_template,
        )

    def get_offsets_for_train_detail(
        self, text: str, train_details: List[Dict], mask_untrainable: bool = True
    ) -> List[int]:
        tokenized_output = self.tokenizer(
            text, return_offsets_mapping=True, add_special_tokens=False
        )
        tokens = tokenized_output.tokens()
        token_offsets = tokenized_output["offset_mapping"]

        LOG.debug(f"Tokenizing text: {text}")
        LOG.debug(f"Tokens: {tokens}")
        # Adjust the end offsets. For some reason by default they are set to the same value as the start offsets.
        for i in range(len(token_offsets) - 1):
            token_offsets[i] = (token_offsets[i][0], token_offsets[i + 1][0] - 1)
        # Ensure the last token's end offset is set correctly
        token_offsets[-1] = (token_offsets[-1][0], len(text) - 1)
        LOG.debug(f"Token offsets: {token_offsets}")

        # Initialize all offsets as IGNORE_TOKEN_ID (not trained)
        result = [IGNORE_TOKEN_ID] * len(token_offsets)

        # Adjust train_details to align with token boundaries
        adjusted_train_details = self.adjust_train_details(train_details, token_offsets)

        for idx, (start, end) in enumerate(token_offsets):
            for detail in adjusted_train_details:
                # Check if the token is completely within the detail's range
                if start >= detail["begin_offset"] and end <= detail["end_offset"]:
                    if detail["train"] or not mask_untrainable:
                        result[idx] = start
                        LOG.debug(f"Token {idx} ({tokens[idx]}) marked for training")
                    else:
                        LOG.debug(
                            f"Token {idx} ({tokens[idx]}) marked as non-trainable"
                        )
                elif start < detail["end_offset"] and end > detail["begin_offset"]:
                    # Token partially overlaps with detail, always mark as non-trainable
                    LOG.debug(
                        f"Token {idx} ({tokens[idx]}) partially overlaps detail, marked as non-trainable"
                    )

        LOG.debug(f"Final result: {result}")
        return result

    def adjust_train_details(
        self, train_details: List[Dict], token_offsets: List[tuple]
    ) -> List[Dict]:
        adjusted_details = []
        for detail in train_details:
            begin_offset = detail["begin_offset"]
            end_offset = detail["end_offset"]

            # Find the first token that starts after or at the begin_offset
            begin_token = next(
                (
                    i
                    for i, (t_start, t_end) in enumerate(token_offsets)
                    if t_start >= begin_offset
                ),
                len(token_offsets),
            )
            if begin_token > 0 and token_offsets[begin_token - 1][1] > begin_offset:
                begin_token -= 1

            # Find the last token that ends before or at the end_offset
            end_token = next(
                (
                    i
                    for i in range(len(token_offsets) - 1, -1, -1)
                    if token_offsets[i][1] <= end_offset
                ),
                -1,
            )
            if (
                end_token < len(token_offsets) - 1
                and token_offsets[end_token + 1][0] < end_offset
            ):
                end_token += 1

            if begin_token <= end_token:
                adjusted_begin = token_offsets[begin_token][0]
                adjusted_end = token_offsets[end_token][1]

                if adjusted_begin != begin_offset or adjusted_end != end_offset:
                    LOG.warning(
                        f"Adjusting detail offsets: ({begin_offset}, {end_offset}) -> ({adjusted_begin}, {adjusted_end})"
                    )

                adjusted_details.append(
                    {
                        "begin_offset": adjusted_begin,
                        "end_offset": adjusted_end,
                        "train": detail["train"],
                    }
                )
            else:
                LOG.warning(
                    f"Could not adjust detail offsets: ({begin_offset}, {end_offset}). Skipping this detail."
                )

        return adjusted_details

    def get_chat_template_msg_variables(
        self, chat_template: str, field_messages: str
    ) -> Set[str]:
        template_analyzer = JinjaTemplateAnalyzer(chat_template)
        return template_analyzer.get_message_vars(field_messages)
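The `roles` argument of `ChatTemplatePrompter.__init__` maps each target role to a list of source role names, and the constructor inverts it into a source-to-target lookup. The following is a minimal standalone sketch of that inversion; the helper name `invert_roles` and the sample role names are hypothetical and not part of axolotl.

```python
from typing import Dict, List


def invert_roles(roles: Dict[str, List[str]]) -> Dict[str, str]:
    """Invert {target: [sources]} into {source: target} (hypothetical helper)."""
    return {source: target for target, sources in roles.items() for source in sources}


# Sample mapping, made up for illustration.
custom_roles = {"user": ["human", "customer"], "assistant": ["gpt", "bot"]}
lookup = invert_roles(custom_roles)

# Unknown roles fall through unchanged, mirroring `roles.get(role, role)`.
normalized = [lookup.get(role, role) for role in ["customer", "bot", "tool"]]
print(normalized)
```

Note that once a custom `roles` mapping is supplied, the built-in defaults (`human`, `gpt`, and so on) are no longer applied, so any source name that should be remapped has to be listed explicitly.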


class ChatTemplateStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for instruction-based prompts.
    """

    def __init__(
        self,
        prompter: "ChatTemplatePrompter",
        tokenizer,
        train_on_inputs,
        sequence_len,
        roles_to_train=None,
        train_on_eos=None,
    ):
        super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
        self.prompter: ChatTemplatePrompter = prompter

        self.roles_to_train = []
        if roles_to_train:
            # map each role through prompter.roles if present, else use it as is
            self.roles_to_train = [
                prompter.roles.get(role, role) for role in roles_to_train
            ]

        self.train_on_eos = train_on_eos
        self.images = "images"

        LOG.debug(
            f"The chat template uses the following properties on the message: {self.prompter.chat_template_msg_variables}"
        )

    @property
    def supports_batched(self) -> bool:
        # Let calling code know we can handle lists of examples
        return True

    def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
        try:
            return all(isinstance(v, list) for v in prompt.values()) and all(
                isinstance(v, list) for v in prompt[self.prompter.field_messages]
            )
        except KeyError:
            return False

    def tokenize_prompt(self, prompt: dict[str, Any]):
        """
        Public method that can handle either a single prompt or a batch of prompts.
        """

        if not self.is_prompt_batched(prompt) or not self.supports_batched:
            return self._tokenize_single_prompt(prompt)

        res = defaultdict(lambda: [])
        feature_names = list(prompt.keys())

        # Process each prompt individually
        for row in zip(*prompt.values()):
            tokenized_prompt = self._tokenize_single_prompt(
                dict(zip(feature_names, row))
            )
            for key, val in tokenized_prompt.items():
                for i in range(0, len(val), self.sequence_len):
                    res[key].append(val[i : i + self.sequence_len])

        # If there are no examples left, return an empty dictionary
        if not res:
            return {}

        return dict(res)

    def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
        # Old simple legacy behavior that works reliably.
        if (
            not self.roles_to_train
            and not self.train_on_eos
            and not self.prompter.message_field_training  # type: ignore
            and not self.prompter.message_field_training_detail  # type: ignore
        ):
            turns = self.get_conversation_thread(prompt)
            images = self.get_images(prompt)
            prompt_ids = self.prompter.build_prompt(  # type: ignore
                turns[:-1],
                add_generation_prompt=True,
                images=images,
            )
            tokenized_res = self.prompter.build_prompt(turns, images=images)  # type: ignore
            tokenized_prompt = {}
            if isinstance(tokenized_res, list):
                input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
                tokenized_prompt["input_ids"] = input_ids
                tokenized_prompt["attention_mask"] = [1] * len(input_ids)
            else:
                input_ids = tokenized_res["input_ids"]
                tokenized_prompt = tokenized_res

            if not self.train_on_inputs:
                user_prompt_len = len(prompt_ids)
                labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
            else:
                labels = input_ids

            tokenized_prompt["labels"] = labels

            return tokenized_prompt

        turns = self.get_conversation_thread(prompt)
        input_ids = self.prompter.build_prompt(turns)  # type: ignore
        labels = [IGNORE_TOKEN_ID] * len(input_ids)

        last_eos_idx = -1
        for index, turn in enumerate(turns):
            role = turn.get("role")
            content = turn.get("content")
            train_turn = turn.get("training")
            train_detail = turn.get("training_detail")

            LOG.debug(
                f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
            )

            should_train = None
            if train_turn is not None:
                should_train = train_turn
            elif train_detail is not None:
                should_train = bool(train_detail)
            else:
                should_train = self.train_on_inputs or role in self.roles_to_train

            LOG.debug(f"Should train: {should_train}")

            turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)

            LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")

            if should_train and turn_start_idx != -1 and turn_end_idx != -1:
                if train_detail:
                    token_offsets = self.prompter.get_offsets_for_train_detail(  # type: ignore
                        content, train_detail
                    )
                    LOG.debug(f"Token offsets: {token_offsets}")
                    for i, offset in enumerate(token_offsets):
                        if offset != IGNORE_TOKEN_ID and turn_start_idx + i < len(
                            input_ids
                        ):
                            labels[turn_start_idx + i] = input_ids[turn_start_idx + i]
                            LOG.debug(
                                f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
                            )
                else:
                    labels[turn_start_idx:turn_end_idx] = input_ids[
                        turn_start_idx:turn_end_idx
                    ]
                    LOG.debug(
                        f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
                    )

                LOG.debug(f"Labels after processing turn {index}: {labels}")

            # Handle EOS token
            eos_idx = self.find_first_eos_token(input_ids, start_idx=turn_end_idx)
            if abs(eos_idx - turn_end_idx) <= 3:  # Allow for some template padding
                last_eos_idx = eos_idx
                if self.train_on_eos == "all" or (
                    self.train_on_eos == "turn" and should_train
                ):
                    labels[eos_idx] = input_ids[eos_idx]
                    LOG.debug(f"EOS token set for training at index {eos_idx}")
            else:
                LOG.debug(
                    f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}"
                )

        # Handle 'last' option for train_on_eos
        if self.train_on_eos == "last" and last_eos_idx != -1:
            labels[last_eos_idx] = input_ids[last_eos_idx]
            LOG.debug(f"Last EOS token set for training at index {last_eos_idx}")

        LOG.debug(f"Final labels: {labels}")

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": [1] * len(input_ids),
        }

    def find_first_eos_token(self, input_ids, start_idx):
        eos_token_id = self.tokenizer.eos_token_id
        for i in range(start_idx, len(input_ids)):
            if input_ids[i] == eos_token_id:
                return i
        return -1

    def find_turn(self, turns: list[dict], turn_idx: int):
        """
        Locate the starting and ending indices of the specified turn in a conversation.
        """
        # pylint: disable=too-many-return-statements

        if turn_idx >= len(turns):
            raise ValueError(f"Turn index {turn_idx} out of range")

        # Mistral templates emit nothing for a conversation that contains only a system message
        if (
            turn_idx == 0
            and turns[0].get("role") == "system"
            and "mistral" in self.tokenizer.name_or_path.lower()
        ):
            return -1, -1

        empty_turn = {
            "role": turns[turn_idx].get("role"),
            "content": "[[dummy_message]]",
        }

        # Create conversation versions
        turns_with_empty = turns[:turn_idx] + [empty_turn]
        turns_with_content = turns[: turn_idx + 1]

        # Generate the conversation up to the turn, with final turn replaced with dummy content
        dummy_ids = self.prompter.build_prompt(turns_with_empty)  # type: ignore

        # Generate the conversation up to the turn, with final turn included
        full_ids = self.prompter.build_prompt(turns_with_content)  # type: ignore

        if not full_ids or not dummy_ids:
            LOG.warning(f"Empty template generated for turn {turn_idx}")
            return -1, -1

        # Find first difference (start of content)
        start_idx = None
        min_len = min(len(dummy_ids), len(full_ids))
        for i in range(min_len):
            if dummy_ids[i] != full_ids[i]:
                start_idx = i
                break

        if start_idx is None:
            LOG.warning(f"Could not find content start boundary for turn {turn_idx}")
            return -1, -1

        # Find last difference (end of content)
        end_idx = None
        for i in range(min_len):
            dummy_pos = len(dummy_ids) - 1 - i
            full_pos = len(full_ids) - 1 - i
            if dummy_ids[dummy_pos] != full_ids[full_pos]:
                end_idx = full_pos + 1  # Add one so the last token is included when slicing
                break

        if end_idx is None:
            LOG.warning(f"Could not find content end boundary for turn {turn_idx}")
            return -1, -1

        if end_idx < start_idx:
            LOG.warning(
                f"Content end boundary is before start boundary for turn {turn_idx}"
            )
            return -1, -1

        if end_idx == start_idx:
            LOG.warning(
                f"Content end boundary is the same as start boundary for turn {turn_idx}. This is likely an empty turn."
            )
            return -1, -1

        LOG.debug(f"Content boundaries: {start_idx}, {end_idx}")
        LOG.debug(
            f"Content tokens: {self.tokenizer.convert_ids_to_tokens(full_ids[start_idx:end_idx])}"
        )

        return start_idx, end_idx

    def get_conversation_thread(self, prompt):
        turns = []
        for message in prompt[self.prompter.field_messages]:
            transformed_message = self.transform_message(message)

            turn = {
                **transformed_message,
                "training": message.get(self.prompter.message_field_training),
                "training_detail": message.get(
                    self.prompter.message_field_training_detail
                ),
            }

            turns.append(turn)

        if self.prompter.drop_system_message and turns[0]["role"] == "system":
            turns = turns[1:]

        return turns

    def transform_message(self, message):
        # Build the initial transformed message from the mappings
        transformed_message = {}
        for key, value in self.prompter.message_property_mappings.items():
            if message.get(value) is not None:
                transformed_message[key] = message[value]
            else:
                LOG.debug(
                    f"Could not find value for property {value} in message: {message}"
                )

        # Map the role if necessary
        if "role" in transformed_message:
            transformed_message["role"] = self.prompter.roles.get(
                transformed_message["role"], transformed_message["role"]
            )

        # Determine which keys in the original message were not mapped
        mapped_values = set(self.prompter.message_property_mappings.values())
        remaining_keys = set(message) - mapped_values

        # Keep only the properties defined in the chat template
        # and not already mapped
        for key in self.prompter.chat_template_msg_variables:
            if key in remaining_keys:
                val = message.get(key)
                if val is not None:
                    transformed_message[key] = val

        return transformed_message

    def get_images(self, prompt):
        return prompt.get(self.images, None)
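`find_turn` locates a turn's content by tokenizing the conversation twice, once with placeholder content and once with the real content, and then comparing the two token sequences from both ends. The sketch below isolates that comparison on plain integer lists. The function name `find_content_span` and the token ids are hypothetical, and no tokenizer or chat template is involved.

```python
from typing import List, Tuple


def find_content_span(dummy_ids: List[int], full_ids: List[int]) -> Tuple[int, int]:
    """Return (start, end) of the tokens in full_ids that differ from dummy_ids.

    Hypothetical standalone helper; returns (-1, -1) when no usable span is found.
    """
    min_len = min(len(dummy_ids), len(full_ids))

    # First position, scanning from the left, where the sequences diverge.
    start = next((i for i in range(min_len) if dummy_ids[i] != full_ids[i]), None)
    if start is None:
        return -1, -1

    # First position, scanning from the right, where the sequences diverge.
    end = None
    for i in range(min_len):
        if dummy_ids[len(dummy_ids) - 1 - i] != full_ids[len(full_ids) - 1 - i]:
            end = len(full_ids) - i  # exclusive end, ready for slicing
            break

    if end is None or end <= start:
        return -1, -1
    return start, end


# Made-up ids: 1 and 2 stand for template tokens, 9 for an end-of-turn token.
dummy = [1, 2, 50, 51, 9]  # template + placeholder content
full = [1, 2, 70, 71, 72, 9]  # template + real content
span = find_content_span(dummy, full)
print(span, full[span[0] : span[1]])
```

One consequence of this approach is that a turn whose real content happens to share leading or trailing tokens with the placeholder would yield a span that is too narrow, which is presumably why the placeholder is an unusual string such as `[[dummy_message]]`.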


class StrategyLoader:
    """
    Load chat template strategy based on configuration.
    """

    def _get_strategy_cls(self):
        return ChatTemplateStrategy

    def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
        return {
            "train_on_inputs": cfg.train_on_inputs,
            "sequence_len": cfg.sequence_len,
            "roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
            "train_on_eos": ds_cfg.get("train_on_eos", "turn"),
        }

    def __call__(
        self,
        tokenizer,
        cfg,
        ds_cfg: Optional[Union[Dict[str, Any], DatasetConfig]] = None,
        processor=None,
    ):
        if ds_cfg is None:
            dataset_config = {}
        elif isinstance(ds_cfg, BaseModel):
            dataset_config = ds_cfg.model_dump()
        else:
            dataset_config = ds_cfg

        chat_template_string = get_chat_template_from_config(
            cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
        )
        LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")

        prompter_params = {
            "tokenizer": tokenizer,
            "chat_template": chat_template_string,
            "message_property_mappings": dataset_config.get(
                "message_property_mappings", {}
            ),
            "message_field_training": dataset_config.get(
                "message_field_training", None
            ),
            "message_field_training_detail": dataset_config.get(
                "message_field_training_detail",
                None,
            ),
            "field_messages": dataset_config.get("field_messages", "messages"),
            "roles": dataset_config.get("roles"),
            "drop_system_message": dataset_config.get("drop_system_message", False),
            # add one so that sequences exceeding the `sequence_len` limit can be detected
            "max_length": cfg.sequence_len + 1,
            "processor": processor,
        }

        strategy_params = self._get_strategy_params(cfg, dataset_config)
        strategy_cls = self._get_strategy_cls()

        strategy = strategy_cls(
            ChatTemplatePrompter(**prompter_params),
            tokenizer=tokenizer,
            **strategy_params,
        )

        return strategy


load = StrategyLoader()
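`StrategyLoader.__call__` accepts `ds_cfg` as `None`, a plain dict, or a pydantic model, and reduces all three to a dict before reading keys with fallbacks. The sketch below imitates that flow without importing axolotl or pydantic. `FakeDatasetConfig`, `normalize_ds_cfg`, and `strategy_params` are hypothetical names, and the duck-typed `hasattr` check stands in for the `isinstance(ds_cfg, BaseModel)` test used in the loader.

```python
from typing import Any, Dict, Optional


class FakeDatasetConfig:
    """Stand-in for a pydantic model; only `model_dump` is imitated (assumption)."""

    def __init__(self, **fields: Any) -> None:
        self._fields = fields

    def model_dump(self) -> Dict[str, Any]:
        return dict(self._fields)


def normalize_ds_cfg(ds_cfg: Optional[Any]) -> Dict[str, Any]:
    """Coerce None, a dict, or a model-like object into a plain dict."""
    if ds_cfg is None:
        return {}
    if hasattr(ds_cfg, "model_dump"):  # duck-typed here; the loader checks for BaseModel
        return ds_cfg.model_dump()
    return ds_cfg


def strategy_params(dataset_config: Dict[str, Any]) -> Dict[str, Any]:
    """Apply the same fallbacks as `_get_strategy_params` for the dataset-level keys."""
    return {
        "roles_to_train": dataset_config.get("roles_to_train", ["assistant"]),
        "train_on_eos": dataset_config.get("train_on_eos", "turn"),
    }


defaults = strategy_params(normalize_ds_cfg(None))
custom = strategy_params(normalize_ds_cfg(FakeDatasetConfig(train_on_eos="last")))
print(defaults, custom)
```

Because the fallbacks rely on `dict.get`, they apply only when a key is absent. A key that is present with the value `None` is passed through as `None` rather than replaced by the default.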