"""
Tests for new config validators added to AxolotlInputConfig.
Covers:
- save_strategy: 'best' requires metric_for_best_model
- streaming=True with val_set_size > 0 is rejected
- lora_target_modules with invalid regex patterns is rejected
- GRPO: generation batch size must be divisible by num_generations,
num_generations >= 2, and effective_gbs >= num_generations * world_size
"""
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
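
# ``min_base_cfg`` is assumed to be provided by a conftest.py fixture that
# supplies a minimal valid config. A rough sketch of the shape these tests rely
# on (field values are illustrative, not authoritative):
#
#   @pytest.fixture
#   def min_base_cfg():
#       return DictDefault(
#           base_model="HuggingFaceTB/SmolLM2-135M",
#           learning_rate=1e-5,
#           datasets=[{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
#           micro_batch_size=1,
#           gradient_accumulation_steps=1,
#       )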


class TestSaveStrategyBestValidator:
    """save_strategy: 'best' must be accompanied by metric_for_best_model."""

    def test_save_strategy_best_without_metric_raises(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(save_strategy="best")
        with pytest.raises(ValueError, match="metric_for_best_model"):
            validate_config(cfg)

    def test_save_strategy_best_with_metric_passes(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            save_strategy="best",
            metric_for_best_model="eval_loss",
        )
        validated = validate_config(cfg)
        assert validated.save_strategy == "best"
        assert validated.metric_for_best_model == "eval_loss"

    def test_save_strategy_epoch_without_metric_passes(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(save_strategy="epoch")
        validated = validate_config(cfg)
        assert validated.save_strategy == "epoch"

    def test_save_strategy_no_without_metric_passes(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(save_strategy="no")
        validated = validate_config(cfg)
        assert validated.save_strategy == "no"

    def test_save_strategy_unset_without_metric_passes(self, min_base_cfg):
        """The default (None / not set) should not require metric_for_best_model."""
        validated = validate_config(min_base_cfg)
        assert validated.save_strategy is None


class TestStreamingWithValSetSizeValidator:
    """streaming=True is incompatible with val_set_size > 0."""

    def test_streaming_with_val_set_size_raises(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            streaming=True, val_set_size=0.1, max_steps=100
        )
        with pytest.raises(ValueError, match="val_set_size"):
            validate_config(cfg)

    def test_streaming_with_val_set_size_zero_passes(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            streaming=True, val_set_size=0.0, max_steps=100
        )
        validated = validate_config(cfg)
        assert validated.streaming is True

    def test_streaming_false_with_val_set_size_passes(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(streaming=False, val_set_size=0.1)
        validated = validate_config(cfg)
        assert validated.val_set_size == pytest.approx(0.1)

    def test_streaming_unset_with_val_set_size_passes(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(val_set_size=0.2)
        validated = validate_config(cfg)
        assert validated.val_set_size == pytest.approx(0.2)


class TestLoraTargetModulesRegexValidator:
    """lora_target_modules entries must be valid Python regex patterns."""

    def test_invalid_regex_raises(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            adapter="lora",
            lora_target_modules=["q_proj", "[invalid_regex"],
        )
        with pytest.raises(ValueError, match="invalid regex pattern"):
            validate_config(cfg)

    def test_valid_regex_passes(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            adapter="lora",
            lora_target_modules=["q_proj", "v_proj", r".*_proj"],
        )
        validated = validate_config(cfg)
        assert "q_proj" in validated.lora_target_modules

    def test_plain_module_names_pass(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            adapter="lora",
            lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        )
        validated = validate_config(cfg)
        assert len(validated.lora_target_modules) == 4

    def test_lora_target_linear_string_not_validated(self, min_base_cfg):
        """When lora_target_modules is a string (e.g. 'all-linear'), skip regex check."""
        cfg = min_base_cfg | DictDefault(
            adapter="lora",
            lora_target_modules="all-linear",
        )
        # Should not raise
        validate_config(cfg)

    def test_multiple_invalid_patterns_reported(self, min_base_cfg):
        cfg = min_base_cfg | DictDefault(
            adapter="lora",
            lora_target_modules=["[bad1", "[bad2"],
        )
        with pytest.raises(ValueError, match="invalid regex pattern"):
            validate_config(cfg)
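

# Why patterns like "[invalid_regex" and "[bad1" above are rejected: an
# unterminated character class cannot be compiled. The validator presumably
# relies on something equivalent to re.compile, which makes the failure easy
# to reproduce in isolation:
#
#   import re
#   re.compile("[invalid_regex")  # raises re.error (unterminated character set)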


class TestGRPOBatchSizeValidator:
    """GRPO requires (mb*GA) % num_generations == 0 and num_generations >= 2.

    These tests call the ``@model_validator(mode="before")`` classmethod
    directly on a plain dict, the same input shape it receives during full
    Pydantic validation, without pulling in unrelated fields (datasets, model
    loading, etc.) that are irrelevant to what's under test. The validator is
    registered on ``RLValidationMixin`` (which ``AxolotlInputConfig`` inherits),
    so this is the same code path ``axolotl train`` exercises.
    """

    @staticmethod
    def _check(data):
        from axolotl.utils.schemas.validation import RLValidationMixin

        return RLValidationMixin.check_grpo_batch_size_divisibility(data)
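
    # The arithmetic the tests below assume (a sketch; the validator's exact
    # computation may differ in detail): the effective generation batch size is
    # trl.generation_batch_size if set, otherwise
    # micro_batch_size * gradient_accumulation_steps. For example,
    # micro_batch_size=1 with gradient_accumulation_steps=4 gives 4, which is
    # divisible by num_generations=4; with world_size=2 it must additionally
    # reach num_generations * world_size = 8.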

    def test_divisible_passes(self):
        data = {
            "rl": "grpo",
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 4,
            "trl": {"num_generations": 4},
        }
        # Should return data unchanged (no exception)
        out = self._check(data)
        assert out["trl"]["num_generations"] == 4

    def test_non_divisible_raises(self):
        data = {
            "rl": "grpo",
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 2,
            "trl": {"num_generations": 4},
        }
        with pytest.raises(ValueError, match="num_generations"):
            self._check(data)

    def test_non_divisible_error_includes_fix_hint(self):
        data = {
            "rl": "grpo",
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 3,
            "trl": {"num_generations": 4},
        }
        with pytest.raises(ValueError, match="gradient_accumulation_steps: 4"):
            self._check(data)

    def test_num_generations_one_raises(self):
        data = {
            "rl": "grpo",
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 4,
            "trl": {"num_generations": 1},
        }
        with pytest.raises(ValueError, match=r"num_generations >= 2"):
            self._check(data)

    def test_explicit_generation_batch_size_divisible_passes(self):
        data = {
            "rl": "grpo",
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "trl": {"num_generations": 4, "generation_batch_size": 8},
        }
        out = self._check(data)
        assert out["trl"]["generation_batch_size"] == 8

    def test_explicit_generation_batch_size_non_divisible_raises(self):
        data = {
            "rl": "grpo",
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "trl": {"num_generations": 4, "generation_batch_size": 6},
        }
        with pytest.raises(ValueError, match="trl.generation_batch_size"):
            self._check(data)

    def test_non_grpo_skips_check(self):
        # Anything other than rl=grpo should pass through untouched, even with
        # non-divisible batch sizes; the check is irrelevant to RL methods that
        # don't use group-relative advantages.
        data = {
            "rl": "dpo",
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 3,
            "trl": {"num_generations": 4},
        }
        assert self._check(data) is data

    def test_no_rl_set_skips_check(self):
        data = {
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 3,
        }
        assert self._check(data) is data

    def test_grpo_without_num_generations_skips_check(self):
        # If num_generations isn't set, TRL uses its own default; we don't have
        # enough information to validate, so the validator must short-circuit
        # rather than guess.
        data = {
            "rl": "grpo",
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 3,
            "trl": {},
        }
        out = self._check(data)
        assert out["rl"] == "grpo"

    def test_multi_rank_group_size_check(self):
        data = {
            "rl": "grpo",
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 4,  # gbs=4
            "world_size": 2,  # need gbs >= 4*2 = 8
            "trl": {"num_generations": 4},
        }
        with pytest.raises(ValueError, match=r"world_size=2"):
            self._check(data)

    def test_multi_rank_group_size_satisfied(self):
        data = {
            "rl": "grpo",
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 8,  # gbs=8 >= 4*2
            "world_size": 2,
            "trl": {"num_generations": 4},
        }
        out = self._check(data)
        assert out["gradient_accumulation_steps"] == 8