fix: DPO tool role KeyError (#3217), dataset hash output_dir (#3303), config validators (#3538) [skip ci]
* fix: DPO tool role KeyError, dataset hash output_dir, config validators [skip-e2e] - Add 'tool' to default role_map_inv in dpo/chat_template.py default() and argilla_chat() so datasets with tool-call messages no longer raise KeyError: 'tool' (closes #3217) - Fix generate_dataset_hash_from_config to use canonical tokenizer config + overrides content instead of tokenizer.name_or_path when added_tokens_overrides is set, preventing cache busting when only output_dir changes (closes #3303) - Add three Pydantic config validators to AxolotlConfigWCapabilities: * save_strategy: 'best' requires metric_for_best_model * streaming=True is incompatible with val_set_size > 0 * lora_target_modules list entries must be valid Python regex patterns - Tests for all three changes * review: condense comment in shared.py, swap Mistral model for SmolLM2-135M in test_hash * chore: lint * move the validators out of the w/ capabilities schema --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -29,6 +29,7 @@ def default(cfg, dataset_idx=0, **kwargs):
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
"system": ["system"],
|
||||
"tool": ["tool"],
|
||||
},
|
||||
)
|
||||
role_map = {}
|
||||
@@ -174,6 +175,7 @@ def argilla_chat(cfg, dataset_idx=0, **kwargs):
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
"system": ["system"],
|
||||
"tool": ["tool"],
|
||||
},
|
||||
)
|
||||
role_map = {}
|
||||
|
||||
@@ -516,12 +516,21 @@ def generate_dataset_hash_from_config(
|
||||
Returns:
|
||||
MD5 hash string representing the configuration.
|
||||
"""
|
||||
# When added_tokens_overrides is set, tokenizer.name_or_path contains output_dir.
|
||||
# Use the canonical tokenizer config + overrides content so the hash is stable across output_dir changes.
|
||||
if cfg.get("added_tokens_overrides"):
|
||||
tokenizer_fingerprint = f"{cfg.tokenizer_config}+overrides:" + ",".join(
|
||||
f"{k}={v}" for k, v in sorted(cfg.added_tokens_overrides.items())
|
||||
)
|
||||
else:
|
||||
tokenizer_fingerprint = tokenizer_name
|
||||
|
||||
config_str = (
|
||||
f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
|
||||
f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}@"
|
||||
f"{cfg.dataset_exact_deduplication or False}|"
|
||||
f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}"
|
||||
f"|{tokenizer_name}"
|
||||
f"|{tokenizer_fingerprint}"
|
||||
)
|
||||
return str(md5(config_str))
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Module with Pydantic models for configuration."""
|
||||
|
||||
import re
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from accelerate.utils import is_fp8_available
|
||||
@@ -1338,6 +1339,39 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_save_strategy_best_requires_metric(cls, data):
    """Reject `save_strategy: best` unless `metric_for_best_model` is also set."""
    # Picking the "best" checkpoint is undefined without a metric to rank by,
    # so fail fast with a hint instead of erroring deep inside the trainer.
    if data.get("save_strategy") != "best":
        return data
    if data.get("metric_for_best_model"):
        return data
    raise ValueError(
        "save_strategy: 'best' requires metric_for_best_model to be set. "
        "Please specify the metric to use, e.g. metric_for_best_model: eval_loss"
    )
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def check_lora_target_modules_regex(cls, data):
    """Ensure every string entry in `lora_target_modules` is a valid Python regex.

    String values such as "all-linear" (not a list) and non-string list
    entries are left untouched for other validators to handle.
    """
    modules = data.get("lora_target_modules")
    if not isinstance(modules, list):
        # Only list-valued configs are regex-checked.
        return data
    bad_patterns = []
    for candidate in modules:
        if not isinstance(candidate, str):
            continue
        try:
            re.compile(candidate)
        except re.error:
            bad_patterns.append(candidate)
    if bad_patterns:
        raise ValueError(
            f"lora_target_modules contains invalid regex pattern(s): {bad_patterns}. "
            "Please provide valid Python regex patterns or plain module name strings."
        )
    return data
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""Wrapper to valdiate GPU capabilities with the configured options"""
|
||||
|
||||
@@ -294,5 +294,88 @@ class TestArgillaChatDPOChatTemplate:
|
||||
assert result["rejected"] == "party on<|end|>"
|
||||
|
||||
|
||||
class TestDPOChatTemplateToolRole:
    """
    Test that DPO chat template strategy handles tool role messages without KeyError.

    Regression test for https://github.com/axolotl-ai-cloud/axolotl/issues/3217
    """

    def test_tool_role_default_no_key_error(self, llama3_tokenizer):
        """Messages list with a 'tool' role should not raise KeyError."""
        sample = {
            "messages": [
                {"role": "user", "content": "What is the weather?"},
                {"role": "assistant", "content": "Let me check."},
                {"role": "tool", "content": "22°C, sunny."},
            ],
            "chosen": {"role": "assistant", "content": "It is 22°C and sunny."},
            "rejected": {"role": "assistant", "content": "I don't know."},
        }
        dataset = Dataset.from_list([sample])
        strategy_cfg = DictDefault(
            {
                "chat_template": "llama3",
                "datasets": [{"type": "chat_template"}],
            }
        )
        transform_fn, _ = default(strategy_cfg)
        # Should not raise KeyError: 'tool'
        result = transform_fn(dataset[0], tokenizer=llama3_tokenizer)
        for key in ("prompt", "chosen", "rejected"):
            assert key in result

    def test_tool_role_custom_mapping_preserved(self, llama3_tokenizer):
        """A user-supplied roles mapping that overrides 'tool' is still respected."""
        sample = {
            "messages": [
                {"role": "user", "content": "hello"},
                {"role": "tool_result", "content": "42"},
            ],
            "chosen": {"role": "assistant", "content": "The answer is 42."},
            "rejected": {"role": "assistant", "content": "Unknown."},
        }
        dataset = Dataset.from_list([sample])
        strategy_cfg = DictDefault(
            {
                "chat_template": "llama3",
                "datasets": [
                    {
                        "type": "chat_template",
                        "roles": {
                            "user": ["user"],
                            "assistant": ["assistant"],
                            "system": ["system"],
                            "tool": ["tool_result"],
                        },
                    }
                ],
            }
        )
        transform_fn, _ = default(strategy_cfg)
        result = transform_fn(dataset[0], tokenizer=llama3_tokenizer)
        assert "prompt" in result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
135
tests/utils/data/test_hash.py
Normal file
135
tests/utils/data/test_hash.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
Tests for generate_dataset_hash_from_config.
|
||||
|
||||
Regression test for https://github.com/axolotl-ai-cloud/axolotl/issues/3303:
|
||||
changing output_dir should not bust the dataset cache when added_tokens_overrides
|
||||
is set.
|
||||
"""
|
||||
|
||||
from axolotl.utils.data.shared import generate_dataset_hash_from_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def _base_cfg(**kwargs):
    """Return a minimal hashing config as a DictDefault, with optional overrides."""
    base = {
        "sequence_len": 2048,
        "sample_packing": False,
        "eval_sample_packing": False,
        "group_by_length": False,
        "kd_temperature": None,
        "dataset_exact_deduplication": False,
        "tokenizer_config": "NousResearch/Llama-3.2-1B",
    }
    # Caller-supplied keys win over the defaults above.
    base.update(kwargs)
    return DictDefault(base)
|
||||
|
||||
|
||||
def _datasets():
    """Return a single alpaca-style dataset entry with all fields the hash reads."""
    spec = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
        "shards": None,
        "conversation": None,
        "split": "train",
        "temperature": None,
    }
    return [DictDefault(spec)]
|
||||
|
||||
|
||||
class TestGenerateDatasetHashFromConfig:
    """Unit tests for the dataset cache hash (regression coverage for issue #3303)."""

    def test_same_config_same_hash(self):
        """Identical configs produce identical hashes."""
        cfg = _base_cfg()
        tokenizer = "NousResearch/Llama-3.2-1B"
        first = generate_dataset_hash_from_config(cfg, _datasets(), tokenizer)
        second = generate_dataset_hash_from_config(cfg, _datasets(), tokenizer)
        assert first == second

    def test_different_tokenizer_different_hash(self):
        """A different tokenizer path produces a different hash."""
        cfg = _base_cfg()
        hashes = [
            generate_dataset_hash_from_config(cfg, _datasets(), name)
            for name in ("NousResearch/Llama-3.2-1B", "HuggingFaceTB/SmolLM2-135M")
        ]
        assert hashes[0] != hashes[1]

    def test_different_sequence_len_different_hash(self):
        """Changing sequence_len must bust the cache."""
        short_cfg = _base_cfg(sequence_len=2048)
        long_cfg = _base_cfg(sequence_len=4096)
        assert generate_dataset_hash_from_config(
            short_cfg, _datasets(), "tok"
        ) != generate_dataset_hash_from_config(long_cfg, _datasets(), "tok")

    # --- Regression: added_tokens_overrides + output_dir ---

    def test_added_tokens_overrides_hash_stable_across_output_dir(self):
        """Hash must not change when only output_dir changes (issue #3303).

        When added_tokens_overrides is set the tokenizer is saved into output_dir,
        making tokenizer.name_or_path an absolute path that includes output_dir.
        The hash should be derived from the canonical tokenizer config + overrides,
        not from the output-dir-dependent path.
        """
        overrides = {32000: "<PAD>", 32001: "<MASK>"}
        first_cfg = _base_cfg(
            output_dir="/tmp/run_1", added_tokens_overrides=overrides
        )
        second_cfg = _base_cfg(
            output_dir="/tmp/run_2_different_name", added_tokens_overrides=overrides
        )

        # Simulate what happens in practice: tokenizer.name_or_path becomes the
        # output_dir-based path after modify_tokenizer_files() saves the tokenizer.
        first_hash = generate_dataset_hash_from_config(
            first_cfg, _datasets(), "/tmp/run_1/modified_tokenizer"
        )
        second_hash = generate_dataset_hash_from_config(
            second_cfg, _datasets(), "/tmp/run_2_different_name/modified_tokenizer"
        )

        assert first_hash == second_hash, (
            "Dataset cache hash must not change when only output_dir changes "
            "while added_tokens_overrides stays the same (issue #3303)."
        )

    def test_added_tokens_overrides_different_overrides_different_hash(self):
        """Different added_tokens_overrides produce different hashes."""
        tokenizer_path = "/tmp/run_a/modified_tokenizer"
        cfg_pad = _base_cfg(
            output_dir="/tmp/run_a",
            added_tokens_overrides={32000: "<PAD>"},
        )
        # Same output_dir, different override token string.
        cfg_other = _base_cfg(
            output_dir="/tmp/run_a",
            added_tokens_overrides={32000: "<OTHER>"},
        )
        assert generate_dataset_hash_from_config(
            cfg_pad, _datasets(), tokenizer_path
        ) != generate_dataset_hash_from_config(cfg_other, _datasets(), tokenizer_path)

    def test_no_added_tokens_overrides_uses_tokenizer_name_as_before(self):
        """Without added_tokens_overrides the old behaviour is preserved."""
        cfg = _base_cfg()  # no added_tokens_overrides
        baseline = generate_dataset_hash_from_config(
            cfg, _datasets(), "NousResearch/Llama-3.2-1B"
        )
        # Changing tokenizer_name still changes the hash
        other = generate_dataset_hash_from_config(cfg, _datasets(), "some/other-model")
        assert baseline != other
|
||||
119
tests/utils/schemas/validation/test_config_validators.py
Normal file
119
tests/utils/schemas/validation/test_config_validators.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
Tests for new config validators added to AxolotlInputConfig.
|
||||
|
||||
Covers:
|
||||
- save_strategy: 'best' requires metric_for_best_model
|
||||
- streaming=True with val_set_size > 0 is rejected
|
||||
- lora_target_modules with invalid regex patterns is rejected
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestSaveStrategyBestValidator:
    """save_strategy: 'best' must be accompanied by metric_for_best_model."""

    def test_save_strategy_best_without_metric_raises(self, min_base_cfg):
        bad_cfg = min_base_cfg | DictDefault(save_strategy="best")
        with pytest.raises(ValueError, match="metric_for_best_model"):
            validate_config(bad_cfg)

    def test_save_strategy_best_with_metric_passes(self, min_base_cfg):
        good_cfg = min_base_cfg | DictDefault(
            save_strategy="best",
            metric_for_best_model="eval_loss",
        )
        result = validate_config(good_cfg)
        assert result.save_strategy == "best"
        assert result.metric_for_best_model == "eval_loss"

    def test_save_strategy_epoch_without_metric_passes(self, min_base_cfg):
        result = validate_config(min_base_cfg | DictDefault(save_strategy="epoch"))
        assert result.save_strategy == "epoch"

    def test_save_strategy_no_without_metric_passes(self, min_base_cfg):
        result = validate_config(min_base_cfg | DictDefault(save_strategy="no"))
        assert result.save_strategy == "no"

    def test_save_strategy_unset_without_metric_passes(self, min_base_cfg):
        """The default (None / not set) should not require metric_for_best_model."""
        assert validate_config(min_base_cfg).save_strategy is None
|
||||
|
||||
|
||||
class TestStreamingWithValSetSizeValidator:
    """streaming=True is incompatible with val_set_size > 0."""

    def test_streaming_with_val_set_size_raises(self, min_base_cfg):
        bad_cfg = min_base_cfg | DictDefault(
            streaming=True, val_set_size=0.1, max_steps=100
        )
        with pytest.raises(ValueError, match="val_set_size"):
            validate_config(bad_cfg)

    def test_streaming_with_val_set_size_zero_passes(self, min_base_cfg):
        ok_cfg = min_base_cfg | DictDefault(
            streaming=True, val_set_size=0.0, max_steps=100
        )
        assert validate_config(ok_cfg).streaming is True

    def test_streaming_false_with_val_set_size_passes(self, min_base_cfg):
        ok_cfg = min_base_cfg | DictDefault(streaming=False, val_set_size=0.1)
        assert validate_config(ok_cfg).val_set_size == pytest.approx(0.1)

    def test_streaming_unset_with_val_set_size_passes(self, min_base_cfg):
        ok_cfg = min_base_cfg | DictDefault(val_set_size=0.2)
        assert validate_config(ok_cfg).val_set_size == pytest.approx(0.2)
|
||||
|
||||
|
||||
class TestLoraTargetModulesRegexValidator:
    """lora_target_modules entries must be valid Python regex patterns."""

    def _lora_cfg(self, base, modules):
        # Helper: layer a lora adapter config with the given target modules
        # on top of the shared minimal base config.
        return base | DictDefault(adapter="lora", lora_target_modules=modules)

    def test_invalid_regex_raises(self, min_base_cfg):
        cfg = self._lora_cfg(min_base_cfg, ["q_proj", "[invalid_regex"])
        with pytest.raises(ValueError, match="invalid regex pattern"):
            validate_config(cfg)

    def test_valid_regex_passes(self, min_base_cfg):
        cfg = self._lora_cfg(min_base_cfg, ["q_proj", "v_proj", r".*_proj"])
        assert "q_proj" in validate_config(cfg).lora_target_modules

    def test_plain_module_names_pass(self, min_base_cfg):
        cfg = self._lora_cfg(min_base_cfg, ["q_proj", "k_proj", "v_proj", "o_proj"])
        assert len(validate_config(cfg).lora_target_modules) == 4

    def test_lora_target_linear_string_not_validated(self, min_base_cfg):
        """When lora_target_modules is a string (e.g. 'all-linear'), skip regex check."""
        cfg = self._lora_cfg(min_base_cfg, "all-linear")
        # Should not raise
        validate_config(cfg)

    def test_multiple_invalid_patterns_reported(self, min_base_cfg):
        cfg = self._lora_cfg(min_base_cfg, ["[bad1", "[bad2"])
        with pytest.raises(ValueError, match="invalid regex pattern"):
            validate_config(cfg)
|
||||
Reference in New Issue
Block a user