Files
axolotl/src/axolotl/utils/schemas/datasets.py
Wing Lian 2d77165dc0 automatically split out reasoning trace from dataset (#2579)
* automatically split out reasoning trace from dataset

* chore: lint

* fix import
2025-04-28 18:23:03 -04:00

166 lines
5.3 KiB
Python

"""Pydantic models for datasets-related configuration"""
from pydantic import BaseModel, model_validator
from axolotl.utils.schemas.enums import ChatTemplate
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
class UserDefinedPrompterType(BaseModel):
"""Structure for user defined prompt types"""
system_prompt: str | None = None
system_format: str | None = None
field_system: str | None = None
field_instruction: str | None = None
field_input: str | None = None
field_output: str | None = None
format: str | None = None
no_input_format: str | None = None
field: str | None = None
class SFTDataset(BaseModel):
"""SFT configuration subset"""
path: str | None = None
split: str | None = None
type: str | UserDefinedPrompterType | None = None
input_transform: str | None = None
shards: int | None = None
shards_idx: int | None = None
preprocess_shards: int | None = None
conversation: str | None = None
# Do not make this too strict or it will break the validator to choose different dataset class
chat_template: ChatTemplate | str | None = None
chat_template_jinja: str | None = None
data_files: str | list[str] | None = None
input_format: str | None = None
name: str | None = None
ds_type: str | None = None
field: str | None = None
field_human: str | None = None
field_model: str | None = None
field_messages: str | None = None
# deprecated, use message_property_mappings
message_field_role: str | None = None
# deprecated, use message_property_mappings
message_field_content: str | None = None
message_property_mappings: dict[str, str] | None = None
message_field_training: str | None = None
message_field_training_detail: str | None = None
split_thinking: bool | None = None
logprobs_field: str | None = None
temperature: float | None = None
roles_to_train: list[str] | None = None
train_on_eos: str | None = None
roles: dict[str, list[str]] | None = None
drop_system_message: bool | None = None
trust_remote_code: bool | None = False
revision: str | None = None
@model_validator(mode="before")
@classmethod
def handle_legacy_message_fields(cls, data):
"""Handle backwards compatibility between legacy message field mapping and new property mapping system."""
return handle_legacy_message_fields_logic(data)
@model_validator(mode="before")
@classmethod
# pylint: disable=duplicate-code
def check_chat_template_config(cls, data):
if isinstance(data, BaseModel):
data = data.model_dump()
# Set chat_template to tokenizer_default if not set
if data.get("type") == "chat_template" and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.tokenizer_default
# if chat_template is set to jinja, chat_template_jinja is required
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
"chat_template_jinja"
):
raise ValueError(
"chat_template_jinja is required when chat_template is set to jinja"
)
# If chat_template_jinja is set, set chat_template to jinja
if data.get("chat_template_jinja") and not data.get("chat_template"):
data["chat_template"] = ChatTemplate.jinja
return data
class PretrainingDataset(BaseModel):
"""Pretraining dataset configuration subset"""
name: str | None = None
path: str | None = None
split: str | None = "train"
text_column: str | None = "text"
type: str | None = "pretrain"
trust_remote_code: bool | None = False
data_files: str | None = None
skip: int | None = None
class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO"""
field_system: str | None = None
field_prompt: str | None = None
field_chosen: str | None = None
field_rejected: str | None = None
prompt_format: str | None = None
chosen_format: str | None = None
rejected_format: str | None = None
class DPODataset(BaseModel):
"""DPO configuration subset"""
path: str | None = None
split: str | None = None
type: UserDefinedDPOType | str | None = None
data_files: list[str] | None = None
revision: str | None = None
field_messages: str | None = None
class StepwiseSupervisedDataset(BaseModel):
"""Stepwise supervised dataset configuration subset"""
path: str | None = None
split: str | None = None
data_files: list[str] | None = None
revision: str | None = None
step_separator: str | None = None
max_completion_length: int | None = None
train_on_last_step_only: bool | None = None
class UserDefinedKTOType(BaseModel):
"""User defined typing for KTO"""
field_system: str | None = None
field_prompt: str | None = None
field_completion: str | None = None
field_label: bool | None = None
prompt_format: str | None = None
completion_format: str | None = None
class KTODataset(BaseModel):
"""KTO configuration subset"""
path: str | None = None
split: str | None = None
type: UserDefinedKTOType | str | None = None
data_files: list[str] | None = None
trust_remote_code: bool | None = False
revision: str | None = None
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset