Files
axolotl/src/axolotl/utils/schemas/config.py

1236 lines
47 KiB
Python

"""Module with Pydantic models for configuration."""
# pylint: disable=too-many-lines
from typing import Annotated, Any, Literal
from annotated_types import MinLen
from packaging import version
from pydantic import (
BaseModel,
Field,
StringConstraints,
field_serializer,
model_validator,
)
from axolotl.utils.datasets import get_default_process_count
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import (
DatasetConfig,
DPODataset,
KTODataset,
PretrainingDataset,
SFTDataset,
StepwiseSupervisedDataset,
)
from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters
from axolotl.utils.schemas.enums import ChatTemplate, RingAttnFunc, RLType
from axolotl.utils.schemas.integrations import (
CometConfig,
GradioConfig,
LISAConfig,
MLFlowConfig,
RayConfig,
WandbConfig,
)
from axolotl.utils.schemas.internal import EnvCapabilities, GPUCapabilities
from axolotl.utils.schemas.model import (
ModelInputConfig,
ModelOutputConfig,
SpecialTokensConfig,
)
from axolotl.utils.schemas.multimodal import MultiModalConfig
from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig
from axolotl.utils.schemas.quantization import PTQConfig, QATConfig
from axolotl.utils.schemas.training import HyperparametersConfig
from axolotl.utils.schemas.trl import TRLConfig
from axolotl.utils.schemas.validation import ValidationMixin
from axolotl.utils.schemas.vllm import VllmConfig
# Module-level logger; the capability validators below emit their info/warning
# messages through it.
LOG = get_logger(__name__)
# pylint: disable=too-many-ancestors
class AxolotlInputConfig(
    ModelInputConfig,
    ModelOutputConfig,
    LoraConfig,
    ReLoRAConfig,
    HyperparametersConfig,
    WandbConfig,
    MLFlowConfig,
    CometConfig,
    LISAConfig,
    GradioConfig,
    RayConfig,
    MultiModalConfig,
    RemappedParameters,
    DeprecatedParameters,
    ValidationMixin,
    BaseModel,
):
    """Wrapper of all config options.

    Combines the option groups inherited from the schema mixins listed as base
    classes with the top-level options declared below. Each option's
    ``json_schema_extra["description"]`` is user-facing documentation that is
    merged into the generated JSON schema, so keep it accurate.

    ``populate_by_name`` lets fields that declare an alias also be populated by
    their Python attribute name.
    """

    model_config = {"populate_by_name": True}

    strict: bool | None = Field(
        default=False,
        json_schema_extra={"description": "Allow overwrite yml config using from cli"},
    )
    resume_from_checkpoint: str | None = Field(
        default=None,
        json_schema_extra={"description": "Resume from a specific checkpoint dir"},
    )
    auto_resume_from_checkpoints: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "If resume_from_checkpoint isn't set and you simply want it to start where it left off. Be careful with this being turned on between different models."
        },
    )
    resize_token_embeddings_to_32x: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Resize the model embeddings when new tokens are added to multiples of 32. This is reported to improve training speed on some models"
        },
    )
    mean_resizing_embeddings: bool | None = False
    # optionally shrink the embeddings when the tokenizer vocab size is smaller
    shrink_embeddings: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink."
        },
    )
    embeddings_skip_upcast: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs"
        },
    )

    # --- RL / reward-model options ---
    rl: RLType | None = Field(
        default=None,
        json_schema_extra={
            "description": "Use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'"
        },
    )
    # default_factory builds a fresh sub-config per instance instead of sharing one.
    trl: TRLConfig | None = Field(
        default_factory=lambda: TRLConfig(),  # pylint: disable=unnecessary-lambda
    )
    vllm: VllmConfig | None = Field(
        default_factory=lambda: VllmConfig(),  # pylint: disable=unnecessary-lambda
    )
    qat: QATConfig | None = None
    quantization: PTQConfig | None = None
    reward_model: bool | None = Field(
        default=None,
        json_schema_extra={"description": "Reward modelling: `True` or `False`"},
    )
    process_reward_model: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Process reward modelling: `True` or `False`"
        },
    )
    num_labels: int | None = None
    # Whether to use weighting in DPO trainer.
    # If `None`, default is `False` in the trainer.
    dpo_use_weighting: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to perform weighting in DPO trainer"
        },
    )
    dpo_use_logits_to_keep: bool | None = None
    dpo_label_smoothing: float | None = None
    dpo_norm_loss: bool | None = None
    dpo_padding_free: bool | None = None
    dpo_generate_during_eval: bool | None = None

    # --- Datasets ---
    # MinLen(1): when a list is given it must contain at least one entry.
    datasets: (
        Annotated[
            list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
            MinLen(1),
        ]
        | None
    ) = Field(
        default=None,
        json_schema_extra={
            "description": "A list of one or more datasets to finetune the model with"
        },
    )
    test_datasets: (
        Annotated[
            list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset],
            MinLen(1),
        ]
        | None
    ) = Field(
        default=None,
        json_schema_extra={
            "description": "A list of one or more datasets to eval the model with. You can use either test_datasets, or val_set_size, but not both."
        },
    )
    shuffle_merged_datasets: bool | None = Field(
        default=True,
        json_schema_extra={
            "description": "If false, the datasets will not be shuffled and will keep their original order in `datasets`. The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true."
        },
    )
    shuffle_before_merging_datasets: bool | None = Field(
        default=False,
        json_schema_extra={
            "description": "If true, each dataset in `datasets` will be shuffled before merging. This allows curriculum learning strategies to be applied at the dataset level. Default is false."
        },
    )
    dataset_prepared_path: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "Axolotl attempts to save the dataset as an arrow after packing the data together so subsequent training attempts load faster, relative path"
        },
    )
    dataset_shard_num: int | None = Field(
        default=None, json_schema_extra={"description": "Num shards for whole dataset"}
    )
    dataset_shard_idx: int | None = Field(
        default=None,
        json_schema_extra={"description": "Index of shard to use for whole dataset"},
    )
    skip_prepare_dataset: bool | None = False
    num_dataset_shards_to_save: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of shards to save the prepared dataset"
        },
    )
    pretraining_dataset: (
        Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None
    ) = Field(
        default=None,
        json_schema_extra={
            "description": "Set to HF dataset for type: 'completion' for streaming instead of pre-tokenize"
        },
    )
    dataset_processes: int | None = Field(
        default=None,
        json_schema_extra={
            "description": (
                "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
                "For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
            )
        },
    )
    dataset_exact_deduplication: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Deduplicates datasets and test_datasets with identical entries"
        },
    )
    dataset_keep_in_memory: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Keep dataset in memory while preprocessing. Only needed if cached dataset is taking too much storage"
        },
    )
    dataloader_pin_memory: bool | None = None
    dataloader_num_workers: int | None = None
    dataloader_prefetch_factor: int | None = None
    dataloader_drop_last: bool | None = None
    accelerator_config: dict[str, Any] | None = None
    remove_unused_columns: bool | None = None
    push_dataset_to_hub: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "Push prepared dataset to hub - repo_org/repo_name"
        },
    )
    hf_use_auth_token: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets. Required to be true when used in combination with `push_dataset_to_hub`"
        },
    )

    # --- Devices / distributed ---
    device: Any | None = None
    device_map: Any | None = Field(
        default=None,
        json_schema_extra={
            "description": "Passed through to transformers when loading the model when launched without accelerate. Use `sequential` when training w/ model parallelism to limit memory"
        },
    )
    world_size: int | None = None
    local_rank: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Don't mess with this, it's here for accelerate and torchrun"
        },
    )
    ddp: bool | None = None
    seed: int | None = Field(
        default=None, json_schema_extra={"description": "Seed for reproducibility"}
    )
    ddp_timeout: int | None = Field(
        default=None,
        json_schema_extra={"description": "Advanced DDP Arguments - timeout"},
    )
    ddp_bucket_cap_mb: int | None = Field(
        default=None,
        json_schema_extra={"description": "Advanced DDP Arguments - bucket cap in MB"},
    )
    ddp_broadcast_buffers: bool | None = Field(
        default=None,
        json_schema_extra={"description": "Advanced DDP Arguments - broadcast buffers"},
    )
    ddp_find_unused_parameters: bool | None = None

    # --- Evaluation ---
    eval_table_size: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0"
        },
    )
    eval_max_new_tokens: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Total number of tokens generated for predictions sent to wandb. Default is 128"
        },
    )
    do_causal_lm_eval: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`"
        },
    )
    eval_causal_lm_metrics: list[str] | None = Field(
        default=None,
        json_schema_extra={
            "description": "HF evaluate metrics used during evaluation. Default is ['sacrebleu', 'comet', 'ter', 'chrf', 'perplexity']"
        },
    )
    do_bench_eval: bool | None = None
    bench_dataset: str | None = None
    bench_split: str | None = None
    metric_for_best_model: str | None = None
    greater_is_better: bool | None = None
    loss_watchdog_threshold: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)"
        },
    )
    loss_watchdog_patience: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of high-loss steps in a row before the trainer aborts (default: 3)"
        },
    )
    gc_steps: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Run garbage collection every `gc_steps` steps. -1 will run on epoch end and before evaluations. Default is 0 (disabled)."
        },
    )

    # --- Precision ---
    # NOTE(review): the description mentions 'full', but the annotation only
    # accepts "auto", True/False or None, so 'full' would be rejected by
    # validation unless it is rewritten before this model sees it — confirm and
    # align the type and the description.
    bf16: Literal["auto"] | bool | None = Field(
        default="auto",
        json_schema_extra={
            "description": "Use CUDA bf16. bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere"
        },
    )
    fp16: bool | None = Field(
        default=None, json_schema_extra={"description": "Use CUDA fp16"}
    )
    fp8: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Enable FP8 mixed precision training using TorchAO. Best "
            "used in combination with torch.compile."
        },
    )
    fp8_enable_fsdp_float8_all_gather: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Enable FSDP float8 all-gather optimization for FP8 training. Can "
            "improve training speed by 10-15% when FSDP is enabled."
        },
    )
    bfloat16: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "No AMP (automatic mixed precision) - require >=ampere"
        },
    )  # for non-AMP cases
    float16: bool | None = Field(
        default=None,
        json_schema_extra={"description": "No AMP (automatic mixed precision)"},
    )  # for non-AMP cases
    tf32: bool | None = Field(
        default=None,
        json_schema_extra={"description": "Use CUDA tf32 - require >=ampere"},
    )
    float32: bool | None = None

    # --- Memory-saving techniques ---
    gradient_checkpointing: Literal["offload", "offload_disk"] | bool | None = Field(
        default=False,
        json_schema_extra={
            "description": "Whether to use gradient checkpointing. Available options are: true, false, 'offload', 'offload_disk'. https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing"
        },
    )
    gradient_checkpointing_kwargs: dict[str, Any] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Additional kwargs to pass to the trainer for gradient checkpointing"
        },
    )
    activation_offloading: Literal["legacy", "disk"] | bool | None = Field(
        default=False,
        json_schema_extra={
            "description": "Whether to offload activations. Available options are: true, false, 'legacy', 'disk'."
        },
    )
    unfrozen_parameters: list[str] | None = None

    # --- Sequence length / packing ---
    sequence_len: int = Field(
        default=512,
        json_schema_extra={
            "description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
        },
    )
    eval_sequence_len: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len"
        },
    )
    min_sample_len: int | None = None
    max_prompt_len: int = Field(
        default=512,
        json_schema_extra={"description": "maximum prompt length for RL training"},
    )
    sample_packing: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'"
        },
    )
    sample_packing_group_size: int | None = Field(
        default=100_000,
        json_schema_extra={
            # Fixed typo: was "(<%1.)".
            "description": "The number of samples packed at a time. Increasing the following values helps with packing, but usually only slightly (<1%)."
        },
    )
    sample_packing_bin_size: int | None = Field(
        default=200,
        json_schema_extra={
            "description": "The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples."
        },
    )
    sample_packing_sequentially: bool | None = Field(
        default=None,
        json_schema_extra={"description": "Whether to pack samples sequentially"},
    )
    sample_packing_mp_start_method: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
        },
    )
    eval_sample_packing: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Set to 'false' if getting errors during eval with sample_packing on"
        },
    )
    pad_to_sequence_len: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Pad inputs so each step uses constant sized buffers. This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently. Defaults to True if `sample_packing` enabled"
        },
    )
    curriculum_sampling: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to use sequential sampling for curriculum learning"
        },
    )
    multipack_real_batches: bool | None = None
    pretraining_sample_concatenation: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "whether to concatenate samples during pretraining",
        },
    )
    batch_flattening: Literal["auto"] | bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Use batch flattening for speedups when not using sample_packing"
        },
    )
    # for PoSE context length extension
    use_pose: bool | None = None
    pose_split_on_token_ids: list[int] | None = None
    pose_max_context_len: int | None = None
    pose_num_chunks: int | None = None
    pretrain_multipack_buffer_size: int | None = 10_000
    pretrain_multipack_attn: bool | None = Field(
        default=True,
        json_schema_extra={
            "description": "whether to prevent cross attention for packed sequences during pretraining",
        },
    )

    # --- Attention implementations / kernels ---
    xformers_attention: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to use xformers attention patch https://github.com/facebookresearch/xformers"
        },
    )
    sdp_attention: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to use scaled-dot-product attention https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html"
        },
    )
    s2_attention: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf"
        },
    )
    flex_attention: bool | None = None
    flex_attn_compile_kwargs: dict[str, Any] | None = None
    flash_attention: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention"
        },
    )
    flash_attn_cross_entropy: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to use flash-attention cross entropy implementation - advanced use only"
        },
    )
    flash_attn_rms_norm: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to use flash-attention rms norm implementation - advanced use only"
        },
    )
    flash_attn_fuse_qkv: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to fuse QKV into a single operation"
        },
    )
    flash_attn_fuse_mlp: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to fuse part of the MLP into a single operation"
        },
    )
    flash_optimum: bool | None = Field(
        default=None,
        json_schema_extra={"description": "Whether to use bettertransformers"},
    )
    eager_attention: bool | None = None
    unsloth_cross_entropy_loss: bool | None = None
    unsloth_lora_mlp: bool | None = None
    unsloth_lora_qkv: bool | None = None
    unsloth_lora_o: bool | None = None
    unsloth_rms_norm: bool | None = None
    unsloth_rope: bool | None = None
    lora_mlp_kernel: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
        },
    )
    lora_qkv_kernel: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
        },
    )
    lora_o_kernel: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Apply custom LoRA autograd functions and activation function Triton kernels for speed and memory savings. See: https://docs.axolotl.ai/docs/lora_optims.html"
        },
    )
    chunked_cross_entropy: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to use chunked cross entropy loss for memory efficiency"
        },
    )
    chunked_cross_entropy_num_chunks: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of chunks to use for chunked cross entropy loss"
        },
    )
    tiled_mlp: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to use ALST tiled mlp for memory efficient long context"
        },
    )
    tiled_mlp_num_shards: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of shards to use for ALST tiled mlp. If unset, it will be set based on seqlen/hidden_size"
        },
    )
    tiled_mlp_use_original_mlp: bool | None = Field(
        default=True,
        json_schema_extra={
            "description": "Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on llama."
        },
    )
    llama4_linearized_experts: bool | None = None

    # --- DeepSpeed / FSDP / parallelism ---
    deepspeed: str | dict[str, Any] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Deepspeed config path. e.g., deepspeed_configs/zero3.json"
        },
    )
    deepcompile: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to use deepcompile for faster training with deepspeed"
        },
    )
    # `deprecated=` marks the field deprecated in pydantic (runtime warning on
    # access and a `deprecated` flag in the JSON schema).
    fsdp: list[str] | None = Field(
        default=None,
        json_schema_extra={"description": "FSDP configuration"},
        deprecated="Configuring FSDP using `fsdp` is deprecated. Please use `fsdp_config` instead. ",
    )
    # TODO @SalmanMohammadi strongly type this as its own schema
    fsdp_config: dict[str, Any] | None = Field(
        default=None, json_schema_extra={"description": "FSDP configuration options"}
    )
    fsdp_version: int | None = Field(
        default=None,
        json_schema_extra={"description": "FSDP version"},
    )
    fsdp_final_state_dict_type: (
        Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None
    ) = Field(
        default=None,
        deprecated="Configuring FSDP final state dict type using `fsdp_final_state_dict_type` is deprecated. Please use `fsdp_config.final_state_dict_type` instead.",
    )
    val_set_size: float | None = Field(
        default=0.0,
        json_schema_extra={
            "description": "How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval."
        },
    )
    sequence_parallel_degree: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Set to a divisor of the number of GPUs available to split sequences into chunks of equal size. Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM. E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized subsequences, or set to 4 to split into four equal-sized subsequences. See https://docs.axolotl.ai/docs/sequence_parallelism.html for more details."
        },
    )
    heads_k_stride: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Optional; strides across the key dimension. Larger values use more memory but should make training faster. Must evenly divide the number of KV heads in your model."
        },
    )
    ring_attn_func: RingAttnFunc | None = Field(
        default=None,
        json_schema_extra={
            "description": "One of 'varlen_llama3', 'batch_ring', 'batch_zigzag', 'batch_stripe'. Defaults to 'varlen_llama3' in the sample packing case, and 'batch_ring' in the non-sample packing case."
        },
    )
    tensor_parallel_size: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of tensor parallel processes in TP group. Only supported with DeepSpeed AutoTP."
        },
    )

    # --- Tokenizer ---
    special_tokens: SpecialTokensConfig | None = Field(
        default=None,
        json_schema_extra={
            "description": "Add or change special tokens. If you add tokens here, you don't need to add them to the `tokens` list."
        },
    )
    tokens: list[str] | None = Field(
        default=None,
        json_schema_extra={"description": "Add extra tokens to the tokenizer"},
    )
    added_tokens_overrides: dict[int, str] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer. Only works for tokens that are not part of the base vocab (aka are added_tokens). Can be checked if they exist in tokenizer.json added_tokens."
        },
    )

    # --- torch.compile ---
    torch_compile: Literal["auto"] | bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to use torch.compile and which backend to use. setting to `auto` will enable torch compile when torch>=2.6.0"
        },
    )
    torch_compile_backend: str | None = Field(
        default=None,
        json_schema_extra={"description": "Backend to use for torch.compile"},
    )
    torch_compile_mode: Literal["default", "reduce-overhead", "max-autotune"] | None = (
        None
    )

    # --- Schedule: steps, evals, checkpoints, logging ---
    max_steps: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Maximum number of iterations to train for. It precedes num_epochs which means that if both are set, num_epochs will not be guaranteed. e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps"
        },
    )
    warmup_steps: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of warmup steps. Cannot use with warmup_ratio"
        },
    )
    warmup_ratio: float | None = Field(
        default=None,
        json_schema_extra={"description": "Warmup ratio. Cannot use with warmup_steps"},
    )
    eval_steps: int | float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps"
        },
    )
    evals_per_epoch: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of times per epoch to run evals, mutually exclusive with eval_steps"
        },
    )
    eval_strategy: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "Set to `no` to skip evaluation, `epoch` at end of each epoch, leave empty to infer from `eval_steps`"
        },
    )
    save_steps: int | float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps"
        },
    )
    saves_per_epoch: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of times per epoch to save a checkpoint, mutually exclusive with save_steps"
        },
    )
    save_strategy: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "Set to `no` to skip checkpoint saves, `epoch` at end of each epoch, `best` when better result is achieved, leave empty to infer from `save_steps`"
        },
    )
    save_total_limit: int | None = Field(
        default=None, json_schema_extra={"description": "Checkpoints saved at a time"}
    )
    save_first_step: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to checkpoint a model after the first step of training. Defaults to False."
        },
    )
    logging_steps: int | None = Field(
        default=None, json_schema_extra={"description": "Logging frequency"}
    )
    early_stopping_patience: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Stop training after this many evaluation losses have increased in a row. https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback"
        },
    )
    load_best_model_at_end: bool | None = False
    save_only_model: bool | None = Field(
        default=False,
        json_schema_extra={
            "description": "Save only the model weights, skipping the optimizer. Using this means you can't resume from checkpoints."
        },
    )
    use_tensorboard: bool | None = Field(
        default=None, json_schema_extra={"description": "Use tensorboard for logging"}
    )
    profiler_steps: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Enable the pytorch profiler to capture the first N steps of training to the output_dir. see https://pytorch.org/blog/understanding-gpu-memory-1/ for more information. Snapshots can be visualized @ https://pytorch.org/memory_viz"
        },
    )
    profiler_steps_start: int | None = Field(
        default=0,
        json_schema_extra={
            "description": "Which step to start the profiler at. Useful for only capturing a few steps mid-run."
        },
    )
    include_tokens_per_second: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time."
        },
    )
    neftune_noise_alpha: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings. Currently only supported on Llama and Mistral"
        },
    )

    # --- Preference-optimization loss coefficients ---
    orpo_alpha: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping."
        },
    )
    rpo_alpha: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Weighting of NLL term in loss from RPO paper"
        },
    )
    simpo_gamma: float | None = Field(
        default=None,
        json_schema_extra={"description": "Target reward margin for the SimPO loss"},
    )
    cpo_alpha: float | None = Field(
        default=None, json_schema_extra={"description": "Weight of the BC regularizer"}
    )
    kto_desirable_weight: float | None = Field(
        default=None,
        json_schema_extra={"description": "Factor for desirable loss term in KTO loss"},
    )
    kto_undesirable_weight: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "Factor for undesirable loss term in KTO loss"
        },
    )
    rl_beta: float | None = Field(
        default=None,
        json_schema_extra={"description": "The beta parameter for the RL training"},
    )

    # --- Model loading memory ---
    max_memory: dict[int | Literal["cpu", "disk"], int | str] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model."
        },
    )
    gpu_memory_limit: int | str | None = Field(
        default=None,
        json_schema_extra={
            "description": "Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset"
        },
    )
    low_cpu_mem_usage: bool | None = Field(
        default=None,
        json_schema_extra={"description": "Whether to use low_cpu_mem_usage"},
    )

    # --- Chat templates ---
    # Accepts a ChatTemplate enum value or any string starting with
    # "tokenizer_default_fallback_" (the suffix names the fallback template).
    chat_template: (
        ChatTemplate
        | Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")]
    ) | None = Field(
        default=None,
        json_schema_extra={
            "description": "The name of the chat template to use for training, following values are supported: tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value. alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py. tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer. jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field. The selected chat template will be saved to the tokenizer_config.json for easier inferencing"
        },
    )
    chat_template_jinja: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "Custom jinja template or path to jinja file for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null."
        },
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Additional kwargs to pass to the chat template. This is useful for customizing the chat template. For example, you can pass `thinking=False` to add a generation prompt to the chat template."
        },
    )
    eot_tokens: list[str] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Custom EOT (End-of-Turn) tokens to mask/unmask during training. These tokens mark the boundaries between conversation turns. For example: ['/INST', '</s>', '[/SYSTEM_PROMPT]']. If not specified, defaults to just the model's eos_token. This is useful for templates that use multiple delimiter tokens."
        },
    )
    default_system_message: str | None = Field(
        default=None,
        json_schema_extra={
            "description": "Changes the default system message. Currently only supports chatml."
        },
    )
    fix_untrained_tokens: int | list[int] | None = None

    # INTERNALS - document for now, generally not set externally
    is_preprocess: bool | None = None
    preprocess_iterable: bool | None = None
    total_num_tokens: int | None = Field(
        default=None,
        json_schema_extra={"description": "Total number of tokens - internal use"},
    )
    total_supervised_tokens: int | None = None
    sample_packing_eff_est: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "You can set these packing optimizations AFTER starting a training at least once. The trainer will provide recommended values for these values."
        },
    )
    axolotl_config_path: str | None = None
    is_falcon_derived_model: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Internal use only - Used to identify which model this is based on"
        },
    )
    is_llama_derived_model: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Internal use only - Used to identify which model this is based on"
        },
    )
    is_mistral_derived_model: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Internal use only - Used to identify which model this is based on. Please note that if you set this to true, `padding_side` will be set to 'left' by default"
        },
    )
    is_qwen_derived_model: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Internal use only - Used to identify which model this is based on"
        },
    )
    plugins: list[str] | None = Field(
        default=None,
        json_schema_extra={
            "description": "Add plugins to extend the pipeline. See `src/axolotl/integrations` for the available plugins or doc below for more details. https://docs.axolotl.ai/docs/custom_integrations.html"
        },
    )

    @field_serializer("datasets")
    def datasets_serializer(
        self, ds_configs: list[DatasetConfig] | None
    ) -> list[dict[str, Any]] | None:
        """Serialize ``datasets`` with unset (``None``) options dropped.

        Each dataset config is dumped with ``exclude_none=True``, so the output
        contains only the keys that actually carry a value.

        Args:
            ds_configs: The validated ``datasets`` value.

        Returns:
            A list of plain dicts, or ``None`` when no datasets are configured
            (``None`` or an empty list).
        """
        if ds_configs:
            return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
        return None
class AxolotlConfigWCapabilities(AxolotlInputConfig):
    """Wrapper to validate GPU capabilities with the configured options.

    Extends ``AxolotlInputConfig`` with the detected hardware and software
    environment, so the validators below can cross-check configured options
    against it.
    """

    # Detected GPU capabilities; validators read e.g. ``bf16``, ``n_gpu`` and
    # ``compute_capability`` from it.
    capabilities: GPUCapabilities
    # Detected software environment; validators read ``torch_version`` from it.
    env_capabilities: EnvCapabilities
@model_validator(mode="after")
def check_bf16(self):
    """Reconcile the bf16 options with the GPU's detected bf16 support.

    * The GPU supports bf16 but neither ``bf16`` nor ``bfloat16`` is truthy:
      log an informational hint and accept the config.
    * The GPU lacks bf16 support and ``bf16``/``bfloat16`` is explicitly
      ``True``: raise ``ValueError``, unless this is a ``merge_lora`` run or a
      preprocessing run.
    """
    if self.capabilities.bf16:
        bf16_enabled = self.bf16 or self.bfloat16
        if not bf16_enabled:
            LOG.info(
                "bf16 support detected, but not enabled for this configuration."
            )
        return self

    # LoRA merges and preprocessing runs are exempt from this check.
    exempt_run = self.merge_lora or self.is_preprocess
    # Only an explicit ``True`` counts as a request; other truthy values do not.
    explicitly_requested = self.bf16 is True or self.bfloat16 is True
    if explicitly_requested and not exempt_run:
        raise ValueError(
            "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
        )
    return self
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_sdpa_bf16(cls, data):
    """Warn when sample packing is combined with torch SDPA and bf16.

    This combination is reported as unsupported and may yield 0.0 loss, except
    on GPUs whose compute capability is ``sm_90``. For all other GPUs a warning
    is logged. ``data`` is returned unchanged.
    """
    # Use .get() so a missing or None ``capabilities`` entry is left for
    # pydantic's own required-field validation. Indexing would raise a bare
    # KeyError, which pydantic does not convert into a ValidationError.
    capabilities = data.get("capabilities")
    is_sm_90: bool = (
        bool(capabilities)
        and capabilities.get("compute_capability") == "sm_90"
    )
    if (
        data.get("sample_packing")
        and data.get("sdp_attention")
        and (data.get("bfloat16") or data.get("bf16"))
        and not is_sm_90
    ):
        # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
        LOG.warning(
            "sample_packing & torch sdpa with bf16 is unsupported and may result in 0.0 loss. "
            "This may work on H100s."
        )
    return data
# pylint: disable=duplicate-code
@model_validator(mode="before")
@classmethod
def check_multigpu_unsloth(cls, data):
    """Reject the unsloth LoRA optimizations when more than one GPU is detected.

    Raises:
        ValueError: if any ``unsloth_lora_*`` option is enabled and
            ``capabilities.n_gpu`` is greater than 1.
    """
    unsloth_keys = ("unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o")
    if not any(data.get(key) for key in unsloth_keys):
        return data

    gpu_info = data.get("capabilities")
    if gpu_info and gpu_info.get("n_gpu", 0) > 1:
        raise ValueError(
            "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
        )
    return data
# pylint: disable=duplicate-code
@model_validator(mode="before")
@classmethod
def check_multigpu_lora_kernels(cls, data):
    """Reject the LoRA kernel optimizations for multi-GPU FSDP1 runs.

    Multi-GPU runs without FSDP, and multi-GPU runs with FSDP2
    (``fsdp_version`` equal to 2), are allowed.

    Raises:
        ValueError: if any ``lora_*_kernel`` option is enabled, more than one
            GPU is detected, and ``fsdp_config`` is set with a version other
            than 2.
    """
    kernel_keys = ("lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel")
    if not any(data.get(key) for key in kernel_keys):
        return data

    gpu_info = data.get("capabilities")
    multi_gpu = bool(gpu_info) and gpu_info.get("n_gpu", 0) > 1
    uses_fsdp = data.get("fsdp_config") is not None
    uses_fsdp1 = uses_fsdp and str(data.get("fsdp_version")) != "2"

    if multi_gpu and uses_fsdp1:
        raise ValueError(
            "lora_mlp_kernel, lora_qkv_kernel, and lora_o_kernel are not compatible with FSDP1."
        )
    return data
@model_validator(mode="before")
@classmethod
def check_auto_enable_lora_kernels(cls, data):
    """Enable the LoRA kernel optimizations by default for eligible configs.

    Kernels are auto-enabled only when every one of these holds:

    * ``rl`` is not set,
    * ``adapter`` is ``lora`` or ``qlora``,
    * no ``lora_*_kernel`` option was set by the user,
    * no ``unsloth_lora_*`` option is enabled,
    * the run is not 8-bit LoRA,
    * ``lora_dropout`` equals 0,
    * the run is not multi-GPU with FSDP1.

    When the kernels are enabled, a warning tells the user how to opt out.
    """
    # RL trainers have not been tested with these kernels.
    if data.get("rl"):
        return data

    adapter = data.get("adapter")
    if adapter not in ["lora", "qlora"]:
        return data

    kernel_keys = ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
    unsloth_keys = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]

    # Respect explicit user choices and conflicting optimizations.
    if any(data.get(key) is not None for key in kernel_keys):
        return data
    if any(data.get(key) for key in unsloth_keys):
        return data
    if adapter == "lora" and data.get("load_in_8bit"):
        return data

    # A non-zero dropout would make the runtime patch checks disable the
    # kernels again, so auto-enabling them would be pointless.
    if data.get("lora_dropout") != 0:
        return data

    gpu_info = data.get("capabilities")
    multi_gpu = gpu_info and gpu_info.get("n_gpu", 0) > 1
    uses_fsdp = data.get("fsdp_config") is not None
    uses_fsdp2 = uses_fsdp and str(data.get("fsdp_version")) == "2"
    # Multi-GPU FSDP1 is the one layout the kernels do not support.
    if multi_gpu and uses_fsdp and not uses_fsdp2:
        return data

    # Every kernel key is known to be unset here (checked above).
    for key in kernel_keys:
        data[key] = True
    LOG.warning(
        "Auto-enabling LoRA kernel optimizations for faster training. "
        + "Please explicitly set `lora_*_kernel` config values to `false` to disable. "
        + "See https://docs.axolotl.ai/docs/lora_optims.html for more info."
    )
    return data
@model_validator(mode="before")
@classmethod
def check_adopt_torch_version(cls, data):
    """Require torch >= 2.5.1 when an ADOPT optimizer is configured.

    The torch version is taken from ``env_capabilities.torch_version``. When
    that is absent, the installed torch version (local build suffix stripped)
    is used instead.

    Raises:
        ValueError: if the optimizer name contains ``"adopt"`` and torch is
            older than 2.5.1.
    """
    optimizer = data.get("optimizer")
    if optimizer is None or "adopt" not in optimizer:
        return data

    detected = data.get("env_capabilities", {}).get("torch_version")
    if detected is None:
        import torch

        detected = str(torch.__version__).split("+", maxsplit=1)[0]

    if version.parse(detected) < version.parse("2.5.1"):
        raise ValueError(
            "ADOPT optimizer is incompatible with torch version < 2.5.1"
        )
    return data
@model_validator(mode="before")
@classmethod
def check_flex_torch_version(cls, data):
    """Require torch >= 2.6.0 when ``flex_attention`` is enabled.

    The torch version is taken from ``env_capabilities.torch_version``. When
    that is absent, the installed torch version (local build suffix stripped)
    is used instead.

    Raises:
        ValueError: if flex attention is enabled on torch older than 2.6.0.
    """
    if not data.get("flex_attention"):
        return data

    detected = data.get("env_capabilities", {}).get("torch_version")
    if detected is None:
        import torch

        detected = str(torch.__version__).split("+", maxsplit=1)[0]

    if version.parse(detected) < version.parse("2.6.0"):
        raise ValueError(
            "Flex attention is not supported on torch version < 2.6.0"
        )
    return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_auto(cls, data):
    """Resolve ``torch_compile: auto`` into a concrete boolean.

    ``auto`` becomes ``True`` only when ``env_capabilities.torch_version`` is
    known and is at least 2.5.1. Otherwise it becomes ``False``. Any other
    ``torch_compile`` value is left untouched.
    """
    if data.get("torch_compile") != "auto":
        return data

    detected = data.get("env_capabilities", {}).get("torch_version")
    compile_supported = bool(detected) and version.parse(
        detected
    ) >= version.parse("2.5.1")

    if compile_supported:
        LOG.info("torch.compile is available, setting torch_compile to True")
    data["torch_compile"] = compile_supported
    return data
@model_validator(mode="before")
@classmethod
def check_beta_and_trl_beta_match(cls, data):
    """Ensure the top-level ``beta`` and ``trl.beta`` agree when both are set.

    Raises:
        ValueError: if both values are provided and they differ.
    """
    # ``trl`` may be present but null (e.g. an empty ``trl:`` YAML section).
    # Normalize it to an empty dict so ``.get`` below does not raise
    # AttributeError.
    trl_cfg = data.get("trl") or {}
    beta = data.get("beta")
    trl_beta = trl_cfg.get("beta")
    # Compare against None rather than truthiness, so an explicit 0 is still
    # checked for consistency.
    if beta is not None and trl_beta is not None and beta != trl_beta:
        raise ValueError("beta and trl.beta must match or one must be removed")
    return data
@model_validator(mode="after")
def check_min_torch_version(self):
    """Warn when the detected torch version is older than 2.6.0.

    Only a warning is logged; the configuration is still accepted. Nothing
    happens when no torch version was detected.
    """
    if self.env_capabilities and self.env_capabilities.torch_version:
        torch_version = self.env_capabilities.torch_version
        if version.parse(torch_version) < version.parse("2.6.0"):
            LOG.warning(
                f"torch=={torch_version} may not be supported. Please upgrade to torch>=2.6.0."
            )
    return self
@model_validator(mode="before")
@classmethod
def check_qat_config(cls, data):
    """Validate a QAT (quantization-aware training) configuration.

    When ``qat`` is set, ``peft``, ``load_in_8bit`` and ``load_in_4bit`` are
    rejected, checked in that order. Torch >= 2.6.0 is also required, using
    ``env_capabilities.torch_version`` or, if that is absent, the installed
    torch version.

    Raises:
        ValueError: on an incompatible option or a too-old torch.
    """
    if not data.get("qat", {}):
        return data

    incompatible_options = (
        ("peft", "QAT and PEFT cannot be used together."),
        ("load_in_8bit", "QAT and load_in_8bit cannot be used together."),
        ("load_in_4bit", "QAT and load_in_4bit cannot be used together."),
    )
    for option, message in incompatible_options:
        if data.get(option):
            raise ValueError(message)

    detected = data.get("env_capabilities", {}).get("torch_version")
    if detected is None:
        import torch

        detected = str(torch.__version__).split("+", maxsplit=1)[0]
    if version.parse(detected) < version.parse("2.6.0"):
        raise ValueError("QAT is not supported on torch version < 2.6.0")
    return data
@model_validator(mode="before")
@classmethod
def check_fsdp_torch_version(cls, data):
    """Require torch >= 2.7.0 when FSDP2 is configured.

    FSDP2 counts as configured when ``fsdp_config`` is set and
    ``fsdp_version`` is 2 (as an int or a string). The torch version is taken
    from ``env_capabilities.torch_version`` or, if that is absent, from the
    installed torch.

    Raises:
        ValueError: if FSDP2 is configured on torch older than 2.7.0.
    """
    uses_fsdp2 = bool(data.get("fsdp_config")) and str(data.get("fsdp_version")) == "2"
    if not uses_fsdp2:
        # Nothing to check. Skip resolving the torch version so torch is not
        # imported needlessly, matching the other feature-gated validators.
        return data

    env_capabilities = data.get("env_capabilities", {})
    torch_version = env_capabilities.get("torch_version")
    if torch_version is None:
        import torch

        torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
    if version.parse(torch_version) < version.parse("2.7.0"):
        raise ValueError("FSDP2 is not supported on torch version < 2.7.0")
    return data
@model_validator(mode="before")
@classmethod
def default_dataloader_opts(cls, data):
    """Fill in dataloader defaults when no dataloader option was given.

    Applies only when ``dataloader_num_workers``, ``dataloader_pin_memory``
    and ``dataloader_prefetch_factor`` are all unset. If any one of them is
    set, none of them are touched. The defaults are:

    * ``dataloader_num_workers`` = ``capabilities.n_gpu`` (1 if unknown),
    * ``dataloader_pin_memory`` = ``True``,
    * ``dataloader_prefetch_factor`` = 256.
    """
    if (
        data.get("dataloader_num_workers") is None
        and data.get("dataloader_pin_memory") is None
        and data.get("dataloader_prefetch_factor") is None
    ):
        # ``capabilities`` may be missing or None before pydantic validates
        # the required field. Fall back to the default worker count instead of
        # raising AttributeError on ``None.get``.
        capabilities = data.get("capabilities") or {}
        data["dataloader_num_workers"] = capabilities.get("n_gpu", 1)
        data["dataloader_pin_memory"] = True
        data["dataloader_prefetch_factor"] = 256
    return data
@model_validator(mode="before")
@classmethod
def default_dataset_processes(cls, data):
    """Default ``dataset_processes`` to ``get_default_process_count()`` when unset.

    A key that is missing and a key explicitly set to ``None`` are both
    treated as unset.
    """
    requested_processes = data.get("dataset_processes")
    if requested_processes is not None:
        return data
    data["dataset_processes"] = get_default_process_count()
    return data