fix: DPO tool role KeyError (#3217), dataset hash output_dir (#3303), config validators (#3538) [skip ci]
* fix: DPO tool role KeyError, dataset hash output_dir, config validators [skip-e2e] - Add 'tool' to default role_map_inv in dpo/chat_template.py default() and argilla_chat() so datasets with tool-call messages no longer raise KeyError: 'tool' (closes #3217) - Fix generate_dataset_hash_from_config to use canonical tokenizer config + overrides content instead of tokenizer.name_or_path when added_tokens_overrides is set, preventing cache busting when only output_dir changes (closes #3303) - Add three Pydantic config validators to AxolotlConfigWCapabilities: * save_strategy: 'best' requires metric_for_best_model * streaming=True is incompatible with val_set_size > 0 * lora_target_modules list entries must be valid Python regex patterns - Tests for all three changes * review: condense comment in shared.py, swap Mistral model for SmolLM2-135M in test_hash * chore: lint * move the validators out of the w/ capabilities schema --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -29,6 +29,7 @@ def default(cfg, dataset_idx=0, **kwargs):
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
"system": ["system"],
|
||||
"tool": ["tool"],
|
||||
},
|
||||
)
|
||||
role_map = {}
|
||||
@@ -174,6 +175,7 @@ def argilla_chat(cfg, dataset_idx=0, **kwargs):
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
"system": ["system"],
|
||||
"tool": ["tool"],
|
||||
},
|
||||
)
|
||||
role_map = {}
|
||||
|
||||
@@ -516,12 +516,21 @@ def generate_dataset_hash_from_config(
|
||||
Returns:
|
||||
MD5 hash string representing the configuration.
|
||||
"""
|
||||
# When added_tokens_overrides is set, tokenizer.name_or_path contains output_dir.
|
||||
# Use the canonical tokenizer config + overrides content so the hash is stable across output_dir changes.
|
||||
if cfg.get("added_tokens_overrides"):
|
||||
tokenizer_fingerprint = f"{cfg.tokenizer_config}+overrides:" + ",".join(
|
||||
f"{k}={v}" for k, v in sorted(cfg.added_tokens_overrides.items())
|
||||
)
|
||||
else:
|
||||
tokenizer_fingerprint = tokenizer_name
|
||||
|
||||
config_str = (
|
||||
f"{cfg.sequence_len}@{cfg.sample_packing}@{cfg.eval_sample_packing}@"
|
||||
f"{cfg.group_by_length}@{cfg.kd_temperature or 1.0}@"
|
||||
f"{cfg.dataset_exact_deduplication or False}|"
|
||||
f"{'|'.join(sorted([f'{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}' for d in cfg_datasets]))}"
|
||||
f"|{tokenizer_name}"
|
||||
f"|{tokenizer_fingerprint}"
|
||||
)
|
||||
return str(md5(config_str))
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Module with Pydantic models for configuration."""
|
||||
|
||||
import re
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from accelerate.utils import is_fp8_available
|
||||
@@ -1338,6 +1339,39 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_save_strategy_best_requires_metric(cls, data):
|
||||
if data.get("save_strategy") == "best" and not data.get(
|
||||
"metric_for_best_model"
|
||||
):
|
||||
raise ValueError(
|
||||
"save_strategy: 'best' requires metric_for_best_model to be set. "
|
||||
"Please specify the metric to use, e.g. metric_for_best_model: eval_loss"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_lora_target_modules_regex(cls, data):
|
||||
lora_target_modules = data.get("lora_target_modules")
|
||||
if not isinstance(lora_target_modules, list):
|
||||
return data
|
||||
invalid = []
|
||||
for pattern in lora_target_modules:
|
||||
if not isinstance(pattern, str):
|
||||
continue
|
||||
try:
|
||||
re.compile(pattern)
|
||||
except re.error:
|
||||
invalid.append(pattern)
|
||||
if invalid:
|
||||
raise ValueError(
|
||||
f"lora_target_modules contains invalid regex pattern(s): {invalid}. "
|
||||
"Please provide valid Python regex patterns or plain module name strings."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""Wrapper to valdiate GPU capabilities with the configured options"""
|
||||
|
||||
Reference in New Issue
Block a user