Config doc autogen (#2718)

* config reference doc autogen

* improvements

* cleanup; still ugly but working

* reformat

* remove autogen config ref from git

* factor out validations

* rewrite

* rewrite

* cleanup

* progress

* progress

* progress

* lint and minifying somewhat

* remove unneeded

* coderabbit

* coderabbit

* update preview-docs workflow triggers

* installing with deps

* coderabbit

* update refs

* overwrote file accidentally
This commit is contained in:
Dan Saunders
2025-06-18 15:36:53 -04:00
committed by GitHub
parent da8f6c32b9
commit 9d5bfc127e
23 changed files with 3060 additions and 2129 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,8 @@
"""Pydantic models for datasets-related configuration"""
from pydantic import BaseModel, model_validator
from typing import Literal
from pydantic import BaseModel, Field, model_validator
from axolotl.utils.schemas.enums import ChatTemplate
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
@@ -9,57 +11,178 @@ from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
class UserDefinedPrompterType(BaseModel):
    """Structure for user defined prompt types.

    Every field is optional and defaults to None. Fields that carry a
    human-readable description store it under `json_schema_extra`.
    """

    system_prompt: str | None = Field(
        default=None,
        json_schema_extra={"description": "Custom user instruction prompt"},
    )
    system_format: str | None = Field(
        default=None,
        json_schema_extra={"description": "Use {system} as key to be replaced"},
    )
    field_system: str | None = None
    field_instruction: str | None = None
    field_input: str | None = None
    field_output: str | None = None

    format: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "Customizable to be single line or multi-line. Use {instruction}/{input} as key to be replaced. 'format' can include {input}"
        },
    )
    no_input_format: str | None = Field(
        default=None,
        json_schema_extra={"description": "'no_input_format' cannot include {input}"},
    )
    field: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "For `completion` datasets only, uses the provided field instead of `text` column"
        },
    )
class SFTDataset(BaseModel):
"""SFT configuration subset"""
path: str | None = None
split: str | None = None
type: str | UserDefinedPrompterType | None = None
path: str | None = Field(
default=None,
json_schema_extra={
"description": "HuggingFace dataset repo | s3:// | gs:// | path to local file or directory"
},
)
split: str | None = Field(
default=None,
json_schema_extra={"description": "name of dataset split to load from"},
)
type: str | UserDefinedPrompterType | None = Field(
default=None,
json_schema_extra={
"description": "The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]"
},
)
input_transform: str | None = None
shards: int | None = None
shards_idx: int | None = None
preprocess_shards: int | None = None
shards: int | None = Field(
default=None,
json_schema_extra={
"description": "split dataset into N pieces (use with shards_idx)"
},
)
shards_idx: int | None = Field(
default=None,
json_schema_extra={"description": "the index of sharded dataset to use"},
)
preprocess_shards: int | None = Field(
default=None,
json_schema_extra={
"description": "process dataset in N sequential chunks for memory efficiency (exclusive with `shards`)"
},
)
conversation: str | None = None
# Do not make this too strict or it will break the validator to choose different dataset class
chat_template: ChatTemplate | str | None = None
chat_template_jinja: str | None = None
data_files: str | list[str] | None = None
chat_template: ChatTemplate | str | None = Field(
default=None,
json_schema_extra={
"description": "The name of the chat template to use for training, following values are supported: tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default. alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py. tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml. jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field."
},
)
chat_template_jinja: str | None = Field(
default=None,
json_schema_extra={
"description": "Custom jinja chat template. Used only if `chat_template: jinja` or empty."
},
)
data_files: str | list[str] | None = Field(
default=None, json_schema_extra={"description": "path to source data files"}
)
input_format: str | None = None
name: str | None = None
ds_type: str | None = None
name: str | None = Field(
default=None,
json_schema_extra={"description": "name of dataset configuration to load"},
)
ds_type: str | None = Field(
default=None,
json_schema_extra={"description": "defines the datatype when path is a file"},
)
field: str | None = None
field_human: str | None = None
field_model: str | None = None
field_messages: str | None = None
field_tools: str | None = None
field_messages: str | None = Field(
default=None,
json_schema_extra={
"description": 'Key containing the messages (default: "messages")'
},
)
field_tools: str | None = Field(
default=None,
json_schema_extra={
"description": 'Key containing the tools (default: "tools"). Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).'
},
)
# deprecated, use message_property_mappings
message_field_role: str | None = None
# deprecated, use message_property_mappings
message_field_content: str | None = None
message_property_mappings: dict[str, str] | None = None
message_field_training: str | None = None
message_field_training_detail: str | None = None
split_thinking: bool | None = None
message_property_mappings: dict[str, str] | None = Field(
default=None,
json_schema_extra={
"description": "Mapping of properties from the input dataset to the chat template. (default: message_property_mappings={'role':'role', 'content':'content'}) If a property exists in the template but not in this mapping, the system will attempt to load it directly from the message using the property name as the key. Example: In the mapping below, 'from' is loaded from input dataset and used as 'role', while 'value' is loaded and used as 'content' in the chat template."
},
)
message_field_training: str | None = Field(
default=None,
json_schema_extra={
"description": "The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`."
},
)
message_field_training_detail: str | None = Field(
default=None,
json_schema_extra={
"description": "The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn. The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train)."
},
)
split_thinking: bool | None = Field(
default=None,
json_schema_extra={
"description": "(for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags"
},
)
logprobs_field: str | None = None
temperature: float | None = None
roles_to_train: list[str] | None = None
train_on_eos: str | None = None
roles: dict[str, list[str]] | None = None
drop_system_message: bool | None = None
trust_remote_code: bool | None = False
revision: str | None = None
roles_to_train: list[str] | None = Field(
default=None,
json_schema_extra={
"description": "Roles to train on. The tokens from these roles will be considered for the loss."
},
)
train_on_eos: Literal["all", "turn", "last"] | None = Field(
default=None,
json_schema_extra={
"description": "Which EOS tokens to train on in the conversation. Possible values are: all: train on all EOS tokens, turn (default): train on the EOS token at the end of each trainable turn, last: train on the last EOS token in the conversation"
},
)
roles: dict[str, list[str]] | None = Field(
default=None,
json_schema_extra={
"description": 'Roles mapping in the messages. The format is {target_role: [source_roles]}. All source roles will be mapped to the target role. The default is: user: ["human", "user"], assistant: ["gpt", "assistant"], system: ["system"], tool: ["tool"]'
},
)
drop_system_message: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to drop the system turn from the dataset. Only works with chat_template. This does not drop the default system message from chat_template if it exists. If you wish to, we recommend using a custom jinja template with the default system message removed or adding a system turn with empty content."
},
)
trust_remote_code: bool | None = Field(
default=False,
json_schema_extra={"description": "Trust remote code for untrusted source"},
)
revision: str | None = Field(
default=None,
json_schema_extra={
"description": "The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets."
},
)
@model_validator(mode="before")
@classmethod

View File

@@ -60,10 +60,30 @@ class RemappedParameters(BaseModel):
"""Parameters that have been remapped to other names"""
overrides_of_model_config: dict[str, Any] | None = Field(
default=None, alias="model_config"
default=None,
alias="model_config",
json_schema_extra={
"description": "optional overrides to the base model configuration"
},
)
overrides_of_model_kwargs: dict[str, Any] | None = Field(
default=None, alias="model_kwargs"
default=None,
alias="model_kwargs",
json_schema_extra={
"description": "optional overrides the base model loading from_pretrained"
},
)
type_of_model: str | None = Field(
default=None,
alias="model_type",
json_schema_extra={
"description": "If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too"
},
)
revision_of_model: str | None = Field(
default=None,
alias="model_revision",
json_schema_extra={
"description": "You can specify to choose a specific model revision from huggingface hub"
},
)
type_of_model: str | None = Field(default=None, alias="model_type")
revision_of_model: str | None = Field(default=None, alias="model_revision")

View File

@@ -1,5 +1,7 @@
"""Enums for Axolotl input config"""
# pylint: disable=invalid-name
from enum import Enum
import torch
@@ -8,81 +10,81 @@ import torch
class TorchIntDType(Enum):
    """Torch integer data types - `getattr` guards against torch < 2.6 which does not support int4.

    Each member's value is the `torch` attribute of the same name, or None
    when the installed torch does not define it. Members that resolve to the
    same value (for example several None values) become aliases of the first
    such member, as with any `Enum`.
    """

    uint1 = getattr(torch, "uint1", None)
    uint2 = getattr(torch, "uint2", None)
    uint3 = getattr(torch, "uint3", None)
    uint4 = getattr(torch, "uint4", None)
    uint5 = getattr(torch, "uint5", None)
    uint6 = getattr(torch, "uint6", None)
    uint7 = getattr(torch, "uint7", None)
    int4 = getattr(torch, "int4", None)
    int8 = getattr(torch, "int8", None)
class RLType(str, Enum):
    """RL trainer type configuration subset.

    Members are `str` subclasses, so each compares equal to its lowercase
    string value.
    """

    DPO = "dpo"
    GRPO = "grpo"
    IPO = "ipo"
    ORPO = "orpo"
    KTO = "kto"
    SIMPO = "simpo"
class ChatTemplate(str, Enum):
    """Chat templates configuration subset.

    Each member's value is identical to its name, and members are `str`
    subclasses, so they compare equal to the plain template name.
    """

    alpaca = "alpaca"
    chatml = "chatml"
    mistral_v1 = "mistral_v1"
    mistral_v2v3 = "mistral_v2v3"
    mistral_v3_tekken = "mistral_v3_tekken"
    mistral_v7_tekken = "mistral_v7_tekken"
    gemma = "gemma"
    cohere = "cohere"
    llama3 = "llama3"
    llama3_2_vision = "llama3_2_vision"
    llama4 = "llama4"
    phi_3 = "phi_3"
    phi_35 = "phi_35"
    deepseek_v2 = "deepseek_v2"
    deepseek_v3 = "deepseek_v3"
    jamba = "jamba"
    jinja = "jinja"
    qwen_25 = "qwen_25"
    qwen3 = "qwen3"
    tokenizer_default = "tokenizer_default"
    exaone = "exaone"
    metharme = "metharme"
    pixtral = "pixtral"
    llava = "llava"
    qwen2_vl = "qwen2_vl"
    gemma3 = "gemma3"
    command_a = "command_a"
    command_a_tool_use = "command_a_tool_use"
    command_a_rag = "command_a_rag"
    aya = "aya"
class CustomSupportedOptimizers(str, Enum):
    """Custom supported optimizers.

    Each member's value is identical to its name; members are `str`
    subclasses and compare equal to that name.
    """

    optimi_adamw = "optimi_adamw"
    ao_adamw_4bit = "ao_adamw_4bit"
    ao_adamw_8bit = "ao_adamw_8bit"
    ao_adamw_fp8 = "ao_adamw_fp8"
    adopt_adamw = "adopt_adamw"
    came_pytorch = "came_pytorch"
    muon = "muon"
class RingAttnFunc(str, Enum):
    """Enum class for supported `ring-flash-attn` implementations"""

    VARLEN_LLAMA3 = "varlen_llama3"
    BATCH_RING = "batch_ring"
    # Remaining variants are left commented out, i.e. not selectable:
    # VARLEN_RING = "varlen_ring"
    # VARLEN_ZIGZAG = "varlen_zigzag"
    # BATCH_ZIGZAG = "batch_zigzag"
    # BATCH_STRIPE = "batch_stripe"

View File

@@ -13,10 +13,21 @@ class MLFlowConfig(BaseModel):
"""MLFlow configuration subset"""
use_mlflow: bool | None = None
mlflow_tracking_uri: str | None = None
mlflow_experiment_name: str | None = None
mlflow_run_name: str | None = None
hf_mlflow_log_artifacts: bool | None = None
mlflow_tracking_uri: str | None = Field(
default=None, json_schema_extra={"description": "URI to mlflow"}
)
mlflow_experiment_name: str | None = Field(
default=None, json_schema_extra={"description": "Your experiment name"}
)
mlflow_run_name: str | None = Field(
default=None, json_schema_extra={"description": "Your run name"}
)
hf_mlflow_log_artifacts: bool | None = Field(
default=None,
json_schema_extra={
"description": "set to true to copy each saved checkpoint on each save to mlflow artifact registry"
},
)
class LISAConfig(BaseModel):
@@ -40,13 +51,33 @@ class WandbConfig(BaseModel):
"""Wandb configuration subset"""
use_wandb: bool | None = None
wandb_name: str | None = None
wandb_run_id: str | None = None
wandb_mode: str | None = None
wandb_project: str | None = None
wandb_entity: str | None = None
wandb_name: str | None = Field(
default=None,
json_schema_extra={"description": "Set the name of your wandb run"},
)
wandb_run_id: str | None = Field(
default=None, json_schema_extra={"description": "Set the ID of your wandb run"}
)
wandb_mode: str | None = Field(
default=None,
json_schema_extra={
"description": '"offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb'
},
)
wandb_project: str | None = Field(
default=None, json_schema_extra={"description": "Your wandb project name"}
)
wandb_entity: str | None = Field(
default=None,
json_schema_extra={"description": "A wandb Team name if using a Team"},
)
wandb_watch: str | None = None
wandb_log_model: str | None = None
wandb_log_model: str | None = Field(
default=None,
json_schema_extra={
"description": '"checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training'
},
)
@model_validator(mode="before")
@classmethod
@@ -64,14 +95,52 @@ class WandbConfig(BaseModel):
class CometConfig(BaseModel):
    """Comet configuration subset.

    Every field is optional and defaults to None; the per-field description
    is stored under `json_schema_extra`.
    """

    use_comet: bool | None = Field(
        default=None,
        json_schema_extra={"description": "Enable or disable Comet integration."},
    )
    comet_api_key: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "API key for Comet. Recommended to set via `comet login`."
        },
    )
    comet_workspace: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "Workspace name in Comet. Defaults to the user's default workspace."
        },
    )
    comet_project_name: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "Project name in Comet. Defaults to Uncategorized."
        },
    )
    comet_experiment_key: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key."
        },
    )
    comet_mode: str | None = Field(
        default=None,
        json_schema_extra={
            "description": 'Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.'
        },
    )
    comet_online: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Set to True to log data to Comet server, or False for offline storage. Default is True."
        },
    )
    comet_experiment_config: dict[str, Any] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Dictionary for additional configuration settings, see the doc for more details."
        },
    )
class GradioConfig(BaseModel):

View File

@@ -12,20 +12,55 @@ class ModelInputConfig(BaseModel):
model_config = {"protected_namespaces": ()}
base_model: str
base_model_config: str | None = None
base_model: str = Field(
json_schema_extra={
"description": "This is the huggingface model that contains *.pt, *.safetensors, or *.bin files. This can also be a relative path to a model on disk"
}
)
base_model_config: str | None = Field(
default=None,
json_schema_extra={
"description": "If the base_model repo on hf hub doesn't include configuration .json files, You can set that here, or leave this empty to default to base_model"
},
)
cls_model_config: str | None = None
tokenizer_config: str | None = None
tokenizer_use_fast: bool | None = None
tokenizer_legacy: bool | None = None
tokenizer_use_mistral_common: bool | None = None
tokenizer_config: str | None = Field(
default=None,
json_schema_extra={
"description": "Optional tokenizer configuration path in case you want to use a different tokenizer than the one defined in the base model"
},
)
tokenizer_use_fast: bool | None = Field(
default=None,
json_schema_extra={
"description": "use_fast option for tokenizer loading from_pretrained, default to True"
},
)
tokenizer_legacy: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use the legacy tokenizer setting, defaults to True"
},
)
tokenizer_use_mistral_common: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer."
},
)
tokenizer_type: str | None = Field(
default=None, json_schema_extra={"description": "transformers tokenizer class"}
default=None,
json_schema_extra={
"description": "Corresponding tokenizer for the model AutoTokenizer is a good choice"
},
)
processor_type: str | None = Field(
default=None, json_schema_extra={"description": "transformers processor class"}
)
trust_remote_code: bool | None = None
trust_remote_code: bool | None = Field(
default=None,
json_schema_extra={"description": "Trust remote code for untrusted source"},
)
@field_validator("trust_remote_code")
@classmethod
@@ -40,10 +75,23 @@ class ModelInputConfig(BaseModel):
class ModelOutputConfig(BaseModel):
    """model save configuration subset

    `output_dir` defaults to "./model-out" and `save_safetensors` to True;
    the hub fields default to None.
    """

    output_dir: str = Field(
        default="./model-out",
        json_schema_extra={"description": "Where to save the full-finetuned model to"},
    )
    hub_model_id: str | None = Field(
        default=None, json_schema_extra={"description": "push checkpoints to hub"}
    )
    hub_strategy: str | None = Field(
        default=None,
        json_schema_extra={"description": "how to push checkpoints to hub"},
    )
    save_safetensors: bool | None = Field(
        default=True,
        json_schema_extra={
            "description": "Save model as safetensors (require safetensors package). Default True"
        },
    )
class SpecialTokensConfig(BaseModel):

View File

@@ -9,7 +9,7 @@ class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
loftq_bits: int = Field(
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
default=4, json_schema_extra={"description": "typically 4 bits"}
)
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
@@ -17,31 +17,78 @@ class LoftQConfig(BaseModel):
class PeftConfig(BaseModel):
    """PEFT configuration subset"""

    # Optional nested LoftQ settings; None when not provided.
    loftq_config: LoftQConfig | None = Field(
        default=None,
        json_schema_extra={
            "description": "Configuration options for loftq initialization for LoRA"
        },
    )
class LoraConfig(BaseModel):
"""Peft / LoRA configuration subset"""
load_in_8bit: bool | None = Field(default=False)
load_in_4bit: bool | None = Field(default=False)
load_in_8bit: bool | None = Field(
default=False,
json_schema_extra={
"description": "This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer"
},
)
load_in_4bit: bool | None = Field(
default=False, json_schema_extra={"description": "Use bitsandbytes 4 bit"}
)
adapter: str | None = None
lora_model_dir: str | None = None
adapter: str | None = Field(
default=None,
json_schema_extra={
"description": "If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model"
},
)
lora_model_dir: str | None = Field(
default=None,
json_schema_extra={
"description": "If you already have a lora model trained that you want to load, put that here. This means after training, if you want to test the model, you should set this to the value of `output_dir`. Note that if you merge an adapter to the base model, a new subdirectory `merged` will be created under the `output_dir`."
},
)
lora_r: int | None = None
lora_alpha: int | None = None
lora_fan_in_fan_out: bool | None = None
lora_target_modules: str | list[str] | None = None
lora_target_linear: bool | None = None
lora_modules_to_save: list[str] | None = None
lora_target_linear: bool | None = Field(
default=None,
json_schema_extra={"description": "If true, will target all linear modules"},
)
lora_modules_to_save: list[str] | None = Field(
default=None,
json_schema_extra={
"description": "If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens. For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models. `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities."
},
)
lora_dropout: float | None = 0.0
peft_layers_to_transform: list[int] | None = None
peft_layers_to_transform: list[int] | None = Field(
default=None,
json_schema_extra={
"description": "The layer indices to transform, otherwise, apply to all layers"
},
)
peft_layers_pattern: list[str] | None = None
peft: PeftConfig | None = None
peft_use_dora: bool | None = None
peft_use_rslora: bool | None = None
peft_layer_replication: list[tuple[int, int]] | None = None
peft_init_lora_weights: bool | str | None = None
peft_use_dora: bool | None = Field(
default=None, json_schema_extra={"description": "Whether to use DoRA."}
)
peft_use_rslora: bool | None = Field(
default=None, json_schema_extra={"description": "Whether to use RSLoRA."}
)
peft_layer_replication: list[tuple[int, int]] | None = Field(
default=None,
json_schema_extra={"description": "List of layer indices to replicate."},
)
peft_init_lora_weights: bool | str | None = Field(
default=None,
json_schema_extra={
"description": "How to initialize LoRA weights. Default to True which is MS original implementation."
},
)
qlora_sharded_model_loading: bool | None = Field(
default=False,
@@ -49,9 +96,24 @@ class LoraConfig(BaseModel):
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
},
)
lora_on_cpu: bool | None = None
gptq: bool | None = None
bnb_config_kwargs: dict[str, Any] | None = None
lora_on_cpu: bool | None = Field(
default=None,
json_schema_extra={
"description": "Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge"
},
)
gptq: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether you are training a 4-bit GPTQ quantized model"
},
)
bnb_config_kwargs: dict[str, Any] | None = Field(
default=None,
json_schema_extra={
"description": "optional overrides to the bnb 4bit quantization configuration"
},
)
loraplus_lr_ratio: float | None = Field(
default=None,
@@ -62,7 +124,7 @@ class LoraConfig(BaseModel):
loraplus_lr_embedding: float | None = Field(
default=1e-6,
json_schema_extra={
"description": "loraplus learning rate for lora embedding layers."
"description": "loraplus learning rate for lora embedding layers. Default value is 1e-6."
},
)
@@ -125,8 +187,29 @@ class LoraConfig(BaseModel):
class ReLoRAConfig(BaseModel):
    """ReLoRA configuration subset.

    Every field is optional and defaults to None; the per-field description
    is stored under `json_schema_extra`.
    """

    relora_steps: int | None = Field(
        default=None,
        json_schema_extra={"description": "Number of steps per ReLoRA restart"},
    )
    relora_warmup_steps: int | None = Field(
        default=None,
        json_schema_extra={"description": "Number of per-restart warmup steps"},
    )
    relora_anneal_steps: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of anneal steps for each relora cycle"
        },
    )
    relora_prune_ratio: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "threshold for optimizer magnitude when pruning"
        },
    )
    relora_cpu_offload: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "True to perform lora weight merges on cpu during restarts, for modest gpu memory savings"
        },
    )

View File

@@ -15,17 +15,22 @@ class QATConfig(BaseModel):
"""
activation_dtype: TorchIntDType | None = Field(
default=None, description="Activation dtype"
default=None,
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
)
weight_dtype: TorchIntDType = Field(
default=TorchIntDType.int8, description="Weight dtype"
default=TorchIntDType.int8,
description='Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"',
)
quantize_embedding: bool | None = Field(
default=False, description="Quantize embedding"
)
group_size: int | None = Field(default=32, description="Group size")
group_size: int | None = Field(
default=32,
description="The number of elements in each group for per-group fake quantization",
)
fake_quant_after_n_steps: int | None = Field(
default=None, description="Fake quant after n steps"
default=None, description="The number of steps to apply fake quantization after"
)
@field_validator("activation_dtype", "weight_dtype", mode="before")
@@ -44,15 +49,20 @@ class PTQConfig(BaseModel):
"""
weight_dtype: TorchIntDType = Field(
default=TorchIntDType.int8, description="Weight dtype"
default=TorchIntDType.int8,
description="Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8",
)
activation_dtype: TorchIntDType | None = Field(
default=None, description="Activation dtype"
default=None,
description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
)
quantize_embedding: bool | None = Field(
default=None, description="Quantize embedding"
default=None, description="Whether to quantize the embedding layer."
)
group_size: int | None = Field(
default=32,
description="The number of elements in each group for per-group fake quantization",
)
group_size: int | None = Field(default=32, description="Group size")
@field_validator("activation_dtype", "weight_dtype", mode="before")
@classmethod

View File

@@ -23,10 +23,17 @@ class LrGroup(BaseModel):
class HyperparametersConfig(BaseModel):
"""Training hyperparams configuration subset"""
gradient_accumulation_steps: int | None = Field(default=1)
gradient_accumulation_steps: int | None = Field(
default=1,
json_schema_extra={
"description": "If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps."
},
)
micro_batch_size: int | None = Field(
default=1,
json_schema_extra={"description": "per gpu micro batch size for training"},
json_schema_extra={
"description": "The number of samples to include in each batch. This is the number of samples sent to each GPU. Batch size per gpu = micro_batch_size * gradient_accumulation_steps"
},
)
batch_size: int | None = Field(
default=None,
@@ -41,45 +48,99 @@ class HyperparametersConfig(BaseModel):
},
)
auto_find_batch_size: bool | None = None
auto_find_batch_size: bool | None = Field(
default=None,
json_schema_extra={
"description": "whether to find batch size that fits in memory. Passed to underlying transformers Trainer"
},
)
train_on_inputs: bool | None = False
group_by_length: bool | None = None
train_on_inputs: bool | None = Field(
default=False,
json_schema_extra={
"description": "Whether to mask out or include the human's prompt from the training labels"
},
)
group_by_length: bool | None = Field(
default=None,
json_schema_extra={
"description": "Group similarly sized data to minimize padding. May be slower to start, as it must download and sort the entire dataset. Note that training loss may have an oscillating pattern with this enabled."
},
)
learning_rate: str | float
embedding_lr: float | None = None
embedding_lr_scale: float | None = None
weight_decay: float | None = 0.0
optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = (
OptimizerNames.ADAMW_TORCH_FUSED
weight_decay: float | None = Field(
default=0.0, json_schema_extra={"description": "Specify weight decay"}
)
optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = Field(
default=OptimizerNames.ADAMW_TORCH_FUSED,
json_schema_extra={"description": "Specify optimizer"},
)
optim_args: (str | dict[str, Any]) | None = Field(
default=None,
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
json_schema_extra={
"description": "Dictionary of arguments to pass to the optimizer"
},
)
optim_target_modules: (list[str] | Literal["all_linear"]) | None = Field(
default=None,
json_schema_extra={
"description": "The target modules to optimize, i.e. the module names that you would like to train."
"description": "The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm"
},
)
torchdistx_path: str | None = Field(
default=None,
json_schema_extra={
"description": "Path to torch distx for optim 'adamw_anyprecision'"
},
)
torchdistx_path: str | None = None
lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = (
SchedulerType.COSINE
)
lr_scheduler_kwargs: dict[str, Any] | None = None
lr_scheduler_kwargs: dict[str, Any] | None = Field(
default=None,
json_schema_extra={
"description": "Specify a scheduler and kwargs to use with the optimizer"
},
)
lr_quadratic_warmup: bool | None = None
cosine_min_lr_ratio: float | None = None
cosine_constant_lr_ratio: float | None = None
lr_div_factor: float | None = None
cosine_min_lr_ratio: float | None = Field(
default=None,
json_schema_extra={
"description": "decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr"
},
)
cosine_constant_lr_ratio: float | None = Field(
default=None,
json_schema_extra={
"description": "freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step"
},
)
lr_div_factor: float | None = Field(
default=None, json_schema_extra={"description": "Learning rate div factor"}
)
lr_groups: list[LrGroup] | None = None
adam_epsilon: float | None = None
adam_epsilon2: float | None = None
adam_beta1: float | None = None
adam_beta2: float | None = None
adam_beta3: float | None = None
max_grad_norm: float | None = None
adam_epsilon: float | None = Field(
default=None, json_schema_extra={"description": "adamw hyperparams"}
)
adam_epsilon2: float | None = Field(
default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
)
adam_beta1: float | None = Field(
default=None, json_schema_extra={"description": "adamw hyperparams"}
)
adam_beta2: float | None = Field(
default=None, json_schema_extra={"description": "adamw hyperparams"}
)
adam_beta3: float | None = Field(
default=None, json_schema_extra={"description": "only used for CAME Optimizer"}
)
max_grad_norm: float | None = Field(
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
)
num_epochs: float = Field(default=1.0)
@field_validator("batch_size")

View File

@@ -10,12 +10,14 @@ class TRLConfig(BaseModel):
beta: float | None = Field(
default=None,
json_schema_extra={"description": "Beta for RL training"},
json_schema_extra={
"description": "Beta parameter for the RL training. Same as `rl_beta`."
},
)
max_completion_length: int | None = Field(
default=None,
json_schema_extra={
"description": "Maximum length of the completion for RL training"
"description": "Maximum length of the completion for RL training."
},
)
@@ -23,81 +25,69 @@ class TRLConfig(BaseModel):
# Ref: https://github.com/huggingface/trl/blob/26d86757a7c7e24e397ea44f57ecce6031dfac01/trl/trainer/grpo_config.py#L23
use_vllm: bool = Field(
default=False,
json_schema_extra={"description": "Whether to use VLLM for RL training"},
json_schema_extra={"description": "Whether to use VLLM for RL training."},
)
vllm_server_host: str | None = Field(
default="0.0.0.0", # nosec B104
json_schema_extra={"description": "Host of the vLLM server to connect to"},
json_schema_extra={"description": "Host of the vLLM server to connect to."},
)
vllm_server_port: int | None = Field(
default=8000,
json_schema_extra={"description": "Port of the vLLM server to connect to"},
json_schema_extra={"description": "Port of the vLLM server to connect to."},
)
vllm_server_timeout: int | None = Field(
default=None,
json_schema_extra={
"description": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up "
"after the timeout, a `ConnectionError` is raised."
"description": "Total timeout (in seconds) to wait for the vLLM server to respond."
},
)
vllm_guided_decoding_regex: str | None = Field(
default=None,
json_schema_extra={
"description": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."
},
json_schema_extra={"description": "Regex for vLLM guided decoding."},
)
reward_funcs: list[str] | None = Field(
default=None,
json_schema_extra={"description": "List of reward functions to load"},
json_schema_extra={
"description": "List of reward functions to load. Paths must be importable from current dir."
},
)
reward_weights: list[float] | None = Field(
default=None,
json_schema_extra={
"description": "Weights for each reward function. Must match the number of reward functions."
"description": "List of reward weights for the reward functions."
},
)
num_generations: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
},
json_schema_extra={"description": "Number of generations to sample."},
)
log_completions: bool | None = Field(
default=False,
json_schema_extra={"description": "Whether to log completions"},
json_schema_extra={"description": "Whether to log completions."},
)
num_completions_to_print: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of completions to print. If `log_completions` is `True`, this will be the number of completions logged."
"description": "Number of completions to print when log_completions is True."
},
)
sync_ref_model: bool | None = Field(
default=False,
json_schema_extra={
"description": (
"Whether to sync the reference model every `ref_model_sync_steps` "
"steps, using the `ref_model_mixup_alpha` parameter."
)
},
json_schema_extra={"description": "Whether to sync the reference model."},
)
ref_model_mixup_alpha: float | None = Field(
default=0.9,
json_schema_extra={
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
},
json_schema_extra={"description": "Mixup alpha for the reference model."},
)
ref_model_sync_steps: int | None = Field(
default=64,
json_schema_extra={
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
},
json_schema_extra={"description": "Sync steps for the reference model."},
)
scale_rewards: bool = Field(
default=True,
json_schema_extra={
"description": "Whether to scale the rewards for GRPO by dividing them by their standard deviation."
"description": "Whether to scale rewards by their standard deviation."
},
)
@@ -124,13 +114,13 @@ class TRLConfig(BaseModel):
repetition_penalty: float | None = Field(
default=None,
json_schema_extra={
"description": "Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far."
"description": "Penalty for tokens that appear in prompt and generated text."
},
)
num_iterations: int | None = Field(
default=None,
json_schema_extra={
"description": "Number of iterations per batch (denoted as μ in the algorithm) for GRPO."
"description": "Number of iterations per batch (μ) for GRPO."
},
)
epsilon: float | None = Field(
@@ -152,12 +142,12 @@ class TRLConfig(BaseModel):
loss_type: str | None = Field(
default=None,
json_schema_extra={
"description": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`."
"description": "Loss formulation to use. Supported values: grpo, bnpo, dr_grpo."
},
)
mask_truncated_completions: bool = Field(
default=False,
json_schema_extra={
"description": "When enabled, truncated completions are excluded from the loss calculation."
"description": "Whether to exclude truncated completions from loss calculation."
},
)

File diff suppressed because it is too large Load Diff