Files
axolotl/tests/test_validation_dataset.py
salman bbd3486f57 Distributed Muon Optimizer (#3264)
* init

* working

* updating configs

* removing unneeded files

* lint

* comments

* lint

* fix regex match

* bump contribs version

* comments

* fixing tests and imports

* muon imports in test v2

* test cleanup

* bump contribs version

---------

Co-authored-by: Salman Mohammadi <“salman.mohammadi@outlook.com”>
2025-12-19 10:43:47 -05:00

368 lines
11 KiB
Python

"""Module for testing the validation module for the dataset config"""
import warnings
from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.schemas.datasets import ChatTemplate
warnings.filterwarnings("error")
@pytest.fixture(name="minimal_cfg")
def fixture_cfg():
return DictDefault(
{
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
"learning_rate": 0.000001,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
)
class BaseValidation:
"""
Base validation module to setup the log capture
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
class TestValidationCheckDatasetConfig(BaseValidation):
"""
Test the validation for the dataset config to ensure no correct parameters are dropped
"""
def test_dataset_config_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"shards": 10,
}
]
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
env_capabilities={
"torch_version": "2.6.0",
},
)
_check_config()
def test_dataset_default_chat_template_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"field_messages": "conversations",
"shards": 10,
"message_field_role": "from",
"message_field_content": "value",
}
],
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.chat_template is None
assert (
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
)
assert (
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
)
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
assert (
checked_cfg.datasets[0].message_field_role
== cfg.datasets[0].message_field_role
)
assert (
checked_cfg.datasets[0].message_field_content
== cfg.datasets[0].message_field_content
)
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
env_capabilities={
"torch_version": "2.6.0",
},
)
_check_config()
def test_dataset_partial_default_chat_template_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"chat_template": "chatml",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"field_messages": "conversations",
"shards": 10,
"message_field_role": "from",
"message_field_content": "value",
}
],
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.chat_template == ChatTemplate.chatml
assert (
checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default
)
assert (
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
)
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
assert (
checked_cfg.datasets[0].message_field_role
== cfg.datasets[0].message_field_role
)
assert (
checked_cfg.datasets[0].message_field_content
== cfg.datasets[0].message_field_content
)
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
env_capabilities={
"torch_version": "2.6.0",
},
)
_check_config()
def test_dataset_chatml_chat_template_no_drop_param(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"chat_template": "chatml",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "chat_template",
"chat_template": "gemma",
"field_messages": "conversations",
"shards": 10,
"message_field_role": "from",
"message_field_content": "value",
}
],
}
)
checked_cfg = validate_config(cfg)
def _check_config():
assert checked_cfg.datasets[0].path == cfg.datasets[0].path
assert checked_cfg.datasets[0].type == cfg.datasets[0].type
assert checked_cfg.chat_template == cfg.chat_template
assert (
checked_cfg.datasets[0].chat_template == cfg.datasets[0].chat_template
)
assert (
checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages
)
assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
assert (
checked_cfg.datasets[0].message_field_role
== cfg.datasets[0].message_field_role
)
assert (
checked_cfg.datasets[0].message_field_content
== cfg.datasets[0].message_field_content
)
_check_config()
checked_cfg = validate_config(
cfg,
capabilities={
"bf16": "false",
"n_gpu": 1,
"compute_capability": "8.0",
},
env_capabilities={
"torch_version": "2.6.0",
},
)
_check_config()
def test_dataset_sharegpt_deprecation(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"chat_template": "chatml",
"datasets": [
{
"path": "LDJnr/Puffin",
"type": "sharegpt",
"conversation": "chatml",
}
],
}
)
# Check sharegpt deprecation is raised
with pytest.raises(ValueError, match=r".*type: sharegpt.*` is deprecated.*"):
validate_config(cfg)
# Check that deprecation is not thrown for non-str type
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": {
"field_instruction": "instruction",
"field_output": "output",
"field_system": "system",
"format": "<|user|> {instruction} {input} <|model|>",
"no_input_format": "<|user|> {instruction} <|model|>",
"system_prompt": "",
},
}
],
}
)
validate_config(cfg)
# Check that deprecation is not thrown for non-sharegpt type
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
}
)
validate_config(cfg)
def test_message_property_mappings(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"message_property_mappings": {
"role": "role",
"content": "content",
},
}
],
}
)
validate_config(cfg)
class TestOptimizerValidation(BaseValidation):
"""
Test muon optimizer validation
"""
def test_muon_deepspeed(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"optimizer": "muon",
"deepspeed": "deepspeed_configs/zero3.json",
}
)
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
validate_config(cfg)
def test_muon_fsdp(self, minimal_cfg):
cfg = DictDefault(
minimal_cfg
| {
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"optimizer": "muon",
"fsdp": ["full_shard"],
"fsdp_config": {
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
},
}
)
with pytest.raises(ValueError, match=r".*only compatible with FSDP2.*"):
validate_config(cfg)