fix: DPO tool role KeyError (#3217), dataset hash output_dir (#3303), config validators (#3538) [skip ci]

* fix: DPO tool role KeyError, dataset hash output_dir, config validators [skip-e2e]

- Add 'tool' to default role_map_inv in dpo/chat_template.py default() and
  argilla_chat() so datasets with tool-call messages no longer raise
  KeyError: 'tool' (closes #3217)

- Fix generate_dataset_hash_from_config to use canonical tokenizer config +
  overrides content instead of tokenizer.name_or_path when added_tokens_overrides
  is set, preventing cache busting when only output_dir changes (closes #3303)

- Add three Pydantic config validators to AxolotlConfigWCapabilities:
  * save_strategy: 'best' requires metric_for_best_model
  * streaming=True is incompatible with val_set_size > 0
  * lora_target_modules list entries must be valid Python regex patterns

- Tests for all three changes

* review: condense comment in shared.py, swap Mistral model for SmolLM2-135M in test_hash

* chore: lint

* move the validators out of the w/ capabilities schema

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
Edward Zion Saji
2026-04-02 05:27:07 +05:30
committed by GitHub
parent c92b71bd0c
commit 55a7950e3d
6 changed files with 383 additions and 1 deletions

View File

@@ -29,6 +29,7 @@ def default(cfg, dataset_idx=0, **kwargs):
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
"tool": ["tool"],
},
)
role_map = {}
@@ -174,6 +175,7 @@ def argilla_chat(cfg, dataset_idx=0, **kwargs):
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
"tool": ["tool"],
},
)
role_map = {}

View File

@@ -516,12 +516,21 @@ def generate_dataset_hash_from_config(
Returns:
MD5 hash string representing the configuration.
"""
# When added_tokens_overrides is set, tokenizer.name_or_path contains output_dir.
# Use the canonical tokenizer config + overrides content so the hash is stable across output_dir changes.
if cfg.get("added_tokens_overrides"):
tokenizer_fingerprint = f"{cfg.tokenizer_config}+overrides:" + ",".join(
f"{k}={v}" for k, v in sorted(cfg.added_tokens_overrides.items())
)
else:
tokenizer_fingerprint = tokenizer_name
config_str = (
f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}@"
f"{cfg.dataset_exact_deduplication or False}|"
f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}"
f"|{tokenizer_name}"
f"|{tokenizer_fingerprint}"
)
return str(md5(config_str))

View File

@@ -1,5 +1,6 @@
"""Module with Pydantic models for configuration."""
import re
from typing import Annotated, Any, Literal
from accelerate.utils import is_fp8_available
@@ -1338,6 +1339,39 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_save_strategy_best_requires_metric(cls, data):
    """Reject ``save_strategy: 'best'`` unless ``metric_for_best_model`` is also set.

    Runs on the raw input mapping (mode="before"), so plain ``dict.get`` access
    is used rather than attribute access on a validated model.
    """
    wants_best_checkpoint = data.get("save_strategy") == "best"
    has_best_metric = bool(data.get("metric_for_best_model"))
    if wants_best_checkpoint and not has_best_metric:
        raise ValueError(
            "save_strategy: 'best' requires metric_for_best_model to be set. "
            "Please specify the metric to use, e.g. metric_for_best_model: eval_loss"
        )
    return data
@model_validator(mode="before")
@classmethod
def check_lora_target_modules_regex(cls, data):
    """Validate that every string entry of ``lora_target_modules`` is a
    compilable Python regex.

    Non-list values (including ``None``) and non-string entries are passed
    through untouched — only string entries are regex-checked here.
    """
    modules = data.get("lora_target_modules")
    if not isinstance(modules, list):
        return data

    def _is_invalid_regex(candidate):
        # Leave non-strings for other validators; only strings are compiled.
        if not isinstance(candidate, str):
            return False
        try:
            re.compile(candidate)
        except re.error:
            return True
        return False

    invalid = [entry for entry in modules if _is_invalid_regex(entry)]
    if invalid:
        raise ValueError(
            f"lora_target_modules contains invalid regex pattern(s): {invalid}. "
            "Please provide valid Python regex patterns or plain module name strings."
        )
    return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""Wrapper to validate GPU capabilities with the configured options"""