axolotl/tests/utils/schemas/validation/test_config_validators.py
Edward Zion Saji 55a7950e3d fix: DPO tool role KeyError (#3217), dataset hash output_dir (#3303), config validators (#3538) [skip ci]
* fix: DPO tool role KeyError, dataset hash output_dir, config validators [skip-e2e]

- Add 'tool' to default role_map_inv in dpo/chat_template.py default() and
  argilla_chat() so datasets with tool-call messages no longer raise
  KeyError: 'tool' (closes #3217; see the sketch after this list)

- Fix generate_dataset_hash_from_config to use canonical tokenizer config +
  overrides content instead of tokenizer.name_or_path when added_tokens_overrides
  is set, preventing cache busting when only output_dir changes (closes #3303;
  see the sketch after this list)

- Add three Pydantic config validators to AxolotlConfigWCapabilities (see the
  sketch after this list):
  * save_strategy: 'best' requires metric_for_best_model
  * streaming=True is incompatible with val_set_size > 0
  * lora_target_modules list entries must be valid Python regex patterns

- Tests for all three changes
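
A minimal sketch of the role_map_inv fix, assuming the DPO chat-template
transforms keep a dict mapping dataset roles onto template roles. The dict
literal and the map_role helper below are illustrative stand-ins, not the
actual code in dpo/chat_template.py:

# Hypothetical illustration of the KeyError being fixed; the real mapping
# lives in axolotl's dpo/chat_template.py default() and argilla_chat().
role_map_inv = {
    "system": "system",
    "user": "user",
    "assistant": "assistant",
    "tool": "tool",  # newly included so tool-call messages resolve cleanly
}


def map_role(message: dict) -> str:
    # Before "tool" was in the default map, a dataset message with
    # role == "tool" raised KeyError: 'tool' right here.
    return role_map_inv[message["role"]]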
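
A rough sketch of the dataset-hash fix, in the spirit of
generate_dataset_hash_from_config. The function name
dataset_prep_fingerprint, its signature, and the hashlib/json usage are
assumptions for illustration, not axolotl's actual implementation:

import hashlib
import json


def dataset_prep_fingerprint(tokenizer_config: dict, added_tokens_overrides: dict | None) -> str:
    """Hash canonical tokenizer settings plus overrides, not a filesystem path.

    Hashing the config contents keeps the dataset prep-cache key stable even
    when the tokenizer was re-saved under a run-specific output_dir, whereas
    hashing tokenizer.name_or_path busted the cache whenever output_dir changed.
    """
    payload = {
        "tokenizer_config": tokenizer_config,
        "added_tokens_overrides": added_tokens_overrides or {},
    }
    blob = json.dumps(payload, sort_keys=True, default=str).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()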
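
And a condensed sketch of the three new validators, assuming Pydantic v2
model_validator hooks. The ConfigSketch model below is a stand-in that carries
only the fields being validated, not the real AxolotlInputConfig:

import re
from typing import Optional, Union

from pydantic import BaseModel, model_validator


class ConfigSketch(BaseModel):
    """Stand-in carrying only the fields the three validators look at."""

    save_strategy: Optional[str] = None
    metric_for_best_model: Optional[str] = None
    streaming: Optional[bool] = None
    val_set_size: float = 0.0
    lora_target_modules: Optional[Union[str, list[str]]] = None

    @model_validator(mode="after")
    def check_save_strategy_best(self):
        # save_strategy 'best' needs a metric to decide which checkpoint is best
        if self.save_strategy == "best" and not self.metric_for_best_model:
            raise ValueError("save_strategy 'best' requires metric_for_best_model to be set")
        return self

    @model_validator(mode="after")
    def check_streaming_val_set_size(self):
        # a streamed (iterable) dataset cannot be split off into a holdout set
        if self.streaming and self.val_set_size and self.val_set_size > 0:
            raise ValueError("streaming is not compatible with val_set_size > 0")
        return self

    @model_validator(mode="after")
    def check_lora_target_modules_regex(self):
        # only list entries are treated as regexes; strings such as "all-linear" pass through
        if isinstance(self.lora_target_modules, list):
            for pattern in self.lora_target_modules:
                try:
                    re.compile(pattern)
                except re.error as exc:
                    raise ValueError(
                        f"invalid regex pattern in lora_target_modules: {pattern!r}"
                    ) from exc
        return self

Constructing the model with save_strategy="best" and no metric_for_best_model,
with streaming=True and val_set_size=0.1, or with an unterminated character
class in lora_target_modules then raises the ValueError messages the tests
below match on.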

* review: condense comment in shared.py, swap Mistral model for SmolLM2-135M in test_hash

* chore: lint

* move the validators out of the w/ capabilities schema

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2026-04-01 19:57:07 -04:00

120 lines
4.5 KiB
Python

"""
Tests for new config validators added to AxolotlInputConfig.
Covers:
- save_strategy: 'best' requires metric_for_best_model
- streaming=True with val_set_size > 0 is rejected
- lora_target_modules with invalid regex patterns is rejected
"""
import pytest

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


class TestSaveStrategyBestValidator:
    """save_strategy: 'best' must be accompanied by metric_for_best_model."""

    def test_save_strategy_best_without_metric_raises(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(save_strategy="best")
        with pytest.raises(ValueError, match="metric_for_best_model"):
            validate_config(cfg)

    def test_save_strategy_best_with_metric_passes(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            save_strategy="best",
            metric_for_best_model="eval_loss",
        )
        validated = validate_config(cfg)
        assert validated.save_strategy == "best"
        assert validated.metric_for_best_model == "eval_loss"

    def test_save_strategy_epoch_without_metric_passes(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(save_strategy="epoch")
        validated = validate_config(cfg)
        assert validated.save_strategy == "epoch"

    def test_save_strategy_no_without_metric_passes(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(save_strategy="no")
        validated = validate_config(cfg)
        assert validated.save_strategy == "no"

    def test_save_strategy_unset_without_metric_passes(self, min_base_cfg):
        """The default (None / not set) should not require metric_for_best_model."""
        validated = validate_config(min_base_cfg)
        assert validated.save_strategy is None


class TestStreamingWithValSetSizeValidator:
    """streaming=True is incompatible with val_set_size > 0."""

    def test_streaming_with_val_set_size_raises(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            streaming=True, val_set_size=0.1, max_steps=100
        )
        with pytest.raises(ValueError, match="val_set_size"):
            validate_config(cfg)

    def test_streaming_with_val_set_size_zero_passes(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            streaming=True, val_set_size=0.0, max_steps=100
        )
        validated = validate_config(cfg)
        assert validated.streaming is True

    def test_streaming_false_with_val_set_size_passes(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(streaming=False, val_set_size=0.1)
        validated = validate_config(cfg)
        assert validated.val_set_size == pytest.approx(0.1)

    def test_streaming_unset_with_val_set_size_passes(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(val_set_size=0.2)
        validated = validate_config(cfg)
        assert validated.val_set_size == pytest.approx(0.2)


class TestLoraTargetModulesRegexValidator:
    """lora_target_modules entries must be valid Python regex patterns."""

    def test_invalid_regex_raises(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            adapter="lora",
            lora_target_modules=["q_proj", "[invalid_regex"],
        )
        with pytest.raises(ValueError, match="invalid regex pattern"):
            validate_config(cfg)

    def test_valid_regex_passes(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            adapter="lora",
            lora_target_modules=["q_proj", "v_proj", r".*_proj"],
        )
        validated = validate_config(cfg)
        assert "q_proj" in validated.lora_target_modules

    def test_plain_module_names_pass(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            adapter="lora",
            lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )
        validated = validate_config(cfg)
        assert len(validated.lora_target_modules) == 4

    def test_lora_target_linear_string_not_validated(self, min_base_cfg):
        """When lora_target_modules is a string (e.g. 'all-linear'), skip regex check."""
        cfg = min_base_cfg | DictDefault(
            adapter="lora",
            lora_target_modules="all-linear",
        )
        # Should not raise
        validate_config(cfg)

    def test_multiple_invalid_patterns_reported(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            adapter="lora",
            lora_target_modules=["[bad1", "[bad2"],
        )
        with pytest.raises(ValueError, match="invalid regex pattern"):
            validate_config(cfg)