feat: add config for optional parameters in a chat message (#2260)

* feat: add config for optional parameters in a chat message

* chore: cleanup

* chore: fix nits and add light docs

* docs: update docs/dataset-formats/conversation.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* feat: configurable message mappings, jinja template analyzer

* chore: handle bradley terry

* docs: update docs

* refactor: change order of mappings, improve message transform

* refactor: make chat awware of property mappings

* chore: remove .python-version

* chore: revert change

* chore: add dataset validation to tests where appropriate

* chore: add dataset validation to tests where appropriate

* chore: clean up handling of ds_cfg

* chore: recursively serialize config

* make sure to use the return value from validate_config

* DefaultDict pickle/unpickle fix

* fix super call for override

* refactor: message fields

* chore: empty commit

* tests: validate config before using

* chore: add config validation to all e2e tests

* chore: add unneeded logging

* chore: add missed config validation

* chore: pass field_messages to prompter

* test: fix borked test

* chore: remove uninteded file

* chore: add deprecation warning and update chat_datasets script

* chore: lint

* refactor: message fields

* feat: update axolotlinputconfig and test_models

- add configdict import in axolotl/utils/config/models/input/v0_4_1/__init__.py
- remove unnecessary line breaks in sftdataset, dpodataset, ktodataset, stepwisesuperviseddataset classes
- update model_dump method in axolotlinputconfig to exclude none values
- correct typo in test_models.py comment

* feat: simplify dpodataset and ktodataset classes in config models

removed several optional fields from dpodataset and ktodataset classes in axolotl/utils/config/models/input/v0_4_1. this simplifies the configuration subsets for these datasets.

* feat: improve readability and structure in dataset configuration models

this commit enhances the readability and structure of the dataset configuration models in the `axolotl/utils/config/models/input/v0_4_1` module. it removes unused `configdict` import and adds line breaks to separate class definitions for better clarity. additionally, a minor documentation fix is included to ensure a newline at the end of the `stepwise_supervised.qmd` file.

* feat: change log level from info to debug in chattemplatestrategy

* feat(prompt_strategies): refactor chattemplateprompter and chattemplatestrategy

- Make `chat_template` a required parameter in `ChatTemplatePrompter` constructor
- Add default value for `message_property_mappings` in `ChatTemplatePrompter` constructor
- Add `messages_array_name` property to `ChatTemplatePrompter`
- Change `processor` type to Optional in `ChatTemplatePrompter`
- Add TypeError check for `processor` in `ChatTemplatePrompter.build_prompt`
- Remove `_messages` property from `ChatTemplateStrategy`
- Make `prompter` a required parameter and add type hint in `ChatTemplateStrategy` constructor
- Remove `messages` getter and setter from `ChatTemplateStrategy`
- Use `prompter.messages_array_name` in `ChatTemplateStrategy.get_conversation_thread`
- Remove condition to set `messages` field in `load` function

* feat(tests/utils): ignore type check in load_model call in test_models.py

* feat: improve type handling and test structure in chat templates

- Add return type hint for `get_chat_template` function in `chat_templates.py`
- Remove unnecessary assignment of `strategy.messages` in several test cases
- Add `messages_array_name` parameter to various test configurations in `test_chat_templates.py` and `test_chat_templates_advanced.py`
- Remove redundant `strategy.messages` assignment in `test_chat_templates_advanced.py`

* feat(axolotl): enhance chat strategy with datasetconfig support

This commit introduces support for DatasetConfig in the ChatTemplateStrategy. It also refines the strategy loader to handle different types of ds_cfg inputs and improves the clarity of the code by formatting and reordering. The key changes include:

- Importing Union from typing and BaseModel from pydantic.
- Adding DatasetConfig as an optional type for ds_cfg in StrategyLoader.
- Adjusting the handling of ds_cfg in StrategyLoader to account for BaseModel instances.
- Refactoring the prompter_params and strategy_params for better readability.
- Changing the reference from prompt[self.messages] to prompt[self.prompter.messages_array_name] in the is_prompt_batched method.

* feat: update message handling in btchattemplatestrategy

* Replace `self.messages` with direct string references to "chosen_messages" and "rejected_messages"
* Append system, user, and assistant content directly to "chosen_messages" and "rejected_messages"
* Add a new attribute "messages_array_name" to the `load` function parameters
* Remove the conditional attribute assignment for "field_messages" in the `load` function

* feat: add config validation in test_kd.py

- Import `validate_config` from `axolotl.utils.config`
- Validate the configuration in `test_llama_kd` and another function in `TestKnowledgeDistillation` class

* feat: enhance config validation and capabilities handling

* Import `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
* Update `validate_config` function to create `KTODataset` and `SFTDataset` instances using `dict(ds_cfg)`
* Replace `capabilities` and `env_capabilities` with instances of `GPUCapabilities` and `EnvCapabilities` respectively in `AxolotlConfigWCapabilities` model dump

* feat: update config validation in axolotl utils

- Remove import of `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
- Update `validate_config` function to use `capabilities` and `env_capabilities` directly instead of creating new instances of `GPUCapabilities` and `EnvCapabilities`

* feat: refactor strategyloader in chat_template.py

- Extracted the creation of strategy parameters into a separate function, `_get_strategy_params(cfg, dataset_config)`
- Created a new function, `_get_strategy_cls()`, to obtain the strategy class
- Replaced `ChatTemplateStrategy` with `strategy_cls` for strategy instantiation

* trigger CI

* chore: revert dataset config changes for kto/dpo

* subject: refactor: rename 'messages_array_name' to 'field_messages'

Body:
- Renamed 'messages_array_name' to 'field_messages' in 'ChatTemplatePrompter' class and its usages in 'chat_template.py'
- Updated 'load' function in 'bradley_terry/chat_template.py' to reflect the change
- Adjusted 'get_chat_template_msg_variables' and 'get_message_vars' methods in 'jinja_template_analyzer.py' to use the new variable name
- Modified 'StrategyLoader' in 'chat_template.py' to use 'field_messages'
- Updated tests in 'test_chat_templates.py' and 'test_chat_templates_advanced.py' to use 'field_messages' instead of 'messages_array_name'

* feat: refactor prompt strategies and update config models

* Remove redundant 'return None' in `axolotl/prompt_strategies/__init__.py`
* Simplify message handling in `axolotl/prompt_strategies/bradley_terry/chat_template.py` by using a single 'messages' list instead of separate 'chosen_messages' and 'rejected_messages' lists
* Update default 'message_property_mappings' in `axolotl/prompt_strategies/bradley_terry/chat_template.py`
* Add 'field_messages' field to `axolotl/utils/config/models/input/v0_4_1/__init__.py` configuration model

* chore: remove unused input

* chore: remove redundant type ignore

* fix: remove old configs and update examples

* fix: type check

* fix: remove loading old config in ChatMessage

* fix: update faq with potential new undefinederror

* fix: add debug if property mapped is not found

* chore: improve explanation for unmapped properties

* fix: update docs with new config

* chore: add note for deprecation config and del old config from dict

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
NJordan72
2025-02-17 21:59:27 -05:00
committed by GitHub
parent 3aac3b1da9
commit b194e17c28
51 changed files with 1190 additions and 230 deletions

View File

@@ -9,7 +9,7 @@ from e2e.utils import check_tensorboard, require_torch_2_5_1
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
@@ -79,6 +79,7 @@ class TestKnowledgeDistillation:
def test_llama_kd(self, temp_dir, kd_min_cfg):
cfg = DictDefault(kd_min_cfg)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -109,6 +110,7 @@ class TestKnowledgeDistillation:
| kd_min_cfg
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()

View File

@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard
@@ -76,7 +76,9 @@ class TestFAXentropyLlama:
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -10,7 +10,7 @@ from pathlib import Path
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -73,6 +73,8 @@ class TestReLoraLlama(unittest.TestCase):
"use_tensorboard": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_preference_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -108,6 +110,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -153,6 +157,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -198,6 +204,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -242,6 +250,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -289,6 +299,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
@@ -353,6 +365,8 @@ class TestDPOLlamaLora(unittest.TestCase):
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -56,6 +56,8 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"use_tensorboard": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -65,6 +65,8 @@ class TestFalcon(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -118,6 +120,8 @@ class TestFalcon(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -157,6 +161,8 @@ class TestFalcon(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -10,7 +10,7 @@ from e2e.utils import check_model_output_exists
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
@@ -56,6 +56,8 @@ class TestLlama:
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -99,6 +101,8 @@ class TestLlama:
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -138,6 +142,8 @@ class TestLlama:
"save_safetensors": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -10,7 +10,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard
@@ -69,6 +69,8 @@ class TestPretrainLlama:
"use_tensorboard": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -62,6 +62,8 @@ class TestLlamaVision(unittest.TestCase):
"bf16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -59,6 +59,8 @@ class TestLoraLlama(unittest.TestCase):
"max_steps": 20,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -11,7 +11,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -59,6 +59,8 @@ class TestMamba(unittest.TestCase):
"save_safetensors": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,8 @@ class TestMistral(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -106,6 +108,8 @@ class TestMistral(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -69,6 +69,8 @@ class TestMixtral(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -123,6 +125,8 @@ class TestMixtral(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -180,6 +184,8 @@ class TestMixtral(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -233,6 +239,8 @@ class TestMixtral(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
@@ -281,6 +289,8 @@ class TestMixtral(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir
@@ -59,6 +59,8 @@ class TestCustomOptimizers(unittest.TestCase):
"lr_scheduler": "cosine",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -103,6 +105,8 @@ class TestCustomOptimizers(unittest.TestCase):
"lr_scheduler": "cosine",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -139,6 +143,8 @@ class TestCustomOptimizers(unittest.TestCase):
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -11,7 +11,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
@@ -59,6 +59,8 @@ class TestPackedLlama(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
@@ -61,6 +61,7 @@ class TestPhi(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -40,8 +40,10 @@ class TestE2eQwen:
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"message_property_mappings": {
"role": "role",
"content": "content",
},
"roles": {
"system": ["system"],
"user": ["user"],

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -66,6 +66,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
"use_tensorboard": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -7,6 +7,7 @@ from datasets import Dataset
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
from axolotl.utils.chat_templates import _CHAT_TEMPLATES
@@ -174,3 +175,32 @@ def fixture_llama3_2_vision_with_hardcoded_date() -> str:
modified_template = template.replace(old_date_logic, new_date_logic)
return modified_template
@pytest.fixture(name="chat_template_jinja_with_optional_fields")
def fixture_chat_template_jinja_with_optional_fields() -> str:
return """{% for message in messages %}
{{'<|im_start|>'}}{{ message['role'] }}
{% if message['thoughts'] is defined %}[Thoughts: {{ message['thoughts'] }}]{% endif %}
{% if message['tool_calls'] is defined %}[Tool: {{ message['tool_calls'][0]['type'] }}]{% endif %}
{{ message['content'] }}{{'<|im_end|>'}}
{% endfor %}"""
@pytest.fixture(name="basic_jinja_template_analyzer")
def basic_jinja_template_analyzer():
return JinjaTemplateAnalyzer(
"""{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|>
' + message['content'] + '<|end|>
'}}{% elif message['role'] == 'user' %}{{'<|user|>
' + message['content'] + '<|end|>
'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>
' + message['content'] + '<|end|>
'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>
' }}{% else %}{{ eos_token }}{% endif %}"""
)
@pytest.fixture(name="mistral_jinja_template_analyzer")
def mistral_jinja_template_analyzer(mistralv03_tokenizer_chat_template_jinja):
return JinjaTemplateAnalyzer(mistralv03_tokenizer_chat_template_jinja)

View File

@@ -38,6 +38,10 @@ class TestAssistantChatTemplateLlama3:
"chat_template": "llama3",
"message_field_role": "role",
"message_field_content": "content",
"message_property_mappings": {
"role": "role",
"content": "content",
},
"roles": {
"user": ["user"],
"assistant": ["assistant"],
@@ -74,8 +78,10 @@ class TestAssistantChatTemplateLlama3:
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
message_field_role="role",
message_field_content="content",
message_property_mappings={
"role": "role",
"content": "content",
},
roles={
"user": ["user"],
"assistant": ["assistant"],
@@ -86,7 +92,7 @@ class TestAssistantChatTemplateLlama3:
train_on_inputs=False,
sequence_len=512,
)
strategy.messages = "messages"
res = strategy.tokenize_prompt(assistant_dataset[0])
input_ids = res["input_ids"]
# fmt: off
@@ -114,8 +120,10 @@ class TestAssistantChatTemplateLlama3:
ChatTemplatePrompter(
phi35_tokenizer,
chat_template=get_chat_template("phi_35"),
message_field_role="role",
message_field_content="content",
message_property_mappings={
"role": "role",
"content": "content",
},
roles={
"user": ["user"],
"assistant": ["assistant"],
@@ -126,7 +134,7 @@ class TestAssistantChatTemplateLlama3:
train_on_inputs=False,
sequence_len=512,
)
strategy.messages = "messages"
res = strategy.tokenize_prompt(assistant_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
@@ -170,9 +178,11 @@ class TestAssistantChatTemplateLlama3:
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
message_field_role="role",
message_field_content="content",
message_field_training="training",
message_property_mappings={
"role": "role",
"content": "content",
},
roles={
"user": ["user"],
"assistant": ["assistant"],
@@ -185,7 +195,7 @@ class TestAssistantChatTemplateLlama3:
sequence_len=512,
roles_to_train=["assistant"],
)
strategy.messages = "messages"
prompt_tokens = strategy.prompter.build_prompt(
assistant_dataset[0]["messages"], False
)
@@ -230,8 +240,11 @@ class TestSharegptChatTemplateLlama3:
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
message_field_role="from",
message_field_content="value",
message_property_mappings={
"role": "from",
"content": "value",
},
field_messages="conversations",
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -239,7 +252,7 @@ class TestSharegptChatTemplateLlama3:
sequence_len=512,
roles_to_train=["gpt"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(sharegpt_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
@@ -287,8 +300,11 @@ class TestSharegptChatTemplateLlama3:
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
message_field_role="from",
message_field_content="value",
message_property_mappings={
"role": "from",
"content": "value",
},
field_messages="conversations",
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -296,7 +312,7 @@ class TestSharegptChatTemplateLlama3:
sequence_len=512,
roles_to_train=["human"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(sharegpt_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
@@ -344,8 +360,11 @@ class TestSharegptChatTemplateLlama3:
ChatTemplatePrompter(
llama3_tokenizer,
chat_template=get_chat_template("llama3"),
message_field_role="from",
message_field_content="value",
message_property_mappings={
"role": "from",
"content": "value",
},
field_messages="conversations",
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -353,7 +372,7 @@ class TestSharegptChatTemplateLlama3:
sequence_len=512,
roles_to_train=["system", "human"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
input_ids = res["input_ids"]
labels = res["labels"]
@@ -417,8 +436,7 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
chat_template=get_chat_template(
"jinja", jinja_template=llama3_2_vision_chat_template_jinja
),
message_field_role="role",
message_field_content="content",
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,
@@ -486,8 +504,7 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
chat_template=get_chat_template(
"jinja", jinja_template=llama3_2_vision_chat_template_jinja
),
message_field_role="role",
message_field_content="content",
message_property_mappings={"role": "role", "content": "content"},
),
tokenizer=llama3_tokenizer,
train_on_inputs=False,

View File

@@ -3,7 +3,6 @@ tests for chat_template prompt strategy
"""
import logging
import unittest
from copy import deepcopy
import pytest
@@ -123,15 +122,15 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=True,
sequence_len=512,
roles_to_train=["assistant"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
turns = strategy.get_conversation_thread(basic_dataset[0])
labels = res["labels"]
@@ -180,15 +179,15 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
turns = strategy.get_conversation_thread(basic_dataset[0])
labels = res["labels"]
@@ -241,20 +240,15 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant", "human"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
turns = strategy.get_conversation_thread(basic_dataset[0])
labels = res["labels"]
@@ -307,15 +301,15 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=True,
sequence_len=512,
roles_to_train=["human", "assistant"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
turns = strategy.get_conversation_thread(basic_dataset[0])
labels = res["labels"]
@@ -360,8 +354,8 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
@@ -369,7 +363,7 @@ class TestChatTemplateConfigurations:
roles_to_train=[],
train_on_eos="none", # Add this line
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
@@ -400,8 +394,8 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
@@ -409,7 +403,7 @@ class TestChatTemplateConfigurations:
roles_to_train=["assistant"],
train_on_eos="all",
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
@@ -446,8 +440,8 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
@@ -455,7 +449,6 @@ class TestChatTemplateConfigurations:
roles_to_train=["assistant"],
train_on_eos="turn",
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
turns = strategy.get_conversation_thread(basic_dataset[0])
labels = res["labels"]
@@ -526,8 +519,8 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
@@ -535,7 +528,7 @@ class TestChatTemplateConfigurations:
roles_to_train=["assistant"],
train_on_eos="last",
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
@@ -578,8 +571,8 @@ class TestChatTemplateConfigurations:
chat_template=get_chat_template(
chat_template, jinja_template=chat_template_jinja
),
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
@@ -587,7 +580,7 @@ class TestChatTemplateConfigurations:
roles_to_train=["assistant"],
train_on_eos="none",
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
labels = res["labels"]
input_ids = res["input_ids"]
@@ -624,15 +617,15 @@ class TestChatTemplateConfigurations:
chat_template, jinja_template=chat_template_jinja
),
drop_system_message=True,
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
field_messages="conversations",
),
tokenizer=tokenizer,
train_on_inputs=False,
sequence_len=512,
roles_to_train=["assistant"],
)
strategy.messages = "conversations"
res = strategy.tokenize_prompt(basic_dataset[0])
input_ids = res["input_ids"]
@@ -668,8 +661,7 @@ class TestChatTemplateConfigurations:
chat_template, jinja_template=chat_template_jinja
),
roles=custom_roles,
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
),
tokenizer=tokenizer,
train_on_inputs=False,
@@ -741,8 +733,7 @@ class TestChatTemplateConfigurations:
),
message_field_training="train",
message_field_training_detail="train_detail",
message_field_role="from",
message_field_content="value",
message_property_mappings={"role": "from", "content": "value"},
),
tokenizer=tokenizer,
train_on_inputs=False,
@@ -911,6 +902,64 @@ class TestChatTemplateConfigurations:
LOG.debug(f"Final labels: {labels}")
LOG.debug(f"Final input_ids: {input_ids}")
def test_get_chat_template_variables(
self, tokenizer, chat_template, chat_template_jinja, eos_token, request
):
LOG.info("Testing get_chat_template_variables")
if __name__ == "__main__":
unittest.main()
actual_tokenizer, actual_jinja_template = self.setup_tokenizer(
tokenizer, chat_template, chat_template_jinja, eos_token, request
)
prompter = ChatTemplatePrompter(
actual_tokenizer,
chat_template=get_chat_template(
chat_template, jinja_template=actual_jinja_template
),
message_property_mappings={"from": "role", "value": "content"},
)
variables = prompter.get_chat_template_msg_variables(
actual_jinja_template
if actual_jinja_template
else actual_tokenizer.get_chat_template(),
"messages",
)
if chat_template == "llama3":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "chatml":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "jinja" and tokenizer == "mistralv03_tokenizer":
assert variables == {"role", "content", "tool_call_id", "tool_calls"}, (
f"Expected variables: {'role', 'content', 'tool_call_id', 'tool_calls'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "jinja" and tokenizer == "gemma2_tokenizer":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
elif chat_template == "phi_35":
assert variables == {"role", "content"}, (
f"Expected variables: {'role', 'content'} from {tokenizer}/{chat_template}\n"
f"Got: {variables}\n"
f"Chat template: {actual_jinja_template}"
)
else:
LOG.warning(
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
)
raise ValueError(
f"Unsupported chat template: {chat_template} with {chat_template_jinja}"
)

View File

@@ -0,0 +1,159 @@
"""
tests for jinja_template_analyzer
"""
import logging
import pytest
from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnalyzer
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger("axolotl")
class TestJinjaTemplateAnalyzer:
"""
tests for jinja_template_analyzer
"""
def test_basic_variable_extraction(self, basic_jinja_template_analyzer):
"""Test that all top-level variables are correctly extracted."""
LOG.info("Testing with train_on_inputs=True")
variables = basic_jinja_template_analyzer.get_template_variables()
expected_vars = {"messages", "add_generation_prompt", "eos_token", "message"}
assert set(variables.keys()) == expected_vars
def test_mixtral_variable_extraction(self, mistral_jinja_template_analyzer):
"""Test that all top-level variables are correctly extracted."""
LOG.info("Testing with train_on_inputs=True")
variables = mistral_jinja_template_analyzer.get_template_variables()
expected_vars = {
"messages",
"content",
"eos_token",
"message",
"tools",
"system_message",
"loop_messages",
"ns",
"tool_call",
"tool",
"loop",
"bos_token",
"raise_exception",
}
assert set(variables.keys()) == expected_vars
message_vars = variables["message"]
assert message_vars == {"role", "content", "tool_calls", "tool_call_id"}
def test_message_property_access(self, basic_jinja_template_analyzer):
"""Test that properties accessed on 'message' variable are correctly identified."""
LOG.info("Testing message property access")
variables = basic_jinja_template_analyzer.get_template_variables()
assert "messages" in variables
assert "message" in variables
assert "role" in variables["message"]
assert "content" in variables["message"]
def test_detailed_analysis(self, basic_jinja_template_analyzer):
"""Test the detailed analysis of variable usage."""
LOG.info("Testing detailed analysis")
analysis = basic_jinja_template_analyzer.analyze_template()
assert analysis["messages"]["is_iterated"] is True
assert "role" in analysis["message"]["accessed_properties"]
assert "content" in analysis["message"]["accessed_properties"]
assert analysis["add_generation_prompt"]["is_conditional"] is True
assert len(analysis["add_generation_prompt"]["accessed_properties"]) == 0
assert not analysis["eos_token"]["is_iterated"]
assert len(analysis["eos_token"]["accessed_properties"]) == 0
def test_nested_property_access(self):
"""Test handling of nested property access."""
LOG.info("Testing nested property access")
template = """{{ user.profile.name }}{{ user.settings['preference'] }}"""
analyzer = JinjaTemplateAnalyzer(template)
variables = analyzer.get_template_variables()
assert "user" in variables
assert "profile" in variables["user"]
assert "settings" in variables["user"]
def test_loop_variable_handling(self):
"""Test handling of loop variables and their properties."""
LOG.info("Testing loop variable handling")
template = """
{% for item in items %}
{{ item.name }}
{% for subitem in item.subitems %}
{{ subitem.value }}
{% endfor %}
{% endfor %}
"""
analyzer = JinjaTemplateAnalyzer(template)
analysis = analyzer.analyze_template()
assert analysis["items"]["is_iterated"]
assert "name" in analysis["item"]["accessed_properties"]
assert "subitems" in analysis["item"]["accessed_properties"]
def test_conditional_variable_usage(self):
"""Test detection of variables used in conditional statements."""
LOG.info("Testing conditional variable usage")
template = """
{% if user.is_admin and config.debug_mode %}
{{ debug_info }}
{% endif %}
"""
analyzer = JinjaTemplateAnalyzer(template)
analysis = analyzer.analyze_template()
assert analysis["user"]["is_conditional"]
assert analysis["config"]["is_conditional"]
assert "is_admin" in analysis["user"]["accessed_properties"]
assert "debug_mode" in analysis["config"]["accessed_properties"]
def test_complex_expressions(self):
"""Test handling of complex expressions and filters."""
LOG.info("Testing complex expressions and filters")
template = """
{{ user.name | upper }}
{{ messages | length > 0 and messages[0].content }}
{{ data['key'].nested['value'] }}
"""
analyzer = JinjaTemplateAnalyzer(template)
variables = analyzer.get_template_variables()
assert "user" in variables
assert "name" in variables["user"]
assert "messages" in variables
assert "content" in variables["messages"]
assert "data" in variables
def test_basic_msg_vars(self, basic_jinja_template_analyzer):
"""Test that the basic message variables are correctly identified."""
LOG.info("Testing basic message variables")
variables = basic_jinja_template_analyzer.get_message_vars()
assert variables == {"role", "content"}
def test_mixtral_msg_vars(self, mistral_jinja_template_analyzer):
"""Test that the mixtral message variables are correctly identified."""
LOG.info("Testing mixtral message variables")
variables = mistral_jinja_template_analyzer.get_message_vars()
assert variables == {"role", "content", "tool_calls", "tool_call_id"}
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -302,3 +302,22 @@ class TestValidationCheckDatasetConfig(BaseValidation):
)
validate_config(cfg)
def test_message_property_mappings(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"message_property_mappings": {
"role": "role",
"content": "content",
},
}
],
}
)
validate_config(cfg)

View File

@@ -76,7 +76,7 @@ class TestModelsUtils:
mocked_load_model_config.return_value = {}
with pytest.raises(ValueError) as exc:
# Should error before hitting tokenizer, so we pass in an empty str
load_model(cfg, tokenizer="")
load_model(cfg, tokenizer="") # type: ignore
assert (
"shifted-sparse attention does not currently support sample packing"
in str(exc.value)
@@ -116,3 +116,79 @@ class TestModelsUtils:
assert self.model_loader.model_kwargs.get(
"quantization_config", BitsAndBytesConfig
)
def test_message_property_mapping(self):
"""Test message property mapping configuration validation"""
from axolotl.utils.config.models.input.v0_4_1 import SFTDataset
# Test legacy fields are mapped orrectly
dataset = SFTDataset(
path="test_path",
message_field_role="role_field",
message_field_content="content_field",
)
assert dataset.message_property_mappings == {
"role": "role_field",
"content": "content_field",
}
# Test direct message_property_mapping works
dataset = SFTDataset(
path="test_path",
message_property_mappings={
"role": "custom_role",
"content": "custom_content",
},
)
assert dataset.message_property_mappings == {
"role": "custom_role",
"content": "custom_content",
}
# Test both legacy and new fields work when they match
dataset = SFTDataset(
path="test_path",
message_field_role="same_role",
message_property_mappings={"role": "same_role"},
)
assert dataset.message_property_mappings == {
"role": "same_role",
"content": "content",
}
# Test both legacy and new fields work when they don't overlap
dataset = SFTDataset(
path="test_path",
message_field_role="role_field",
message_property_mappings={"content": "content_field"},
)
assert dataset.message_property_mappings == {
"role": "role_field",
"content": "content_field",
}
# Test no role or content provided
dataset = SFTDataset(
path="test_path",
)
assert dataset.message_property_mappings == {
"role": "role",
"content": "content",
}
# Test error when legacy and new fields conflict
with pytest.raises(ValueError) as exc_info:
SFTDataset(
path="test_path",
message_field_role="legacy_role",
message_property_mappings={"role": "different_role"},
)
assert "Conflicting message role fields" in str(exc_info.value)
with pytest.raises(ValueError) as exc_info:
SFTDataset(
path="test_path",
message_field_content="legacy_content",
message_property_mappings={"content": "different_content"},
)
assert "Conflicting message content fields" in str(exc_info.value)