axolotl/tests/utils/test_models.py
NJordan72 b194e17c28 feat: add config for optional parameters in a chat message (#2260)
* feat: add config for optional parameters in a chat message

* chore: cleanup

* chore: fix nits and add light docs

* docs: update docs/dataset-formats/conversation.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* feat: configurable message mappings, jinja template analyzer

* chore: handle bradley terry

* docs: update docs

* refactor: change order of mappings, improve message transform

* refactor: make chat aware of property mappings

* chore: remove .python-version

* chore: revert change

* chore: add dataset validation to tests where appropriate

* chore: add dataset validation to tests where appropriate

* chore: clean up handling of ds_cfg

* chore: recursively serialize config

* make sure to use the return value from validate_config

* DefaultDict pickle/unpickle fix

* fix super call for override

* refactor: message fields

* chore: empty commit

* tests: validate config before using

* chore: add config validation to all e2e tests

* chore: add unneeded logging

* chore: add missed config validation

* chore: pass field_messages to prompter

* test: fix borked test

* chore: remove unintended file

* chore: add deprecation warning and update chat_datasets script

* chore: lint

* refactor: message fields

* feat: update AxolotlInputConfig and test_models

- add ConfigDict import in axolotl/utils/config/models/input/v0_4_1/__init__.py
- remove unnecessary line breaks in SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset classes
- update model_dump method in AxolotlInputConfig to exclude None values
- correct typo in test_models.py comment

* feat: simplify DPODataset and KTODataset classes in config models

Removed several optional fields from the DPODataset and KTODataset classes in axolotl/utils/config/models/input/v0_4_1. This simplifies the configuration subsets for these datasets.

* feat: improve readability and structure in dataset configuration models

This commit improves the readability and structure of the dataset configuration models in the `axolotl/utils/config/models/input/v0_4_1` module. It removes the unused `ConfigDict` import and adds line breaks to separate class definitions. It also includes a minor documentation fix that ensures a newline at the end of the `stepwise_supervised.qmd` file.

* feat: change log level from info to debug in ChatTemplateStrategy

* feat(prompt_strategies): refactor ChatTemplatePrompter and ChatTemplateStrategy

- Make `chat_template` a required parameter in `ChatTemplatePrompter` constructor
- Add default value for `message_property_mappings` in `ChatTemplatePrompter` constructor
- Add `messages_array_name` property to `ChatTemplatePrompter`
- Change `processor` type to Optional in `ChatTemplatePrompter`
- Add TypeError check for `processor` in `ChatTemplatePrompter.build_prompt`
- Remove `_messages` property from `ChatTemplateStrategy`
- Make `prompter` a required parameter and add type hint in `ChatTemplateStrategy` constructor
- Remove `messages` getter and setter from `ChatTemplateStrategy`
- Use `prompter.messages_array_name` in `ChatTemplateStrategy.get_conversation_thread`
- Remove condition to set `messages` field in `load` function

* feat(tests/utils): ignore type check in load_model call in test_models.py

* feat: improve type handling and test structure in chat templates

- Add return type hint for `get_chat_template` function in `chat_templates.py`
- Remove unnecessary assignment of `strategy.messages` in several test cases
- Add `messages_array_name` parameter to various test configurations in `test_chat_templates.py` and `test_chat_templates_advanced.py`
- Remove redundant `strategy.messages` assignment in `test_chat_templates_advanced.py`

* feat(axolotl): enhance chat strategy with DatasetConfig support

This commit introduces support for DatasetConfig in the ChatTemplateStrategy. It also refines the strategy loader to handle different types of ds_cfg inputs and improves the clarity of the code by formatting and reordering. The key changes include:

- Importing Union from typing and BaseModel from pydantic.
- Adding DatasetConfig as an optional type for ds_cfg in StrategyLoader.
- Adjusting the handling of ds_cfg in StrategyLoader to account for BaseModel instances.
- Refactoring the prompter_params and strategy_params for better readability.
- Changing the reference from prompt[self.messages] to prompt[self.prompter.messages_array_name] in the is_prompt_batched method.
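
The `ds_cfg` handling described above can be sketched roughly as follows. This is a minimal illustration, not the loader's actual code: `normalize_ds_cfg` and `FakeDatasetConfig` are hypothetical names, and the stand-in class only mimics the `model_dump()` method of a pydantic model so the sketch runs without pydantic installed.

```python
from typing import Any, Dict, Optional, Union


class FakeDatasetConfig:
    """Hypothetical stand-in for a pydantic model; only model_dump() is mimicked."""

    def __init__(self, **fields: Any) -> None:
        self._fields = fields

    def model_dump(self) -> Dict[str, Any]:
        return dict(self._fields)


def normalize_ds_cfg(
    ds_cfg: Optional[Union[Dict[str, Any], FakeDatasetConfig]],
) -> Dict[str, Any]:
    """Return a plain dict whether ds_cfg is None, a dict, or a model-like object."""
    if ds_cfg is None:
        return {}
    if hasattr(ds_cfg, "model_dump"):
        # Model-like input (assumed pydantic-style): serialize it to a dict.
        return ds_cfg.model_dump()
    # Dict-like input: copy so later edits do not leak back to the caller.
    return dict(ds_cfg)


as_dict = normalize_ds_cfg({"field_messages": "conversations"})
as_model = normalize_ds_cfg(FakeDatasetConfig(field_messages="conversations"))
```

Both calls yield the same plain dict, so downstream code can read keys the same way regardless of the input type.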

* feat: update message handling in BTChatTemplateStrategy

* Replace `self.messages` with direct string references to "chosen_messages" and "rejected_messages"
* Append system, user, and assistant content directly to "chosen_messages" and "rejected_messages"
* Add a new attribute "messages_array_name" to the `load` function parameters
* Remove the conditional attribute assignment for "field_messages" in the `load` function

* feat: add config validation in test_kd.py

- Import `validate_config` from `axolotl.utils.config`
- Validate the configuration in `test_llama_kd` and another function in `TestKnowledgeDistillation` class

* feat: enhance config validation and capabilities handling

* Import `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
* Update `validate_config` function to create `KTODataset` and `SFTDataset` instances using `dict(ds_cfg)`
* Replace `capabilities` and `env_capabilities` with instances of `GPUCapabilities` and `EnvCapabilities` respectively in `AxolotlConfigWCapabilities` model dump

* feat: update config validation in axolotl utils

- Remove import of `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
- Update `validate_config` function to use `capabilities` and `env_capabilities` directly instead of creating new instances of `GPUCapabilities` and `EnvCapabilities`

* feat: refactor StrategyLoader in chat_template.py

- Extracted the creation of strategy parameters into a separate function, `_get_strategy_params(cfg, dataset_config)`
- Created a new function, `_get_strategy_cls()`, to obtain the strategy class
- Replaced `ChatTemplateStrategy` with `strategy_cls` for strategy instantiation

* trigger CI

* chore: revert dataset config changes for kto/dpo

* refactor: rename 'messages_array_name' to 'field_messages'

- Renamed 'messages_array_name' to 'field_messages' in 'ChatTemplatePrompter' class and its usages in 'chat_template.py'
- Updated 'load' function in 'bradley_terry/chat_template.py' to reflect the change
- Adjusted 'get_chat_template_msg_variables' and 'get_message_vars' methods in 'jinja_template_analyzer.py' to use the new variable name
- Modified 'StrategyLoader' in 'chat_template.py' to use 'field_messages'
- Updated tests in 'test_chat_templates.py' and 'test_chat_templates_advanced.py' to use 'field_messages' instead of 'messages_array_name'

* feat: refactor prompt strategies and update config models

* Remove redundant 'return None' in `axolotl/prompt_strategies/__init__.py`
* Simplify message handling in `axolotl/prompt_strategies/bradley_terry/chat_template.py` by using a single 'messages' list instead of separate 'chosen_messages' and 'rejected_messages' lists
* Update default 'message_property_mappings' in `axolotl/prompt_strategies/bradley_terry/chat_template.py`
* Add 'field_messages' field to `axolotl/utils/config/models/input/v0_4_1/__init__.py` configuration model

* chore: remove unused input

* chore: remove redundant type ignore

* fix: remove old configs and update examples

* fix: type check

* fix: remove loading old config in ChatMessage

* fix: update faq with potential new undefinederror

* fix: add debug if property mapped is not found

* chore: improve explanation for unmapped properties

* fix: update docs with new config

* chore: add note for deprecation config and del old config from dict

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-02-18 09:59:27 +07:00

195 lines
7.0 KiB
Python

"""Module for testing models utils file."""

from unittest.mock import MagicMock, patch

import pytest
from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils.import_utils import is_torch_mps_available

from axolotl.utils.dict import DictDefault
from axolotl.utils.models import ModelLoader, load_model


class TestModelsUtils:
    """Testing module for models utils."""

    def setup_method(self) -> None:
        # load config
        self.cfg = DictDefault(  # pylint: disable=attribute-defined-outside-init
            {
                "base_model": "JackFram/llama-68m",
                "model_type": "LlamaForCausalLM",
                "tokenizer_type": "LlamaTokenizer",
                "load_in_8bit": True,
                "load_in_4bit": False,
                "adapter": "lora",
                "flash_attention": False,
                "sample_packing": True,
                "device_map": "auto",
            }
        )
        self.tokenizer = MagicMock(  # pylint: disable=attribute-defined-outside-init
            spec=PreTrainedTokenizerBase
        )
        self.inference = False  # pylint: disable=attribute-defined-outside-init
        self.reference_model = True  # pylint: disable=attribute-defined-outside-init

        # init ModelLoader
        self.model_loader = (  # pylint: disable=attribute-defined-outside-init
            ModelLoader(
                cfg=self.cfg,
                tokenizer=self.tokenizer,
                inference=self.inference,
                reference_model=self.reference_model,
            )
        )

    def test_set_device_map_config(self):
        # check device_map
        device_map = self.cfg.device_map
        if is_torch_mps_available():
            device_map = "mps"
        self.model_loader.set_device_map_config()
        if is_deepspeed_zero3_enabled():
            assert "device_map" not in self.model_loader.model_kwargs
        else:
            assert device_map in self.model_loader.model_kwargs["device_map"]

        # check torch_dtype
        assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]

    def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
        cfg = DictDefault(
            {
                "s2_attention": True,
                "sample_packing": True,
                "base_model": "",
                "model_type": "LlamaForCausalLM",
            }
        )

        # Mock out call to HF hub
        with patch(
            "axolotl.utils.models.load_model_config"
        ) as mocked_load_model_config:
            mocked_load_model_config.return_value = {}
            with pytest.raises(ValueError) as exc:
                # Should error before hitting tokenizer, so we pass in an empty str
                load_model(cfg, tokenizer="")  # type: ignore
            assert (
                "shifted-sparse attention does not currently support sample packing"
                in str(exc.value)
            )

    @pytest.mark.parametrize("adapter", ["lora", "qlora", None])
    @pytest.mark.parametrize("load_in_8bit", [True, False])
    @pytest.mark.parametrize("load_in_4bit", [True, False])
    @pytest.mark.parametrize("gptq", [True, False])
    def test_set_quantization_config(
        self,
        adapter,
        load_in_8bit,
        load_in_4bit,
        gptq,
    ):
        # init cfg as args
        self.cfg.load_in_8bit = load_in_8bit
        self.cfg.load_in_4bit = load_in_4bit
        self.cfg.gptq = gptq
        self.cfg.adapter = adapter

        self.model_loader.set_quantization_config()
        if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
            assert not (
                hasattr(self.model_loader.model_kwargs, "load_in_8bit")
                and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
            )
        elif load_in_8bit and self.cfg.adapter is not None:
            assert self.model_loader.model_kwargs["load_in_8bit"]
        elif load_in_4bit and self.cfg.adapter is not None:
            assert self.model_loader.model_kwargs["load_in_4bit"]

        if (self.cfg.adapter == "qlora" and load_in_4bit) or (
            self.cfg.adapter == "lora" and load_in_8bit
        ):
            assert self.model_loader.model_kwargs.get(
                "quantization_config", BitsAndBytesConfig
            )

    def test_message_property_mapping(self):
        """Test message property mapping configuration validation"""
        from axolotl.utils.config.models.input.v0_4_1 import SFTDataset

        # Test legacy fields are mapped correctly
        dataset = SFTDataset(
            path="test_path",
            message_field_role="role_field",
            message_field_content="content_field",
        )
        assert dataset.message_property_mappings == {
            "role": "role_field",
            "content": "content_field",
        }

        # Test direct message_property_mappings works
        dataset = SFTDataset(
            path="test_path",
            message_property_mappings={
                "role": "custom_role",
                "content": "custom_content",
            },
        )
        assert dataset.message_property_mappings == {
            "role": "custom_role",
            "content": "custom_content",
        }

        # Test both legacy and new fields work when they match
        dataset = SFTDataset(
            path="test_path",
            message_field_role="same_role",
            message_property_mappings={"role": "same_role"},
        )
        assert dataset.message_property_mappings == {
            "role": "same_role",
            "content": "content",
        }

        # Test both legacy and new fields work when they don't overlap
        dataset = SFTDataset(
            path="test_path",
            message_field_role="role_field",
            message_property_mappings={"content": "content_field"},
        )
        assert dataset.message_property_mappings == {
            "role": "role_field",
            "content": "content_field",
        }

        # Test no role or content provided
        dataset = SFTDataset(
            path="test_path",
        )
        assert dataset.message_property_mappings == {
            "role": "role",
            "content": "content",
        }

        # Test error when legacy and new fields conflict
        with pytest.raises(ValueError) as exc_info:
            SFTDataset(
                path="test_path",
                message_field_role="legacy_role",
                message_property_mappings={"role": "different_role"},
            )
        assert "Conflicting message role fields" in str(exc_info.value)

        with pytest.raises(ValueError) as exc_info:
            SFTDataset(
                path="test_path",
                message_field_content="legacy_content",
                message_property_mappings={"content": "different_content"},
            )
        assert "Conflicting message content fields" in str(exc_info.value)
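
The behavior that `test_message_property_mapping` checks can be approximated by a small standalone function. This is a sketch inferred only from the assertions above, not the validator in `SFTDataset`: the name `merge_message_mappings` and its structure are assumptions, and the error wording past the "Conflicting message ... fields" prefix is invented for illustration.

```python
from typing import Dict, Optional

# Defaults implied by the "no role or content provided" case in the test.
DEFAULT_MAPPINGS = {"role": "role", "content": "content"}


def merge_message_mappings(
    message_field_role: Optional[str] = None,
    message_field_content: Optional[str] = None,
    message_property_mappings: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
    """Fold the legacy per-field options into a single mapping dict."""
    mappings = dict(message_property_mappings or {})
    legacy = {"role": message_field_role, "content": message_field_content}
    for prop, legacy_value in legacy.items():
        if legacy_value is None:
            continue
        if prop in mappings and mappings[prop] != legacy_value:
            # Both styles were given for the same property and they disagree.
            raise ValueError(
                f"Conflicting message {prop} fields: "
                f"{legacy_value!r} vs {mappings[prop]!r}"
            )
        mappings[prop] = legacy_value
    # Any property still unmapped falls back to its default name.
    for prop, default in DEFAULT_MAPPINGS.items():
        mappings.setdefault(prop, default)
    return mappings


merged = merge_message_mappings(
    message_field_role="role_field",
    message_property_mappings={"content": "content_field"},
)
```

Here `merged` maps `role` to `role_field` and `content` to `content_field`, matching the non-overlapping case in the test.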