diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 81f3021e8..7290b948d 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -139,10 +139,13 @@ class SFTDataset(BaseModel):
     type: Optional[Union[str, UserDefinedPrompterType]] = None
     shards: Optional[int] = None
     conversation: Optional[str] = None
-    chat_template: Union[
-        ChatTemplate,
-        Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
-    ] = ChatTemplate.tokenizer_default
+    # Do not make this too strict, or it will break the validator that chooses the dataset class
+    chat_template: Optional[
+        Union[
+            ChatTemplate,
+            str,
+        ]
+    ] = None
     chat_template_jinja: Optional[str] = None
     data_files: Optional[Union[str, List[str]]] = None
     name: Optional[str] = None
@@ -165,6 +168,10 @@ class SFTDataset(BaseModel):
     @model_validator(mode="before")
     @classmethod
     def check_chat_template_config(cls, data):
+        # Set chat_template to tokenizer_default if not set
+        if data.get("type") == "chat_template" and not data.get("chat_template"):
+            data["chat_template"] = ChatTemplate.tokenizer_default
+
         # if chat_template is set to jinja, chat_template_jinja is required
         if data.get("chat_template") == ChatTemplate.jinja and not data.get(
             "chat_template_jinja"
@@ -735,10 +742,12 @@ class AxolotlInputConfig(
     gpu_memory_limit: Optional[Union[int, str]] = None
     low_cpu_mem_usage: Optional[bool] = None
 
-    chat_template: Union[
-        ChatTemplate,
-        Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
-    ] = ChatTemplate.tokenizer_default
+    chat_template: Optional[
+        Union[
+            ChatTemplate,
+            Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
+        ]
+    ] = None
     chat_template_jinja: Optional[str] = None
     default_system_message: Optional[str] = None
 
@@ -781,6 +790,20 @@ class AxolotlInputConfig(
         )
         return datasets
 
+    @model_validator(mode="before")
+    @classmethod
+    def set_default_chat_template(cls, data):
+        if data.get("chat_template") is None:
+            # use .get() here: a dataset entry may omit "type" entirely
+            use_chat_template = any(
+                dataset.get("type") == "chat_template"
+                for dataset in data.get("datasets", [])
+            )
+
+            if use_chat_template:
+                data["chat_template"] = ChatTemplate.tokenizer_default
+
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_batch_size_fields(cls, data):
diff --git a/tests/test_validation_dataset.py b/tests/test_validation_dataset.py
new file mode 100644
index 000000000..f889c9813
--- /dev/null
+++ b/tests/test_validation_dataset.py
@@ -0,0 +1,237 @@
+"""Module for testing dataset config validation"""
+
+import warnings
+from typing import Optional
+
+import pytest
+
+from axolotl.utils.config import validate_config
+from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
+from axolotl.utils.dict import DictDefault
+
+warnings.filterwarnings("error")
+
+
+@pytest.fixture(name="minimal_cfg")
+def fixture_cfg():
+    return DictDefault(
+        {
+            "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+            "learning_rate": 0.000001,
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 1,
+        }
+    )
+
+
+class BaseValidation:
+    """
+    Base validation class to set up log capture
+    """
+
+    _caplog: Optional[pytest.LogCaptureFixture] = None
+
+    @pytest.fixture(autouse=True)
+    def inject_fixtures(self, caplog):
+        self._caplog = caplog
+
+
+class TestValidationCheckDatasetConfig(BaseValidation):
+    """
+    Test the validation for the dataset config to ensure correct
+    parameters are not dropped
+    """
+
+    def test_dataset_config_no_drop_param(self, minimal_cfg):
+        cfg = DictDefault(
+            minimal_cfg
+            | {
+                "datasets": [
+                    {
+                        "path": "LDJnr/Puffin",
+                        "type": "sharegpt",
+                        "conversation": "chatml",
+                        "shards": 10,
+                    }
+                ]
+            }
+        )
+
+        checked_cfg = validate_config(cfg)
+
+        def _check_config():
+            assert checked_cfg.datasets[0].path == cfg.datasets[0].path
+            assert checked_cfg.datasets[0].type == cfg.datasets[0].type
+            assert checked_cfg.datasets[0].conversation == cfg.datasets[0].conversation
+            assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
+
+        _check_config()
+
+        checked_cfg = validate_config(
+            cfg,
+            capabilities={
+                "bf16": "false",
+                "n_gpu": 1,
+                "compute_capability": "8.0",
+            },
+        )
+
+        _check_config()
+
+    def test_dataset_default_chat_template_no_drop_param(self, minimal_cfg):
+        cfg = DictDefault(
+            minimal_cfg
+            | {
+                "datasets": [
+                    {
+                        "path": "LDJnr/Puffin",
+                        "type": "chat_template",
+                        "field_messages": "conversations",
+                        "shards": 10,
+                        "message_field_role": "from",
+                        "message_field_content": "value",
+                    }
+                ],
+            }
+        )
+
+        checked_cfg = validate_config(cfg)
+
+        def _check_config():
+            assert checked_cfg.datasets[0].path == cfg.datasets[0].path
+            assert checked_cfg.datasets[0].type == cfg.datasets[0].type
+            assert checked_cfg.chat_template == ChatTemplate.tokenizer_default
+            assert (
+                checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
+            )
+            assert (
+                checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
+            )
+            assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
+            assert (
+                checked_cfg.datasets[0].message_field_role
+                == cfg.datasets[0].message_field_role
+            )
+            assert (
+                checked_cfg.datasets[0].message_field_content
+                == cfg.datasets[0].message_field_content
+            )
+
+        _check_config()
+
+        checked_cfg = validate_config(
+            cfg,
+            capabilities={
+                "bf16": "false",
+                "n_gpu": 1,
+                "compute_capability": "8.0",
+            },
+        )
+
+        _check_config()
+
+    def test_dataset_partial_default_chat_template_no_drop_param(self, minimal_cfg):
+        cfg = DictDefault(
+            minimal_cfg
+            | {
+                "chat_template": "chatml",
+                "datasets": [
+                    {
+                        "path": "LDJnr/Puffin",
+                        "type": "chat_template",
+                        "field_messages": "conversations",
+                        "shards": 10,
+                        "message_field_role": "from",
+                        "message_field_content": "value",
+                    }
+                ],
+            }
+        )
+
+        checked_cfg = validate_config(cfg)
+
+        def _check_config():
+            assert checked_cfg.datasets[0].path == cfg.datasets[0].path
+            assert checked_cfg.datasets[0].type == cfg.datasets[0].type
+            assert checked_cfg.chat_template == ChatTemplate.chatml
+            assert (
+                checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
+            )
+            assert (
+                checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
+            )
+            assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
+            assert (
+                checked_cfg.datasets[0].message_field_role
+                == cfg.datasets[0].message_field_role
+            )
+            assert (
+                checked_cfg.datasets[0].message_field_content
+                == cfg.datasets[0].message_field_content
+            )
+
+        _check_config()
+
+        checked_cfg = validate_config(
+            cfg,
+            capabilities={
+                "bf16": "false",
+                "n_gpu": 1,
+                "compute_capability": "8.0",
+            },
+        )
+
+        _check_config()
+
+    def test_dataset_chatml_chat_template_no_drop_param(self, minimal_cfg):
+        cfg = DictDefault(
+            minimal_cfg
+            | {
+                "chat_template": "chatml",
+                "datasets": [
+                    {
+                        "path": "LDJnr/Puffin",
+                        "type": "chat_template",
+                        "chat_template": "gemma",
+                        "field_messages": "conversations",
+                        "shards": 10,
+                        "message_field_role": "from",
"message_field_content": "value", + } + ], + } + ) + + checked_cfg = validate_config(cfg) + + def _check_config(): + assert checked_cfg.datasets[0].path == cfg.datasets[0].path + assert checked_cfg.datasets[0].type == cfg.datasets[0].type + assert checked_cfg.chat_template == cfg.chat_template + assert ( + checked_cfg.datasets[0].chat_template == cfg.datasets[0].chat_template + ) + assert ( + checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages + ) + assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards + assert ( + checked_cfg.datasets[0].message_field_role + == cfg.datasets[0].message_field_role + ) + assert ( + checked_cfg.datasets[0].message_field_content + == cfg.datasets[0].message_field_content + ) + + _check_config() + + checked_cfg = validate_config( + cfg, + capabilities={ + "bf16": "false", + "n_gpu": 1, + "compute_capability": "8.0", + }, + ) + + _check_config()