feat: add config for optional parameters in a chat message (#2260)
* feat: add config for optional parameters in a chat message * chore: cleanup * chore: fix nits and add light docs * docs: update docs/dataset-formats/conversation.qmd Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * feat: configurable message mappings, jinja template analyzer * chore: handle bradley terry * docs: update docs * refactor: change order of mappings, improve message transform * refactor: make chat awware of property mappings * chore: remove .python-version * chore: revert change * chore: add dataset validation to tests where appropriate * chore: add dataset validation to tests where appropriate * chore: clean up handling of ds_cfg * chore: recursively serialize config * make sure to use the return value from validate_config * DefaultDict pickle/unpickle fix * fix super call for override * refactor: message fields * chore: empty commit * tests: validate config before using * chore: add config validation to all e2e tests * chore: add unneeded logging * chore: add missed config validation * chore: pass field_messages to prompter * test: fix borked test * chore: remove uninteded file * chore: add deprecation warning and update chat_datasets script * chore: lint * refactor: message fields * feat: update axolotlinputconfig and test_models - add configdict import in axolotl/utils/config/models/input/v0_4_1/__init__.py - remove unnecessary line breaks in sftdataset, dpodataset, ktodataset, stepwisesuperviseddataset classes - update model_dump method in axolotlinputconfig to exclude none values - correct typo in test_models.py comment * feat: simplify dpodataset and ktodataset classes in config models removed several optional fields from dpodataset and ktodataset classes in axolotl/utils/config/models/input/v0_4_1. this simplifies the configuration subsets for these datasets. * feat: improve readability and structure in dataset configuration models this commit enhances the readability and structure of the dataset configuration models in the `axolotl/utils/config/models/input/v0_4_1` module. it removes unused `configdict` import and adds line breaks to separate class definitions for better clarity. additionally, a minor documentation fix is included to ensure a newline at the end of the `stepwise_supervised.qmd` file. * feat: change log level from info to debug in chattemplatestrategy * feat(prompt_strategies): refactor chattemplateprompter and chattemplatestrategy - Make `chat_template` a required parameter in `ChatTemplatePrompter` constructor - Add default value for `message_property_mappings` in `ChatTemplatePrompter` constructor - Add `messages_array_name` property to `ChatTemplatePrompter` - Change `processor` type to Optional in `ChatTemplatePrompter` - Add TypeError check for `processor` in `ChatTemplatePrompter.build_prompt` - Remove `_messages` property from `ChatTemplateStrategy` - Make `prompter` a required parameter and add type hint in `ChatTemplateStrategy` constructor - Remove `messages` getter and setter from `ChatTemplateStrategy` - Use `prompter.messages_array_name` in `ChatTemplateStrategy.get_conversation_thread` - Remove condition to set `messages` field in `load` function * feat(tests/utils): ignore type check in load_model call in test_models.py * feat: improve type handling and test structure in chat templates - Add return type hint for `get_chat_template` function in `chat_templates.py` - Remove unnecessary assignment of `strategy.messages` in several test cases - Add `messages_array_name` parameter to various test configurations in `test_chat_templates.py` and `test_chat_templates_advanced.py` - Remove redundant `strategy.messages` assignment in `test_chat_templates_advanced.py` * feat(axolotl): enhance chat strategy with datasetconfig support This commit introduces support for DatasetConfig in the ChatTemplateStrategy. It also refines the strategy loader to handle different types of ds_cfg inputs and improves the clarity of the code by formatting and reordering. The key changes include: - Importing Union from typing and BaseModel from pydantic. - Adding DatasetConfig as an optional type for ds_cfg in StrategyLoader. - Adjusting the handling of ds_cfg in StrategyLoader to account for BaseModel instances. - Refactoring the prompter_params and strategy_params for better readability. - Changing the reference from prompt[self.messages] to prompt[self.prompter.messages_array_name] in the is_prompt_batched method. * feat: update message handling in btchattemplatestrategy * Replace `self.messages` with direct string references to "chosen_messages" and "rejected_messages" * Append system, user, and assistant content directly to "chosen_messages" and "rejected_messages" * Add a new attribute "messages_array_name" to the `load` function parameters * Remove the conditional attribute assignment for "field_messages" in the `load` function * feat: add config validation in test_kd.py - Import `validate_config` from `axolotl.utils.config` - Validate the configuration in `test_llama_kd` and another function in `TestKnowledgeDistillation` class * feat: enhance config validation and capabilities handling * Import `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals` * Update `validate_config` function to create `KTODataset` and `SFTDataset` instances using `dict(ds_cfg)` * Replace `capabilities` and `env_capabilities` with instances of `GPUCapabilities` and `EnvCapabilities` respectively in `AxolotlConfigWCapabilities` model dump * feat: update config validation in axolotl utils - Remove import of `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals` - Update `validate_config` function to use `capabilities` and `env_capabilities` directly instead of creating new instances of `GPUCapabilities` and `EnvCapabilities` * feat: refactor strategyloader in chat_template.py - Extracted the creation of strategy parameters into a separate function, `_get_strategy_params(cfg, dataset_config)` - Created a new function, `_get_strategy_cls()`, to obtain the strategy class - Replaced `ChatTemplateStrategy` with `strategy_cls` for strategy instantiation * trigger CI * chore: revert dataset config changes for kto/dpo * subject: refactor: rename 'messages_array_name' to 'field_messages' Body: - Renamed 'messages_array_name' to 'field_messages' in 'ChatTemplatePrompter' class and its usages in 'chat_template.py' - Updated 'load' function in 'bradley_terry/chat_template.py' to reflect the change - Adjusted 'get_chat_template_msg_variables' and 'get_message_vars' methods in 'jinja_template_analyzer.py' to use the new variable name - Modified 'StrategyLoader' in 'chat_template.py' to use 'field_messages' - Updated tests in 'test_chat_templates.py' and 'test_chat_templates_advanced.py' to use 'field_messages' instead of 'messages_array_name' * feat: refactor prompt strategies and update config models * Remove redundant 'return None' in `axolotl/prompt_strategies/__init__.py` * Simplify message handling in `axolotl/prompt_strategies/bradley_terry/chat_template.py` by using a single 'messages' list instead of separate 'chosen_messages' and 'rejected_messages' lists * Update default 'message_property_mappings' in `axolotl/prompt_strategies/bradley_terry/chat_template.py` * Add 'field_messages' field to `axolotl/utils/config/models/input/v0_4_1/__init__.py` configuration model * chore: remove unused input * chore: remove redundant type ignore * fix: remove old configs and update examples * fix: type check * fix: remove loading old config in ChatMessage * fix: update faq with potential new undefinederror * fix: add debug if property mapped is not found * chore: improve explanation for unmapped properties * fix: update docs with new config * chore: add note for deprecation config and del old config from dict --------- Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> Co-authored-by: Wing Lian <wing@axolotl.ai> Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -142,10 +142,19 @@ datasets:
|
|||||||
|
|
||||||
# Key containing the messages (default: "messages")
|
# Key containing the messages (default: "messages")
|
||||||
field_messages: messages
|
field_messages: messages
|
||||||
# Key for role in each message (default: "role")
|
|
||||||
message_field_role: role
|
# Mapping of properties from the input dataset to the chat template.
|
||||||
# Key for content in each message (default: "content")
|
# (default: message_property_mappings={'role':'role', 'content':'content'})
|
||||||
message_field_content: content
|
# If a property exists in the template but not in this mapping, the system will attempt
|
||||||
|
# to load it directly from the message using the property name as the key.
|
||||||
|
# Example: In the mapping below, 'from' is loaded from input dataset and used as 'role',
|
||||||
|
# while 'value' is loaded and used as 'content' in the chat template.
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
# ...
|
||||||
|
|
||||||
|
message_property_mappings:
|
||||||
|
|
||||||
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
|
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
|
||||||
roles:
|
roles:
|
||||||
|
|||||||
@@ -42,8 +42,9 @@ datasets:
|
|||||||
type: chat_template
|
type: chat_template
|
||||||
|
|
||||||
field_messages: conversations
|
field_messages: conversations
|
||||||
message_field_role: from
|
message_property_mappings:
|
||||||
message_field_content: value
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
# new (if setting a new chat_template like chatml, gemma, etc)
|
# new (if setting a new chat_template like chatml, gemma, etc)
|
||||||
chat_template: chatml
|
chat_template: chatml
|
||||||
@@ -52,8 +53,9 @@ datasets:
|
|||||||
type: chat_template
|
type: chat_template
|
||||||
|
|
||||||
field_messages: conversations
|
field_messages: conversations
|
||||||
message_field_role: from
|
message_property_mappings:
|
||||||
message_field_content: value
|
role: from
|
||||||
|
content: value
|
||||||
```
|
```
|
||||||
|
|
||||||
We recommend checking the below examples for other usecases.
|
We recommend checking the below examples for other usecases.
|
||||||
@@ -138,8 +140,9 @@ datasets:
|
|||||||
type: chat_template
|
type: chat_template
|
||||||
chat_template: tokenizer_default
|
chat_template: tokenizer_default
|
||||||
field_messages: conversations
|
field_messages: conversations
|
||||||
message_field_role: from
|
message_property_mappings:
|
||||||
message_field_content: value
|
role: from
|
||||||
|
content: value
|
||||||
roles_to_train: []
|
roles_to_train: []
|
||||||
train_on_eos: turn
|
train_on_eos: turn
|
||||||
message_field_training: train
|
message_field_training: train
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ A flow chart is as follows:
|
|||||||
|
|
||||||
4. Is your dataset in an "instruct" format, containing `{ instruction, response }`? If yes, check [Instruction Dataset](#instruction-dataset)
|
4. Is your dataset in an "instruct" format, containing `{ instruction, response }`? If yes, check [Instruction Dataset](#instruction-dataset)
|
||||||
|
|
||||||
If you went through the flow chart and did not find one that matches, it is recommended to preprocess your dataset into one of the above or create a Github Discussion.
|
If you went through the flow chart and did not find one that matches, it is recommended to preprocess your dataset into one of the above or create a thread on Github Discussion.
|
||||||
|
|
||||||
::: {.callout-tip}
|
::: {.callout-tip}
|
||||||
You can mix and match within each approach or across approaches to train a model on a variety of datasets.
|
You can mix and match within each approach or across approaches to train a model on a variety of datasets.
|
||||||
@@ -289,9 +289,10 @@ If your dataset format is different, here are the keys you should check (with th
|
|||||||
```yaml
|
```yaml
|
||||||
datasets:
|
datasets:
|
||||||
...
|
...
|
||||||
field_messages: messages
|
field_messages: messages # this should point to the key containing the list of conversations
|
||||||
message_field_role: role
|
message_property_mappings: # this is a mapping from keys in your dataset to keys in chat_template
|
||||||
message_field_content: content
|
role: role
|
||||||
|
content: content
|
||||||
```
|
```
|
||||||
|
|
||||||
In some `chat_templates` (e.g. [Gemma](https://huggingface.co/google/gemma-2b-it/blob/main/tokenizer_config.json#L1507)), the roles are hardcoded to `user` and `assistant`. Consequently, you may find it necessary to map the roles in your dataset to these above. We currently have some defaults that should work for common datasets, but if you get a `KeyError`, it would be necessary to add mapping for your roles. Here is an example of how it would look like:
|
In some `chat_templates` (e.g. [Gemma](https://huggingface.co/google/gemma-2b-it/blob/main/tokenizer_config.json#L1507)), the roles are hardcoded to `user` and `assistant`. Consequently, you may find it necessary to map the roles in your dataset to these above. We currently have some defaults that should work for common datasets, but if you get a `KeyError`, it would be necessary to add mapping for your roles. Here is an example of how it would look like:
|
||||||
@@ -348,13 +349,14 @@ datasets:
|
|||||||
- path: A.jsonl
|
- path: A.jsonl
|
||||||
type: chat_template
|
type: chat_template
|
||||||
|
|
||||||
# step 1
|
# step 1
|
||||||
chat_template: chatml
|
chat_template: chatml
|
||||||
|
|
||||||
# step 2
|
# step 2
|
||||||
field_messages: messages
|
field_messages: messages
|
||||||
message_field_role: role
|
message_property_mappings:
|
||||||
message_field_content: content
|
role: role
|
||||||
|
content: content
|
||||||
|
|
||||||
roles:
|
roles:
|
||||||
assistant:
|
assistant:
|
||||||
@@ -365,8 +367,8 @@ datasets:
|
|||||||
- human
|
- human
|
||||||
- user
|
- user
|
||||||
|
|
||||||
# step 3
|
# step 3
|
||||||
roles_to_train: ["assistant"]
|
roles_to_train: ["assistant"]
|
||||||
train_on_eos: "turn"
|
train_on_eos: "turn"
|
||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|||||||
@@ -23,3 +23,7 @@ description: Frequently asked questions
|
|||||||
**Q: The codes is stuck on saving preprocessed datasets.**
|
**Q: The codes is stuck on saving preprocessed datasets.**
|
||||||
|
|
||||||
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
|
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
|
||||||
|
|
||||||
|
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||||
|
|
||||||
|
> A: This means that the property mapping for the stated attribute does not exist when building `chat_template` prompt. For example, if `no attribute 'content'`, please check you have added the correct mapping for `content` under `message_property_mappings`.
|
||||||
|
|||||||
@@ -229,8 +229,9 @@ datasets:
|
|||||||
field_messages: "messages"
|
field_messages: "messages"
|
||||||
field_chosen: "chosen"
|
field_chosen: "chosen"
|
||||||
field_rejected: "rejected"
|
field_rejected: "rejected"
|
||||||
message_field_role: "role"
|
message_property_mappings:
|
||||||
message_field_content: "content"
|
role: role
|
||||||
|
content: content
|
||||||
roles:
|
roles:
|
||||||
user: ["user"]
|
user: ["user"]
|
||||||
assistant: ["assistant"]
|
assistant: ["assistant"]
|
||||||
|
|||||||
@@ -21,8 +21,9 @@ datasets:
|
|||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:20%]
|
split: train[:20%]
|
||||||
field_messages: conversations
|
field_messages: conversations
|
||||||
message_field_role: from
|
message_property_mappings:
|
||||||
message_field_content: value
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
|
|||||||
@@ -16,8 +16,9 @@ datasets:
|
|||||||
type: chat_template
|
type: chat_template
|
||||||
drop_system_message: true
|
drop_system_message: true
|
||||||
field_messages: conversations
|
field_messages: conversations
|
||||||
message_field_role: from
|
message_property_mappings:
|
||||||
message_field_content: value
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
output_dir: ./outputs/out
|
output_dir: ./outputs/out
|
||||||
|
|||||||
@@ -13,8 +13,9 @@ datasets:
|
|||||||
type: chat_template
|
type: chat_template
|
||||||
drop_system_message: true
|
drop_system_message: true
|
||||||
field_messages: conversations
|
field_messages: conversations
|
||||||
message_field_role: from
|
message_property_mappings:
|
||||||
message_field_content: value
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.0
|
val_set_size: 0.0
|
||||||
|
|||||||
@@ -17,8 +17,9 @@ datasets:
|
|||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:20%]
|
split: train[:20%]
|
||||||
field_messages: conversations
|
field_messages: conversations
|
||||||
message_field_role: from
|
message_property_mappings:
|
||||||
message_field_content: value
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
|
|||||||
@@ -17,8 +17,9 @@ datasets:
|
|||||||
field_messages: conversation
|
field_messages: conversation
|
||||||
field_chosen: chosen
|
field_chosen: chosen
|
||||||
field_rejected: rejected
|
field_rejected: rejected
|
||||||
message_field_role: role
|
message_property_mappings:
|
||||||
message_field_content: content
|
role: role
|
||||||
|
content: content
|
||||||
roles:
|
roles:
|
||||||
system:
|
system:
|
||||||
- system
|
- system
|
||||||
|
|||||||
@@ -14,8 +14,9 @@ datasets:
|
|||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
type: chat_template
|
type: chat_template
|
||||||
field_messages: messages
|
field_messages: messages
|
||||||
message_field_role: role
|
message_property_mappings:
|
||||||
message_field_content: content
|
role: role
|
||||||
|
content: content
|
||||||
roles:
|
roles:
|
||||||
user:
|
user:
|
||||||
- user
|
- user
|
||||||
|
|||||||
@@ -17,8 +17,9 @@ datasets:
|
|||||||
field_messages: conversation
|
field_messages: conversation
|
||||||
field_chosen: chosen
|
field_chosen: chosen
|
||||||
field_rejected: rejected
|
field_rejected: rejected
|
||||||
message_field_role: role
|
message_property_mappings:
|
||||||
message_field_content: content
|
role: role
|
||||||
|
content: content
|
||||||
roles:
|
roles:
|
||||||
system:
|
system:
|
||||||
- system
|
- system
|
||||||
@@ -31,8 +32,9 @@ datasets:
|
|||||||
field_messages: conversation
|
field_messages: conversation
|
||||||
field_chosen: chosen
|
field_chosen: chosen
|
||||||
field_rejected: rejected
|
field_rejected: rejected
|
||||||
message_field_role: role
|
message_property_mappings:
|
||||||
message_field_content: content
|
role: role
|
||||||
|
content: content
|
||||||
roles:
|
roles:
|
||||||
system:
|
system:
|
||||||
- system
|
- system
|
||||||
|
|||||||
@@ -22,8 +22,9 @@ datasets:
|
|||||||
field_messages: conversation
|
field_messages: conversation
|
||||||
field_chosen: chosen
|
field_chosen: chosen
|
||||||
field_rejected: rejected
|
field_rejected: rejected
|
||||||
message_field_role: role
|
message_property_mappings:
|
||||||
message_field_content: content
|
role: role
|
||||||
|
content: content
|
||||||
|
|
||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
|
|||||||
@@ -14,8 +14,9 @@ datasets:
|
|||||||
- path: fozziethebeat/alpaca_messages_2k_test
|
- path: fozziethebeat/alpaca_messages_2k_test
|
||||||
type: chat_template
|
type: chat_template
|
||||||
field_messages: messages
|
field_messages: messages
|
||||||
message_field_role: role
|
message_property_mappings:
|
||||||
message_field_content: content
|
role: role
|
||||||
|
content: content
|
||||||
roles:
|
roles:
|
||||||
user:
|
user:
|
||||||
- user
|
- user
|
||||||
|
|||||||
@@ -12,8 +12,9 @@ datasets:
|
|||||||
field_messages: conversation
|
field_messages: conversation
|
||||||
field_chosen: chosen
|
field_chosen: chosen
|
||||||
field_rejected: rejected
|
field_rejected: rejected
|
||||||
message_field_role: role
|
message_property_mappings:
|
||||||
message_field_content: content
|
role: role
|
||||||
|
content: content
|
||||||
roles:
|
roles:
|
||||||
system:
|
system:
|
||||||
- system
|
- system
|
||||||
|
|||||||
@@ -31,27 +31,26 @@ def parse_dataset(dataset=None, split="train"):
|
|||||||
ds_cfg["field_messages"] = field_messages
|
ds_cfg["field_messages"] = field_messages
|
||||||
|
|
||||||
message_fields = features[field_messages][0].keys()
|
message_fields = features[field_messages][0].keys()
|
||||||
message_field_role = None
|
|
||||||
|
message_property_mappings = {"role": None, "content": None}
|
||||||
for key in ["from", "role"]:
|
for key in ["from", "role"]:
|
||||||
if key in message_fields:
|
if key in message_fields:
|
||||||
message_field_role = key
|
message_property_mappings["role"] = key
|
||||||
break
|
break
|
||||||
if not message_field_role:
|
if not message_property_mappings["role"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f'No role field found in messages: {", ".join(message_fields)}'
|
f'No role field found in messages: {", ".join(message_fields)}'
|
||||||
)
|
)
|
||||||
ds_cfg["message_field_role"] = message_field_role
|
|
||||||
|
|
||||||
message_field_content = None
|
|
||||||
for key in ["content", "text", "value"]:
|
for key in ["content", "text", "value"]:
|
||||||
if key in message_fields:
|
if key in message_fields:
|
||||||
message_field_content = key
|
message_property_mappings["content"] = key
|
||||||
break
|
break
|
||||||
if not message_field_content:
|
if not message_property_mappings["content"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f'No content field found in messages: {", ".join(message_fields)}'
|
f'No content field found in messages: {", ".join(message_fields)}'
|
||||||
)
|
)
|
||||||
ds_cfg["message_field_content"] = message_field_content
|
ds_cfg["message_property_mappings"] = message_property_mappings
|
||||||
|
|
||||||
print(yaml.dump({"datasets": [ds_cfg]}))
|
print(yaml.dump({"datasets": [ds_cfg]}))
|
||||||
|
|
||||||
|
|||||||
@@ -41,10 +41,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
|||||||
load_kwargs["ds_cfg"] = ds_cfg
|
load_kwargs["ds_cfg"] = ds_cfg
|
||||||
if "processor" in sig.parameters:
|
if "processor" in sig.parameters:
|
||||||
load_kwargs["processor"] = processor
|
load_kwargs["processor"] = processor
|
||||||
|
|
||||||
return func(tokenizer, cfg, **load_kwargs)
|
return func(tokenizer, cfg, **load_kwargs)
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
return None
|
return None
|
||||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||||
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
||||||
raise exc
|
raise exc
|
||||||
return None
|
|
||||||
|
|||||||
@@ -34,15 +34,12 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
|
|
||||||
max_length = self.prompter.max_length
|
max_length = self.prompter.max_length
|
||||||
|
|
||||||
self.messages = "chosen_messages"
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
prompt[self.messages] = []
|
prompt["messages"] = []
|
||||||
if prompt["system"]:
|
if prompt["system"]:
|
||||||
prompt[self.messages].append(
|
prompt["messages"].append({"role": "system", "content": prompt["system"]})
|
||||||
{"role": "system", "content": prompt["system"]}
|
prompt["messages"].append({"role": "user", "content": prompt["input"]})
|
||||||
)
|
prompt["messages"].append({"role": "assistant", "content": prompt["chosen"]})
|
||||||
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
|
|
||||||
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
|
|
||||||
chosen_tokenized = super()._tokenize_single_prompt(prompt)
|
chosen_tokenized = super()._tokenize_single_prompt(prompt)
|
||||||
|
|
||||||
if len(chosen_tokenized["input_ids"]) > max_length:
|
if len(chosen_tokenized["input_ids"]) > max_length:
|
||||||
@@ -55,17 +52,12 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
|
|||||||
:max_length
|
:max_length
|
||||||
]
|
]
|
||||||
|
|
||||||
self.messages = "rejected_messages"
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
prompt[self.messages] = []
|
prompt["messages"] = []
|
||||||
if prompt["system"]:
|
if prompt["system"]:
|
||||||
prompt[self.messages].append(
|
prompt["messages"].append({"role": "system", "content": prompt["system"]})
|
||||||
{"role": "system", "content": prompt["system"]}
|
prompt["messages"].append({"role": "user", "content": prompt["input"]})
|
||||||
)
|
prompt["messages"].append({"role": "assistant", "content": prompt["rejected"]})
|
||||||
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
|
|
||||||
prompt[self.messages].append(
|
|
||||||
{"role": "assistant", "content": prompt["rejected"]}
|
|
||||||
)
|
|
||||||
rejected_tokenized = super()._tokenize_single_prompt(prompt)
|
rejected_tokenized = super()._tokenize_single_prompt(prompt)
|
||||||
|
|
||||||
if len(rejected_tokenized["input_ids"]) > max_length:
|
if len(rejected_tokenized["input_ids"]) > max_length:
|
||||||
@@ -99,8 +91,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
prompter_params = {
|
prompter_params = {
|
||||||
"tokenizer": tokenizer,
|
"tokenizer": tokenizer,
|
||||||
"chat_template": chat_template_string,
|
"chat_template": chat_template_string,
|
||||||
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
"message_property_mappings": ds_cfg.get(
|
||||||
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
"message_property_mappings",
|
||||||
|
{
|
||||||
|
"role": "role",
|
||||||
|
"content": "content",
|
||||||
|
},
|
||||||
|
),
|
||||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||||
"message_field_training_detail": ds_cfg.get(
|
"message_field_training_detail": ds_cfg.get(
|
||||||
"message_field_training_detail", None
|
"message_field_training_detail", None
|
||||||
@@ -124,7 +121,4 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
|
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
|
||||||
)
|
)
|
||||||
|
|
||||||
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
|
||||||
strategy.messages = ds_cfg["field_messages"]
|
|
||||||
|
|
||||||
return strategy
|
return strategy
|
||||||
|
|||||||
@@ -4,13 +4,16 @@ HF Chat Templates prompt strategy
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Set, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
from transformers import ProcessorMixin
|
from transformers import ProcessorMixin
|
||||||
|
|
||||||
|
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
|
||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
|
||||||
|
|
||||||
# Configure the logger
|
# Configure the logger
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
@@ -23,16 +26,23 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
chat_template: str,
|
||||||
processor=None,
|
processor=None,
|
||||||
chat_template=None,
|
|
||||||
max_length=2048,
|
max_length=2048,
|
||||||
message_field_role: str = "role",
|
message_property_mappings: Optional[Dict[str, str]] = None,
|
||||||
message_field_content: str = "content",
|
|
||||||
message_field_training: Optional[str] = None,
|
message_field_training: Optional[str] = None,
|
||||||
message_field_training_detail: Optional[str] = None,
|
message_field_training_detail: Optional[str] = None,
|
||||||
|
field_messages: str = "messages",
|
||||||
roles: Optional[Dict[str, List[str]]] = None,
|
roles: Optional[Dict[str, List[str]]] = None,
|
||||||
drop_system_message: bool = False,
|
drop_system_message: bool = False,
|
||||||
):
|
):
|
||||||
|
# check if message_property_mappings is None or empty dict
|
||||||
|
if message_property_mappings is None or (not message_property_mappings):
|
||||||
|
message_property_mappings = {
|
||||||
|
"role": "role",
|
||||||
|
"content": "content",
|
||||||
|
}
|
||||||
|
|
||||||
if roles:
|
if roles:
|
||||||
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
self.roles = {s: t for t, sources in roles.items() for s in sources}
|
||||||
else:
|
else:
|
||||||
@@ -45,18 +55,28 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
"tool": "tool",
|
"tool": "tool",
|
||||||
}
|
}
|
||||||
|
|
||||||
self.message_field_role = message_field_role
|
self._chat_template_msg_variables = self.get_chat_template_msg_variables(
|
||||||
self.message_field_content = message_field_content
|
chat_template, field_messages
|
||||||
|
)
|
||||||
|
self.message_property_mappings = message_property_mappings
|
||||||
self.message_field_training = message_field_training
|
self.message_field_training = message_field_training
|
||||||
self.message_field_training_detail = message_field_training_detail
|
self.message_field_training_detail = message_field_training_detail
|
||||||
|
self.field_messages = field_messages
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.processor: ProcessorMixin = processor
|
self.processor: Optional[ProcessorMixin] = processor
|
||||||
self.chat_template = chat_template
|
self.chat_template = chat_template
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.drop_system_message = drop_system_message
|
self.drop_system_message = drop_system_message
|
||||||
|
|
||||||
|
@property
|
||||||
|
def chat_template_msg_variables(self) -> Set[str]:
|
||||||
|
return self._chat_template_msg_variables
|
||||||
|
|
||||||
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
|
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
|
||||||
if self.processor:
|
if self.processor:
|
||||||
|
if not callable(self.processor):
|
||||||
|
raise TypeError("Processor must be callable")
|
||||||
|
|
||||||
text = self.processor.apply_chat_template(
|
text = self.processor.apply_chat_template(
|
||||||
conversation,
|
conversation,
|
||||||
chat_template=self.chat_template,
|
chat_template=self.chat_template,
|
||||||
@@ -184,17 +204,21 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
|
|
||||||
return adjusted_details
|
return adjusted_details
|
||||||
|
|
||||||
|
def get_chat_template_msg_variables(
|
||||||
|
self, chat_template: str, field_messages: str
|
||||||
|
) -> Set[str]:
|
||||||
|
template_analyzer = JinjaTemplateAnalyzer(chat_template)
|
||||||
|
return template_analyzer.get_message_vars(field_messages)
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplateStrategy(PromptTokenizingStrategy):
|
class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||||
"""
|
"""
|
||||||
Tokenizing strategy for instruction-based prompts.
|
Tokenizing strategy for instruction-based prompts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_messages = "messages"
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
prompter: ChatTemplatePrompter,
|
prompter: "ChatTemplatePrompter",
|
||||||
tokenizer,
|
tokenizer,
|
||||||
train_on_inputs,
|
train_on_inputs,
|
||||||
sequence_len,
|
sequence_len,
|
||||||
@@ -202,6 +226,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
train_on_eos=None,
|
train_on_eos=None,
|
||||||
):
|
):
|
||||||
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
|
||||||
|
self.prompter: ChatTemplatePrompter = prompter
|
||||||
|
|
||||||
self.roles_to_train = []
|
self.roles_to_train = []
|
||||||
if roles_to_train:
|
if roles_to_train:
|
||||||
@@ -213,13 +238,9 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
self.train_on_eos = train_on_eos
|
self.train_on_eos = train_on_eos
|
||||||
self.images = "images"
|
self.images = "images"
|
||||||
|
|
||||||
@property
|
LOG.debug(
|
||||||
def messages(self):
|
f"The chat template uses the following properites on the message: {self.prompter.chat_template_msg_variables}"
|
||||||
return self._messages
|
)
|
||||||
|
|
||||||
@messages.setter
|
|
||||||
def messages(self, messages):
|
|
||||||
self._messages = messages
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supports_batched(self) -> bool:
|
def supports_batched(self) -> bool:
|
||||||
@@ -229,7 +250,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
|
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
|
||||||
try:
|
try:
|
||||||
return all(isinstance(v, list) for v in prompt.values()) and all(
|
return all(isinstance(v, list) for v in prompt.values()) and all(
|
||||||
isinstance(v, list) for v in prompt[self.messages]
|
isinstance(v, list) for v in prompt[self.prompter.field_messages]
|
||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return False
|
return False
|
||||||
@@ -464,30 +485,17 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
def get_conversation_thread(self, prompt):
|
||||||
turns = []
|
turns = []
|
||||||
optional_keys = [
|
for message in prompt[self.prompter.field_messages]:
|
||||||
"tool_calls", # tool that 'assistant' calls
|
transformed_message = self.transform_message(message)
|
||||||
"name", # name of tool given by 'tool'
|
|
||||||
"tool_call_id", # mistral/mixtral requires this
|
|
||||||
]
|
|
||||||
for message in prompt[self.messages]:
|
|
||||||
turn = {
|
turn = {
|
||||||
"role": self.prompter.roles[message[self.prompter.message_field_role]],
|
**transformed_message,
|
||||||
"training": message.get(self.prompter.message_field_training),
|
"training": message.get(self.prompter.message_field_training),
|
||||||
"training_detail": message.get(
|
"training_detail": message.get(
|
||||||
self.prompter.message_field_training_detail
|
self.prompter.message_field_training_detail
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
# do not add content if None as it may conflict with some templates due to tools
|
|
||||||
content = message.get(self.prompter.message_field_content, None)
|
|
||||||
if content is not None:
|
|
||||||
turn["content"] = content
|
|
||||||
|
|
||||||
for key in optional_keys:
|
|
||||||
value = message.get(key, None)
|
|
||||||
if value is not None:
|
|
||||||
turn[key] = value
|
|
||||||
|
|
||||||
turns.append(turn)
|
turns.append(turn)
|
||||||
|
|
||||||
if self.prompter.drop_system_message and turns[0]["role"] == "system":
|
if self.prompter.drop_system_message and turns[0]["role"] == "system":
|
||||||
@@ -495,6 +503,37 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
return turns
|
return turns
|
||||||
|
|
||||||
|
def transform_message(self, message):
|
||||||
|
# Build the initial transformed message from the mappings
|
||||||
|
transformed_message = {}
|
||||||
|
for key, value in self.prompter.message_property_mappings.items():
|
||||||
|
if message.get(value) is not None:
|
||||||
|
transformed_message[key] = message[value]
|
||||||
|
else:
|
||||||
|
LOG.debug(
|
||||||
|
f"Could not find value for property {value} in message: {message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Map the role if necessary
|
||||||
|
if "role" in transformed_message:
|
||||||
|
transformed_message["role"] = self.prompter.roles.get(
|
||||||
|
transformed_message["role"], transformed_message["role"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine which keys in the original message were not mapped
|
||||||
|
mapped_values = set(self.prompter.message_property_mappings.values())
|
||||||
|
remaining_keys = set(message) - mapped_values
|
||||||
|
|
||||||
|
# Keep only the properties defined in the chat template
|
||||||
|
# and not already mapped
|
||||||
|
for key in self.prompter.chat_template_msg_variables:
|
||||||
|
if key in remaining_keys:
|
||||||
|
val = message.get(key)
|
||||||
|
if val is not None:
|
||||||
|
transformed_message[key] = val
|
||||||
|
|
||||||
|
return transformed_message
|
||||||
|
|
||||||
def get_images(self, prompt):
|
def get_images(self, prompt):
|
||||||
return prompt.get(self.images, None)
|
return prompt.get(self.images, None)
|
||||||
|
|
||||||
@@ -516,33 +555,46 @@ class StrategyLoader:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
|
self,
|
||||||
|
tokenizer,
|
||||||
|
cfg,
|
||||||
|
ds_cfg: Optional[Union[Dict[str, Any], DatasetConfig]] = None,
|
||||||
|
processor=None,
|
||||||
):
|
):
|
||||||
# pylint: disable=duplicate-code
|
if ds_cfg is None:
|
||||||
ds_cfg = ds_cfg or {}
|
dataset_config = {}
|
||||||
|
elif isinstance(ds_cfg, BaseModel):
|
||||||
|
dataset_config = ds_cfg.model_dump()
|
||||||
|
else:
|
||||||
|
dataset_config = ds_cfg
|
||||||
|
|
||||||
chat_template_string = get_chat_template_from_config(
|
chat_template_string = get_chat_template_from_config(
|
||||||
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
cfg=cfg, ds_cfg=dataset_config, tokenizer=tokenizer
|
||||||
)
|
)
|
||||||
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
|
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
|
||||||
|
|
||||||
prompter_params = {
|
prompter_params = {
|
||||||
"tokenizer": tokenizer,
|
"tokenizer": tokenizer,
|
||||||
"chat_template": chat_template_string,
|
"chat_template": chat_template_string,
|
||||||
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
"message_property_mappings": dataset_config.get(
|
||||||
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
"message_property_mappings", {}
|
||||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
),
|
||||||
"message_field_training_detail": ds_cfg.get(
|
"message_field_training": dataset_config.get(
|
||||||
|
"message_field_training", None
|
||||||
|
),
|
||||||
|
"message_field_training_detail": dataset_config.get(
|
||||||
"message_field_training_detail",
|
"message_field_training_detail",
|
||||||
None,
|
None,
|
||||||
),
|
),
|
||||||
"roles": ds_cfg.get("roles"),
|
"field_messages": dataset_config.get("field_messages", "messages"),
|
||||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
"roles": dataset_config.get("roles"),
|
||||||
|
"drop_system_message": dataset_config.get("drop_system_message", False),
|
||||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||||
"max_length": cfg.sequence_len + 1,
|
"max_length": cfg.sequence_len + 1,
|
||||||
"processor": processor,
|
"processor": processor,
|
||||||
}
|
}
|
||||||
|
|
||||||
strategy_params = self._get_strategy_params(cfg, ds_cfg)
|
strategy_params = self._get_strategy_params(cfg, dataset_config)
|
||||||
strategy_cls = self._get_strategy_cls()
|
strategy_cls = self._get_strategy_cls()
|
||||||
|
|
||||||
strategy = strategy_cls(
|
strategy = strategy_cls(
|
||||||
@@ -551,9 +603,6 @@ class StrategyLoader:
|
|||||||
**strategy_params,
|
**strategy_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
|
||||||
strategy.messages = ds_cfg["field_messages"]
|
|
||||||
|
|
||||||
return strategy
|
return strategy
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,20 +3,28 @@ DPO prompt strategies for using tokenizer chat templates.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic
|
||||||
|
|
||||||
|
|
||||||
def default(
|
def default(
|
||||||
cfg, dataset_idx=0, **kwargs
|
cfg, dataset_idx=0, **kwargs
|
||||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||||
ds_cfg = cfg["datasets"][dataset_idx]
|
ds_cfg = cfg["datasets"][dataset_idx]
|
||||||
|
ds_cfg = handle_legacy_message_fields_logic(ds_cfg)
|
||||||
|
|
||||||
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||||
cfg=cfg, ds_cfg=ds_cfg
|
cfg=cfg, ds_cfg=ds_cfg
|
||||||
)
|
)
|
||||||
field_messages = ds_cfg.get("field_messages", "messages")
|
field_messages = ds_cfg.get("field_messages", "messages")
|
||||||
field_chosen = ds_cfg.get("field_chosen", "chosen")
|
field_chosen = ds_cfg.get("field_chosen", "chosen")
|
||||||
field_rejected = ds_cfg.get("field_rejected", "rejected")
|
field_rejected = ds_cfg.get("field_rejected", "rejected")
|
||||||
field_message_role = ds_cfg.get("message_field_role", "role")
|
message_property_mappings = ds_cfg.get(
|
||||||
field_message_content = ds_cfg.get("message_field_content", "content")
|
"message_property_mappings",
|
||||||
|
{
|
||||||
|
"role": "role",
|
||||||
|
"content": "content",
|
||||||
|
},
|
||||||
|
)
|
||||||
role_map_inv = ds_cfg.get(
|
role_map_inv = ds_cfg.get(
|
||||||
"roles",
|
"roles",
|
||||||
{
|
{
|
||||||
@@ -40,18 +48,18 @@ def default(
|
|||||||
messages = sample[field_messages]
|
messages = sample[field_messages]
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": role_map[m[field_message_role]],
|
"role": role_map[m[message_property_mappings["role"]]],
|
||||||
"content": m[field_message_content],
|
"content": m[message_property_mappings["content"]],
|
||||||
}
|
}
|
||||||
for m in messages
|
for m in messages
|
||||||
]
|
]
|
||||||
chosen = {
|
chosen = {
|
||||||
"role": role_map[sample[field_chosen][field_message_role]],
|
"role": role_map[sample[field_chosen][message_property_mappings["role"]]],
|
||||||
"content": sample[field_chosen][field_message_content],
|
"content": sample[field_chosen][message_property_mappings["content"]],
|
||||||
}
|
}
|
||||||
rejected = {
|
rejected = {
|
||||||
"role": role_map[sample[field_rejected][field_message_role]],
|
"role": role_map[sample[field_rejected][message_property_mappings["role"]]],
|
||||||
"content": sample[field_rejected][field_message_content],
|
"content": sample[field_rejected][message_property_mappings["content"]],
|
||||||
}
|
}
|
||||||
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
|
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
|
||||||
|
|
||||||
|
|||||||
318
src/axolotl/prompt_strategies/jinja_template_analyzer.py
Normal file
318
src/axolotl/prompt_strategies/jinja_template_analyzer.py
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
"""Module for inspect jinja templates for the variables they use"""
|
||||||
|
from typing import Dict, Optional, Set, TypedDict, Union
|
||||||
|
|
||||||
|
from jinja2 import Environment, meta, nodes
|
||||||
|
|
||||||
|
|
||||||
|
class JinjaTemplateAnalysis(TypedDict):
|
||||||
|
"""
|
||||||
|
Represents the detailed analysis of a Jinja template variable.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
accessed_properties (Set[str]): A set of properties accessed from the variable
|
||||||
|
(e.g., `foo.bar` results in 'bar' being accessed for 'foo').
|
||||||
|
accessed_indices (Set[Union[int, float]]): A set of indices accessed from the variable.
|
||||||
|
is_iterated (bool): Indicates if the variable is used as an iteration source in a `for` loop.
|
||||||
|
is_conditional (bool): Indicates if the variable is referenced within a conditional statement (e.g., an `if` block).
|
||||||
|
iteration_source (Optional[str]): The name of the variable being iterated over, if applicable.
|
||||||
|
iteration_target (Optional[Union[str, list[str]]]): The loop target(s) assigned in the iteration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
accessed_properties: Set[str]
|
||||||
|
accessed_indices: Set[Union[int, float]]
|
||||||
|
is_iterated: bool
|
||||||
|
is_conditional: bool
|
||||||
|
iteration_source: Optional[str]
|
||||||
|
iteration_target: Optional[Union[str, list[str]]]
|
||||||
|
|
||||||
|
|
||||||
|
class JinjaTemplateAnalyzer:
|
||||||
|
"""
|
||||||
|
Analyzes Jinja templates to extract information about variable usage,
|
||||||
|
including accessed properties, iteration, and conditional references.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
env (jinja2.Environment): The Jinja2 environment used for parsing templates.
|
||||||
|
property_access (Dict[str, Set[str]]): Tracks accessed properties for variables.
|
||||||
|
iteration_targets (Dict[str, str]): Maps iteration target variables to their sources.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
get_template_variables(template: str) -> Dict[str, Set[str]]:
|
||||||
|
Parse a Jinja template and return a mapping of variables to their accessed properties.
|
||||||
|
|
||||||
|
analyze_template(template: str) -> Dict[str, JinjaTemplateAnalysis]:
|
||||||
|
Perform a detailed analysis of the template, including variable usage,
|
||||||
|
iteration, and conditional references.
|
||||||
|
|
||||||
|
Private Methods:
|
||||||
|
_visit_node(node) -> None:
|
||||||
|
Recursively visit AST nodes to detect attribute access and iteration targets.
|
||||||
|
|
||||||
|
_get_base_name(node) -> Optional[str]:
|
||||||
|
Extract the base variable name from a node.
|
||||||
|
|
||||||
|
_get_target_name(node) -> Optional[Union[str, list[str]]]:
|
||||||
|
Extract the target name(s) from a `For` node.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, template: str):
|
||||||
|
self.env: Environment = Environment(autoescape=True)
|
||||||
|
self.property_access: Dict[str, Set[str]] = {}
|
||||||
|
self.iteration_targets: Dict[str, Union[str, list[str]]] = {}
|
||||||
|
self.index_access: Dict[str, Set[Union[int, float]]] = {}
|
||||||
|
self.ast: nodes.Node = self.env.parse(template)
|
||||||
|
self.template: str = template
|
||||||
|
self.variable_assignments: Dict[str, str] = {}
|
||||||
|
|
||||||
|
def _visit_node(self, node) -> None:
|
||||||
|
"""Recursively visit AST nodes to find attribute access."""
|
||||||
|
# Handle attribute access (dot notation)
|
||||||
|
if isinstance(node, nodes.Getattr):
|
||||||
|
base_name = self._get_base_name(node.node)
|
||||||
|
if base_name:
|
||||||
|
self.property_access.setdefault(base_name, set()).add(node.attr)
|
||||||
|
|
||||||
|
# Handle dictionary access (subscript notation)
|
||||||
|
elif isinstance(node, nodes.Getitem):
|
||||||
|
base_name = self._get_base_name(node.node)
|
||||||
|
if base_name and isinstance(node.arg, nodes.Const):
|
||||||
|
value = node.arg.value
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
self.index_access.setdefault(base_name, set()).add(value)
|
||||||
|
else:
|
||||||
|
self.property_access.setdefault(base_name, set()).add(value)
|
||||||
|
|
||||||
|
elif isinstance(node, nodes.Test) and node.name == "defined":
|
||||||
|
base_name = self._get_base_name(node.node)
|
||||||
|
if base_name:
|
||||||
|
if isinstance(node.node, nodes.Getattr):
|
||||||
|
self.property_access.setdefault(base_name, set()).add(
|
||||||
|
node.node.attr
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle loop variables
|
||||||
|
elif isinstance(node, nodes.For):
|
||||||
|
iter_name = self._get_base_name(node.iter)
|
||||||
|
target_name = self._get_target_name(node.target)
|
||||||
|
if iter_name and target_name:
|
||||||
|
self.iteration_targets[target_name] = iter_name
|
||||||
|
self.property_access.setdefault(iter_name, set())
|
||||||
|
|
||||||
|
elif isinstance(node, nodes.Assign):
|
||||||
|
target_name = self._get_target_name(node.target)
|
||||||
|
source_name = self._get_base_name(node.node)
|
||||||
|
if target_name and source_name:
|
||||||
|
self.variable_assignments[target_name] = source_name
|
||||||
|
|
||||||
|
elif isinstance(node, nodes.Filter):
|
||||||
|
if node.name == "selectattr":
|
||||||
|
target = self._get_base_name(node.node)
|
||||||
|
if target:
|
||||||
|
self.variable_assignments[f"filtered_{target}"] = target
|
||||||
|
|
||||||
|
for child in node.iter_child_nodes():
|
||||||
|
self._visit_node(child)
|
||||||
|
|
||||||
|
def _get_target_name(self, node) -> Optional[str]:
|
||||||
|
"""Get the target variable name from a For node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: A Jinja AST node representing either a Name or Tuple node
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- str: For simple variable targets (e.g., "item" in "for item in items")
|
||||||
|
- None: If the node type is not recognized or is a tuple
|
||||||
|
"""
|
||||||
|
if isinstance(node, nodes.Name):
|
||||||
|
return node.name
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_target_names(self, node) -> list[str]:
|
||||||
|
"""Get all target variable names from a For node, including tuple unpacking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: A Jinja AST node representing either a Name or Tuple node
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of target variable names
|
||||||
|
"""
|
||||||
|
if isinstance(node, nodes.Name):
|
||||||
|
return [node.name]
|
||||||
|
|
||||||
|
if isinstance(node, nodes.Tuple):
|
||||||
|
names = []
|
||||||
|
for n in node.items:
|
||||||
|
if isinstance(n, nodes.Name):
|
||||||
|
names.append(n.name)
|
||||||
|
return names
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _get_base_name(self, node) -> Optional[str]:
|
||||||
|
"""Get the base variable name from a node."""
|
||||||
|
if isinstance(node, nodes.Name):
|
||||||
|
return node.name
|
||||||
|
|
||||||
|
if isinstance(node, nodes.Getattr):
|
||||||
|
return self._get_base_name(node.node)
|
||||||
|
|
||||||
|
if isinstance(node, nodes.Getitem):
|
||||||
|
return self._get_base_name(node.node)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_template_variables(self) -> Dict[str, Set[str]]:
|
||||||
|
"""
|
||||||
|
Parse a Jinja template and return both variables and their accessed properties.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
template (str): The Jinja template string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Set[str]]: Dictionary mapping variable names to sets of accessed properties
|
||||||
|
"""
|
||||||
|
# Parse the template
|
||||||
|
ast = self.env.parse(self.template)
|
||||||
|
|
||||||
|
# Get all undeclared variables
|
||||||
|
variables = meta.find_undeclared_variables(ast)
|
||||||
|
|
||||||
|
# Reset property access tracking
|
||||||
|
self.property_access = {}
|
||||||
|
|
||||||
|
# Visit all nodes to find property access
|
||||||
|
self._visit_node(ast)
|
||||||
|
|
||||||
|
# Create result dictionary
|
||||||
|
result: Dict[str, Set[str]] = {var: set() for var in variables}
|
||||||
|
# Merge in any discovered sub-properties
|
||||||
|
for var, props in self.property_access.items():
|
||||||
|
if var not in result:
|
||||||
|
result[var] = set()
|
||||||
|
result[var].update(props)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def analyze_template(self) -> Dict[str, JinjaTemplateAnalysis]:
|
||||||
|
"""
|
||||||
|
Provide a detailed analysis of template variables and their usage.
|
||||||
|
"""
|
||||||
|
variables = self.get_template_variables()
|
||||||
|
self.iteration_targets = {}
|
||||||
|
|
||||||
|
analysis: Dict[str, JinjaTemplateAnalysis] = {
|
||||||
|
var: JinjaTemplateAnalysis(
|
||||||
|
accessed_properties=props,
|
||||||
|
accessed_indices=set(),
|
||||||
|
is_iterated=False,
|
||||||
|
is_conditional=False,
|
||||||
|
iteration_source=None,
|
||||||
|
iteration_target=None,
|
||||||
|
)
|
||||||
|
for var, props in variables.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
for var, indices in self.index_access.items():
|
||||||
|
if var in analysis:
|
||||||
|
analysis[var]["accessed_indices"] = indices
|
||||||
|
|
||||||
|
def visit_node(node):
|
||||||
|
if isinstance(node, nodes.If):
|
||||||
|
|
||||||
|
def find_test_vars(test_node):
|
||||||
|
if isinstance(test_node, nodes.Name):
|
||||||
|
if test_node.name in analysis:
|
||||||
|
analysis[test_node.name]["is_conditional"] = True
|
||||||
|
for child in test_node.iter_child_nodes():
|
||||||
|
find_test_vars(child)
|
||||||
|
|
||||||
|
find_test_vars(node.test)
|
||||||
|
|
||||||
|
if isinstance(node, nodes.For):
|
||||||
|
iter_target = self._get_base_name(node.iter)
|
||||||
|
target_name = self._get_target_name(node.target)
|
||||||
|
if iter_target in analysis:
|
||||||
|
analysis[iter_target]["is_iterated"] = True
|
||||||
|
if target_name:
|
||||||
|
analysis[iter_target]["iteration_target"] = target_name
|
||||||
|
if isinstance(target_name, str) and target_name not in analysis:
|
||||||
|
analysis[target_name] = {
|
||||||
|
"accessed_properties": set(),
|
||||||
|
"is_iterated": False,
|
||||||
|
"is_conditional": False,
|
||||||
|
"iteration_source": iter_target,
|
||||||
|
"iteration_target": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
for child in node.iter_child_nodes():
|
||||||
|
visit_node(child)
|
||||||
|
|
||||||
|
visit_node(self.ast)
|
||||||
|
return analysis
|
||||||
|
|
||||||
|
def get_downstream_properties(self, start_var: str) -> Dict[str, Set[str]]:
|
||||||
|
"""
|
||||||
|
Get all properties accessed on a variable and its downstream assignments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_var: The starting variable to trace
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping variable names to their accessed properties
|
||||||
|
"""
|
||||||
|
visited = set()
|
||||||
|
properties = {}
|
||||||
|
|
||||||
|
def trace_variable(var_name: str):
|
||||||
|
if var_name in visited:
|
||||||
|
return
|
||||||
|
visited.add(var_name)
|
||||||
|
|
||||||
|
# Get direct properties
|
||||||
|
if var_name in self.property_access:
|
||||||
|
properties[var_name] = self.property_access[var_name]
|
||||||
|
|
||||||
|
# Get properties from iteration targets
|
||||||
|
if var_name in self.iteration_targets:
|
||||||
|
target = self.iteration_targets[var_name]
|
||||||
|
if isinstance(target, str):
|
||||||
|
trace_variable(target)
|
||||||
|
elif isinstance(target, list):
|
||||||
|
for t in target:
|
||||||
|
trace_variable(t)
|
||||||
|
|
||||||
|
# Follow assignments
|
||||||
|
for target, source in self.variable_assignments.items():
|
||||||
|
if source == var_name:
|
||||||
|
trace_variable(target)
|
||||||
|
|
||||||
|
# Check for array slicing
|
||||||
|
analysis = self.analyze_template()
|
||||||
|
if var_name in analysis:
|
||||||
|
var_info = analysis[var_name]
|
||||||
|
if var_info["accessed_indices"]:
|
||||||
|
# If this variable is sliced, follow the resulting assignment
|
||||||
|
slice_result = f"{var_name}_slice"
|
||||||
|
if slice_result in self.property_access:
|
||||||
|
trace_variable(slice_result)
|
||||||
|
|
||||||
|
trace_variable(start_var)
|
||||||
|
return properties
|
||||||
|
|
||||||
|
def get_message_vars(self, field_messages: str = "messages") -> Set[str]:
|
||||||
|
"""
|
||||||
|
Get all properties accessed on messages and derived variables.
|
||||||
|
"""
|
||||||
|
all_properties = self.get_downstream_properties(field_messages)
|
||||||
|
|
||||||
|
# Combine all properties from all related variables
|
||||||
|
combined_properties = set()
|
||||||
|
for properties in all_properties.values():
|
||||||
|
combined_properties.update(properties)
|
||||||
|
|
||||||
|
# Also include properties from the message iteration variable
|
||||||
|
analysis = self.analyze_template()
|
||||||
|
if "message" in analysis:
|
||||||
|
combined_properties.update(analysis["message"]["accessed_properties"])
|
||||||
|
|
||||||
|
return combined_properties
|
||||||
@@ -51,8 +51,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
ds_cfg = ds_cfg or {}
|
ds_cfg = ds_cfg or {}
|
||||||
|
|
||||||
field_messages = ds_cfg.get("field_messages")
|
field_messages = ds_cfg.get("field_messages")
|
||||||
message_field_role = ds_cfg.get("message_field_role")
|
message_property_mappings = ds_cfg.get("message_property_mappings")
|
||||||
message_field_content = ds_cfg.get("message_field_content")
|
message_field_role = (
|
||||||
|
message_property_mappings.get("role") if message_property_mappings else None
|
||||||
|
)
|
||||||
|
message_field_content = (
|
||||||
|
message_property_mappings.get("content") if message_property_mappings else None
|
||||||
|
)
|
||||||
message_field_training = ds_cfg.get("message_field_training")
|
message_field_training = ds_cfg.get("message_field_training")
|
||||||
|
|
||||||
builder_kwargs = {}
|
builder_kwargs = {}
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ def get_chat_template(
|
|||||||
user_choice: str,
|
user_choice: str,
|
||||||
jinja_template: Optional[str] = None,
|
jinja_template: Optional[str] = None,
|
||||||
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
||||||
):
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.
|
Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.
|
||||||
|
|
||||||
@@ -70,7 +70,7 @@ def get_chat_template(
|
|||||||
f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
|
f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
|
||||||
f"Please add a chat_template in tokenizer config"
|
f"Please add a chat_template in tokenizer config"
|
||||||
)
|
)
|
||||||
return tokenizer.chat_template
|
return tokenizer.chat_template # type: ignore
|
||||||
|
|
||||||
if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
|
if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
|
||||||
if not tokenizer:
|
if not tokenizer:
|
||||||
@@ -78,7 +78,7 @@ def get_chat_template(
|
|||||||
f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
|
f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
|
||||||
)
|
)
|
||||||
if tokenizer.chat_template:
|
if tokenizer.chat_template:
|
||||||
return tokenizer.chat_template
|
return tokenizer.chat_template # type: ignore
|
||||||
|
|
||||||
user_choice = user_choice[
|
user_choice = user_choice[
|
||||||
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
|
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from axolotl.utils.config.models.input.v0_4_1 import (
|
|||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import DPODataset, KTODataset, SFTDataset
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model_config
|
from axolotl.utils.models import load_model_config
|
||||||
|
|
||||||
@@ -258,7 +259,7 @@ def validate_config(
|
|||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
capabilities: Optional[dict] = None,
|
capabilities: Optional[dict] = None,
|
||||||
env_capabilities: Optional[dict] = None,
|
env_capabilities: Optional[dict] = None,
|
||||||
):
|
) -> DictDefault:
|
||||||
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
|
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
|
||||||
AxolotlInputConfig = AxolotlInputConfigBase
|
AxolotlInputConfig = AxolotlInputConfigBase
|
||||||
|
|
||||||
@@ -268,6 +269,16 @@ def validate_config(
|
|||||||
AxolotlInputConfig, # pylint: disable=invalid-name
|
AxolotlInputConfig, # pylint: disable=invalid-name
|
||||||
) = merge_input_args()
|
) = merge_input_args()
|
||||||
|
|
||||||
|
# Convert datasets to proper format if needed
|
||||||
|
if cfg.get("datasets"):
|
||||||
|
for idx, ds_cfg in enumerate(cfg["datasets"]):
|
||||||
|
if cfg.get("rl") == "dpo" and not isinstance(ds_cfg, DPODataset):
|
||||||
|
cfg["datasets"][idx] = DPODataset(**ds_cfg)
|
||||||
|
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
|
||||||
|
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))
|
||||||
|
elif not isinstance(ds_cfg, SFTDataset):
|
||||||
|
cfg["datasets"][idx] = SFTDataset(**dict(ds_cfg))
|
||||||
|
|
||||||
if capabilities or env_capabilities:
|
if capabilities or env_capabilities:
|
||||||
if (capabilities and env_capabilities is None) or (
|
if (capabilities and env_capabilities is None) or (
|
||||||
env_capabilities and capabilities is None
|
env_capabilities and capabilities is None
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from pydantic import (
|
|||||||
Field,
|
Field,
|
||||||
StringConstraints,
|
StringConstraints,
|
||||||
conlist,
|
conlist,
|
||||||
|
field_serializer,
|
||||||
field_validator,
|
field_validator,
|
||||||
model_validator,
|
model_validator,
|
||||||
)
|
)
|
||||||
@@ -186,8 +187,13 @@ class SFTDataset(BaseModel):
|
|||||||
field_human: Optional[str] = None
|
field_human: Optional[str] = None
|
||||||
field_model: Optional[str] = None
|
field_model: Optional[str] = None
|
||||||
field_messages: Optional[str] = None
|
field_messages: Optional[str] = None
|
||||||
message_field_role: Optional[str] = None
|
message_field_role: Optional[
|
||||||
message_field_content: Optional[str] = None
|
str
|
||||||
|
] = None # deprecated, use message_property_mappings
|
||||||
|
message_field_content: Optional[
|
||||||
|
str
|
||||||
|
] = None # deprecated, use message_property_mappings
|
||||||
|
message_property_mappings: Optional[Dict[str, str]] = None
|
||||||
message_field_training: Optional[str] = None
|
message_field_training: Optional[str] = None
|
||||||
message_field_training_detail: Optional[str] = None
|
message_field_training_detail: Optional[str] = None
|
||||||
logprobs_field: Optional[str] = None
|
logprobs_field: Optional[str] = None
|
||||||
@@ -199,9 +205,18 @@ class SFTDataset(BaseModel):
|
|||||||
trust_remote_code: Optional[bool] = False
|
trust_remote_code: Optional[bool] = False
|
||||||
revision: Optional[str] = None
|
revision: Optional[str] = None
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def handle_legacy_message_fields(cls, data):
|
||||||
|
"""Handle backwards compatibility between legacy message field mapping and new property mapping system."""
|
||||||
|
return handle_legacy_message_fields_logic(data)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_chat_template_config(cls, data):
|
def check_chat_template_config(cls, data):
|
||||||
|
if isinstance(data, BaseModel):
|
||||||
|
data = data.model_dump()
|
||||||
|
|
||||||
# Set chat_template to tokenizer_default if not set
|
# Set chat_template to tokenizer_default if not set
|
||||||
if data.get("type") == "chat_template" and not data.get("chat_template"):
|
if data.get("type") == "chat_template" and not data.get("chat_template"):
|
||||||
data["chat_template"] = ChatTemplate.tokenizer_default
|
data["chat_template"] = ChatTemplate.tokenizer_default
|
||||||
@@ -241,6 +256,7 @@ class DPODataset(BaseModel):
|
|||||||
type: Optional[Union[UserDefinedDPOType, str]] = None
|
type: Optional[Union[UserDefinedDPOType, str]] = None
|
||||||
data_files: Optional[List[str]] = None
|
data_files: Optional[List[str]] = None
|
||||||
revision: Optional[str] = None
|
revision: Optional[str] = None
|
||||||
|
field_messages: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class StepwiseSupervisedDataset(BaseModel):
|
class StepwiseSupervisedDataset(BaseModel):
|
||||||
@@ -277,6 +293,9 @@ class KTODataset(BaseModel):
|
|||||||
revision: Optional[str] = None
|
revision: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
DatasetConfig = Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset]
|
||||||
|
|
||||||
|
|
||||||
class LoftQConfig(BaseModel):
|
class LoftQConfig(BaseModel):
|
||||||
"""LoftQ configuration subset"""
|
"""LoftQ configuration subset"""
|
||||||
|
|
||||||
@@ -680,17 +699,15 @@ class AxolotlInputConfig(
|
|||||||
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||||
dpo_use_logits_to_keep: Optional[bool] = None
|
dpo_use_logits_to_keep: Optional[bool] = None
|
||||||
|
|
||||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
datasets: Optional[conlist(DatasetConfig, min_length=1)] = None # type: ignore
|
||||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset], min_length=1)] = None # type: ignore
|
test_datasets: Optional[conlist(DatasetConfig, min_length=1)] = None # type: ignore
|
||||||
shuffle_merged_datasets: Optional[bool] = True
|
shuffle_merged_datasets: Optional[bool] = True
|
||||||
dataset_prepared_path: Optional[str] = None
|
dataset_prepared_path: Optional[str] = None
|
||||||
dataset_shard_num: Optional[int] = None
|
dataset_shard_num: Optional[int] = None
|
||||||
dataset_shard_idx: Optional[int] = None
|
dataset_shard_idx: Optional[int] = None
|
||||||
skip_prepare_dataset: Optional[bool] = False
|
skip_prepare_dataset: Optional[bool] = False
|
||||||
|
|
||||||
pretraining_dataset: Optional[ # type: ignore
|
pretraining_dataset: Optional[conlist(Union[PretrainingDataset, SFTDataset], min_length=1)] = Field( # type: ignore
|
||||||
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
|
|
||||||
] = Field(
|
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
||||||
)
|
)
|
||||||
@@ -895,10 +912,15 @@ class AxolotlInputConfig(
|
|||||||
@classmethod
|
@classmethod
|
||||||
def deprecate_sharegpt_datasets(cls, datasets):
|
def deprecate_sharegpt_datasets(cls, datasets):
|
||||||
for _, ds_cfg in enumerate(datasets):
|
for _, ds_cfg in enumerate(datasets):
|
||||||
if not ds_cfg.get("type"):
|
# Handle both dict and pydantic model cases
|
||||||
|
ds_type = (
|
||||||
|
ds_cfg.get("type")
|
||||||
|
if isinstance(ds_cfg, dict)
|
||||||
|
else getattr(ds_cfg, "type", None)
|
||||||
|
)
|
||||||
|
if not ds_type:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ds_type = ds_cfg["type"]
|
|
||||||
# skip if it's a dict (for custom user instruction prompt)
|
# skip if it's a dict (for custom user instruction prompt)
|
||||||
if isinstance(ds_type, dict):
|
if isinstance(ds_type, dict):
|
||||||
continue
|
continue
|
||||||
@@ -910,6 +932,14 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return datasets
|
return datasets
|
||||||
|
|
||||||
|
@field_serializer("datasets")
|
||||||
|
def datasets_serializer(
|
||||||
|
self, ds_configs: Optional[List[DatasetConfig]]
|
||||||
|
) -> Optional[List[Dict[str, Any]]]:
|
||||||
|
if ds_configs:
|
||||||
|
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
|
||||||
|
return None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_batch_size_fields(cls, data):
|
def check_batch_size_fields(cls, data):
|
||||||
@@ -1762,3 +1792,77 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
else:
|
else:
|
||||||
data["torch_compile"] = False
|
data["torch_compile"] = False
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def handle_legacy_message_fields_logic(data: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Handle backwards compatibility between legacy message field mapping and new property mapping system.
|
||||||
|
|
||||||
|
Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
|
||||||
|
- message_field_role: Mapped to the role field
|
||||||
|
- message_field_content: Mapped to the content field
|
||||||
|
|
||||||
|
The new system uses message_property_mappings to support arbitrary field mappings:
|
||||||
|
message_property_mappings:
|
||||||
|
role: source_role_field
|
||||||
|
content: source_content_field
|
||||||
|
additional_field: source_field
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary containing configuration data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated dictionary with message field mappings consolidated
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If there are conflicts between legacy and new mappings
|
||||||
|
"""
|
||||||
|
data = data.copy() # Create a copy to avoid modifying the original
|
||||||
|
|
||||||
|
if data.get("message_property_mappings") is None:
|
||||||
|
data["message_property_mappings"] = {}
|
||||||
|
|
||||||
|
# Check for conflicts and handle role
|
||||||
|
if "message_field_role" in data:
|
||||||
|
LOG.warning(
|
||||||
|
"message_field_role is deprecated, use message_property_mappings instead. "
|
||||||
|
f"Example: message_property_mappings: {{role: {data['message_field_role']}}}"
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
"role" in data["message_property_mappings"]
|
||||||
|
and data["message_property_mappings"]["role"] != data["message_field_role"]
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Conflicting message role fields: message_field_role='{data['message_field_role']}' "
|
||||||
|
f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'"
|
||||||
|
)
|
||||||
|
data["message_property_mappings"]["role"] = data["message_field_role"] or "role"
|
||||||
|
|
||||||
|
del data["message_field_role"]
|
||||||
|
elif "role" not in data["message_property_mappings"]:
|
||||||
|
data["message_property_mappings"]["role"] = "role"
|
||||||
|
|
||||||
|
# Check for conflicts and handle content
|
||||||
|
if "message_field_content" in data:
|
||||||
|
LOG.warning(
|
||||||
|
"message_field_content is deprecated, use message_property_mappings instead. "
|
||||||
|
f"Example: message_property_mappings: {{content: {data['message_field_content']}}}"
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
"content" in data["message_property_mappings"]
|
||||||
|
and data["message_property_mappings"]["content"]
|
||||||
|
!= data["message_field_content"]
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
|
||||||
|
f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
|
||||||
|
)
|
||||||
|
data["message_property_mappings"]["content"] = (
|
||||||
|
data["message_field_content"] or "content"
|
||||||
|
)
|
||||||
|
|
||||||
|
del data["message_field_content"]
|
||||||
|
elif "content" not in data["message_property_mappings"]:
|
||||||
|
data["message_property_mappings"]["content"] = "content"
|
||||||
|
|
||||||
|
return data
|
||||||
|
|||||||
@@ -180,6 +180,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
) -> Tuple[DatasetDict, List[Prompter]]:
|
) -> Tuple[DatasetDict, List[Prompter]]:
|
||||||
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
|
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
|
||||||
tokenizer_name = cfg.tokenizer_config
|
tokenizer_name = cfg.tokenizer_config
|
||||||
|
|
||||||
ds_hash = str(
|
ds_hash = str(
|
||||||
md5(
|
md5(
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -13,3 +13,26 @@ class DictDefault(Dict):
|
|||||||
|
|
||||||
def __or__(self, other):
|
def __or__(self, other):
|
||||||
return DictDefault(super().__ror__(other))
|
return DictDefault(super().__ror__(other))
|
||||||
|
|
||||||
|
def __setitem__(self, name, value):
|
||||||
|
# workaround for pickle/unpickle issues and __frozen not being available
|
||||||
|
try:
|
||||||
|
isFrozen = hasattr( # pylint: disable=invalid-name
|
||||||
|
self, "__frozen"
|
||||||
|
) and object.__getattribute__(self, "__frozen")
|
||||||
|
except AttributeError:
|
||||||
|
isFrozen = False # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
if isFrozen and name not in super().keys():
|
||||||
|
raise KeyError(name)
|
||||||
|
super(Dict, self).__setitem__(name, value) # pylint: disable=bad-super-call
|
||||||
|
try:
|
||||||
|
p = object.__getattribute__(self, "__parent")
|
||||||
|
key = object.__getattribute__(self, "__key")
|
||||||
|
except AttributeError:
|
||||||
|
p = None
|
||||||
|
key = None
|
||||||
|
if p is not None:
|
||||||
|
p[key] = self
|
||||||
|
object.__delattr__(self, "__parent")
|
||||||
|
object.__delattr__(self, "__key")
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from e2e.utils import check_tensorboard, require_torch_2_5_1
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
@@ -79,6 +79,7 @@ class TestKnowledgeDistillation:
|
|||||||
def test_llama_kd(self, temp_dir, kd_min_cfg):
|
def test_llama_kd(self, temp_dir, kd_min_cfg):
|
||||||
cfg = DictDefault(kd_min_cfg)
|
cfg = DictDefault(kd_min_cfg)
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = validate_config(cfg)
|
||||||
prepare_plugins(cfg)
|
prepare_plugins(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
@@ -109,6 +110,7 @@ class TestKnowledgeDistillation:
|
|||||||
| kd_min_cfg
|
| kd_min_cfg
|
||||||
)
|
)
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = validate_config(cfg)
|
||||||
prepare_plugins(cfg)
|
prepare_plugins(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, check_tensorboard
|
from ..utils import check_model_output_exists, check_tensorboard
|
||||||
@@ -76,7 +76,9 @@ class TestFAXentropyLlama:
|
|||||||
else:
|
else:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from pathlib import Path
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||||
@@ -73,6 +73,8 @@ class TestReLoraLlama(unittest.TestCase):
|
|||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import pytest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_preference_datasets
|
from axolotl.common.datasets import load_preference_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -63,6 +63,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -108,6 +110,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -153,6 +157,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -198,6 +204,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -242,6 +250,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -289,6 +299,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -353,6 +365,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||||
@@ -56,6 +56,8 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
|||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -65,6 +65,8 @@ class TestFalcon(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -118,6 +120,8 @@ class TestFalcon(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -157,6 +161,8 @@ class TestFalcon(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from e2e.utils import check_model_output_exists
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
@@ -56,6 +56,8 @@ class TestLlama:
|
|||||||
"save_safetensors": True,
|
"save_safetensors": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -99,6 +101,8 @@ class TestLlama:
|
|||||||
"save_safetensors": True,
|
"save_safetensors": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -138,6 +142,8 @@ class TestLlama:
|
|||||||
"save_safetensors": True,
|
"save_safetensors": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import pytest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, check_tensorboard
|
from .utils import check_model_output_exists, check_tensorboard
|
||||||
@@ -69,6 +69,8 @@ class TestPretrainLlama:
|
|||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -62,6 +62,8 @@ class TestLlamaVision(unittest.TestCase):
|
|||||||
"bf16": True,
|
"bf16": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -59,6 +59,8 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"max_steps": 20,
|
"max_steps": 20,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import pytest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -59,6 +59,8 @@ class TestMamba(unittest.TestCase):
|
|||||||
"save_safetensors": False,
|
"save_safetensors": False,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -63,6 +63,8 @@ class TestMistral(unittest.TestCase):
|
|||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -106,6 +108,8 @@ class TestMistral(unittest.TestCase):
|
|||||||
cfg.bf16 = True
|
cfg.bf16 = True
|
||||||
else:
|
else:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -69,6 +69,8 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -123,6 +125,8 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -180,6 +184,8 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cfg.bf16 = True
|
cfg.bf16 = True
|
||||||
else:
|
else:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -233,6 +239,8 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
if is_torch_bf16_gpu_available():
|
if is_torch_bf16_gpu_available():
|
||||||
cfg.bf16 = True
|
cfg.bf16 = True
|
||||||
@@ -281,6 +289,8 @@ class TestMixtral(unittest.TestCase):
|
|||||||
cfg.bf16 = True
|
cfg.bf16 = True
|
||||||
else:
|
else:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir
|
from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir
|
||||||
@@ -59,6 +59,8 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -103,6 +105,8 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -139,6 +143,8 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_tensorboard, with_temp_dir
|
from .utils import check_tensorboard, with_temp_dir
|
||||||
@@ -59,6 +59,8 @@ class TestPackedLlama(unittest.TestCase):
|
|||||||
cfg.bf16 = True
|
cfg.bf16 = True
|
||||||
else:
|
else:
|
||||||
cfg.fp16 = True
|
cfg.fp16 = True
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, with_temp_dir
|
from .utils import check_model_output_exists, with_temp_dir
|
||||||
@@ -61,6 +61,7 @@ class TestPhi(unittest.TestCase):
|
|||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -40,8 +40,10 @@ class TestE2eQwen:
|
|||||||
"field_messages": "conversation",
|
"field_messages": "conversation",
|
||||||
"field_chosen": "chosen",
|
"field_chosen": "chosen",
|
||||||
"field_rejected": "rejected",
|
"field_rejected": "rejected",
|
||||||
"message_field_role": "role",
|
"message_property_mappings": {
|
||||||
"message_field_content": "content",
|
"role": "role",
|
||||||
|
"content": "content",
|
||||||
|
},
|
||||||
"roles": {
|
"roles": {
|
||||||
"system": ["system"],
|
"system": ["system"],
|
||||||
"user": ["user"],
|
"user": ["user"],
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import unittest
|
|||||||
from axolotl.cli.args import TrainerCliArgs
|
from axolotl.cli.args import TrainerCliArgs
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, validate_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||||
@@ -66,6 +66,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
|
|||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
cfg = validate_config(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from datasets import Dataset
|
|||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
|
||||||
from axolotl.utils.chat_templates import _CHAT_TEMPLATES
|
from axolotl.utils.chat_templates import _CHAT_TEMPLATES
|
||||||
|
|
||||||
|
|
||||||
@@ -174,3 +175,32 @@ def fixture_llama3_2_vision_with_hardcoded_date() -> str:
|
|||||||
modified_template = template.replace(old_date_logic, new_date_logic)
|
modified_template = template.replace(old_date_logic, new_date_logic)
|
||||||
|
|
||||||
return modified_template
|
return modified_template
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="chat_template_jinja_with_optional_fields")
|
||||||
|
def fixture_chat_template_jinja_with_optional_fields() -> str:
|
||||||
|
return """{% for message in messages %}
|
||||||
|
{{'<|im_start|>'}}{{ message['role'] }}
|
||||||
|
{% if message['thoughts'] is defined %}[Thoughts: {{ message['thoughts'] }}]{% endif %}
|
||||||
|
{% if message['tool_calls'] is defined %}[Tool: {{ message['tool_calls'][0]['type'] }}]{% endif %}
|
||||||
|
{{ message['content'] }}{{'<|im_end|>'}}
|
||||||
|
{% endfor %}"""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="basic_jinja_template_analyzer")
|
||||||
|
def basic_jinja_template_analyzer():
|
||||||
|
return JinjaTemplateAnalyzer(
|
||||||
|
"""{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>
|
||||||
|
' + message['content'] + '<|end|>
|
||||||
|
'}}{% elif message['role'] == 'user' %}{{'<|user|>
|
||||||
|
' + message['content'] + '<|end|>
|
||||||
|
'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>
|
||||||
|
' + message['content'] + '<|end|>
|
||||||
|
'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>
|
||||||
|
' }}{% else %}{{ eos_token }}{% endif %}"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="mistral_jinja_template_analyzer")
|
||||||
|
def mistral_jinja_template_analyzer(mistralv03_tokenizer_chat_template_jinja):
|
||||||
|
return JinjaTemplateAnalyzer(mistralv03_tokenizer_chat_template_jinja)
|
||||||
|
|||||||
@@ -38,6 +38,10 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
"chat_template": "llama3",
|
"chat_template": "llama3",
|
||||||
"message_field_role": "role",
|
"message_field_role": "role",
|
||||||
"message_field_content": "content",
|
"message_field_content": "content",
|
||||||
|
"message_property_mappings": {
|
||||||
|
"role": "role",
|
||||||
|
"content": "content",
|
||||||
|
},
|
||||||
"roles": {
|
"roles": {
|
||||||
"user": ["user"],
|
"user": ["user"],
|
||||||
"assistant": ["assistant"],
|
"assistant": ["assistant"],
|
||||||
@@ -74,8 +78,10 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
chat_template=get_chat_template("llama3"),
|
chat_template=get_chat_template("llama3"),
|
||||||
message_field_role="role",
|
message_property_mappings={
|
||||||
message_field_content="content",
|
"role": "role",
|
||||||
|
"content": "content",
|
||||||
|
},
|
||||||
roles={
|
roles={
|
||||||
"user": ["user"],
|
"user": ["user"],
|
||||||
"assistant": ["assistant"],
|
"assistant": ["assistant"],
|
||||||
@@ -86,7 +92,7 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
)
|
)
|
||||||
strategy.messages = "messages"
|
|
||||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||||
input_ids = res["input_ids"]
|
input_ids = res["input_ids"]
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@@ -114,8 +120,10 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
phi35_tokenizer,
|
phi35_tokenizer,
|
||||||
chat_template=get_chat_template("phi_35"),
|
chat_template=get_chat_template("phi_35"),
|
||||||
message_field_role="role",
|
message_property_mappings={
|
||||||
message_field_content="content",
|
"role": "role",
|
||||||
|
"content": "content",
|
||||||
|
},
|
||||||
roles={
|
roles={
|
||||||
"user": ["user"],
|
"user": ["user"],
|
||||||
"assistant": ["assistant"],
|
"assistant": ["assistant"],
|
||||||
@@ -126,7 +134,7 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
)
|
)
|
||||||
strategy.messages = "messages"
|
|
||||||
res = strategy.tokenize_prompt(assistant_dataset[0])
|
res = strategy.tokenize_prompt(assistant_dataset[0])
|
||||||
input_ids = res["input_ids"]
|
input_ids = res["input_ids"]
|
||||||
labels = res["labels"]
|
labels = res["labels"]
|
||||||
@@ -170,9 +178,11 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
chat_template=get_chat_template("llama3"),
|
chat_template=get_chat_template("llama3"),
|
||||||
message_field_role="role",
|
|
||||||
message_field_content="content",
|
|
||||||
message_field_training="training",
|
message_field_training="training",
|
||||||
|
message_property_mappings={
|
||||||
|
"role": "role",
|
||||||
|
"content": "content",
|
||||||
|
},
|
||||||
roles={
|
roles={
|
||||||
"user": ["user"],
|
"user": ["user"],
|
||||||
"assistant": ["assistant"],
|
"assistant": ["assistant"],
|
||||||
@@ -185,7 +195,7 @@ class TestAssistantChatTemplateLlama3:
|
|||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
roles_to_train=["assistant"],
|
roles_to_train=["assistant"],
|
||||||
)
|
)
|
||||||
strategy.messages = "messages"
|
|
||||||
prompt_tokens = strategy.prompter.build_prompt(
|
prompt_tokens = strategy.prompter.build_prompt(
|
||||||
assistant_dataset[0]["messages"], False
|
assistant_dataset[0]["messages"], False
|
||||||
)
|
)
|
||||||
@@ -230,8 +240,11 @@ class TestSharegptChatTemplateLlama3:
|
|||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
chat_template=get_chat_template("llama3"),
|
chat_template=get_chat_template("llama3"),
|
||||||
message_field_role="from",
|
message_property_mappings={
|
||||||
message_field_content="value",
|
"role": "from",
|
||||||
|
"content": "value",
|
||||||
|
},
|
||||||
|
field_messages="conversations",
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -239,7 +252,7 @@ class TestSharegptChatTemplateLlama3:
|
|||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
roles_to_train=["gpt"],
|
roles_to_train=["gpt"],
|
||||||
)
|
)
|
||||||
strategy.messages = "conversations"
|
|
||||||
res = strategy.tokenize_prompt(sharegpt_dataset[0])
|
res = strategy.tokenize_prompt(sharegpt_dataset[0])
|
||||||
input_ids = res["input_ids"]
|
input_ids = res["input_ids"]
|
||||||
labels = res["labels"]
|
labels = res["labels"]
|
||||||
@@ -287,8 +300,11 @@ class TestSharegptChatTemplateLlama3:
|
|||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
chat_template=get_chat_template("llama3"),
|
chat_template=get_chat_template("llama3"),
|
||||||
message_field_role="from",
|
message_property_mappings={
|
||||||
message_field_content="value",
|
"role": "from",
|
||||||
|
"content": "value",
|
||||||
|
},
|
||||||
|
field_messages="conversations",
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -296,7 +312,7 @@ class TestSharegptChatTemplateLlama3:
|
|||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
roles_to_train=["human"],
|
roles_to_train=["human"],
|
||||||
)
|
)
|
||||||
strategy.messages = "conversations"
|
|
||||||
res = strategy.tokenize_prompt(sharegpt_dataset[0])
|
res = strategy.tokenize_prompt(sharegpt_dataset[0])
|
||||||
input_ids = res["input_ids"]
|
input_ids = res["input_ids"]
|
||||||
labels = res["labels"]
|
labels = res["labels"]
|
||||||
@@ -344,8 +360,11 @@ class TestSharegptChatTemplateLlama3:
|
|||||||
ChatTemplatePrompter(
|
ChatTemplatePrompter(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
chat_template=get_chat_template("llama3"),
|
chat_template=get_chat_template("llama3"),
|
||||||
message_field_role="from",
|
message_property_mappings={
|
||||||
message_field_content="value",
|
"role": "from",
|
||||||
|
"content": "value",
|
||||||
|
},
|
||||||
|
field_messages="conversations",
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -353,7 +372,7 @@ class TestSharegptChatTemplateLlama3:
|
|||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
roles_to_train=["system", "human"],
|
roles_to_train=["system", "human"],
|
||||||
)
|
)
|
||||||
strategy.messages = "conversations"
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
input_ids = res["input_ids"]
|
input_ids = res["input_ids"]
|
||||||
labels = res["labels"]
|
labels = res["labels"]
|
||||||
@@ -417,8 +436,7 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
|
|||||||
chat_template=get_chat_template(
|
chat_template=get_chat_template(
|
||||||
"jinja", jinja_template=llama3_2_vision_chat_template_jinja
|
"jinja", jinja_template=llama3_2_vision_chat_template_jinja
|
||||||
),
|
),
|
||||||
message_field_role="role",
|
message_property_mappings={"role": "role", "content": "content"},
|
||||||
message_field_content="content",
|
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -486,8 +504,7 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
|
|||||||
chat_template=get_chat_template(
|
chat_template=get_chat_template(
|
||||||
"jinja", jinja_template=llama3_2_vision_chat_template_jinja
|
"jinja", jinja_template=llama3_2_vision_chat_template_jinja
|
||||||
),
|
),
|
||||||
message_field_role="role",
|
message_property_mappings={"role": "role", "content": "content"},
|
||||||
message_field_content="content",
|
|
||||||
),
|
),
|
||||||
tokenizer=llama3_tokenizer,
|
tokenizer=llama3_tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ tests for chat_template prompt strategy
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import unittest
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -123,15 +122,15 @@ class TestChatTemplateConfigurations:
|
|||||||
chat_template=get_chat_template(
|
chat_template=get_chat_template(
|
||||||
chat_template, jinja_template=chat_template_jinja
|
chat_template, jinja_template=chat_template_jinja
|
||||||
),
|
),
|
||||||
message_field_role="from",
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
message_field_content="value",
|
field_messages="conversations",
|
||||||
),
|
),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
train_on_inputs=True,
|
train_on_inputs=True,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
roles_to_train=["assistant"],
|
roles_to_train=["assistant"],
|
||||||
)
|
)
|
||||||
strategy.messages = "conversations"
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||||
labels = res["labels"]
|
labels = res["labels"]
|
||||||
@@ -180,15 +179,15 @@ class TestChatTemplateConfigurations:
|
|||||||
chat_template=get_chat_template(
|
chat_template=get_chat_template(
|
||||||
chat_template, jinja_template=chat_template_jinja
|
chat_template, jinja_template=chat_template_jinja
|
||||||
),
|
),
|
||||||
message_field_role="from",
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
message_field_content="value",
|
field_messages="conversations",
|
||||||
),
|
),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
roles_to_train=["assistant"],
|
roles_to_train=["assistant"],
|
||||||
)
|
)
|
||||||
strategy.messages = "conversations"
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||||
labels = res["labels"]
|
labels = res["labels"]
|
||||||
@@ -241,20 +240,15 @@ class TestChatTemplateConfigurations:
|
|||||||
chat_template=get_chat_template(
|
chat_template=get_chat_template(
|
||||||
chat_template, jinja_template=chat_template_jinja
|
chat_template, jinja_template=chat_template_jinja
|
||||||
),
|
),
|
||||||
message_field_role="from",
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
message_field_content="value",
|
field_messages="conversations",
|
||||||
),
|
),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
roles_to_train=["assistant", "human"],
|
roles_to_train=["assistant", "human"],
|
||||||
)
|
)
|
||||||
strategy.messages = "conversations"
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
|
||||||
labels = res["labels"]
|
|
||||||
input_ids = res["input_ids"]
|
|
||||||
|
|
||||||
strategy.messages = "conversations"
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||||
labels = res["labels"]
|
labels = res["labels"]
|
||||||
@@ -307,15 +301,15 @@ class TestChatTemplateConfigurations:
|
|||||||
chat_template=get_chat_template(
|
chat_template=get_chat_template(
|
||||||
chat_template, jinja_template=chat_template_jinja
|
chat_template, jinja_template=chat_template_jinja
|
||||||
),
|
),
|
||||||
message_field_role="from",
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
message_field_content="value",
|
field_messages="conversations",
|
||||||
),
|
),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
train_on_inputs=True,
|
train_on_inputs=True,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
roles_to_train=["human", "assistant"],
|
roles_to_train=["human", "assistant"],
|
||||||
)
|
)
|
||||||
strategy.messages = "conversations"
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||||
labels = res["labels"]
|
labels = res["labels"]
|
||||||
@@ -360,8 +354,8 @@ class TestChatTemplateConfigurations:
|
|||||||
chat_template=get_chat_template(
|
chat_template=get_chat_template(
|
||||||
chat_template, jinja_template=chat_template_jinja
|
chat_template, jinja_template=chat_template_jinja
|
||||||
),
|
),
|
||||||
message_field_role="from",
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
message_field_content="value",
|
field_messages="conversations",
|
||||||
),
|
),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -369,7 +363,7 @@ class TestChatTemplateConfigurations:
|
|||||||
roles_to_train=[],
|
roles_to_train=[],
|
||||||
train_on_eos="none", # Add this line
|
train_on_eos="none", # Add this line
|
||||||
)
|
)
|
||||||
strategy.messages = "conversations"
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
labels = res["labels"]
|
labels = res["labels"]
|
||||||
|
|
||||||
@@ -400,8 +394,8 @@ class TestChatTemplateConfigurations:
|
|||||||
chat_template=get_chat_template(
|
chat_template=get_chat_template(
|
||||||
chat_template, jinja_template=chat_template_jinja
|
chat_template, jinja_template=chat_template_jinja
|
||||||
),
|
),
|
||||||
message_field_role="from",
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
message_field_content="value",
|
field_messages="conversations",
|
||||||
),
|
),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -409,7 +403,7 @@ class TestChatTemplateConfigurations:
|
|||||||
roles_to_train=["assistant"],
|
roles_to_train=["assistant"],
|
||||||
train_on_eos="all",
|
train_on_eos="all",
|
||||||
)
|
)
|
||||||
strategy.messages = "conversations"
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
labels = res["labels"]
|
labels = res["labels"]
|
||||||
input_ids = res["input_ids"]
|
input_ids = res["input_ids"]
|
||||||
@@ -446,8 +440,8 @@ class TestChatTemplateConfigurations:
|
|||||||
chat_template=get_chat_template(
|
chat_template=get_chat_template(
|
||||||
chat_template, jinja_template=chat_template_jinja
|
chat_template, jinja_template=chat_template_jinja
|
||||||
),
|
),
|
||||||
message_field_role="from",
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
message_field_content="value",
|
field_messages="conversations",
|
||||||
),
|
),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -455,7 +449,6 @@ class TestChatTemplateConfigurations:
|
|||||||
roles_to_train=["assistant"],
|
roles_to_train=["assistant"],
|
||||||
train_on_eos="turn",
|
train_on_eos="turn",
|
||||||
)
|
)
|
||||||
strategy.messages = "conversations"
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
turns = strategy.get_conversation_thread(basic_dataset[0])
|
turns = strategy.get_conversation_thread(basic_dataset[0])
|
||||||
labels = res["labels"]
|
labels = res["labels"]
|
||||||
@@ -526,8 +519,8 @@ class TestChatTemplateConfigurations:
|
|||||||
chat_template=get_chat_template(
|
chat_template=get_chat_template(
|
||||||
chat_template, jinja_template=chat_template_jinja
|
chat_template, jinja_template=chat_template_jinja
|
||||||
),
|
),
|
||||||
message_field_role="from",
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
message_field_content="value",
|
field_messages="conversations",
|
||||||
),
|
),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -535,7 +528,7 @@ class TestChatTemplateConfigurations:
|
|||||||
roles_to_train=["assistant"],
|
roles_to_train=["assistant"],
|
||||||
train_on_eos="last",
|
train_on_eos="last",
|
||||||
)
|
)
|
||||||
strategy.messages = "conversations"
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
labels = res["labels"]
|
labels = res["labels"]
|
||||||
input_ids = res["input_ids"]
|
input_ids = res["input_ids"]
|
||||||
@@ -578,8 +571,8 @@ class TestChatTemplateConfigurations:
|
|||||||
chat_template=get_chat_template(
|
chat_template=get_chat_template(
|
||||||
chat_template, jinja_template=chat_template_jinja
|
chat_template, jinja_template=chat_template_jinja
|
||||||
),
|
),
|
||||||
message_field_role="from",
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
message_field_content="value",
|
field_messages="conversations",
|
||||||
),
|
),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -587,7 +580,7 @@ class TestChatTemplateConfigurations:
|
|||||||
roles_to_train=["assistant"],
|
roles_to_train=["assistant"],
|
||||||
train_on_eos="none",
|
train_on_eos="none",
|
||||||
)
|
)
|
||||||
strategy.messages = "conversations"
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
labels = res["labels"]
|
labels = res["labels"]
|
||||||
input_ids = res["input_ids"]
|
input_ids = res["input_ids"]
|
||||||
@@ -624,15 +617,15 @@ class TestChatTemplateConfigurations:
|
|||||||
chat_template, jinja_template=chat_template_jinja
|
chat_template, jinja_template=chat_template_jinja
|
||||||
),
|
),
|
||||||
drop_system_message=True,
|
drop_system_message=True,
|
||||||
message_field_role="from",
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
message_field_content="value",
|
field_messages="conversations",
|
||||||
),
|
),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
sequence_len=512,
|
sequence_len=512,
|
||||||
roles_to_train=["assistant"],
|
roles_to_train=["assistant"],
|
||||||
)
|
)
|
||||||
strategy.messages = "conversations"
|
|
||||||
res = strategy.tokenize_prompt(basic_dataset[0])
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
input_ids = res["input_ids"]
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
@@ -668,8 +661,7 @@ class TestChatTemplateConfigurations:
|
|||||||
chat_template, jinja_template=chat_template_jinja
|
chat_template, jinja_template=chat_template_jinja
|
||||||
),
|
),
|
||||||
roles=custom_roles,
|
roles=custom_roles,
|
||||||
message_field_role="from",
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
message_field_content="value",
|
|
||||||
),
|
),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -741,8 +733,7 @@ class TestChatTemplateConfigurations:
|
|||||||
),
|
),
|
||||||
message_field_training="train",
|
message_field_training="train",
|
||||||
message_field_training_detail="train_detail",
|
message_field_training_detail="train_detail",
|
||||||
message_field_role="from",
|
message_property_mappings={"role": "from", "content": "value"},
|
||||||
message_field_content="value",
|
|
||||||
),
|
),
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
train_on_inputs=False,
|
train_on_inputs=False,
|
||||||
@@ -911,6 +902,64 @@ class TestChatTemplateConfigurations:
|
|||||||
LOG.debug(f"Final labels: {labels}")
|
LOG.debug(f"Final labels: {labels}")
|
||||||
LOG.debug(f"Final input_ids: {input_ids}")
|
LOG.debug(f"Final input_ids: {input_ids}")
|
||||||
|
|
||||||
|
def test_get_chat_template_variables(
|
||||||
|
self, tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||||
|
):
|
||||||
|
LOG.info("Testing get_chat_template_variables")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
actual_tokenizer, actual_jinja_template = self.setup_tokenizer(
|
||||||
unittest.main()
|
tokenizer, chat_template, chat_template_jinja, eos_token, request
|
||||||
|
)
|
||||||
|
|
||||||
|
prompter = ChatTemplatePrompter(
|
||||||
|
actual_tokenizer,
|
||||||
|
chat_template=get_chat_template(
|
||||||
|
chat_template, jinja_template=actual_jinja_template
|
||||||
|
),
|
||||||
|
message_property_mappings={"from": "role", "value": "content"},
|
||||||
|
)
|
||||||
|
|
||||||
|
variables = prompter.get_chat_template_msg_variables(
|
||||||
|
actual_jinja_template
|
||||||
|
if actual_jinja_template
|
||||||
|
else actual_tokenizer.get_chat_template(),
|
||||||
|
"messages",
|
||||||
|
)
|
||||||
|
|
||||||
|
if chat_template == "llama3":
|
||||||
|
assert variables == {"role", "content"}, (
|
||||||
|
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||||
|
f"Got: {variables}\n"
|
||||||
|
f"Chat template: {actual_jinja_template}"
|
||||||
|
)
|
||||||
|
elif chat_template == "chatml":
|
||||||
|
assert variables == {"role", "content"}, (
|
||||||
|
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||||
|
f"Got: {variables}\n"
|
||||||
|
f"Chat template: {actual_jinja_template}"
|
||||||
|
)
|
||||||
|
elif chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
|
||||||
|
assert variables == {"role", "content", "tool_call_id", "tool_calls"}, (
|
||||||
|
f"Expected variables: {'role', 'content', 'tool_call_id', 'tool_calls'} from {tokenizer}/{chat_template}\n"
|
||||||
|
f"Got: {variables}\n"
|
||||||
|
f"Chat template: {actual_jinja_template}"
|
||||||
|
)
|
||||||
|
elif chat_template == "jinja" and tokenizer == "gemma2_tokenizer":
|
||||||
|
assert variables == {"role", "content"}, (
|
||||||
|
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||||
|
f"Got: {variables}\n"
|
||||||
|
f"Chat template: {actual_jinja_template}"
|
||||||
|
)
|
||||||
|
elif chat_template == "phi_35":
|
||||||
|
assert variables == {"role", "content"}, (
|
||||||
|
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
|
||||||
|
f"Got: {variables}\n"
|
||||||
|
f"Chat template: {actual_jinja_template}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.warning(
|
||||||
|
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
|
||||||
|
)
|
||||||
|
|||||||
159
tests/prompt_strategies/test_jinja_template_analyzer.py
Normal file
159
tests/prompt_strategies/test_jinja_template_analyzer.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
"""
|
||||||
|
tests for jinja_template_analyzer
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
class TestJinjaTemplateAnalyzer:
|
||||||
|
"""
|
||||||
|
tests for jinja_template_analyzer
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_basic_variable_extraction(self, basic_jinja_template_analyzer):
|
||||||
|
"""Test that all top-level variables are correctly extracted."""
|
||||||
|
LOG.info("Testing with train_on_inputs=True")
|
||||||
|
|
||||||
|
variables = basic_jinja_template_analyzer.get_template_variables()
|
||||||
|
expected_vars = {"messages", "add_generation_prompt", "eos_token", "message"}
|
||||||
|
assert set(variables.keys()) == expected_vars
|
||||||
|
|
||||||
|
def test_mixtral_variable_extraction(self, mistral_jinja_template_analyzer):
|
||||||
|
"""Test that all top-level variables are correctly extracted."""
|
||||||
|
LOG.info("Testing with train_on_inputs=True")
|
||||||
|
|
||||||
|
variables = mistral_jinja_template_analyzer.get_template_variables()
|
||||||
|
expected_vars = {
|
||||||
|
"messages",
|
||||||
|
"content",
|
||||||
|
"eos_token",
|
||||||
|
"message",
|
||||||
|
"tools",
|
||||||
|
"system_message",
|
||||||
|
"loop_messages",
|
||||||
|
"ns",
|
||||||
|
"tool_call",
|
||||||
|
"tool",
|
||||||
|
"loop",
|
||||||
|
"bos_token",
|
||||||
|
"raise_exception",
|
||||||
|
}
|
||||||
|
assert set(variables.keys()) == expected_vars
|
||||||
|
message_vars = variables["message"]
|
||||||
|
assert message_vars == {"role", "content", "tool_calls", "tool_call_id"}
|
||||||
|
|
||||||
|
def test_message_property_access(self, basic_jinja_template_analyzer):
|
||||||
|
"""Test that properties accessed on 'message' variable are correctly identified."""
|
||||||
|
LOG.info("Testing message property access")
|
||||||
|
|
||||||
|
variables = basic_jinja_template_analyzer.get_template_variables()
|
||||||
|
assert "messages" in variables
|
||||||
|
assert "message" in variables
|
||||||
|
assert "role" in variables["message"]
|
||||||
|
assert "content" in variables["message"]
|
||||||
|
|
||||||
|
def test_detailed_analysis(self, basic_jinja_template_analyzer):
|
||||||
|
"""Test the detailed analysis of variable usage."""
|
||||||
|
LOG.info("Testing detailed analysis")
|
||||||
|
|
||||||
|
analysis = basic_jinja_template_analyzer.analyze_template()
|
||||||
|
|
||||||
|
assert analysis["messages"]["is_iterated"] is True
|
||||||
|
assert "role" in analysis["message"]["accessed_properties"]
|
||||||
|
assert "content" in analysis["message"]["accessed_properties"]
|
||||||
|
|
||||||
|
assert analysis["add_generation_prompt"]["is_conditional"] is True
|
||||||
|
assert len(analysis["add_generation_prompt"]["accessed_properties"]) == 0
|
||||||
|
|
||||||
|
assert not analysis["eos_token"]["is_iterated"]
|
||||||
|
assert len(analysis["eos_token"]["accessed_properties"]) == 0
|
||||||
|
|
||||||
|
def test_nested_property_access(self):
|
||||||
|
"""Test handling of nested property access."""
|
||||||
|
LOG.info("Testing nested property access")
|
||||||
|
|
||||||
|
template = """{{ user.profile.name }}{{ user.settings['preference'] }}"""
|
||||||
|
analyzer = JinjaTemplateAnalyzer(template)
|
||||||
|
variables = analyzer.get_template_variables()
|
||||||
|
|
||||||
|
assert "user" in variables
|
||||||
|
assert "profile" in variables["user"]
|
||||||
|
assert "settings" in variables["user"]
|
||||||
|
|
||||||
|
def test_loop_variable_handling(self):
|
||||||
|
"""Test handling of loop variables and their properties."""
|
||||||
|
LOG.info("Testing loop variable handling")
|
||||||
|
|
||||||
|
template = """
|
||||||
|
{% for item in items %}
|
||||||
|
{{ item.name }}
|
||||||
|
{% for subitem in item.subitems %}
|
||||||
|
{{ subitem.value }}
|
||||||
|
{% endfor %}
|
||||||
|
{% endfor %}
|
||||||
|
"""
|
||||||
|
analyzer = JinjaTemplateAnalyzer(template)
|
||||||
|
analysis = analyzer.analyze_template()
|
||||||
|
|
||||||
|
assert analysis["items"]["is_iterated"]
|
||||||
|
assert "name" in analysis["item"]["accessed_properties"]
|
||||||
|
assert "subitems" in analysis["item"]["accessed_properties"]
|
||||||
|
|
||||||
|
def test_conditional_variable_usage(self):
|
||||||
|
"""Test detection of variables used in conditional statements."""
|
||||||
|
LOG.info("Testing conditional variable usage")
|
||||||
|
|
||||||
|
template = """
|
||||||
|
{% if user.is_admin and config.debug_mode %}
|
||||||
|
{{ debug_info }}
|
||||||
|
{% endif %}
|
||||||
|
"""
|
||||||
|
analyzer = JinjaTemplateAnalyzer(template)
|
||||||
|
analysis = analyzer.analyze_template()
|
||||||
|
|
||||||
|
assert analysis["user"]["is_conditional"]
|
||||||
|
assert analysis["config"]["is_conditional"]
|
||||||
|
assert "is_admin" in analysis["user"]["accessed_properties"]
|
||||||
|
assert "debug_mode" in analysis["config"]["accessed_properties"]
|
||||||
|
|
||||||
|
def test_complex_expressions(self):
|
||||||
|
"""Test handling of complex expressions and filters."""
|
||||||
|
LOG.info("Testing complex expressions and filters")
|
||||||
|
|
||||||
|
template = """
|
||||||
|
{{ user.name | upper }}
|
||||||
|
{{ messages | length > 0 and messages[0].content }}
|
||||||
|
{{ data['key'].nested['value'] }}
|
||||||
|
"""
|
||||||
|
analyzer = JinjaTemplateAnalyzer(template)
|
||||||
|
variables = analyzer.get_template_variables()
|
||||||
|
|
||||||
|
assert "user" in variables
|
||||||
|
assert "name" in variables["user"]
|
||||||
|
assert "messages" in variables
|
||||||
|
assert "content" in variables["messages"]
|
||||||
|
assert "data" in variables
|
||||||
|
|
||||||
|
def test_basic_msg_vars(self, basic_jinja_template_analyzer):
|
||||||
|
"""Test that the basic message variables are correctly identified."""
|
||||||
|
LOG.info("Testing basic message variables")
|
||||||
|
|
||||||
|
variables = basic_jinja_template_analyzer.get_message_vars()
|
||||||
|
assert variables == {"role", "content"}
|
||||||
|
|
||||||
|
def test_mixtral_msg_vars(self, mistral_jinja_template_analyzer):
|
||||||
|
"""Test that the mixtral message variables are correctly identified."""
|
||||||
|
LOG.info("Testing mixtral message variables")
|
||||||
|
|
||||||
|
variables = mistral_jinja_template_analyzer.get_message_vars()
|
||||||
|
assert variables == {"role", "content", "tool_calls", "tool_call_id"}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
@@ -302,3 +302,22 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_message_property_mappings(self, minimal_cfg):
|
||||||
|
cfg = DictDefault(
|
||||||
|
minimal_cfg
|
||||||
|
| {
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
"message_property_mappings": {
|
||||||
|
"role": "role",
|
||||||
|
"content": "content",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ class TestModelsUtils:
|
|||||||
mocked_load_model_config.return_value = {}
|
mocked_load_model_config.return_value = {}
|
||||||
with pytest.raises(ValueError) as exc:
|
with pytest.raises(ValueError) as exc:
|
||||||
# Should error before hitting tokenizer, so we pass in an empty str
|
# Should error before hitting tokenizer, so we pass in an empty str
|
||||||
load_model(cfg, tokenizer="")
|
load_model(cfg, tokenizer="") # type: ignore
|
||||||
assert (
|
assert (
|
||||||
"shifted-sparse attention does not currently support sample packing"
|
"shifted-sparse attention does not currently support sample packing"
|
||||||
in str(exc.value)
|
in str(exc.value)
|
||||||
@@ -116,3 +116,79 @@ class TestModelsUtils:
|
|||||||
assert self.model_loader.model_kwargs.get(
|
assert self.model_loader.model_kwargs.get(
|
||||||
"quantization_config", BitsAndBytesConfig
|
"quantization_config", BitsAndBytesConfig
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_message_property_mapping(self):
|
||||||
|
"""Test message property mapping configuration validation"""
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import SFTDataset
|
||||||
|
|
||||||
|
# Test legacy fields are mapped orrectly
|
||||||
|
dataset = SFTDataset(
|
||||||
|
path="test_path",
|
||||||
|
message_field_role="role_field",
|
||||||
|
message_field_content="content_field",
|
||||||
|
)
|
||||||
|
assert dataset.message_property_mappings == {
|
||||||
|
"role": "role_field",
|
||||||
|
"content": "content_field",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test direct message_property_mapping works
|
||||||
|
dataset = SFTDataset(
|
||||||
|
path="test_path",
|
||||||
|
message_property_mappings={
|
||||||
|
"role": "custom_role",
|
||||||
|
"content": "custom_content",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert dataset.message_property_mappings == {
|
||||||
|
"role": "custom_role",
|
||||||
|
"content": "custom_content",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test both legacy and new fields work when they match
|
||||||
|
dataset = SFTDataset(
|
||||||
|
path="test_path",
|
||||||
|
message_field_role="same_role",
|
||||||
|
message_property_mappings={"role": "same_role"},
|
||||||
|
)
|
||||||
|
assert dataset.message_property_mappings == {
|
||||||
|
"role": "same_role",
|
||||||
|
"content": "content",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test both legacy and new fields work when they don't overlap
|
||||||
|
dataset = SFTDataset(
|
||||||
|
path="test_path",
|
||||||
|
message_field_role="role_field",
|
||||||
|
message_property_mappings={"content": "content_field"},
|
||||||
|
)
|
||||||
|
assert dataset.message_property_mappings == {
|
||||||
|
"role": "role_field",
|
||||||
|
"content": "content_field",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test no role or content provided
|
||||||
|
dataset = SFTDataset(
|
||||||
|
path="test_path",
|
||||||
|
)
|
||||||
|
assert dataset.message_property_mappings == {
|
||||||
|
"role": "role",
|
||||||
|
"content": "content",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test error when legacy and new fields conflict
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
SFTDataset(
|
||||||
|
path="test_path",
|
||||||
|
message_field_role="legacy_role",
|
||||||
|
message_property_mappings={"role": "different_role"},
|
||||||
|
)
|
||||||
|
assert "Conflicting message role fields" in str(exc_info.value)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
SFTDataset(
|
||||||
|
path="test_path",
|
||||||
|
message_field_content="legacy_content",
|
||||||
|
message_property_mappings={"content": "different_content"},
|
||||||
|
)
|
||||||
|
assert "Conflicting message content fields" in str(exc_info.value)
|
||||||
|
|||||||
Reference in New Issue
Block a user