fix: config being dropped and unittest to catch that
This commit is contained in:
@@ -139,10 +139,13 @@ class SFTDataset(BaseModel):
|
|||||||
type: Optional[Union[str, UserDefinedPrompterType]] = None
|
type: Optional[Union[str, UserDefinedPrompterType]] = None
|
||||||
shards: Optional[int] = None
|
shards: Optional[int] = None
|
||||||
conversation: Optional[str] = None
|
conversation: Optional[str] = None
|
||||||
chat_template: Union[
|
# Do not make this too strict or it will break the validator to choose different dataset class
|
||||||
ChatTemplate,
|
chat_template: Optional[
|
||||||
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
|
Union[
|
||||||
] = ChatTemplate.tokenizer_default
|
ChatTemplate,
|
||||||
|
str,
|
||||||
|
]
|
||||||
|
] = None
|
||||||
chat_template_jinja: Optional[str] = None
|
chat_template_jinja: Optional[str] = None
|
||||||
data_files: Optional[Union[str, List[str]]] = None
|
data_files: Optional[Union[str, List[str]]] = None
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
@@ -165,6 +168,10 @@ class SFTDataset(BaseModel):
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_chat_template_config(cls, data):
|
def check_chat_template_config(cls, data):
|
||||||
|
# Set chat_template to tokenizer_default if not set
|
||||||
|
if data.get("type") == "chat_template" and not data.get("chat_template"):
|
||||||
|
data["chat_template"] = ChatTemplate.tokenizer_default
|
||||||
|
|
||||||
# if chat_template is set to jinja, chat_template_jinja is required
|
# if chat_template is set to jinja, chat_template_jinja is required
|
||||||
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
||||||
"chat_template_jinja"
|
"chat_template_jinja"
|
||||||
@@ -735,10 +742,12 @@ class AxolotlInputConfig(
|
|||||||
gpu_memory_limit: Optional[Union[int, str]] = None
|
gpu_memory_limit: Optional[Union[int, str]] = None
|
||||||
low_cpu_mem_usage: Optional[bool] = None
|
low_cpu_mem_usage: Optional[bool] = None
|
||||||
|
|
||||||
chat_template: Union[
|
chat_template: Optional[
|
||||||
ChatTemplate,
|
Union[
|
||||||
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
|
ChatTemplate,
|
||||||
] = ChatTemplate.tokenizer_default
|
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
|
||||||
|
]
|
||||||
|
] = None
|
||||||
chat_template_jinja: Optional[str] = None
|
chat_template_jinja: Optional[str] = None
|
||||||
default_system_message: Optional[str] = None
|
default_system_message: Optional[str] = None
|
||||||
|
|
||||||
@@ -781,6 +790,20 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return datasets
|
return datasets
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def set_default_chat_template(cls, data):
|
||||||
|
if data.get("chat_template") is None:
|
||||||
|
use_chat_template = any(
|
||||||
|
dataset["type"] == "chat_template"
|
||||||
|
for dataset in data.get("datasets", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_chat_template:
|
||||||
|
data["chat_template"] = ChatTemplate.tokenizer_default
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_batch_size_fields(cls, data):
|
def check_batch_size_fields(cls, data):
|
||||||
|
|||||||
237
tests/test_validation_dataset.py
Normal file
237
tests/test_validation_dataset.py
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
"""Module for testing the validation module for the dataset config"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.config import validate_config
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
warnings.filterwarnings("error")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="minimal_cfg")
|
||||||
|
def fixture_cfg():
|
||||||
|
return DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
||||||
|
"learning_rate": 0.000001,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseValidation:
|
||||||
|
"""
|
||||||
|
Base validation module to setup the log capture
|
||||||
|
"""
|
||||||
|
|
||||||
|
_caplog: Optional[pytest.LogCaptureFixture] = None
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def inject_fixtures(self, caplog):
|
||||||
|
self._caplog = caplog
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidationCheckDatasetConfig(BaseValidation):
|
||||||
|
"""
|
||||||
|
Test the validation for the dataset config to ensure no correct parameters are dropped
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_dataset_config_no_drop_param(self, minimal_cfg):
|
||||||
|
cfg = DictDefault(
|
||||||
|
minimal_cfg
|
||||||
|
| {
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "LDJnr/Puffin",
|
||||||
|
"type": "sharegpt",
|
||||||
|
"conversation": "chatml",
|
||||||
|
"shards": 10,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
checked_cfg = validate_config(cfg)
|
||||||
|
|
||||||
|
def _check_config():
|
||||||
|
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
|
||||||
|
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
|
||||||
|
assert checked_cfg.datasets[0].conversation == cfg.datasets[0].conversation
|
||||||
|
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
|
||||||
|
|
||||||
|
_check_config()
|
||||||
|
|
||||||
|
checked_cfg = validate_config(
|
||||||
|
cfg,
|
||||||
|
capabilities={
|
||||||
|
"bf16": "false",
|
||||||
|
"n_gpu": 1,
|
||||||
|
"compute_capability": "8.0",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
_check_config()
|
||||||
|
|
||||||
|
def test_dataset_default_chat_template_no_drop_param(self, minimal_cfg):
|
||||||
|
cfg = DictDefault(
|
||||||
|
minimal_cfg
|
||||||
|
| {
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "LDJnr/Puffin",
|
||||||
|
"type": "chat_template",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"shards": 10,
|
||||||
|
"message_field_role": "from",
|
||||||
|
"message_field_content": "value",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
checked_cfg = validate_config(cfg)
|
||||||
|
|
||||||
|
def _check_config():
|
||||||
|
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
|
||||||
|
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
|
||||||
|
assert checked_cfg.chat_template == ChatTemplate.tokenizer_default
|
||||||
|
assert (
|
||||||
|
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
|
||||||
|
)
|
||||||
|
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
|
||||||
|
assert (
|
||||||
|
checked_cfg.datasets[0].message_field_role
|
||||||
|
== cfg.datasets[0].message_field_role
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
checked_cfg.datasets[0].message_field_content
|
||||||
|
== cfg.datasets[0].message_field_content
|
||||||
|
)
|
||||||
|
|
||||||
|
_check_config()
|
||||||
|
|
||||||
|
checked_cfg = validate_config(
|
||||||
|
cfg,
|
||||||
|
capabilities={
|
||||||
|
"bf16": "false",
|
||||||
|
"n_gpu": 1,
|
||||||
|
"compute_capability": "8.0",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
_check_config()
|
||||||
|
|
||||||
|
def test_dataset_partial_default_chat_template_no_drop_param(self, minimal_cfg):
|
||||||
|
cfg = DictDefault(
|
||||||
|
minimal_cfg
|
||||||
|
| {
|
||||||
|
"chat_template": "chatml",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "LDJnr/Puffin",
|
||||||
|
"type": "chat_template",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"shards": 10,
|
||||||
|
"message_field_role": "from",
|
||||||
|
"message_field_content": "value",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
checked_cfg = validate_config(cfg)
|
||||||
|
|
||||||
|
def _check_config():
|
||||||
|
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
|
||||||
|
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
|
||||||
|
assert checked_cfg.chat_template == ChatTemplate.chatml
|
||||||
|
assert (
|
||||||
|
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
|
||||||
|
)
|
||||||
|
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
|
||||||
|
assert (
|
||||||
|
checked_cfg.datasets[0].message_field_role
|
||||||
|
== cfg.datasets[0].message_field_role
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
checked_cfg.datasets[0].message_field_content
|
||||||
|
== cfg.datasets[0].message_field_content
|
||||||
|
)
|
||||||
|
|
||||||
|
_check_config()
|
||||||
|
|
||||||
|
checked_cfg = validate_config(
|
||||||
|
cfg,
|
||||||
|
capabilities={
|
||||||
|
"bf16": "false",
|
||||||
|
"n_gpu": 1,
|
||||||
|
"compute_capability": "8.0",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
_check_config()
|
||||||
|
|
||||||
|
def test_dataset_chatml_chat_template_no_drop_param(self, minimal_cfg):
|
||||||
|
cfg = DictDefault(
|
||||||
|
minimal_cfg
|
||||||
|
| {
|
||||||
|
"chat_template": "chatml",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "LDJnr/Puffin",
|
||||||
|
"type": "chat_template",
|
||||||
|
"chat_template": "gemma",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"shards": 10,
|
||||||
|
"message_field_role": "from",
|
||||||
|
"message_field_content": "value",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
checked_cfg = validate_config(cfg)
|
||||||
|
|
||||||
|
def _check_config():
|
||||||
|
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
|
||||||
|
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
|
||||||
|
assert checked_cfg.chat_template == cfg.chat_template
|
||||||
|
assert (
|
||||||
|
checked_cfg.datasets[0].chat_template == cfg.datasets[0].chat_template
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
|
||||||
|
)
|
||||||
|
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
|
||||||
|
assert (
|
||||||
|
checked_cfg.datasets[0].message_field_role
|
||||||
|
== cfg.datasets[0].message_field_role
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
checked_cfg.datasets[0].message_field_content
|
||||||
|
== cfg.datasets[0].message_field_content
|
||||||
|
)
|
||||||
|
|
||||||
|
_check_config()
|
||||||
|
|
||||||
|
checked_cfg = validate_config(
|
||||||
|
cfg,
|
||||||
|
capabilities={
|
||||||
|
"bf16": "false",
|
||||||
|
"n_gpu": 1,
|
||||||
|
"compute_capability": "8.0",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
_check_config()
|
||||||
Reference in New Issue
Block a user