axolotl/src/axolotl/prompt_strategies/chat_template.py
NJordan72 b194e17c28 feat: add config for optional parameters in a chat message (#2260)
* feat: add config for optional parameters in a chat message

* chore: cleanup

* chore: fix nits and add light docs

* docs: update docs/dataset-formats/conversation.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* feat: configurable message mappings, jinja template analyzer

* chore: handle bradley terry

* docs: update docs

* refactor: change order of mappings, improve message transform

* refactor: make chat aware of property mappings

* chore: remove .python-version

* chore: revert change

* chore: add dataset validation to tests where appropriate

* chore: add dataset validation to tests where appropriate

* chore: clean up handling of ds_cfg

* chore: recursively serialize config

* make sure to use the return value from validate_config

* DefaultDict pickle/unpickle fix

* fix super call for override

* refactor: message fields

* chore: empty commit

* tests: validate config before using

* chore: add config validation to all e2e tests

* chore: add unneeded logging

* chore: add missed config validation

* chore: pass field_messages to prompter

* test: fix borked test

* chore: remove unintended file

* chore: add deprecation warning and update chat_datasets script

* chore: lint

* refactor: message fields

* feat: update AxolotlInputConfig and test_models

- Add `ConfigDict` import in axolotl/utils/config/models/input/v0_4_1/__init__.py
- Remove unnecessary line breaks in `SFTDataset`, `DPODataset`, `KTODataset`, `StepwiseSupervisedDataset` classes
- Update `model_dump` method in `AxolotlInputConfig` to exclude None values
- Correct typo in test_models.py comment

* feat: simplify DPODataset and KTODataset classes in config models

Removed several optional fields from the `DPODataset` and `KTODataset` classes in axolotl/utils/config/models/input/v0_4_1. This simplifies the configuration subsets for these datasets.

* feat: improve readability and structure in dataset configuration models

This commit enhances the readability and structure of the dataset configuration models in the `axolotl/utils/config/models/input/v0_4_1` module. It removes the unused `ConfigDict` import and adds line breaks to separate class definitions for better clarity. Additionally, a minor documentation fix ensures a newline at the end of the `stepwise_supervised.qmd` file.

* feat: change log level from info to debug in ChatTemplateStrategy

* feat(prompt_strategies): refactor ChatTemplatePrompter and ChatTemplateStrategy

- Make `chat_template` a required parameter in `ChatTemplatePrompter` constructor
- Add default value for `message_property_mappings` in `ChatTemplatePrompter` constructor
- Add `messages_array_name` property to `ChatTemplatePrompter`
- Change `processor` type to Optional in `ChatTemplatePrompter`
- Add TypeError check for `processor` in `ChatTemplatePrompter.build_prompt`
- Remove `_messages` property from `ChatTemplateStrategy`
- Make `prompter` a required parameter and add type hint in `ChatTemplateStrategy` constructor
- Remove `messages` getter and setter from `ChatTemplateStrategy`
- Use `prompter.messages_array_name` in `ChatTemplateStrategy.get_conversation_thread`
- Remove condition to set `messages` field in `load` function
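
The defaulting of `message_property_mappings` and the role lookup can be illustrated without Axolotl itself. The following is a minimal standalone sketch, not the library's code: `DEFAULT_MAPPINGS`, `DEFAULT_ROLES`, `invert_roles`, and `normalize_message` are illustrative names, and the pass-through of extra chat-template variables is omitted for brevity.

```python
from typing import Dict, List, Optional

# Illustrative defaults mirroring the values used in chat_template.py.
DEFAULT_MAPPINGS = {"role": "role", "content": "content"}
DEFAULT_ROLES = {
    "human": "user",
    "user": "user",
    "assistant": "assistant",
    "gpt": "assistant",
    "system": "system",
    "tool": "tool",
}


def invert_roles(roles: Optional[Dict[str, List[str]]]) -> Dict[str, str]:
    """Turn {target: [sources]} into {source: target}; fall back to defaults."""
    if not roles:
        return dict(DEFAULT_ROLES)
    return {source: target for target, sources in roles.items() for source in sources}


def normalize_message(message, mappings=None, roles=None):
    """Rename dataset-specific keys to template keys and normalize the role."""
    mappings = mappings or DEFAULT_MAPPINGS  # None and {} both use the defaults
    role_map = invert_roles(roles)
    out = {}
    for template_key, dataset_key in mappings.items():
        if message.get(dataset_key) is not None:
            out[template_key] = message[dataset_key]
    if "role" in out:
        # Unknown roles are kept as-is rather than rejected.
        out["role"] = role_map.get(out["role"], out["role"])
    return out


sharegpt_message = {"from": "gpt", "value": "Hello!"}
normalized = normalize_message(
    sharegpt_message, mappings={"role": "from", "content": "value"}
)
print(normalized)  # {'role': 'assistant', 'content': 'Hello!'}
```

Treating an empty dict the same as `None` means a dataset config that omits the mapping still gets the `role`/`content` defaults.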

* feat(tests/utils): ignore type check in load_model call in test_models.py

* feat: improve type handling and test structure in chat templates

- Add return type hint for `get_chat_template` function in `chat_templates.py`
- Remove unnecessary assignment of `strategy.messages` in several test cases
- Add `messages_array_name` parameter to various test configurations in `test_chat_templates.py` and `test_chat_templates_advanced.py`
- Remove redundant `strategy.messages` assignment in `test_chat_templates_advanced.py`

* feat(axolotl): enhance chat strategy with DatasetConfig support

This commit introduces support for DatasetConfig in the ChatTemplateStrategy. It also refines the strategy loader to handle different types of ds_cfg inputs and improves the clarity of the code by formatting and reordering. The key changes include:

- Importing Union from typing and BaseModel from pydantic.
- Adding DatasetConfig as an optional type for ds_cfg in StrategyLoader.
- Adjusting the handling of ds_cfg in StrategyLoader to account for BaseModel instances.
- Refactoring the prompter_params and strategy_params for better readability.
- Changing the reference from prompt[self.messages] to prompt[self.prompter.messages_array_name] in the is_prompt_batched method.
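
The `ds_cfg` handling amounts to normalizing three input shapes into one plain dict. A rough sketch under stated assumptions: `FakeDatasetConfig` and `normalize_ds_cfg` are hypothetical names, and the sketch duck-types on `model_dump` to stay dependency-free, whereas the loader itself checks `isinstance(ds_cfg, BaseModel)`.

```python
from typing import Any, Dict, Optional


class FakeDatasetConfig:
    """Hypothetical stand-in for a pydantic model; only model_dump() matters here."""

    def __init__(self, **fields: Any) -> None:
        self._fields = fields

    def model_dump(self) -> Dict[str, Any]:
        return dict(self._fields)


def normalize_ds_cfg(ds_cfg: Optional[Any]) -> Dict[str, Any]:
    """Return a plain dict regardless of how the dataset config was supplied."""
    if ds_cfg is None:
        return {}
    if hasattr(ds_cfg, "model_dump"):  # assumption: stands in for the BaseModel check
        return ds_cfg.model_dump()
    return ds_cfg


as_dict = normalize_ds_cfg({"field_messages": "conversations"})
as_model = normalize_ds_cfg(FakeDatasetConfig(field_messages="conversations"))
default_field = normalize_ds_cfg(None).get("field_messages", "messages")
print(as_dict == as_model, default_field)  # True messages
```

Once everything is a dict, the rest of the loader can use `.get(key, default)` uniformly.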

* feat: update message handling in BTChatTemplateStrategy

- Replace `self.messages` with direct string references to "chosen_messages" and "rejected_messages"
- Append system, user, and assistant content directly to "chosen_messages" and "rejected_messages"
- Add a new attribute "messages_array_name" to the `load` function parameters
- Remove the conditional attribute assignment for "field_messages" in the `load` function

* feat: add config validation in test_kd.py

- Import `validate_config` from `axolotl.utils.config`
- Validate the configuration in `test_llama_kd` and another function in `TestKnowledgeDistillation` class

* feat: enhance config validation and capabilities handling

- Import `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
- Update `validate_config` function to create `KTODataset` and `SFTDataset` instances using `dict(ds_cfg)`
- Replace `capabilities` and `env_capabilities` with instances of `GPUCapabilities` and `EnvCapabilities` respectively in `AxolotlConfigWCapabilities` model dump

* feat: update config validation in axolotl utils

- Remove import of `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
- Update `validate_config` function to use `capabilities` and `env_capabilities` directly instead of creating new instances of `GPUCapabilities` and `EnvCapabilities`

* feat: refactor StrategyLoader in chat_template.py

- Extracted the creation of strategy parameters into a separate function, `_get_strategy_params(cfg, dataset_config)`
- Created a new function, `_get_strategy_cls()`, to obtain the strategy class
- Replaced `ChatTemplateStrategy` with `strategy_cls` for strategy instantiation
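
Splitting the loader into `_get_strategy_cls()` and `_get_strategy_params()` lets a subclass swap the strategy class or adjust parameters without reimplementing `__call__`. The sketch below shows the pattern with stand-in classes; `BaseStrategy`, `RewardStrategy`, `BaseLoader`, and `RewardLoader` are hypothetical, and `cfg` is a plain dict here although the loader in this file uses attribute access.

```python
from typing import Any, Dict, Optional


class BaseStrategy:
    """Hypothetical stand-in for ChatTemplateStrategy."""

    def __init__(self, **params: Any) -> None:
        self.params = params


class RewardStrategy(BaseStrategy):
    """Hypothetical subclass, e.g. for a different training objective."""


class BaseLoader:
    def _get_strategy_cls(self):
        return BaseStrategy

    def _get_strategy_params(
        self, cfg: Dict[str, Any], ds_cfg: Dict[str, Any]
    ) -> Dict[str, Any]:
        return {
            "train_on_inputs": cfg["train_on_inputs"],
            "sequence_len": cfg["sequence_len"],
            "roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
            "train_on_eos": ds_cfg.get("train_on_eos", "turn"),
        }

    def __call__(self, cfg: Dict[str, Any], ds_cfg: Optional[Dict[str, Any]] = None):
        ds_cfg = ds_cfg or {}
        strategy_cls = self._get_strategy_cls()
        return strategy_cls(**self._get_strategy_params(cfg, ds_cfg))


class RewardLoader(BaseLoader):
    # Only the hook is overridden; __call__ is inherited unchanged.
    def _get_strategy_cls(self):
        return RewardStrategy


cfg = {"train_on_inputs": False, "sequence_len": 2048}
strategy = RewardLoader()(cfg, {"train_on_eos": "last"})
print(type(strategy).__name__, strategy.params["train_on_eos"])  # RewardStrategy last
```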

* trigger CI

* chore: revert dataset config changes for kto/dpo

* refactor: rename 'messages_array_name' to 'field_messages'

- Renamed 'messages_array_name' to 'field_messages' in 'ChatTemplatePrompter' class and its usages in 'chat_template.py'
- Updated 'load' function in 'bradley_terry/chat_template.py' to reflect the change
- Adjusted 'get_chat_template_msg_variables' and 'get_message_vars' methods in 'jinja_template_analyzer.py' to use the new variable name
- Modified 'StrategyLoader' in 'chat_template.py' to use 'field_messages'
- Updated tests in 'test_chat_templates.py' and 'test_chat_templates_advanced.py' to use 'field_messages' instead of 'messages_array_name'

* feat: refactor prompt strategies and update config models

- Remove redundant 'return None' in `axolotl/prompt_strategies/__init__.py`
- Simplify message handling in `axolotl/prompt_strategies/bradley_terry/chat_template.py` by using a single 'messages' list instead of separate 'chosen_messages' and 'rejected_messages' lists
- Update default 'message_property_mappings' in `axolotl/prompt_strategies/bradley_terry/chat_template.py`
- Add 'field_messages' field to `axolotl/utils/config/models/input/v0_4_1/__init__.py` configuration model

* chore: remove unused input

* chore: remove redundant type ignore

* fix: remove old configs and update examples

* fix: type check

* fix: remove loading old config in ChatMessage

* fix: update faq with potential new undefinederror

* fix: add debug if property mapped is not found

* chore: improve explanation for unmapped properties

* fix: update docs with new config

* chore: add note for deprecation config and del old config from dict

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-02-18 09:59:27 +07:00


"""
HF Chat Templates prompt strategy
"""
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Union
from pydantic import BaseModel
from transformers import ProcessorMixin
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
# Configure the logger
LOG = logging.getLogger("axolotl")
LOG.setLevel(logging.INFO)
class ChatTemplatePrompter(Prompter):
    """Prompter for HF chat templates"""

    def __init__(
        self,
        tokenizer,
        chat_template: str,
        processor=None,
        max_length=2048,
        message_property_mappings: Optional[Dict[str, str]] = None,
        message_field_training: Optional[str] = None,
        message_field_training_detail: Optional[str] = None,
        field_messages: str = "messages",
        roles: Optional[Dict[str, List[str]]] = None,
        drop_system_message: bool = False,
    ):
        # check if message_property_mappings is None or empty dict
        if message_property_mappings is None or (not message_property_mappings):
            message_property_mappings = {
                "role": "role",
                "content": "content",
            }

        if roles:
            self.roles = {s: t for t, sources in roles.items() for s in sources}
        else:
            self.roles = {
                "human": "user",
                "user": "user",
                "assistant": "assistant",
                "gpt": "assistant",
                "system": "system",
                "tool": "tool",
            }

        self._chat_template_msg_variables = self.get_chat_template_msg_variables(
            chat_template, field_messages
        )
        self.message_property_mappings = message_property_mappings
        self.message_field_training = message_field_training
        self.message_field_training_detail = message_field_training_detail
        self.field_messages = field_messages
        self.tokenizer = tokenizer
        self.processor: Optional[ProcessorMixin] = processor
        self.chat_template = chat_template
        self.max_length = max_length
        self.drop_system_message = drop_system_message

    @property
    def chat_template_msg_variables(self) -> Set[str]:
        return self._chat_template_msg_variables

    def build_prompt(self, conversation, add_generation_prompt=False, images=None):
        if self.processor:
            if not callable(self.processor):
                raise TypeError("Processor must be callable")

            text = self.processor.apply_chat_template(
                conversation,
                chat_template=self.chat_template,
                tokenize=False,
                add_generation_prompt=add_generation_prompt,
            )
            batch = self.processor(
                text=text,
                images=images,
                return_tensors="pt",
            )
            # workaround since processor works in batches instead of single examples
            for k, val in batch.items():
                if k in ["pixel_values"]:
                    batch[k] = val.tolist()
                else:
                    batch[k] = val.squeeze().tolist()
            return batch

        return self.tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=add_generation_prompt,
            chat_template=self.chat_template,
        )

    def get_offsets_for_train_detail(
        self, text: str, train_details: List[Dict], mask_untrainable: bool = True
    ) -> List[int]:
        tokenized_output = self.tokenizer(
            text, return_offsets_mapping=True, add_special_tokens=False
        )
        tokens = tokenized_output.tokens()
        token_offsets = tokenized_output["offset_mapping"]

        LOG.debug(f"Tokenizing text: {text}")
        LOG.debug(f"Tokens: {tokens}")
        # Adjust the end offsets. For some reason by default they are set to the same value as the start offsets.
        for i in range(len(token_offsets) - 1):
            token_offsets[i] = (token_offsets[i][0], token_offsets[i + 1][0] - 1)
        # Ensure the last token's end offset is set correctly
        token_offsets[-1] = (token_offsets[-1][0], len(text) - 1)
        LOG.debug(f"Token offsets: {token_offsets}")

        # Initialize all offsets as IGNORE_TOKEN_ID (not trained)
        result = [IGNORE_TOKEN_ID] * len(token_offsets)

        # Adjust train_details to align with token boundaries
        adjusted_train_details = self.adjust_train_details(train_details, token_offsets)

        for idx, (start, end) in enumerate(token_offsets):
            for detail in adjusted_train_details:
                # Check if the token is completely within the detail's range
                if start >= detail["begin_offset"] and end <= detail["end_offset"]:
                    if detail["train"] or not mask_untrainable:
                        result[idx] = start
                        LOG.debug(f"Token {idx} ({tokens[idx]}) marked for training")
                    else:
                        LOG.debug(
                            f"Token {idx} ({tokens[idx]}) marked as non-trainable"
                        )
                elif start < detail["end_offset"] and end > detail["begin_offset"]:
                    # Token partially overlaps with detail, always mark as non-trainable
                    LOG.debug(
                        f"Token {idx} ({tokens[idx]}) partially overlaps detail, marked as non-trainable"
                    )

        LOG.debug(f"Final result: {result}")
        return result

    def adjust_train_details(
        self, train_details: List[Dict], token_offsets: List[tuple]
    ) -> List[Dict]:
        adjusted_details = []
        for detail in train_details:
            begin_offset = detail["begin_offset"]
            end_offset = detail["end_offset"]

            # Find the first token that starts after or at the begin_offset
            begin_token = next(
                (
                    i
                    for i, (t_start, t_end) in enumerate(token_offsets)
                    if t_start >= begin_offset
                ),
                len(token_offsets),
            )
            if begin_token > 0 and token_offsets[begin_token - 1][1] > begin_offset:
                begin_token -= 1

            # Find the last token that ends before or at the end_offset
            end_token = next(
                (
                    i
                    for i in range(len(token_offsets) - 1, -1, -1)
                    if token_offsets[i][1] <= end_offset
                ),
                -1,
            )
            if (
                end_token < len(token_offsets) - 1
                and token_offsets[end_token + 1][0] < end_offset
            ):
                end_token += 1

            if begin_token <= end_token:
                adjusted_begin = token_offsets[begin_token][0]
                adjusted_end = token_offsets[end_token][1]

                if adjusted_begin != begin_offset or adjusted_end != end_offset:
                    LOG.warning(
                        f"Adjusting detail offsets: ({begin_offset}, {end_offset}) -> ({adjusted_begin}, {adjusted_end})"
                    )

                adjusted_details.append(
                    {
                        "begin_offset": adjusted_begin,
                        "end_offset": adjusted_end,
                        "train": detail["train"],
                    }
                )
            else:
                LOG.warning(
                    f"Could not adjust detail offsets: ({begin_offset}, {end_offset}). Skipping this detail."
                )

        return adjusted_details

    def get_chat_template_msg_variables(
        self, chat_template: str, field_messages: str
    ) -> Set[str]:
        template_analyzer = JinjaTemplateAnalyzer(chat_template)
        return template_analyzer.get_message_vars(field_messages)

class ChatTemplateStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for instruction-based prompts.
    """

    def __init__(
        self,
        prompter: "ChatTemplatePrompter",
        tokenizer,
        train_on_inputs,
        sequence_len,
        roles_to_train=None,
        train_on_eos=None,
    ):
        super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
        self.prompter: ChatTemplatePrompter = prompter

        self.roles_to_train = []
        if roles_to_train:
            # map roles if exist in prompter.roles else use the role as is
            self.roles_to_train = [
                prompter.roles.get(role, role) for role in roles_to_train
            ]

        self.train_on_eos = train_on_eos
        self.images = "images"

        LOG.debug(
            f"The chat template uses the following properties on the message: {self.prompter.chat_template_msg_variables}"
        )

    @property
    def supports_batched(self) -> bool:
        # Let calling code know we can handle lists of examples
        return True

    def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
        try:
            return all(isinstance(v, list) for v in prompt.values()) and all(
                isinstance(v, list) for v in prompt[self.prompter.field_messages]
            )
        except KeyError:
            return False

    def tokenize_prompt(self, prompt: dict[str, Any]):
        """
        Public method that can handle either a single prompt or a batch of prompts.
        """
        if not self.is_prompt_batched(prompt) or not self.supports_batched:
            return self._tokenize_single_prompt(prompt)

        res = defaultdict(lambda: [])
        feature_names = list(prompt.keys())

        # Process each prompt individually
        for row in zip(*prompt.values()):
            tokenized_prompt = self._tokenize_single_prompt(
                dict(zip(feature_names, row))
            )
            for key, val in tokenized_prompt.items():
                for i in range(0, len(val), self.sequence_len):
                    res[key].append(val[i : i + self.sequence_len])

        # If there are no examples left, return an empty dictionary
        if not res:
            return {}

        return dict(res)

    def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
        # Old simple legacy behavior that works reliably.
        if (
            not self.roles_to_train
            and not self.train_on_eos
            and not self.prompter.message_field_training  # type: ignore
            and not self.prompter.message_field_training_detail  # type: ignore
        ):
            turns = self.get_conversation_thread(prompt)
            images = self.get_images(prompt)
            prompt_ids = self.prompter.build_prompt(  # type: ignore
                turns[:-1],
                add_generation_prompt=True,
                images=images,
            )
            tokenized_res = self.prompter.build_prompt(turns, images=images)  # type: ignore
            tokenized_prompt = {}
            if isinstance(tokenized_res, list):
                input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
                tokenized_prompt["input_ids"] = input_ids
                tokenized_prompt["attention_mask"] = [1] * len(input_ids)
            else:
                input_ids = tokenized_res["input_ids"]
                tokenized_prompt = tokenized_res

            if not self.train_on_inputs:
                user_prompt_len = len(prompt_ids)
                labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
            else:
                labels = input_ids

            tokenized_prompt["labels"] = labels

            return tokenized_prompt

        turns = self.get_conversation_thread(prompt)
        input_ids = self.prompter.build_prompt(turns)  # type: ignore
        labels = [IGNORE_TOKEN_ID] * len(input_ids)

        last_eos_idx = -1
        for index, turn in enumerate(turns):
            role = turn.get("role")
            content = turn.get("content")
            train_turn = turn.get("training")
            train_detail = turn.get("training_detail")

            LOG.debug(
                f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
            )

            should_train = None
            if train_turn is not None:
                should_train = train_turn
            elif train_detail is not None:
                should_train = bool(train_detail)
            else:
                should_train = self.train_on_inputs or role in self.roles_to_train

            LOG.debug(f"Should train: {should_train}")

            turn_start_idx, turn_end_idx = self.find_turn(turns=turns, turn_idx=index)

            LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")

            if should_train and turn_start_idx != -1 and turn_end_idx != -1:
                if train_detail:
                    token_offsets = self.prompter.get_offsets_for_train_detail(  # type: ignore
                        content, train_detail
                    )
                    LOG.debug(f"Token offsets: {token_offsets}")
                    for i, offset in enumerate(token_offsets):
                        if offset != IGNORE_TOKEN_ID and turn_start_idx + i < len(
                            input_ids
                        ):
                            labels[turn_start_idx + i] = input_ids[turn_start_idx + i]
                            LOG.debug(
                                f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
                            )
                else:
                    labels[turn_start_idx:turn_end_idx] = input_ids[
                        turn_start_idx:turn_end_idx
                    ]
                    LOG.debug(
                        f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
                    )

            LOG.debug(f"Labels after processing turn {index}: {labels}")

            # Handle EOS token
            eos_idx = self.find_first_eos_token(input_ids, start_idx=turn_end_idx)
            if abs(eos_idx - turn_end_idx) <= 3:  # Allow for some template padding
                last_eos_idx = eos_idx
                if self.train_on_eos == "all" or (
                    self.train_on_eos == "turn" and should_train
                ):
                    labels[eos_idx] = input_ids[eos_idx]
                    LOG.debug(f"EOS token set for training at index {eos_idx}")
            else:
                LOG.debug(
                    f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}"
                )

        # Handle 'last' option for train_on_eos
        if self.train_on_eos == "last" and last_eos_idx != -1:
            labels[last_eos_idx] = input_ids[last_eos_idx]
            LOG.debug(f"Last EOS token set for training at index {last_eos_idx}")

        LOG.debug(f"Final labels: {labels}")

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": [1] * len(input_ids),
        }

    def find_first_eos_token(self, input_ids, start_idx):
        eos_token_id = self.tokenizer.eos_token_id
        for i in range(start_idx, len(input_ids)):
            if input_ids[i] == eos_token_id:
                return i
        return -1

    def find_turn(self, turns: list[dict], turn_idx: int):
        """
        Locate the starting and ending indices of the specified turn in a conversation.
        """
        # pylint: disable=too-many-return-statements

        if turn_idx >= len(turns):
            raise ValueError(f"Turn index {turn_idx} out of range")

        # mistral does not output message if it contains only system message
        if (
            turn_idx == 0
            and turns[0].get("role") == "system"
            and "mistral" in self.tokenizer.name_or_path.lower()
        ):
            return -1, -1

        empty_turn = {
            "role": turns[turn_idx].get("role"),
            "content": "[[dummy_message]]",
        }

        # Create conversation versions
        turns_with_empty = turns[:turn_idx] + [empty_turn]
        turns_with_content = turns[: turn_idx + 1]

        # Generate the conversation up to the turn, with final turn replaced with dummy content
        dummy_ids = self.prompter.build_prompt(turns_with_empty)  # type: ignore

        # Generate the conversation up to the turn, with final turn included
        full_ids = self.prompter.build_prompt(turns_with_content)  # type: ignore

        if not full_ids or not dummy_ids:
            LOG.warning(f"Empty template generated for turn {turn_idx}")
            return -1, -1

        # Find first difference (start of content)
        start_idx = None
        min_len = min(len(dummy_ids), len(full_ids))
        for i in range(min_len):
            if dummy_ids[i] != full_ids[i]:
                start_idx = i
                break

        if start_idx is None:
            LOG.warning(f"Could not find content start boundary for turn {turn_idx}")
            return -1, -1

        # Find last difference (end of content)
        end_idx = None
        for i in range(min_len):
            dummy_pos = len(dummy_ids) - 1 - i
            full_pos = len(full_ids) - 1 - i
            if dummy_ids[dummy_pos] != full_ids[full_pos]:
                end_idx = full_pos + 1  # Add one to include the last token when slice
                break

        if end_idx is None:
            LOG.warning(f"Could not find content end boundary for turn {turn_idx}")
            return -1, -1

        if end_idx < start_idx:
            LOG.warning(
                f"Content end boundary is before start boundary for turn {turn_idx}"
            )
            return -1, -1

        if end_idx == start_idx:
            LOG.warning(
                f"Content end boundary is the same as start boundary for turn {turn_idx}. This is likely an empty turn."
            )
            return -1, -1

        LOG.debug(f"Content boundaries: {start_idx}, {end_idx}")
        LOG.debug(
            f"Content tokens: {self.tokenizer.convert_ids_to_tokens(full_ids[start_idx:end_idx])}"
        )

        return start_idx, end_idx

    def get_conversation_thread(self, prompt):
        turns = []

        for message in prompt[self.prompter.field_messages]:
            transformed_message = self.transform_message(message)

            turn = {
                **transformed_message,
                "training": message.get(self.prompter.message_field_training),
                "training_detail": message.get(
                    self.prompter.message_field_training_detail
                ),
            }

            turns.append(turn)

        if self.prompter.drop_system_message and turns[0]["role"] == "system":
            turns = turns[1:]

        return turns

    def transform_message(self, message):
        # Build the initial transformed message from the mappings
        transformed_message = {}
        for key, value in self.prompter.message_property_mappings.items():
            if message.get(value) is not None:
                transformed_message[key] = message[value]
            else:
                LOG.debug(
                    f"Could not find value for property {value} in message: {message}"
                )

        # Map the role if necessary
        if "role" in transformed_message:
            transformed_message["role"] = self.prompter.roles.get(
                transformed_message["role"], transformed_message["role"]
            )

        # Determine which keys in the original message were not mapped
        mapped_values = set(self.prompter.message_property_mappings.values())
        remaining_keys = set(message) - mapped_values

        # Keep only the properties defined in the chat template
        # and not already mapped
        for key in self.prompter.chat_template_msg_variables:
            if key in remaining_keys:
                val = message.get(key)
                if val is not None:
                    transformed_message[key] = val

        return transformed_message

    def get_images(self, prompt):
        return prompt.get(self.images, None)

class StrategyLoader:
    """
    Load chat template strategy based on configuration.
    """

    def _get_strategy_cls(self):
        return ChatTemplateStrategy

    def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
        return {
            "train_on_inputs": cfg.train_on_inputs,
            "sequence_len": cfg.sequence_len,
            "roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
            "train_on_eos": ds_cfg.get("train_on_eos", "turn"),
        }

    def __call__(
        self,
        tokenizer,
        cfg,
        ds_cfg: Optional[Union[Dict[str, Any], DatasetConfig]] = None,
        processor=None,
    ):
        if ds_cfg is None:
            dataset_config = {}
        elif isinstance(ds_cfg, BaseModel):
            dataset_config = ds_cfg.model_dump()
        else:
            dataset_config = ds_cfg

        chat_template_string = get_chat_template_from_config(
            cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
        )
        LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")

        prompter_params = {
            "tokenizer": tokenizer,
            "chat_template": chat_template_string,
            "message_property_mappings": dataset_config.get(
                "message_property_mappings", {}
            ),
            "message_field_training": dataset_config.get(
                "message_field_training", None
            ),
            "message_field_training_detail": dataset_config.get(
                "message_field_training_detail",
                None,
            ),
            "field_messages": dataset_config.get("field_messages", "messages"),
            "roles": dataset_config.get("roles"),
            "drop_system_message": dataset_config.get("drop_system_message", False),
            # we need to add one for detecting sequences with exceeding the `sequence_len` limit.
            "max_length": cfg.sequence_len + 1,
            "processor": processor,
        }

        strategy_params = self._get_strategy_params(cfg, dataset_config)
        strategy_cls = self._get_strategy_cls()

        strategy = strategy_cls(
            ChatTemplatePrompter(**prompter_params),
            tokenizer=tokenizer,
            **strategy_params,
        )

        return strategy


load = StrategyLoader()
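
The boundary detection in `find_turn` compares two tokenizations of the same conversation prefix, one ending in dummy content and one ending in the real turn, and takes the region where they differ. The following is a simplified standalone sketch of that idea on plain integer lists; `find_content_span` is a hypothetical helper, and the token IDs are made up for illustration.

```python
from typing import List, Tuple


def find_content_span(dummy_ids: List[int], full_ids: List[int]) -> Tuple[int, int]:
    """Return (start, end) of the tokens in full_ids that differ from dummy_ids.

    Returns (-1, -1) when no usable boundary is found.
    """
    if not dummy_ids or not full_ids:
        return -1, -1
    min_len = min(len(dummy_ids), len(full_ids))

    # Scan forward for the first position where the sequences differ.
    start = next((i for i in range(min_len) if dummy_ids[i] != full_ids[i]), None)
    if start is None:
        return -1, -1

    # Scan backward from the end; len(full_ids) - i is an exclusive end index.
    end = next(
        (
            len(full_ids) - i
            for i in range(min_len)
            if dummy_ids[-1 - i] != full_ids[-1 - i]
        ),
        None,
    )
    if end is None or end <= start:
        return -1, -1
    return start, end


# Made-up IDs: [1, 2, 3] is the template header, 4 the template footer.
dummy = [1, 2, 3, 99, 4]          # 99 stands for the dummy content
full = [1, 2, 3, 10, 11, 12, 4]   # 10, 11, 12 are the real content tokens
span = find_content_span(dummy, full)
print(span, full[span[0]:span[1]])  # (3, 6) [10, 11, 12]
```

One caveat of this approach: if the real content happens to begin or end with the same token as the dummy content, the detected span will be slightly too narrow.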