166 lines
5.3 KiB
Python
166 lines
5.3 KiB
Python
"""Pydantic models for datasets-related configuration"""
|
|
|
|
from pydantic import BaseModel, model_validator
|
|
|
|
from axolotl.utils.schemas.enums import ChatTemplate
|
|
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
|
|
|
|
|
class UserDefinedPrompterType(BaseModel):
|
|
"""Structure for user defined prompt types"""
|
|
|
|
system_prompt: str | None = None
|
|
system_format: str | None = None
|
|
field_system: str | None = None
|
|
field_instruction: str | None = None
|
|
field_input: str | None = None
|
|
field_output: str | None = None
|
|
|
|
format: str | None = None
|
|
no_input_format: str | None = None
|
|
field: str | None = None
|
|
|
|
|
|
class SFTDataset(BaseModel):
|
|
"""SFT configuration subset"""
|
|
|
|
path: str | None = None
|
|
split: str | None = None
|
|
type: str | UserDefinedPrompterType | None = None
|
|
input_transform: str | None = None
|
|
shards: int | None = None
|
|
shards_idx: int | None = None
|
|
preprocess_shards: int | None = None
|
|
conversation: str | None = None
|
|
# Do not make this too strict or it will break the validator to choose different dataset class
|
|
chat_template: ChatTemplate | str | None = None
|
|
chat_template_jinja: str | None = None
|
|
data_files: str | list[str] | None = None
|
|
input_format: str | None = None
|
|
name: str | None = None
|
|
ds_type: str | None = None
|
|
field: str | None = None
|
|
field_human: str | None = None
|
|
field_model: str | None = None
|
|
field_messages: str | None = None
|
|
# deprecated, use message_property_mappings
|
|
message_field_role: str | None = None
|
|
# deprecated, use message_property_mappings
|
|
message_field_content: str | None = None
|
|
message_property_mappings: dict[str, str] | None = None
|
|
message_field_training: str | None = None
|
|
message_field_training_detail: str | None = None
|
|
split_thinking: bool | None = None
|
|
logprobs_field: str | None = None
|
|
temperature: float | None = None
|
|
roles_to_train: list[str] | None = None
|
|
train_on_eos: str | None = None
|
|
roles: dict[str, list[str]] | None = None
|
|
drop_system_message: bool | None = None
|
|
trust_remote_code: bool | None = False
|
|
revision: str | None = None
|
|
|
|
@model_validator(mode="before")
|
|
@classmethod
|
|
def handle_legacy_message_fields(cls, data):
|
|
"""Handle backwards compatibility between legacy message field mapping and new property mapping system."""
|
|
return handle_legacy_message_fields_logic(data)
|
|
|
|
@model_validator(mode="before")
|
|
@classmethod
|
|
# pylint: disable=duplicate-code
|
|
def check_chat_template_config(cls, data):
|
|
if isinstance(data, BaseModel):
|
|
data = data.model_dump()
|
|
|
|
# Set chat_template to tokenizer_default if not set
|
|
if data.get("type") == "chat_template" and not data.get("chat_template"):
|
|
data["chat_template"] = ChatTemplate.tokenizer_default
|
|
|
|
# if chat_template is set to jinja, chat_template_jinja is required
|
|
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
|
"chat_template_jinja"
|
|
):
|
|
raise ValueError(
|
|
"chat_template_jinja is required when chat_template is set to jinja"
|
|
)
|
|
|
|
# If chat_template_jinja is set, set chat_template to jinja
|
|
if data.get("chat_template_jinja") and not data.get("chat_template"):
|
|
data["chat_template"] = ChatTemplate.jinja
|
|
|
|
return data
|
|
|
|
|
|
class PretrainingDataset(BaseModel):
|
|
"""Pretraining dataset configuration subset"""
|
|
|
|
name: str | None = None
|
|
path: str | None = None
|
|
split: str | None = "train"
|
|
text_column: str | None = "text"
|
|
type: str | None = "pretrain"
|
|
trust_remote_code: bool | None = False
|
|
data_files: str | None = None
|
|
skip: int | None = None
|
|
|
|
|
|
class UserDefinedDPOType(BaseModel):
|
|
"""User defined typing for DPO"""
|
|
|
|
field_system: str | None = None
|
|
field_prompt: str | None = None
|
|
field_chosen: str | None = None
|
|
field_rejected: str | None = None
|
|
prompt_format: str | None = None
|
|
chosen_format: str | None = None
|
|
rejected_format: str | None = None
|
|
|
|
|
|
class DPODataset(BaseModel):
|
|
"""DPO configuration subset"""
|
|
|
|
path: str | None = None
|
|
split: str | None = None
|
|
type: UserDefinedDPOType | str | None = None
|
|
data_files: list[str] | None = None
|
|
revision: str | None = None
|
|
field_messages: str | None = None
|
|
|
|
|
|
class StepwiseSupervisedDataset(BaseModel):
|
|
"""Stepwise supervised dataset configuration subset"""
|
|
|
|
path: str | None = None
|
|
split: str | None = None
|
|
data_files: list[str] | None = None
|
|
revision: str | None = None
|
|
step_separator: str | None = None
|
|
max_completion_length: int | None = None
|
|
train_on_last_step_only: bool | None = None
|
|
|
|
|
|
class UserDefinedKTOType(BaseModel):
|
|
"""User defined typing for KTO"""
|
|
|
|
field_system: str | None = None
|
|
field_prompt: str | None = None
|
|
field_completion: str | None = None
|
|
field_label: bool | None = None
|
|
prompt_format: str | None = None
|
|
completion_format: str | None = None
|
|
|
|
|
|
class KTODataset(BaseModel):
|
|
"""KTO configuration subset"""
|
|
|
|
path: str | None = None
|
|
split: str | None = None
|
|
type: UserDefinedKTOType | str | None = None
|
|
data_files: list[str] | None = None
|
|
trust_remote_code: bool | None = False
|
|
revision: str | None = None
|
|
|
|
|
|
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
|