Config doc autogen (#2718)
* config reference doc autogen * improvements * cleanup; still ugly but working * reformat * remove autogen config ref from git * factor out validations * rewrite * rewrite * cleanup * progress * progress * progress * lint and minifying somewhat * remove unneeded * coderabbit * coderabbit * update preview-docs workflow triggers * installing with deps * coderabbit * update refs * overwrote file accidentally
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,8 @@
|
||||
"""Pydantic models for datasets-related configuration"""
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from axolotl.utils.schemas.enums import ChatTemplate
|
||||
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
||||
@@ -9,57 +11,178 @@ from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
||||
class UserDefinedPrompterType(BaseModel):
|
||||
"""Structure for user defined prompt types"""
|
||||
|
||||
system_prompt: str | None = None
|
||||
system_format: str | None = None
|
||||
system_prompt: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Custom user instruction prompt"},
|
||||
)
|
||||
system_format: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Use {system} as key to be replaced"},
|
||||
)
|
||||
field_system: str | None = None
|
||||
field_instruction: str | None = None
|
||||
field_input: str | None = None
|
||||
field_output: str | None = None
|
||||
|
||||
format: str | None = None
|
||||
no_input_format: str | None = None
|
||||
field: str | None = None
|
||||
format: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Customizable to be single line or multi-line. Use {instruction}/{input} as key to be replaced. 'format' can include {input}"
|
||||
},
|
||||
)
|
||||
no_input_format: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "'no_input_format' cannot include {input}"},
|
||||
)
|
||||
field: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "For `completion` datsets only, uses the provided field instead of `text` column"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class SFTDataset(BaseModel):
|
||||
"""SFT configuration subset"""
|
||||
|
||||
path: str | None = None
|
||||
split: str | None = None
|
||||
type: str | UserDefinedPrompterType | None = None
|
||||
path: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "HuggingFace dataset repo | s3:// | gs:// | path to local file or directory"
|
||||
},
|
||||
)
|
||||
split: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "name of dataset split to load from"},
|
||||
)
|
||||
type: str | UserDefinedPrompterType | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]"
|
||||
},
|
||||
)
|
||||
input_transform: str | None = None
|
||||
shards: int | None = None
|
||||
shards_idx: int | None = None
|
||||
preprocess_shards: int | None = None
|
||||
shards: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "split dataset into N pieces (use with shards_idx)"
|
||||
},
|
||||
)
|
||||
shards_idx: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "the index of sharded dataset to use"},
|
||||
)
|
||||
preprocess_shards: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)"
|
||||
},
|
||||
)
|
||||
conversation: str | None = None
|
||||
# Do not make this too strict or it will break the validator to choose different dataset class
|
||||
chat_template: ChatTemplate | str | None = None
|
||||
chat_template_jinja: str | None = None
|
||||
data_files: str | list[str] | None = None
|
||||
chat_template: ChatTemplate | str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The name of the chat template to use for training, following values are supported: tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default. alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py. tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml. jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field."
|
||||
},
|
||||
)
|
||||
chat_template_jinja: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Custom jinja chat template. Used only if `chat_template: jinja` or empty."
|
||||
},
|
||||
)
|
||||
data_files: str | list[str] | None = Field(
|
||||
default=None, json_schema_extra={"description": "path to source data files"}
|
||||
)
|
||||
input_format: str | None = None
|
||||
name: str | None = None
|
||||
ds_type: str | None = None
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "name of dataset configuration to load"},
|
||||
)
|
||||
ds_type: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "defines the datatype when path is a file"},
|
||||
)
|
||||
field: str | None = None
|
||||
field_human: str | None = None
|
||||
field_model: str | None = None
|
||||
field_messages: str | None = None
|
||||
field_tools: str | None = None
|
||||
field_messages: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": 'Key containing the messages (default: "messages")'
|
||||
},
|
||||
)
|
||||
field_tools: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).'
|
||||
},
|
||||
)
|
||||
# deprecated, use message_property_mappings
|
||||
message_field_role: str | None = None
|
||||
# deprecated, use message_property_mappings
|
||||
message_field_content: str | None = None
|
||||
message_property_mappings: dict[str, str] | None = None
|
||||
message_field_training: str | None = None
|
||||
message_field_training_detail: str | None = None
|
||||
split_thinking: bool | None = None
|
||||
message_property_mappings: dict[str, str] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Mapping of properties from the input dataset to the chat template. (default: message_property_mappings={'role':'role', 'content':'content'}) If a property exists in the template but not in this mapping, the system will attempt to load it directly from the message using the property name as the key. Example: In the mapping below, 'from' is loaded from input dataset and used as 'role', while 'value' is loaded and used as 'content' in the chat template."
|
||||
},
|
||||
)
|
||||
message_field_training: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`."
|
||||
},
|
||||
)
|
||||
message_field_training_detail: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn. The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train)."
|
||||
},
|
||||
)
|
||||
split_thinking: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "(for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags"
|
||||
},
|
||||
)
|
||||
logprobs_field: str | None = None
|
||||
temperature: float | None = None
|
||||
roles_to_train: list[str] | None = None
|
||||
train_on_eos: str | None = None
|
||||
roles: dict[str, list[str]] | None = None
|
||||
drop_system_message: bool | None = None
|
||||
trust_remote_code: bool | None = False
|
||||
revision: str | None = None
|
||||
roles_to_train: list[str] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Roles to train on. The tokens from these roles will be considered for the loss."
|
||||
},
|
||||
)
|
||||
train_on_eos: Literal["all", "turn", "last"] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation"
|
||||
},
|
||||
)
|
||||
roles: dict[str, list[str]] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": 'Roles mapping in the messages. The format is {target_role: [source_roles]}. All source roles will be mapped to the target role. The default is: user: ["human", "user"], assistant: ["gpt", "assistant"], system: ["system"], tool: ["tool"]'
|
||||
},
|
||||
)
|
||||
drop_system_message: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to drop the system turn from the dataset. Only works with chat_template. This does not drop the default system message from chat_template if it exists. If you wish to, we recommend using a custom jinja template with the default system message removed or adding a system turn with empty content."
|
||||
},
|
||||
)
|
||||
trust_remote_code: bool | None = Field(
|
||||
default=False,
|
||||
json_schema_extra={"description": "Trust remote code for untrusted source"},
|
||||
)
|
||||
revision: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets."
|
||||
},
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -60,10 +60,30 @@ class RemappedParameters(BaseModel):
|
||||
"""Parameters that have been remapped to other names"""
|
||||
|
||||
overrides_of_model_config: dict[str, Any] | None = Field(
|
||||
default=None, alias="model_config"
|
||||
default=None,
|
||||
alias="model_config",
|
||||
json_schema_extra={
|
||||
"description": "optional overrides to the base model configuration"
|
||||
},
|
||||
)
|
||||
overrides_of_model_kwargs: dict[str, Any] | None = Field(
|
||||
default=None, alias="model_kwargs"
|
||||
default=None,
|
||||
alias="model_kwargs",
|
||||
json_schema_extra={
|
||||
"description": "optional overrides the base model loading from_pretrained"
|
||||
},
|
||||
)
|
||||
type_of_model: str | None = Field(
|
||||
default=None,
|
||||
alias="model_type",
|
||||
json_schema_extra={
|
||||
"description": "If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too"
|
||||
},
|
||||
)
|
||||
revision_of_model: str | None = Field(
|
||||
default=None,
|
||||
alias="model_revision",
|
||||
json_schema_extra={
|
||||
"description": "You can specify to choose a specific model revision from huggingface hub"
|
||||
},
|
||||
)
|
||||
type_of_model: str | None = Field(default=None, alias="model_type")
|
||||
revision_of_model: str | None = Field(default=None, alias="model_revision")
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Enums for Axolotl input config"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
@@ -8,81 +10,81 @@ import torch
|
||||
class TorchIntDType(Enum):
|
||||
"""Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4"""
|
||||
|
||||
uint1 = getattr(torch, "uint1", None) # pylint: disable=invalid-name
|
||||
uint2 = getattr(torch, "uint2", None) # pylint: disable=invalid-name
|
||||
uint3 = getattr(torch, "uint3", None) # pylint: disable=invalid-name
|
||||
uint4 = getattr(torch, "uint4", None) # pylint: disable=invalid-name
|
||||
uint5 = getattr(torch, "uint5", None) # pylint: disable=invalid-name
|
||||
uint6 = getattr(torch, "uint6", None) # pylint: disable=invalid-name
|
||||
uint7 = getattr(torch, "uint7", None) # pylint: disable=invalid-name
|
||||
int4 = getattr(torch, "int4", None) # pylint: disable=invalid-name
|
||||
int8 = getattr(torch, "int8", None) # pylint: disable=invalid-name
|
||||
uint1 = getattr(torch, "uint1", None)
|
||||
uint2 = getattr(torch, "uint2", None)
|
||||
uint3 = getattr(torch, "uint3", None)
|
||||
uint4 = getattr(torch, "uint4", None)
|
||||
uint5 = getattr(torch, "uint5", None)
|
||||
uint6 = getattr(torch, "uint6", None)
|
||||
uint7 = getattr(torch, "uint7", None)
|
||||
int4 = getattr(torch, "int4", None)
|
||||
int8 = getattr(torch, "int8", None)
|
||||
|
||||
|
||||
class RLType(str, Enum):
|
||||
"""RL trainer type configuration subset"""
|
||||
|
||||
DPO = "dpo" # pylint: disable=invalid-name
|
||||
GRPO = "grpo" # pylint: disable=invalid-name
|
||||
IPO = "ipo" # pylint: disable=invalid-name
|
||||
ORPO = "orpo" # pylint: disable=invalid-name
|
||||
KTO = "kto" # pylint: disable=invalid-name
|
||||
SIMPO = "simpo" # pylint: disable=invalid-name
|
||||
DPO = "dpo"
|
||||
GRPO = "grpo"
|
||||
IPO = "ipo"
|
||||
ORPO = "orpo"
|
||||
KTO = "kto"
|
||||
SIMPO = "simpo"
|
||||
|
||||
|
||||
class ChatTemplate(str, Enum):
|
||||
"""Chat templates configuration subset"""
|
||||
|
||||
alpaca = "alpaca" # pylint: disable=invalid-name
|
||||
chatml = "chatml" # pylint: disable=invalid-name
|
||||
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
|
||||
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
|
||||
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
|
||||
mistral_v7_tekken = "mistral_v7_tekken" # pylint: disable=invalid-name
|
||||
gemma = "gemma" # pylint: disable=invalid-name
|
||||
cohere = "cohere" # pylint: disable=invalid-name
|
||||
llama3 = "llama3" # pylint: disable=invalid-name
|
||||
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
||||
llama4 = "llama4" # pylint: disable=invalid-name
|
||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||
phi_35 = "phi_35" # pylint: disable=invalid-name
|
||||
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
||||
deepseek_v3 = "deepseek_v3" # pylint: disable=invalid-name
|
||||
jamba = "jamba" # pylint: disable=invalid-name
|
||||
jinja = "jinja" # pylint: disable=invalid-name
|
||||
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
||||
qwen3 = "qwen3" # pylint: disable=invalid-name
|
||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||
exaone = "exaone" # pylint: disable=invalid-name
|
||||
metharme = "metharme" # pylint: disable=invalid-name
|
||||
pixtral = "pixtral" # pylint: disable=invalid-name
|
||||
llava = "llava" # pylint: disable=invalid-name
|
||||
qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name
|
||||
gemma3 = "gemma3" # pylint: disable=invalid-name
|
||||
command_a = "command_a" # pylint: disable=invalid-name
|
||||
command_a_tool_use = "command_a_tool_use" # pylint: disable=invalid-name
|
||||
command_a_rag = "command_a_rag" # pylint: disable=invalid-name
|
||||
aya = "aya" # pylint: disable=invalid-name
|
||||
alpaca = "alpaca"
|
||||
chatml = "chatml"
|
||||
mistral_v1 = "mistral_v1"
|
||||
mistral_v2v3 = "mistral_v2v3"
|
||||
mistral_v3_tekken = "mistral_v3_tekken"
|
||||
mistral_v7_tekken = "mistral_v7_tekken"
|
||||
gemma = "gemma"
|
||||
cohere = "cohere"
|
||||
llama3 = "llama3"
|
||||
llama3_2_vision = "llama3_2_vision"
|
||||
llama4 = "llama4"
|
||||
phi_3 = "phi_3"
|
||||
phi_35 = "phi_35"
|
||||
deepseek_v2 = "deepseek_v2"
|
||||
deepseek_v3 = "deepseek_v3"
|
||||
jamba = "jamba"
|
||||
jinja = "jinja"
|
||||
qwen_25 = "qwen_25"
|
||||
qwen3 = "qwen3"
|
||||
tokenizer_default = "tokenizer_default"
|
||||
exaone = "exaone"
|
||||
metharme = "metharme"
|
||||
pixtral = "pixtral"
|
||||
llava = "llava"
|
||||
qwen2_vl = "qwen2_vl"
|
||||
gemma3 = "gemma3"
|
||||
command_a = "command_a"
|
||||
command_a_tool_use = "command_a_tool_use"
|
||||
command_a_rag = "command_a_rag"
|
||||
aya = "aya"
|
||||
|
||||
|
||||
class CustomSupportedOptimizers(str, Enum):
|
||||
"""Custom supported optimizers"""
|
||||
|
||||
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
|
||||
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
|
||||
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
|
||||
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
||||
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
||||
came_pytorch = "came_pytorch" # pylint: disable=invalid-name
|
||||
muon = "muon" # pylint: disable=invalid-name
|
||||
optimi_adamw = "optimi_adamw"
|
||||
ao_adamw_4bit = "ao_adamw_4bit"
|
||||
ao_adamw_8bit = "ao_adamw_8bit"
|
||||
ao_adamw_fp8 = "ao_adamw_fp8"
|
||||
adopt_adamw = "adopt_adamw"
|
||||
came_pytorch = "came_pytorch"
|
||||
muon = "muon"
|
||||
|
||||
|
||||
class RingAttnFunc(str, Enum):
|
||||
"""Enum class for supported `ring-flash-attn` implementations"""
|
||||
|
||||
# VARLEN_RING = "varlen_ring"
|
||||
# VARLEN_ZIGZAG = "varlen_zigzag"
|
||||
VARLEN_LLAMA3 = "varlen_llama3"
|
||||
BATCH_RING = "batch_ring"
|
||||
# VARLEN_RING = "varlen_ring"
|
||||
# VARLEN_ZIGZAG = "varlen_zigzag"
|
||||
# BATCH_ZIGZAG = "batch_zigzag"
|
||||
# BATCH_STRIPE = "batch_stripe"
|
||||
|
||||
@@ -13,10 +13,21 @@ class MLFlowConfig(BaseModel):
|
||||
"""MLFlow configuration subset"""
|
||||
|
||||
use_mlflow: bool | None = None
|
||||
mlflow_tracking_uri: str | None = None
|
||||
mlflow_experiment_name: str | None = None
|
||||
mlflow_run_name: str | None = None
|
||||
hf_mlflow_log_artifacts: bool | None = None
|
||||
mlflow_tracking_uri: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "URI to mlflow"}
|
||||
)
|
||||
mlflow_experiment_name: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "Your experiment name"}
|
||||
)
|
||||
mlflow_run_name: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "Your run name"}
|
||||
)
|
||||
hf_mlflow_log_artifacts: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "set to true to copy each saved checkpoint on each save to mlflow artifact registry"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class LISAConfig(BaseModel):
|
||||
@@ -40,13 +51,33 @@ class WandbConfig(BaseModel):
|
||||
"""Wandb configuration subset"""
|
||||
|
||||
use_wandb: bool | None = None
|
||||
wandb_name: str | None = None
|
||||
wandb_run_id: str | None = None
|
||||
wandb_mode: str | None = None
|
||||
wandb_project: str | None = None
|
||||
wandb_entity: str | None = None
|
||||
wandb_name: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Set the name of your wandb run"},
|
||||
)
|
||||
wandb_run_id: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "Set the ID of your wandb run"}
|
||||
)
|
||||
wandb_mode: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": '"offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb'
|
||||
},
|
||||
)
|
||||
wandb_project: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "Your wandb project name"}
|
||||
)
|
||||
wandb_entity: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "A wandb Team name if using a Team"},
|
||||
)
|
||||
wandb_watch: str | None = None
|
||||
wandb_log_model: str | None = None
|
||||
wandb_log_model: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": '"checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training'
|
||||
},
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -64,14 +95,52 @@ class WandbConfig(BaseModel):
|
||||
class CometConfig(BaseModel):
|
||||
"""Comet configuration subset"""
|
||||
|
||||
use_comet: bool | None = None
|
||||
comet_api_key: str | None = None
|
||||
comet_workspace: str | None = None
|
||||
comet_project_name: str | None = None
|
||||
comet_experiment_key: str | None = None
|
||||
comet_mode: str | None = None
|
||||
comet_online: bool | None = None
|
||||
comet_experiment_config: dict[str, Any] | None = None
|
||||
use_comet: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Enable or disable Comet integration."},
|
||||
)
|
||||
comet_api_key: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "API key for Comet. Recommended to set via `comet login`."
|
||||
},
|
||||
)
|
||||
comet_workspace: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Workspace name in Comet. Defaults to the user's default workspace."
|
||||
},
|
||||
)
|
||||
comet_project_name: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Project name in Comet. Defaults to Uncategorized."
|
||||
},
|
||||
)
|
||||
comet_experiment_key: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key."
|
||||
},
|
||||
)
|
||||
comet_mode: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": 'Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.'
|
||||
},
|
||||
)
|
||||
comet_online: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Set to True to log data to Comet server, or False for offline storage. Default is True."
|
||||
},
|
||||
)
|
||||
comet_experiment_config: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Dictionary for additional configuration settings, see the doc for more details."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class GradioConfig(BaseModel):
|
||||
|
||||
@@ -12,20 +12,55 @@ class ModelInputConfig(BaseModel):
|
||||
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
base_model: str
|
||||
base_model_config: str | None = None
|
||||
base_model: str = Field(
|
||||
json_schema_extra={
|
||||
"description": "This is the huggingface model that contains *.pt, *.safetensors, or *.bin files. This can also be a relative path to a model on disk"
|
||||
}
|
||||
)
|
||||
base_model_config: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model"
|
||||
},
|
||||
)
|
||||
cls_model_config: str | None = None
|
||||
tokenizer_config: str | None = None
|
||||
tokenizer_use_fast: bool | None = None
|
||||
tokenizer_legacy: bool | None = None
|
||||
tokenizer_use_mistral_common: bool | None = None
|
||||
tokenizer_config: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Optional tokenizer configuration path in case you want to use a different tokenizer than the one defined in the base model"
|
||||
},
|
||||
)
|
||||
tokenizer_use_fast: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "use_fast option for tokenizer loading from_pretrained, default to True"
|
||||
},
|
||||
)
|
||||
tokenizer_legacy: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to use the legacy tokenizer setting, defaults to True"
|
||||
},
|
||||
)
|
||||
tokenizer_use_mistral_common: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer."
|
||||
},
|
||||
)
|
||||
tokenizer_type: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "transformers tokenizer class"}
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Corresponding tokenizer for the model AutoTokenizer is a good choice"
|
||||
},
|
||||
)
|
||||
processor_type: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||
)
|
||||
trust_remote_code: bool | None = None
|
||||
trust_remote_code: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Trust remote code for untrusted source"},
|
||||
)
|
||||
|
||||
@field_validator("trust_remote_code")
|
||||
@classmethod
|
||||
@@ -40,10 +75,23 @@ class ModelInputConfig(BaseModel):
|
||||
class ModelOutputConfig(BaseModel):
|
||||
"""model save configuration subset"""
|
||||
|
||||
output_dir: str = Field(default="./model-out")
|
||||
hub_model_id: str | None = None
|
||||
hub_strategy: str | None = None
|
||||
save_safetensors: bool | None = True
|
||||
output_dir: str = Field(
|
||||
default="./model-out",
|
||||
json_schema_extra={"description": "Where to save the full-finetuned model to"},
|
||||
)
|
||||
hub_model_id: str | None = Field(
|
||||
default=None, json_schema_extra={"description": "push checkpoints to hub"}
|
||||
)
|
||||
hub_strategy: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "how to push checkpoints to hub"},
|
||||
)
|
||||
save_safetensors: bool | None = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"description": "Save model as safetensors (require safetensors package). Default True"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class SpecialTokensConfig(BaseModel):
|
||||
|
||||
@@ -9,7 +9,7 @@ class LoftQConfig(BaseModel):
|
||||
"""LoftQ configuration subset"""
|
||||
|
||||
loftq_bits: int = Field(
|
||||
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
|
||||
default=4, json_schema_extra={"description": "typically 4 bits"}
|
||||
)
|
||||
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
|
||||
|
||||
@@ -17,31 +17,78 @@ class LoftQConfig(BaseModel):
|
||||
class PeftConfig(BaseModel):
|
||||
"""peftq configuration subset"""
|
||||
|
||||
loftq_config: LoftQConfig | None = None
|
||||
loftq_config: LoftQConfig | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Configuration options for loftq initialization for LoRA"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class LoraConfig(BaseModel):
|
||||
"""Peft / LoRA configuration subset"""
|
||||
|
||||
load_in_8bit: bool | None = Field(default=False)
|
||||
load_in_4bit: bool | None = Field(default=False)
|
||||
load_in_8bit: bool | None = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer"
|
||||
},
|
||||
)
|
||||
load_in_4bit: bool | None = Field(
|
||||
default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"}
|
||||
)
|
||||
|
||||
adapter: str | None = None
|
||||
lora_model_dir: str | None = None
|
||||
adapter: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model"
|
||||
},
|
||||
)
|
||||
lora_model_dir: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "If you already have a lora model trained that you want to load, put that here. This means after training, if you want to test the model, you should set this to the value of `output_dir`. Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`."
|
||||
},
|
||||
)
|
||||
lora_r: int | None = None
|
||||
lora_alpha: int | None = None
|
||||
lora_fan_in_fan_out: bool | None = None
|
||||
lora_target_modules: str | list[str] | None = None
|
||||
lora_target_linear: bool | None = None
|
||||
lora_modules_to_save: list[str] | None = None
|
||||
lora_target_linear: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "If true, will target all linear modules"},
|
||||
)
|
||||
lora_modules_to_save: list[str] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens. For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models. `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities."
|
||||
},
|
||||
)
|
||||
lora_dropout: float | None = 0.0
|
||||
peft_layers_to_transform: list[int] | None = None
|
||||
peft_layers_to_transform: list[int] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The layer indices to transform, otherwise, apply to all layers"
|
||||
},
|
||||
)
|
||||
peft_layers_pattern: list[str] | None = None
|
||||
peft: PeftConfig | None = None
|
||||
peft_use_dora: bool | None = None
|
||||
peft_use_rslora: bool | None = None
|
||||
peft_layer_replication: list[tuple[int, int]] | None = None
|
||||
peft_init_lora_weights: bool | str | None = None
|
||||
peft_use_dora: bool | None = Field(
|
||||
default=None, json_schema_extra={"description": "Whether to use DoRA."}
|
||||
)
|
||||
peft_use_rslora: bool | None = Field(
|
||||
default=None, json_schema_extra={"description": "Whether to use RSLoRA."}
|
||||
)
|
||||
peft_layer_replication: list[tuple[int, int]] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "List of layer indices to replicate."},
|
||||
)
|
||||
peft_init_lora_weights: bool | str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "How to initialize LoRA weights. Default to True which is MS original implementation."
|
||||
},
|
||||
)
|
||||
|
||||
qlora_sharded_model_loading: bool | None = Field(
|
||||
default=False,
|
||||
@@ -49,9 +96,24 @@ class LoraConfig(BaseModel):
|
||||
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
|
||||
},
|
||||
)
|
||||
lora_on_cpu: bool | None = None
|
||||
gptq: bool | None = None
|
||||
bnb_config_kwargs: dict[str, Any] | None = None
|
||||
lora_on_cpu: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge"
|
||||
},
|
||||
)
|
||||
gptq: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Whether you are training a 4-bit GPTQ quantized model"
|
||||
},
|
||||
)
|
||||
bnb_config_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "optional overrides to the bnb 4bit quantization configuration"
|
||||
},
|
||||
)
|
||||
|
||||
loraplus_lr_ratio: float | None = Field(
|
||||
default=None,
|
||||
@@ -62,7 +124,7 @@ class LoraConfig(BaseModel):
|
||||
loraplus_lr_embedding: float | None = Field(
|
||||
default=1e-6,
|
||||
json_schema_extra={
|
||||
"description": "loraplus learning rate for lora embedding layers."
|
||||
"description": "loraplus learning rate for lora embedding layers. Default value is 1e-6."
|
||||
},
|
||||
)
|
||||
|
||||
@@ -125,8 +187,29 @@ class LoraConfig(BaseModel):
|
||||
class ReLoRAConfig(BaseModel):
|
||||
"""ReLoRA configuration subset"""
|
||||
|
||||
relora_steps: int | None = None
|
||||
relora_warmup_steps: int | None = None
|
||||
relora_anneal_steps: int | None = None
|
||||
relora_prune_ratio: float | None = None
|
||||
relora_cpu_offload: bool | None = None
|
||||
relora_steps: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Number of steps per ReLoRA restart"},
|
||||
)
|
||||
relora_warmup_steps: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Number of per-restart warmup steps"},
|
||||
)
|
||||
relora_anneal_steps: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of anneal steps for each relora cycle"
|
||||
},
|
||||
)
|
||||
relora_prune_ratio: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "threshold for optimizer magnitude when pruning"
|
||||
},
|
||||
)
|
||||
relora_cpu_offload: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "True to perform lora weight merges on cpu during restarts, for modest gpu memory savings"
|
||||
},
|
||||
)
|
||||
|
||||
@@ -15,17 +15,22 @@ class QATConfig(BaseModel):
|
||||
"""
|
||||
|
||||
activation_dtype: TorchIntDType | None = Field(
|
||||
default=None, description="Activation dtype"
|
||||
default=None,
|
||||
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
|
||||
)
|
||||
weight_dtype: TorchIntDType = Field(
|
||||
default=TorchIntDType.int8, description="Weight dtype"
|
||||
default=TorchIntDType.int8,
|
||||
description='Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"',
|
||||
)
|
||||
quantize_embedding: bool | None = Field(
|
||||
default=False, description="Quantize embedding"
|
||||
)
|
||||
group_size: int | None = Field(default=32, description="Group size")
|
||||
group_size: int | None = Field(
|
||||
default=32,
|
||||
description="The number of elements in each group for per-group fake quantization",
|
||||
)
|
||||
fake_quant_after_n_steps: int | None = Field(
|
||||
default=None, description="Fake quant after n steps"
|
||||
default=None, description="The number of steps to apply fake quantization after"
|
||||
)
|
||||
|
||||
@field_validator("activation_dtype", "weight_dtype", mode="before")
|
||||
@@ -44,15 +49,20 @@ class PTQConfig(BaseModel):
|
||||
"""
|
||||
|
||||
weight_dtype: TorchIntDType = Field(
|
||||
default=TorchIntDType.int8, description="Weight dtype"
|
||||
default=TorchIntDType.int8,
|
||||
description="Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8",
|
||||
)
|
||||
activation_dtype: TorchIntDType | None = Field(
|
||||
default=None, description="Activation dtype"
|
||||
default=None,
|
||||
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
|
||||
)
|
||||
quantize_embedding: bool | None = Field(
|
||||
default=None, description="Quantize embedding"
|
||||
default=None, description="Whether to quantize the embedding layer."
|
||||
)
|
||||
group_size: int | None = Field(
|
||||
default=32,
|
||||
description="The number of elements in each group for per-group fake quantization",
|
||||
)
|
||||
group_size: int | None = Field(default=32, description="Group size")
|
||||
|
||||
@field_validator("activation_dtype", "weight_dtype", mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -23,10 +23,17 @@ class LrGroup(BaseModel):
|
||||
class HyperparametersConfig(BaseModel):
|
||||
"""Training hyperparams configuration subset"""
|
||||
|
||||
gradient_accumulation_steps: int | None = Field(default=1)
|
||||
gradient_accumulation_steps: int | None = Field(
|
||||
default=1,
|
||||
json_schema_extra={
|
||||
"description": "If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps."
|
||||
},
|
||||
)
|
||||
micro_batch_size: int | None = Field(
|
||||
default=1,
|
||||
json_schema_extra={"description": "per gpu micro batch size for training"},
|
||||
json_schema_extra={
|
||||
"description": "The number of samples to include in each batch. This is the number of samples sent to each GPU. Batch size per gpu = micro_batch_size * gradient_accumulation_steps"
|
||||
},
|
||||
)
|
||||
batch_size: int | None = Field(
|
||||
default=None,
|
||||
@@ -41,45 +48,99 @@ class HyperparametersConfig(BaseModel):
|
||||
},
|
||||
)
|
||||
|
||||
auto_find_batch_size: bool | None = None
|
||||
auto_find_batch_size: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "whether to find batch size that fits in memory. Passed to underlying transformers Trainer"
|
||||
},
|
||||
)
|
||||
|
||||
train_on_inputs: bool | None = False
|
||||
group_by_length: bool | None = None
|
||||
train_on_inputs: bool | None = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "Whether to mask out or include the human's prompt from the training labels"
|
||||
},
|
||||
)
|
||||
group_by_length: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Group similarly sized data to minimize padding. May be slower to start, as it must download and sort the entire dataset. Note that training loss may have an oscillating pattern with this enabled."
|
||||
},
|
||||
)
|
||||
|
||||
learning_rate: str | float
|
||||
embedding_lr: float | None = None
|
||||
embedding_lr_scale: float | None = None
|
||||
weight_decay: float | None = 0.0
|
||||
optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = (
|
||||
OptimizerNames.ADAMW_TORCH_FUSED
|
||||
weight_decay: float | None = Field(
|
||||
default=0.0, json_schema_extra={"description": "Specify weight decay"}
|
||||
)
|
||||
optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = Field(
|
||||
default=OptimizerNames.ADAMW_TORCH_FUSED,
|
||||
json_schema_extra={"description": "Specify optimizer"},
|
||||
)
|
||||
optim_args: (str | dict[str, Any]) | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
||||
json_schema_extra={
|
||||
"description": "Dictionary of arguments to pass to the optimizer"
|
||||
},
|
||||
)
|
||||
optim_target_modules: (list[str] | Literal["all_linear"]) | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "The target modules to optimize, i.e. the module names that you would like to train."
|
||||
"description": "The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm"
|
||||
},
|
||||
)
|
||||
torchdistx_path: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Path to torch distx for optim 'adamw_anyprecision'"
|
||||
},
|
||||
)
|
||||
torchdistx_path: str | None = None
|
||||
lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = (
|
||||
SchedulerType.COSINE
|
||||
)
|
||||
lr_scheduler_kwargs: dict[str, Any] | None = None
|
||||
lr_scheduler_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Specify a scheduler and kwargs to use with the optimizer"
|
||||
},
|
||||
)
|
||||
lr_quadratic_warmup: bool | None = None
|
||||
cosine_min_lr_ratio: float | None = None
|
||||
cosine_constant_lr_ratio: float | None = None
|
||||
lr_div_factor: float | None = None
|
||||
cosine_min_lr_ratio: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr"
|
||||
},
|
||||
)
|
||||
cosine_constant_lr_ratio: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step"
|
||||
},
|
||||
)
|
||||
lr_div_factor: float | None = Field(
|
||||
default=None, json_schema_extra={"description": "Learning rate div factor"}
|
||||
)
|
||||
lr_groups: list[LrGroup] | None = None
|
||||
|
||||
adam_epsilon: float | None = None
|
||||
adam_epsilon2: float | None = None
|
||||
adam_beta1: float | None = None
|
||||
adam_beta2: float | None = None
|
||||
adam_beta3: float | None = None
|
||||
max_grad_norm: float | None = None
|
||||
adam_epsilon: float | None = Field(
|
||||
default=None, json_schema_extra={"description": "adamw hyperparams"}
|
||||
)
|
||||
adam_epsilon2: float | None = Field(
|
||||
default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
|
||||
)
|
||||
adam_beta1: float | None = Field(
|
||||
default=None, json_schema_extra={"description": "adamw hyperparams"}
|
||||
)
|
||||
adam_beta2: float | None = Field(
|
||||
default=None, json_schema_extra={"description": "adamw hyperparams"}
|
||||
)
|
||||
adam_beta3: float | None = Field(
|
||||
default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
|
||||
)
|
||||
max_grad_norm: float | None = Field(
|
||||
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
|
||||
)
|
||||
num_epochs: float = Field(default=1.0)
|
||||
|
||||
@field_validator("batch_size")
|
||||
|
||||
@@ -10,12 +10,14 @@ class TRLConfig(BaseModel):
|
||||
|
||||
beta: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "Beta for RL training"},
|
||||
json_schema_extra={
|
||||
"description": "Beta parameter for the RL training. Same as `rl_beta`. Use"
|
||||
},
|
||||
)
|
||||
max_completion_length: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Maximum length of the completion for RL training"
|
||||
"description": "Maximum length of the completion for RL training."
|
||||
},
|
||||
)
|
||||
|
||||
@@ -23,81 +25,69 @@ class TRLConfig(BaseModel):
|
||||
# Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23
|
||||
use_vllm: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
||||
json_schema_extra={"description": "Whether to use VLLM for RL training."},
|
||||
)
|
||||
vllm_server_host: str | None = Field(
|
||||
default="0.0.0.0", # nosec B104
|
||||
json_schema_extra={"description": "Host of the vLLM server to connect to"},
|
||||
json_schema_extra={"description": "Host of the vLLM server to connect to."},
|
||||
)
|
||||
vllm_server_port: int | None = Field(
|
||||
default=8000,
|
||||
json_schema_extra={"description": "Port of the vLLM server to connect to"},
|
||||
json_schema_extra={"description": "Port of the vLLM server to connect to."},
|
||||
)
|
||||
vllm_server_timeout: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up "
|
||||
"after the timeout, a `ConnectionError` is raised."
|
||||
"description": "Total timeout (in seconds) to wait for the vLLM server to respond."
|
||||
},
|
||||
)
|
||||
vllm_guided_decoding_regex: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."
|
||||
},
|
||||
json_schema_extra={"description": "Regex for vLLM guided decoding."},
|
||||
)
|
||||
|
||||
reward_funcs: list[str] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={"description": "List of reward functions to load"},
|
||||
json_schema_extra={
|
||||
"description": "List of reward functions to load. Paths must be importable from current dir."
|
||||
},
|
||||
)
|
||||
reward_weights: list[float] | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Weights for each reward function. Must match the number of reward functions."
|
||||
"description": "List of reward weights for the reward functions."
|
||||
},
|
||||
)
|
||||
num_generations: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
||||
},
|
||||
json_schema_extra={"description": "Number of generations to sample."},
|
||||
)
|
||||
log_completions: bool | None = Field(
|
||||
default=False,
|
||||
json_schema_extra={"description": "Whether to log completions"},
|
||||
json_schema_extra={"description": "Whether to log completions."},
|
||||
)
|
||||
num_completions_to_print: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of completions to print. If `log_completions` is `True`, this will be the number of completions logged."
|
||||
"description": "Number of completions to print when log_completions is True."
|
||||
},
|
||||
)
|
||||
sync_ref_model: bool | None = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": (
|
||||
"Whether to sync the reference model every `ref_model_sync_steps` "
|
||||
"steps, using the `ref_model_mixup_alpha` parameter."
|
||||
)
|
||||
},
|
||||
json_schema_extra={"description": "Whether to sync the reference model."},
|
||||
)
|
||||
ref_model_mixup_alpha: float | None = Field(
|
||||
default=0.9,
|
||||
json_schema_extra={
|
||||
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
||||
},
|
||||
json_schema_extra={"description": "Mixup alpha for the reference model."},
|
||||
)
|
||||
ref_model_sync_steps: int | None = Field(
|
||||
default=64,
|
||||
json_schema_extra={
|
||||
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
||||
},
|
||||
json_schema_extra={"description": "Sync steps for the reference model."},
|
||||
)
|
||||
scale_rewards: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"description": "Whether to scale the rewards for GRPO by dividing them by their standard deviation."
|
||||
"description": "Whether to scale rewards by their standard deviation."
|
||||
},
|
||||
)
|
||||
|
||||
@@ -124,13 +114,13 @@ class TRLConfig(BaseModel):
|
||||
repetition_penalty: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far."
|
||||
"description": "Penalty for tokens that appear in prompt and generated text."
|
||||
},
|
||||
)
|
||||
num_iterations: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Number of iterations per batch (denoted as μ in the algorithm) for GRPO."
|
||||
"description": "Number of iterations per batch (μ) for GRPO."
|
||||
},
|
||||
)
|
||||
epsilon: float | None = Field(
|
||||
@@ -152,12 +142,12 @@ class TRLConfig(BaseModel):
|
||||
loss_type: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`."
|
||||
"description": "Loss formulation to use. Supported values: grpo, bnpo, dr_grpo."
|
||||
},
|
||||
)
|
||||
mask_truncated_completions: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"description": "When enabled, truncated completions are excluded from the loss calculation."
|
||||
"description": "Whether to exclude truncated completions from loss calculation."
|
||||
},
|
||||
)
|
||||
|
||||
1073
src/axolotl/utils/schemas/validation.py
Normal file
1073
src/axolotl/utils/schemas/validation.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user