diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py
index 58b4d75bd..83db96750 100644
--- a/src/axolotl/prompt_strategies/dpo/chat_template.py
+++ b/src/axolotl/prompt_strategies/dpo/chat_template.py
@@ -29,6 +29,7 @@ def default(cfg, dataset_idx=0, **kwargs):
             "user": ["user"],
             "assistant": ["assistant"],
             "system": ["system"],
+            "tool": ["tool"],
         },
     )
     role_map = {}
@@ -174,6 +175,7 @@ def argilla_chat(cfg, dataset_idx=0, **kwargs):
             "user": ["user"],
             "assistant": ["assistant"],
             "system": ["system"],
+            "tool": ["tool"],
         },
     )
     role_map = {}
diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py
index 4e6aa1ea3..525e0e7ff 100644
--- a/src/axolotl/utils/data/shared.py
+++ b/src/axolotl/utils/data/shared.py
@@ -516,12 +516,21 @@ def generate_dataset_hash_from_config(
     Returns:
         MD5 hash string representing the configuration.
     """
+    # When added_tokens_overrides is set, tokenizer.name_or_path contains output_dir.
+    # Use the canonical tokenizer config + overrides content so the hash is stable across output_dir changes.
+    if cfg.get("added_tokens_overrides"):
+        tokenizer_fingerprint = f"{cfg.tokenizer_config}+overrides:" + ",".join(
+            f"{k}={v}" for k, v in sorted(cfg.added_tokens_overrides.items())
+        )
+    else:
+        tokenizer_fingerprint = tokenizer_name
+
     config_str = (
         f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
         f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}@"
         f"{cfg.dataset_exact_deduplication or False}|"
         f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}"
-        f"|{tokenizer_name}"
+        f"|{tokenizer_fingerprint}"
     )
 
     return str(md5(config_str))
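The fingerprint construction above reduces to the following standalone sketch. The function name `tokenizer_fingerprint` and the plain-dict `cfg` are illustration only, not axolotl API; axolotl builds the same string inline on a DictDefault cfg.

# Standalone sketch of the fingerprint logic above (illustrative names).
def tokenizer_fingerprint(cfg: dict, tokenizer_name: str) -> str:
    overrides = cfg.get("added_tokens_overrides")
    if overrides:
        # Canonical tokenizer config plus the sorted override contents, so the
        # output_dir-dependent tokenizer path never enters the hash input.
        return f"{cfg['tokenizer_config']}+overrides:" + ",".join(
            f"{k}={v}" for k, v in sorted(overrides.items())
        )
    return tokenizer_name


cfg = {
    "tokenizer_config": "NousResearch/Llama-3.2-1B",
    "added_tokens_overrides": {32001: "</tool_call>", 32000: "<tool_call>"},
}
# The fingerprint is identical no matter where the modified tokenizer was saved:
assert tokenizer_fingerprint(cfg, "/tmp/run_1/modified_tokenizer") == tokenizer_fingerprint(
    cfg, "/tmp/run_2/modified_tokenizer"
)
# -> NousResearch/Llama-3.2-1B+overrides:32000=<tool_call>,32001=</tool_call>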
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index e45534640..d0f588d9b 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -1,5 +1,6 @@
 """Module with Pydantic models for configuration."""
 
+import re
 from typing import Annotated, Any, Literal
 
 from accelerate.utils import is_fp8_available
@@ -1338,6 +1339,39 @@ class AxolotlInputConfig(
         )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_save_strategy_best_requires_metric(cls, data):
+        if data.get("save_strategy") == "best" and not data.get(
+            "metric_for_best_model"
+        ):
+            raise ValueError(
+                "save_strategy: 'best' requires metric_for_best_model to be set. "
+                "Please specify the metric to use, e.g. metric_for_best_model: eval_loss"
+            )
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_lora_target_modules_regex(cls, data):
+        lora_target_modules = data.get("lora_target_modules")
+        if not isinstance(lora_target_modules, list):
+            return data
+        invalid = []
+        for pattern in lora_target_modules:
+            if not isinstance(pattern, str):
+                continue
+            try:
+                re.compile(pattern)
+            except re.error:
+                invalid.append(pattern)
+        if invalid:
+            raise ValueError(
+                f"lora_target_modules contains invalid regex pattern(s): {invalid}. "
+                "Please provide valid Python regex patterns or plain module name strings."
+            )
+        return data
+
 
 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """Wrapper to valdiate GPU capabilities with the configured options"""
diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py
index 72766b5ce..74c98204c 100644
--- a/tests/prompt_strategies/test_dpo_chat_templates.py
+++ b/tests/prompt_strategies/test_dpo_chat_templates.py
@@ -294,5 +294,88 @@ class TestArgillaChatDPOChatTemplate:
         assert result["rejected"] == "party on<|end|>"
 
 
+class TestDPOChatTemplateToolRole:
+    """
+    Test that the DPO chat template strategy handles 'tool' role messages without a KeyError.
+    Regression test for https://github.com/axolotl-ai-cloud/axolotl/issues/3217
+    """
+
+    def test_tool_role_default_no_key_error(self, llama3_tokenizer):
+        """Messages list with a 'tool' role should not raise KeyError."""
+        dataset = Dataset.from_list(
+            [
+                {
+                    "messages": [
+                        {"role": "user", "content": "What is the weather?"},
+                        {
+                            "role": "assistant",
+                            "content": "Let me check.",
+                        },
+                        {
+                            "role": "tool",
+                            "content": "22°C, sunny.",
+                        },
+                    ],
+                    "chosen": {
+                        "role": "assistant",
+                        "content": "It is 22°C and sunny.",
+                    },
+                    "rejected": {
+                        "role": "assistant",
+                        "content": "I don't know.",
+                    },
+                }
+            ]
+        )
+        transform_fn, _ = default(
+            DictDefault(
+                {
+                    "chat_template": "llama3",
+                    "datasets": [{"type": "chat_template"}],
+                }
+            )
+        )
+        # Should not raise KeyError: 'tool'
+        result = transform_fn(dataset[0], tokenizer=llama3_tokenizer)
+        assert "prompt" in result
+        assert "chosen" in result
+        assert "rejected" in result
+
+    def test_tool_role_custom_mapping_preserved(self, llama3_tokenizer):
+        """A user-supplied roles mapping that overrides 'tool' is still respected."""
+        dataset = Dataset.from_list(
+            [
+                {
+                    "messages": [
+                        {"role": "user", "content": "hello"},
+                        {"role": "tool_result", "content": "42"},
+                    ],
+                    "chosen": {"role": "assistant", "content": "The answer is 42."},
+                    "rejected": {"role": "assistant", "content": "Unknown."},
+                }
+            ]
+        )
+        transform_fn, _ = default(
+            DictDefault(
+                {
+                    "chat_template": "llama3",
+                    "datasets": [
+                        {
+                            "type": "chat_template",
+                            "roles": {
+                                "user": ["user"],
+                                "assistant": ["assistant"],
+                                "system": ["system"],
+                                "tool": ["tool_result"],
+                            },
+                        }
+                    ],
+                }
+            )
+        )
+        result = transform_fn(dataset[0], tokenizer=llama3_tokenizer)
+        assert "prompt" in result
+
+
 if __name__ == "__main__":
     unittest.main()
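The KeyError these tests guard against comes from the role inversion step right after the roles dict in the first file (the `role_map = {}` context lines). A standalone sketch of that inversion; the loop body is assumed from context, not quoted from the source:

# Sketch of the source-role -> target-role inversion (assumed loop shape).
roles = {
    "user": ["user"],
    "assistant": ["assistant"],
    "system": ["system"],
    "tool": ["tool"],  # new default entry; without it, role_map has no "tool" key
}
role_map = {}
for target, sources in roles.items():
    for source in sources:
        role_map[source] = target

assert role_map["tool"] == "tool"  # previously: KeyError: 'tool'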
+""" + +from axolotl.utils.data.shared import generate_dataset_hash_from_config +from axolotl.utils.dict import DictDefault + + +def _base_cfg(**kwargs): + return DictDefault( + { + "sequence_len": 2048, + "sample_packing": False, + "eval_sample_packing": False, + "group_by_length": False, + "kd_temperature": None, + "dataset_exact_deduplication": False, + "tokenizer_config": "NousResearch/Llama-3.2-1B", + **kwargs, + } + ) + + +def _datasets(): + return [ + DictDefault( + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + "shards": None, + "conversation": None, + "split": "train", + "temperature": None, + } + ) + ] + + +class TestGenerateDatasetHashFromConfig: + def test_same_config_same_hash(self): + """Identical configs produce identical hashes.""" + cfg = _base_cfg() + h1 = generate_dataset_hash_from_config( + cfg, _datasets(), "NousResearch/Llama-3.2-1B" + ) + h2 = generate_dataset_hash_from_config( + cfg, _datasets(), "NousResearch/Llama-3.2-1B" + ) + assert h1 == h2 + + def test_different_tokenizer_different_hash(self): + """A different tokenizer path produces a different hash.""" + cfg = _base_cfg() + h1 = generate_dataset_hash_from_config( + cfg, _datasets(), "NousResearch/Llama-3.2-1B" + ) + h2 = generate_dataset_hash_from_config( + cfg, _datasets(), "HuggingFaceTB/SmolLM2-135M" + ) + assert h1 != h2 + + def test_different_sequence_len_different_hash(self): + cfg_a = _base_cfg(sequence_len=2048) + cfg_b = _base_cfg(sequence_len=4096) + h1 = generate_dataset_hash_from_config(cfg_a, _datasets(), "tok") + h2 = generate_dataset_hash_from_config(cfg_b, _datasets(), "tok") + assert h1 != h2 + + # --- Regression: added_tokens_overrides + output_dir --- + + def test_added_tokens_overrides_hash_stable_across_output_dir(self): + """Hash must not change when only output_dir changes (issue #3303). + + When added_tokens_overrides is set the tokenizer is saved into output_dir, + making tokenizer.name_or_path an absolute path that includes output_dir. + The hash should be derived from the canonical tokenizer config + overrides, + not from the output-dir-dependent path. + """ + cfg_run1 = _base_cfg( + output_dir="/tmp/run_1", + added_tokens_overrides={32000: "", 32001: ""}, + ) + cfg_run2 = _base_cfg( + output_dir="/tmp/run_2_different_name", + added_tokens_overrides={32000: "", 32001: ""}, + ) + + # Simulate what happens in practice: tokenizer.name_or_path becomes the + # output_dir-based path after modify_tokenizer_files() saves the tokenizer. + tokenizer_name_run1 = "/tmp/run_1/modified_tokenizer" + tokenizer_name_run2 = "/tmp/run_2_different_name/modified_tokenizer" + + h1 = generate_dataset_hash_from_config( + cfg_run1, _datasets(), tokenizer_name_run1 + ) + h2 = generate_dataset_hash_from_config( + cfg_run2, _datasets(), tokenizer_name_run2 + ) + + assert h1 == h2, ( + "Dataset cache hash must not change when only output_dir changes " + "while added_tokens_overrides stays the same (issue #3303)." 
diff --git a/tests/utils/schemas/validation/test_config_validators.py b/tests/utils/schemas/validation/test_config_validators.py
new file mode 100644
index 000000000..c756f1362
--- /dev/null
+++ b/tests/utils/schemas/validation/test_config_validators.py
@@ -0,0 +1,119 @@
+"""
+Tests for new config validators added to AxolotlInputConfig.
+
+Covers:
+  - save_strategy: 'best' requires metric_for_best_model
+  - streaming=True with val_set_size > 0 is rejected
+  - lora_target_modules with invalid regex patterns is rejected
+"""
+
+import pytest
+
+from axolotl.utils.config import validate_config
+from axolotl.utils.dict import DictDefault
+
+
+class TestSaveStrategyBestValidator:
+    """save_strategy: 'best' must be accompanied by metric_for_best_model."""
+
+    def test_save_strategy_best_without_metric_raises(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(save_strategy="best")
+        with pytest.raises(ValueError, match="metric_for_best_model"):
+            validate_config(cfg)
+
+    def test_save_strategy_best_with_metric_passes(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(
+            save_strategy="best",
+            metric_for_best_model="eval_loss",
+        )
+        validated = validate_config(cfg)
+        assert validated.save_strategy == "best"
+        assert validated.metric_for_best_model == "eval_loss"
+
+    def test_save_strategy_epoch_without_metric_passes(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(save_strategy="epoch")
+        validated = validate_config(cfg)
+        assert validated.save_strategy == "epoch"
+
+    def test_save_strategy_no_without_metric_passes(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(save_strategy="no")
+        validated = validate_config(cfg)
+        assert validated.save_strategy == "no"
+
+    def test_save_strategy_unset_without_metric_passes(self, min_base_cfg):
+        """The default (None / not set) should not require metric_for_best_model."""
+        validated = validate_config(min_base_cfg)
+        assert validated.save_strategy is None
+
+
+class TestStreamingWithValSetSizeValidator:
+    """streaming=True is incompatible with val_set_size > 0."""
+
+    def test_streaming_with_val_set_size_raises(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(
+            streaming=True, val_set_size=0.1, max_steps=100
+        )
+        with pytest.raises(ValueError, match="val_set_size"):
+            validate_config(cfg)
+
+    def test_streaming_with_val_set_size_zero_passes(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(
+            streaming=True, val_set_size=0.0, max_steps=100
+        )
+        validated = validate_config(cfg)
+        assert validated.streaming is True
+
+    def test_streaming_false_with_val_set_size_passes(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(streaming=False, val_set_size=0.1)
+        validated = validate_config(cfg)
+        assert validated.val_set_size == pytest.approx(0.1)
+
+    def test_streaming_unset_with_val_set_size_passes(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(val_set_size=0.2)
+        validated = validate_config(cfg)
+        assert validated.val_set_size == pytest.approx(0.2)
+
+
+class TestLoraTargetModulesRegexValidator:
+    """lora_target_modules entries must be valid Python regex patterns."""
+
+    def test_invalid_regex_raises(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(
+            adapter="lora",
+            lora_target_modules=["q_proj", "[invalid_regex"],
+        )
+        with pytest.raises(ValueError, match="invalid regex pattern"):
+            validate_config(cfg)
+
+    def test_valid_regex_passes(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(
+            adapter="lora",
+            lora_target_modules=["q_proj", "v_proj", r".*_proj"],
+        )
+        validated = validate_config(cfg)
+        assert "q_proj" in validated.lora_target_modules
+
+    def test_plain_module_names_pass(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(
+            adapter="lora",
+            lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+        )
+        validated = validate_config(cfg)
+        assert len(validated.lora_target_modules) == 4
+
+    def test_lora_target_modules_string_not_validated(self, min_base_cfg):
+        """When lora_target_modules is a string (e.g. 'all-linear'), skip the regex check."""
+        cfg = min_base_cfg | DictDefault(
+            adapter="lora",
+            lora_target_modules="all-linear",
+        )
+        # Should not raise
+        validate_config(cfg)
+
+    def test_multiple_invalid_patterns_reported(self, min_base_cfg):
+        cfg = min_base_cfg | DictDefault(
+            adapter="lora",
+            lora_target_modules=["[bad1", "[bad2"],
+        )
+        with pytest.raises(ValueError, match="invalid regex pattern"):
+            validate_config(cfg)
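As a quick sanity check on the regex validator's accept/reject behaviour, the snippet below mirrors its loop on a standalone list rather than calling the validator itself. Plain module names such as q_proj are themselves valid regex patterns, which is why the validator needs no separate plain-string path:

import re

patterns = ["q_proj", "v_proj", r".*_proj", "[invalid_regex"]
invalid = []
for pattern in patterns:
    try:
        re.compile(pattern)  # plain names like "q_proj" compile fine
    except re.error:
        invalid.append(pattern)

print(invalid)  # ['[invalid_regex'] -- the only entry that fails to compile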