fix: DPO tool role KeyError (#3217), dataset hash output_dir (#3303), config validators (#3538) [skip ci]

* fix: DPO tool role KeyError, dataset hash output_dir, config validators [skip-e2e]

- Add 'tool' to default role_map_inv in dpo/chat_template.py default() and
  argilla_chat() so datasets with tool-call messages no longer raise
  KeyError: 'tool' (closes #3217)

- Fix generate_dataset_hash_from_config to use canonical tokenizer config +
  overrides content instead of tokenizer.name_or_path when added_tokens_overrides
  is set, preventing cache busting when only output_dir changes (closes #3303)

- Add three Pydantic config validators to AxolotlConfigWCapabilities:
  * save_strategy: 'best' requires metric_for_best_model
  * streaming=True is incompatible with val_set_size > 0
  * lora_target_modules list entries must be valid Python regex patterns

- Tests for all three changes

* review: condense comment in shared.py, swap Mistral model for SmolLM2-135M in test_hash

* chore: lint

* move the validators out of the w/ capabilities schema

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Edward Zion Saji
2026-04-02 05:27:07 +05:30
committed by GitHub
parent c92b71bd0c
commit 55a7950e3d
6 changed files with 383 additions and 1 deletions

View File

@@ -0,0 +1,135 @@
"""
Tests for generate_dataset_hash_from_config.
Regression test for https://github.com/axolotl-ai-cloud/axolotl/issues/3303:
changing output_dir should not bust the dataset cache when added_tokens_overrides
is set.
"""
from axolotl.utils.data.shared import generate_dataset_hash_from_config
from axolotl.utils.dict import DictDefault
def _base_cfg(**kwargs):
return DictDefault(
{
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"group_by_length": False,
"kd_temperature": None,
"dataset_exact_deduplication": False,
"tokenizer_config": "NousResearch/Llama-3.2-1B",
**kwargs,
}
)
def _datasets():
return [
DictDefault(
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"shards": None,
"conversation": None,
"split": "train",
"temperature": None,
}
)
]
class TestGenerateDatasetHashFromConfig:
def test_same_config_same_hash(self):
"""Identical configs produce identical hashes."""
cfg = _base_cfg()
h1 = generate_dataset_hash_from_config(
cfg, _datasets(), "NousResearch/Llama-3.2-1B"
)
h2 = generate_dataset_hash_from_config(
cfg, _datasets(), "NousResearch/Llama-3.2-1B"
)
assert h1 == h2
def test_different_tokenizer_different_hash(self):
"""A different tokenizer path produces a different hash."""
cfg = _base_cfg()
h1 = generate_dataset_hash_from_config(
cfg, _datasets(), "NousResearch/Llama-3.2-1B"
)
h2 = generate_dataset_hash_from_config(
cfg, _datasets(), "HuggingFaceTB/SmolLM2-135M"
)
assert h1 != h2
def test_different_sequence_len_different_hash(self):
cfg_a = _base_cfg(sequence_len=2048)
cfg_b = _base_cfg(sequence_len=4096)
h1 = generate_dataset_hash_from_config(cfg_a, _datasets(), "tok")
h2 = generate_dataset_hash_from_config(cfg_b, _datasets(), "tok")
assert h1 != h2
# --- Regression: added_tokens_overrides + output_dir ---
def test_added_tokens_overrides_hash_stable_across_output_dir(self):
"""Hash must not change when only output_dir changes (issue #3303).
When added_tokens_overrides is set the tokenizer is saved into output_dir,
making tokenizer.name_or_path an absolute path that includes output_dir.
The hash should be derived from the canonical tokenizer config + overrides,
not from the output-dir-dependent path.
"""
cfg_run1 = _base_cfg(
output_dir="/tmp/run_1",
added_tokens_overrides={32000: "<PAD>", 32001: "<MASK>"},
)
cfg_run2 = _base_cfg(
output_dir="/tmp/run_2_different_name",
added_tokens_overrides={32000: "<PAD>", 32001: "<MASK>"},
)
# Simulate what happens in practice: tokenizer.name_or_path becomes the
# output_dir-based path after modify_tokenizer_files() saves the tokenizer.
tokenizer_name_run1 = "/tmp/run_1/modified_tokenizer"
tokenizer_name_run2 = "/tmp/run_2_different_name/modified_tokenizer"
h1 = generate_dataset_hash_from_config(
cfg_run1, _datasets(), tokenizer_name_run1
)
h2 = generate_dataset_hash_from_config(
cfg_run2, _datasets(), tokenizer_name_run2
)
assert h1 == h2, (
"Dataset cache hash must not change when only output_dir changes "
"while added_tokens_overrides stays the same (issue #3303)."
)
def test_added_tokens_overrides_different_overrides_different_hash(self):
"""Different added_tokens_overrides produce different hashes."""
cfg_a = _base_cfg(
output_dir="/tmp/run_a",
added_tokens_overrides={32000: "<PAD>"},
)
cfg_b = _base_cfg(
output_dir="/tmp/run_a", # same output_dir
added_tokens_overrides={32000: "<OTHER>"},
)
tokenizer_path = "/tmp/run_a/modified_tokenizer"
h1 = generate_dataset_hash_from_config(cfg_a, _datasets(), tokenizer_path)
h2 = generate_dataset_hash_from_config(cfg_b, _datasets(), tokenizer_path)
assert h1 != h2
def test_no_added_tokens_overrides_uses_tokenizer_name_as_before(self):
"""Without added_tokens_overrides the old behaviour is preserved."""
cfg = _base_cfg() # no added_tokens_overrides
tokenizer_name = "NousResearch/Llama-3.2-1B"
h1 = generate_dataset_hash_from_config(cfg, _datasets(), tokenizer_name)
# Changing tokenizer_name still changes the hash
h2 = generate_dataset_hash_from_config(cfg, _datasets(), "some/other-model")
assert h1 != h2

View File

@@ -0,0 +1,119 @@
"""
Tests for new config validators added to AxolotlInputConfig.
Covers:
- save_strategy: 'best' requires metric_for_best_model
- streaming=True with val_set_size > 0 is rejected
- lora_target_modules with invalid regex patterns is rejected
"""
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
class TestSaveStrategyBestValidator:
"""save_strategy: 'best' must be accompanied by metric_for_best_model."""
def test_save_strategy_best_without_metric_raises(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(save_strategy="best")
with pytest.raises(ValueError, match="metric_for_best_model"):
validate_config(cfg)
def test_save_strategy_best_with_metric_passes(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
save_strategy="best",
metric_for_best_model="eval_loss",
)
validated = validate_config(cfg)
assert validated.save_strategy == "best"
assert validated.metric_for_best_model == "eval_loss"
def test_save_strategy_epoch_without_metric_passes(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(save_strategy="epoch")
validated = validate_config(cfg)
assert validated.save_strategy == "epoch"
def test_save_strategy_no_without_metric_passes(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(save_strategy="no")
validated = validate_config(cfg)
assert validated.save_strategy == "no"
def test_save_strategy_unset_without_metric_passes(self, min_base_cfg):
"""The default (None / not set) should not require metric_for_best_model."""
validated = validate_config(min_base_cfg)
assert validated.save_strategy is None
class TestStreamingWithValSetSizeValidator:
"""streaming=True is incompatible with val_set_size > 0."""
def test_streaming_with_val_set_size_raises(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
streaming=True, val_set_size=0.1, max_steps=100
)
with pytest.raises(ValueError, match="val_set_size"):
validate_config(cfg)
def test_streaming_with_val_set_size_zero_passes(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
streaming=True, val_set_size=0.0, max_steps=100
)
validated = validate_config(cfg)
assert validated.streaming is True
def test_streaming_false_with_val_set_size_passes(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(streaming=False, val_set_size=0.1)
validated = validate_config(cfg)
assert validated.val_set_size == pytest.approx(0.1)
def test_streaming_unset_with_val_set_size_passes(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(val_set_size=0.2)
validated = validate_config(cfg)
assert validated.val_set_size == pytest.approx(0.2)
class TestLoraTargetModulesRegexValidator:
"""lora_target_modules entries must be valid Python regex patterns."""
def test_invalid_regex_raises(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
adapter="lora",
lora_target_modules=["q_proj", "[invalid_regex"],
)
with pytest.raises(ValueError, match="invalid regex pattern"):
validate_config(cfg)
def test_valid_regex_passes(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
adapter="lora",
lora_target_modules=["q_proj", "v_proj", r".*_proj"],
)
validated = validate_config(cfg)
assert "q_proj" in validated.lora_target_modules
def test_plain_module_names_pass(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
adapter="lora",
lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
validated = validate_config(cfg)
assert len(validated.lora_target_modules) == 4
def test_lora_target_linear_string_not_validated(self, min_base_cfg):
"""When lora_target_modules is a string (e.g. 'all-linear'), skip regex check."""
cfg = min_base_cfg | DictDefault(
adapter="lora",
lora_target_modules="all-linear",
)
# Should not raise
validate_config(cfg)
def test_multiple_invalid_patterns_reported(self, min_base_cfg):
cfg = min_base_cfg | DictDefault(
adapter="lora",
lora_target_modules=["[bad1", "[bad2"],
)
with pytest.raises(ValueError, match="invalid regex pattern"):
validate_config(cfg)