"""Module for testing the validation module for the dataset config""" import warnings from typing import Optional import pytest from axolotl.utils.config import validate_config from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate from axolotl.utils.dict import DictDefault warnings.filterwarnings("error") @pytest.fixture(name="minimal_cfg") def fixture_cfg(): return DictDefault( { "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", "learning_rate": 0.000001, "micro_batch_size": 1, "gradient_accumulation_steps": 1, } ) # pylint: disable=too-many-public-methods (duplicate-code) class BaseValidation: """ Base validation module to setup the log capture """ _caplog: Optional[pytest.LogCaptureFixture] = None @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): self._caplog = caplog class TestValidationCheckDatasetConfig(BaseValidation): """ Test the validation for the dataset config to ensure no correct parameters are dropped """ def test_dataset_config_no_drop_param(self, minimal_cfg): cfg = DictDefault( minimal_cfg | { "datasets": [ { "path": "mhenrichsen/alpaca_2k_test", "type": "alpaca", "shards": 10, } ] } ) checked_cfg = validate_config(cfg) def _check_config(): assert checked_cfg.datasets[0].path == cfg.datasets[0].path assert checked_cfg.datasets[0].type == cfg.datasets[0].type assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards _check_config() checked_cfg = validate_config( cfg, capabilities={ "bf16": "false", "n_gpu": 1, "compute_capability": "8.0", }, ) _check_config() def test_dataset_default_chat_template_no_drop_param(self, minimal_cfg): cfg = DictDefault( minimal_cfg | { "datasets": [ { "path": "LDJnr/Puffin", "type": "chat_template", "field_messages": "conversations", "shards": 10, "message_field_role": "from", "message_field_content": "value", } ], } ) checked_cfg = validate_config(cfg) def _check_config(): assert checked_cfg.datasets[0].path == cfg.datasets[0].path assert checked_cfg.datasets[0].type == cfg.datasets[0].type assert checked_cfg.chat_template is None assert ( checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default ) assert ( checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages ) assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards assert ( checked_cfg.datasets[0].message_field_role == cfg.datasets[0].message_field_role ) assert ( checked_cfg.datasets[0].message_field_content == cfg.datasets[0].message_field_content ) _check_config() checked_cfg = validate_config( cfg, capabilities={ "bf16": "false", "n_gpu": 1, "compute_capability": "8.0", }, ) _check_config() def test_dataset_partial_default_chat_template_no_drop_param(self, minimal_cfg): cfg = DictDefault( minimal_cfg | { "chat_template": "chatml", "datasets": [ { "path": "LDJnr/Puffin", "type": "chat_template", "field_messages": "conversations", "shards": 10, "message_field_role": "from", "message_field_content": "value", } ], } ) checked_cfg = validate_config(cfg) def _check_config(): assert checked_cfg.datasets[0].path == cfg.datasets[0].path assert checked_cfg.datasets[0].type == cfg.datasets[0].type assert checked_cfg.chat_template == ChatTemplate.chatml assert ( checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default ) assert ( checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages ) assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards assert ( checked_cfg.datasets[0].message_field_role == cfg.datasets[0].message_field_role ) assert ( checked_cfg.datasets[0].message_field_content == cfg.datasets[0].message_field_content ) _check_config() checked_cfg = validate_config( cfg, capabilities={ "bf16": "false", "n_gpu": 1, "compute_capability": "8.0", }, ) _check_config() def test_dataset_chatml_chat_template_no_drop_param(self, minimal_cfg): cfg = DictDefault( minimal_cfg | { "chat_template": "chatml", "datasets": [ { "path": "LDJnr/Puffin", "type": "chat_template", "chat_template": "gemma", "field_messages": "conversations", "shards": 10, "message_field_role": "from", "message_field_content": "value", } ], } ) checked_cfg = validate_config(cfg) def _check_config(): assert checked_cfg.datasets[0].path == cfg.datasets[0].path assert checked_cfg.datasets[0].type == cfg.datasets[0].type assert checked_cfg.chat_template == cfg.chat_template assert ( checked_cfg.datasets[0].chat_template == cfg.datasets[0].chat_template ) assert ( checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages ) assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards assert ( checked_cfg.datasets[0].message_field_role == cfg.datasets[0].message_field_role ) assert ( checked_cfg.datasets[0].message_field_content == cfg.datasets[0].message_field_content ) _check_config() checked_cfg = validate_config( cfg, capabilities={ "bf16": "false", "n_gpu": 1, "compute_capability": "8.0", }, ) _check_config()