fix: config being dropped and unittest to catch that
This commit is contained in:
@@ -139,10 +139,13 @@ class SFTDataset(BaseModel):
|
||||
type: Optional[Union[str, UserDefinedPrompterType]] = None
|
||||
shards: Optional[int] = None
|
||||
conversation: Optional[str] = None
|
||||
chat_template: Union[
|
||||
ChatTemplate,
|
||||
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
|
||||
] = ChatTemplate.tokenizer_default
|
||||
# Do not make this too strict or it will break the validator to choose different dataset class
|
||||
chat_template: Optional[
|
||||
Union[
|
||||
ChatTemplate,
|
||||
str,
|
||||
]
|
||||
] = None
|
||||
chat_template_jinja: Optional[str] = None
|
||||
data_files: Optional[Union[str, List[str]]] = None
|
||||
name: Optional[str] = None
|
||||
@@ -165,6 +168,10 @@ class SFTDataset(BaseModel):
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_chat_template_config(cls, data):
|
||||
# Set chat_template to tokenizer_default if not set
|
||||
if data.get("type") == "chat_template" and not data.get("chat_template"):
|
||||
data["chat_template"] = ChatTemplate.tokenizer_default
|
||||
|
||||
# if chat_template is set to jinja, chat_template_jinja is required
|
||||
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
||||
"chat_template_jinja"
|
||||
@@ -735,10 +742,12 @@ class AxolotlInputConfig(
|
||||
gpu_memory_limit: Optional[Union[int, str]] = None
|
||||
low_cpu_mem_usage: Optional[bool] = None
|
||||
|
||||
chat_template: Union[
|
||||
ChatTemplate,
|
||||
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
|
||||
] = ChatTemplate.tokenizer_default
|
||||
chat_template: Optional[
|
||||
Union[
|
||||
ChatTemplate,
|
||||
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
|
||||
]
|
||||
] = None
|
||||
chat_template_jinja: Optional[str] = None
|
||||
default_system_message: Optional[str] = None
|
||||
|
||||
@@ -781,6 +790,20 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return datasets
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_default_chat_template(cls, data):
|
||||
if data.get("chat_template") is None:
|
||||
use_chat_template = any(
|
||||
dataset["type"] == "chat_template"
|
||||
for dataset in data.get("datasets", [])
|
||||
)
|
||||
|
||||
if use_chat_template:
|
||||
data["chat_template"] = ChatTemplate.tokenizer_default
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_batch_size_fields(cls, data):
|
||||
|
||||
237
tests/test_validation_dataset.py
Normal file
237
tests/test_validation_dataset.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Module for testing the validation module for the dataset config"""
|
||||
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
warnings.filterwarnings("error")
|
||||
|
||||
|
||||
@pytest.fixture(name="minimal_cfg")
|
||||
def fixture_cfg():
|
||||
return DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
||||
"learning_rate": 0.000001,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class BaseValidation:
|
||||
"""
|
||||
Base validation module to setup the log capture
|
||||
"""
|
||||
|
||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
self._caplog = caplog
|
||||
|
||||
|
||||
class TestValidationCheckDatasetConfig(BaseValidation):
|
||||
"""
|
||||
Test the validation for the dataset config to ensure no correct parameters are dropped
|
||||
"""
|
||||
|
||||
def test_dataset_config_no_drop_param(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"datasets": [
|
||||
{
|
||||
"path": "LDJnr/Puffin",
|
||||
"type": "sharegpt",
|
||||
"conversation": "chatml",
|
||||
"shards": 10,
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
checked_cfg = validate_config(cfg)
|
||||
|
||||
def _check_config():
|
||||
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
|
||||
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
|
||||
assert checked_cfg.datasets[0].conversation == cfg.datasets[0].conversation
|
||||
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
|
||||
|
||||
_check_config()
|
||||
|
||||
checked_cfg = validate_config(
|
||||
cfg,
|
||||
capabilities={
|
||||
"bf16": "false",
|
||||
"n_gpu": 1,
|
||||
"compute_capability": "8.0",
|
||||
},
|
||||
)
|
||||
|
||||
_check_config()
|
||||
|
||||
def test_dataset_default_chat_template_no_drop_param(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"datasets": [
|
||||
{
|
||||
"path": "LDJnr/Puffin",
|
||||
"type": "chat_template",
|
||||
"field_messages": "conversations",
|
||||
"shards": 10,
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
checked_cfg = validate_config(cfg)
|
||||
|
||||
def _check_config():
|
||||
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
|
||||
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
|
||||
assert checked_cfg.chat_template == ChatTemplate.tokenizer_default
|
||||
assert (
|
||||
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
|
||||
)
|
||||
assert (
|
||||
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
|
||||
)
|
||||
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
|
||||
assert (
|
||||
checked_cfg.datasets[0].message_field_role
|
||||
== cfg.datasets[0].message_field_role
|
||||
)
|
||||
assert (
|
||||
checked_cfg.datasets[0].message_field_content
|
||||
== cfg.datasets[0].message_field_content
|
||||
)
|
||||
|
||||
_check_config()
|
||||
|
||||
checked_cfg = validate_config(
|
||||
cfg,
|
||||
capabilities={
|
||||
"bf16": "false",
|
||||
"n_gpu": 1,
|
||||
"compute_capability": "8.0",
|
||||
},
|
||||
)
|
||||
|
||||
_check_config()
|
||||
|
||||
def test_dataset_partial_default_chat_template_no_drop_param(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"chat_template": "chatml",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "LDJnr/Puffin",
|
||||
"type": "chat_template",
|
||||
"field_messages": "conversations",
|
||||
"shards": 10,
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
checked_cfg = validate_config(cfg)
|
||||
|
||||
def _check_config():
|
||||
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
|
||||
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
|
||||
assert checked_cfg.chat_template == ChatTemplate.chatml
|
||||
assert (
|
||||
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
|
||||
)
|
||||
assert (
|
||||
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
|
||||
)
|
||||
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
|
||||
assert (
|
||||
checked_cfg.datasets[0].message_field_role
|
||||
== cfg.datasets[0].message_field_role
|
||||
)
|
||||
assert (
|
||||
checked_cfg.datasets[0].message_field_content
|
||||
== cfg.datasets[0].message_field_content
|
||||
)
|
||||
|
||||
_check_config()
|
||||
|
||||
checked_cfg = validate_config(
|
||||
cfg,
|
||||
capabilities={
|
||||
"bf16": "false",
|
||||
"n_gpu": 1,
|
||||
"compute_capability": "8.0",
|
||||
},
|
||||
)
|
||||
|
||||
_check_config()
|
||||
|
||||
def test_dataset_chatml_chat_template_no_drop_param(self, minimal_cfg):
|
||||
cfg = DictDefault(
|
||||
minimal_cfg
|
||||
| {
|
||||
"chat_template": "chatml",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "LDJnr/Puffin",
|
||||
"type": "chat_template",
|
||||
"chat_template": "gemma",
|
||||
"field_messages": "conversations",
|
||||
"shards": 10,
|
||||
"message_field_role": "from",
|
||||
"message_field_content": "value",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
checked_cfg = validate_config(cfg)
|
||||
|
||||
def _check_config():
|
||||
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
|
||||
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
|
||||
assert checked_cfg.chat_template == cfg.chat_template
|
||||
assert (
|
||||
checked_cfg.datasets[0].chat_template == cfg.datasets[0].chat_template
|
||||
)
|
||||
assert (
|
||||
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
|
||||
)
|
||||
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
|
||||
assert (
|
||||
checked_cfg.datasets[0].message_field_role
|
||||
== cfg.datasets[0].message_field_role
|
||||
)
|
||||
assert (
|
||||
checked_cfg.datasets[0].message_field_content
|
||||
== cfg.datasets[0].message_field_content
|
||||
)
|
||||
|
||||
_check_config()
|
||||
|
||||
checked_cfg = validate_config(
|
||||
cfg,
|
||||
capabilities={
|
||||
"bf16": "false",
|
||||
"n_gpu": 1,
|
||||
"compute_capability": "8.0",
|
||||
},
|
||||
)
|
||||
|
||||
_check_config()
|
||||
Reference in New Issue
Block a user