diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 1138d99f1..7ddeccff5 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -20,9 +20,11 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: '3.11'
-      - name: install dependencies
+      - name: Install dependencies
         run: |
-          python3 -m pip install jupyter
+          python3 -m pip install jupyter quartodoc
+      - name: Build autodoc
+        run: quartodoc build
       - name: Publish to GitHub Pages (and render)
         uses: quarto-dev/quarto-actions/publish@v2
         with:
diff --git a/.gitignore b/.gitignore
index 7b604d88c..40084b408 100644
--- a/.gitignore
+++ b/.gitignore
@@ -181,6 +181,10 @@ prepared-datasets/
 submit.sh
 *.out*
 
+# Quartodoc generated files
+objects.json
+site_libs/
+
 typings/
 out/
 
diff --git a/README.md b/README.md
index 343816aff..ed2c9e6b1 100644
--- a/README.md
+++ b/README.md
@@ -97,6 +97,7 @@ That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github
 - [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
 - [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
 - [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
+- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/docs/api/) - Auto-generated code documentation
 - [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
 
 ## 🤝 Getting Help
diff --git a/_quarto.yml b/_quarto.yml
index 943ed5293..c564fb0dd 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -1,6 +1,178 @@
 project:
   type: website
 
+quartodoc:
+  dir: docs/api
+  package: axolotl
+  title: API Reference
+  parser: google
+
+  sections:
+    - title: Core
+      desc: Core functionality for training
+      contents:
+        - train
+        - evaluate
+        - datasets
+        - convert
+        - prompt_tokenizers
+        - logging_config
+        - core.trainer_builder
+        - core.training_args
+        - core.chat.messages
+        - core.chat.format.chatml
+        - core.chat.format.llama3x
+        - core.chat.format.shared
+        - core.datasets.chat
+        - core.datasets.transforms.chat_builder
+    - title: CLI
+      desc: Command-line interface
+      contents:
+        - cli.main
+        - cli.train
+        - cli.evaluate
+        - cli.args
+        - cli.checks
+        - cli.config
+        - cli.inference
+        - cli.merge_lora
+        - cli.merge_sharded_fsdp_weights
+        - cli.preprocess
+        - cli.sweeps
+        - cli.utils
+        - cli.cloud.base
+        - cli.cloud.modal_
+    - title: Trainers
+      desc: Training implementations
+      contents:
+        - core.trainers.base
+        - core.trainers.trl
+        - core.trainers.dpo.trainer
+        - core.trainers.grpo.trainer
+    - title: Prompt Strategies
+      desc: Prompt formatting strategies
+      contents:
+        - prompt_strategies.base
+        - prompt_strategies.chat_template
+        - prompt_strategies.alpaca_chat
+        - prompt_strategies.alpaca_instruct
+        - prompt_strategies.alpaca_w_system
+        - prompt_strategies.user_defined
+        - prompt_strategies.llama2_chat
+        - prompt_strategies.completion
+        - prompt_strategies.input_output
+        - prompt_strategies.stepwise_supervised
+        - prompt_strategies.metharme
+        - prompt_strategies.orcamini
+        - prompt_strategies.pygmalion
+        - prompt_strategies.messages.chat
+        - prompt_strategies.dpo.chat_template
+        - prompt_strategies.dpo.llama3
+        - prompt_strategies.dpo.chatml
+        - prompt_strategies.dpo.zephyr
+        - prompt_strategies.dpo.user_defined
+        - prompt_strategies.dpo.passthrough
+        - prompt_strategies.kto.llama3
+        - prompt_strategies.kto.chatml
+        - prompt_strategies.kto.user_defined
+        - prompt_strategies.orpo.chat_template
+        - prompt_strategies.bradley_terry.llama3
+    - title: Kernels
+      desc: Low-level performance optimizations
+      contents:
+        - kernels.lora
+        - kernels.geglu
+        - kernels.swiglu
+        - kernels.quantize
+        - kernels.utils
+    - title: MonkeyPatches
+      desc: Runtime patches for model optimizations
+      contents:
+        - monkeypatch.llama_attn_hijack_flash
+        - monkeypatch.llama_attn_hijack_xformers
+        - monkeypatch.mistral_attn_hijack_flash
+        - monkeypatch.multipack
+        - monkeypatch.relora
+        - monkeypatch.llama_expand_mask
+        - monkeypatch.lora_kernels
+        - monkeypatch.utils
+        - monkeypatch.btlm_attn_hijack_flash
+        - monkeypatch.llama_patch_multipack
+        - monkeypatch.stablelm_attn_hijack_flash
+        - monkeypatch.trainer_fsdp_optim
+        - monkeypatch.transformers_fa_utils
+        - monkeypatch.unsloth_
+        - monkeypatch.attention.mllama
+        - monkeypatch.data.batch_dataset_fetcher
+        - monkeypatch.mixtral
+    - title: Utils
+      desc: Utility functions
+      contents:
+        - utils.models
+        - utils.tokenization
+        - utils.chat_templates
+        - utils.lora
+        - utils.lora_embeddings
+        - utils.model_shard_quant
+        - utils.bench
+        - utils.freeze
+        - utils.trainer
+        - utils.schedulers
+        - utils.distributed
+        - utils.dict
+        - utils.optimizers.adopt
+        - utils.data.pretraining
+        - utils.data.sft
+        - utils.gradient_checkpointing.unsloth
+    - title: Schemas
+      desc: Pydantic data models for Axolotl config
+      contents:
+        - utils.schemas.config
+        - utils.schemas.model
+        - utils.schemas.training
+        - utils.schemas.datasets
+        - utils.schemas.peft
+        - utils.schemas.trl
+        - utils.schemas.integrations
+        - utils.schemas.enums
+        - utils.schemas.utils
+    - title: Integrations
+      desc: Third-party integrations and extensions
+      contents:
+        - integrations.base
+        - integrations.cut_cross_entropy.args
+        - integrations.grokfast.optimizer
+        - integrations.kd.trainer
+        - integrations.liger.args
+        - integrations.lm_eval.args
+        - integrations.spectrum.args
+    - title: Common
+      desc: Common utilities and shared functionality
+      contents:
+        - common.architectures
+        - common.const
+        - common.datasets
+    - title: Models
+      desc: Custom model implementations
+      contents:
+        - models.mamba.modeling_mamba
+    - title: Data Processing
+      desc: Data processing utilities
+      contents:
+        - utils.collators.core
+        - utils.collators.batching
+        - utils.collators.mamba
+        - utils.collators.mm_chat
+        - utils.samplers.multipack
+    - title: Callbacks
+      desc: Training callbacks
+      contents:
+        - utils.callbacks.perplexity
+        - utils.callbacks.profiler
+        - utils.callbacks.lisa
+        - utils.callbacks.mlflow_
+        - utils.callbacks.comet_
+
 website:
   title: "Axolotl"
   description: "We make fine-tuning accessible, scalable, and fun"
@@ -35,6 +207,8 @@ website:
       - docs/inference.qmd
       - docs/cli.qmd
       - docs/config.qmd
+      - text: "API Reference"
+        href: docs/api
 
       - section: "Dataset Formats"
         contents: docs/dataset-formats/*
@@ -80,3 +254,22 @@ format:
     theme: darkly
     css: styles.css
     toc: true
+    # Enable better handling of line breaks in markdown
+    preserve-tabs: true
+    html-math-method: mathjax
+    # Improved markdown processing options
+    md-extensions:
+      - markdown_it
+      - def_list
+      - attr_list
+      - fenced_divs
+      - tables
+      - html_admonition
+      - lineblocks
+      - fancy_lists
+    # Control whitespace handling
+    whitespace: preserve
+    # Process newlines in paragraphs
+    wrap: preserve
+    # Better line break handling
+    preserve-linebreaks: true
diff --git a/docs/.gitignore b/docs/.gitignore
index 4c23a061f..6c3cb2070 100644
--- a/docs/.gitignore
+++ b/docs/.gitignore
@@ -1,2 +1,4 @@
 /.quarto/
 _site/
+/api/*.qmd
+/api/*.html
diff --git a/docs/cli.qmd b/docs/cli.qmd
index a57e54d9a..a3d5cf939 100644
--- a/docs/cli.qmd
+++ b/docs/cli.qmd
@@ -1,5 +1,5 @@
 ---
-title: "CLI Reference"
+title: "Command Line Interface (CLI)"
 format:
   html:
     toc: true
diff --git a/docs/dataset_preprocessing.qmd b/docs/dataset_preprocessing.qmd
index 1075dc8e5..245723e67 100644
--- a/docs/dataset_preprocessing.qmd
+++ b/docs/dataset_preprocessing.qmd
@@ -6,7 +6,7 @@ description: How datasets are processed
 ## Overview
 
 Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
-the [dataset format](docs/dataset-formats) and prompt strategies to:
+the [dataset format](dataset-formats) and prompt strategies to:
 
 - parse the dataset based on the *dataset format*
 - transform the dataset to how you would interact with the model based on the *prompt strategy*
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 4b5df167b..9f523de54 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -2,3 +2,5 @@ pre-commit
 black
 mypy
 types-requests
+quartodoc
+jupyter
diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index f46c8efe2..f53ea825a 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -25,7 +25,7 @@ from axolotl.cli.utils import (
 )
 from axolotl.integrations.lm_eval.cli import lm_eval
 from axolotl.utils import set_pytorch_cuda_alloc_conf
-from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
+from axolotl.utils.schemas.config import AxolotlInputConfig
 
 
 @click.group()
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 19b947fb2..83bfb1c83 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -13,9 +13,7 @@
 # limitations under the License.
 # pylint: disable=too-many-lines
 
-"""
-Builder for the training args and trainer
-"""
+"""Builder for the training args and trainer"""
 
 import abc
 import importlib
@@ -85,8 +83,8 @@ from axolotl.utils.collators import (
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
 from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
-from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
 from axolotl.utils.models import ensure_dtype
+from axolotl.utils.schemas.enums import CustomSupportedOptimizers
 
 try:
     import torch._dynamo  # pylint: disable=ungrouped-imports
diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py
index 52e6363a2..f0c42830d 100644
--- a/src/axolotl/core/trainers/grpo/__init__.py
+++ b/src/axolotl/core/trainers/grpo/__init__.py
@@ -9,7 +9,7 @@ import logging
 from trl.trainer.grpo_trainer import RewardFunc
 
 from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
-from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
+from axolotl.utils.schemas.trl import TRLConfig
 
 LOG = logging.getLogger("axolotl")
 
diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py
index 8d9ddc6ab..f3be9c2f4 100644
--- a/src/axolotl/evaluate.py
+++ b/src/axolotl/evaluate.py
@@ -8,6 +8,8 @@ from typing import Dict, Optional
 
 import torch
 from accelerate.logging import get_logger
+from datasets import Dataset
+from transformers.trainer import Trainer
 
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
@@ -25,18 +27,18 @@ LOG = get_logger("axolotl.evaluate")
 
 
 def evaluate_dataset(
-    trainer, dataset, dataset_type: str, flash_optimum: bool = False
+    trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False
 ) -> Optional[Dict[str, float]]:
-    """Helper function to evaluate a single dataset safely.
+    """Helper function to evaluate a single dataset.
 
     Args:
-        trainer: The trainer instance
-        dataset: Dataset to evaluate
-        dataset_type: Type of dataset ('train' or 'eval')
-        flash_optimum: Whether to use flash optimum
+        trainer: The trainer instance.
+        dataset: Dataset to evaluate.
+        dataset_type: Type of dataset ('train' or 'eval').
+        flash_optimum: Whether to use flash optimum.
 
     Returns:
-        Dictionary of metrics or None if dataset is None
+        Dictionary of metrics or None if dataset is None.
     """
     if dataset is None:
         return None
@@ -63,17 +65,14 @@ def evaluate_dataset(
 
 def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
     """
-    Evaluate a model on training and validation datasets
+    Evaluate a model on training and validation datasets.
 
     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
         dataset_meta: Dataset metadata containing training and evaluation datasets.
 
     Returns:
-        Tuple containing:
-            - The model (either PeftModel or PreTrainedModel)
-            - The tokenizer
-            - Dictionary of evaluation metrics
+        Dictionary mapping metric names to their values.
     """
     # pylint: disable=duplicate-code
     # Enable expandable segments for cuda allocation to improve VRAM usage
diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py
index b4ffd6758..b443f228e 100644
--- a/src/axolotl/integrations/config.py
+++ b/src/axolotl/integrations/config.py
@@ -11,19 +11,17 @@
 # the License.
 
 """
-module to handle merging the plugins' input arguments with the base configurations.
+Module to handle merging the plugins' input arguments with the base configurations.
 
-this was moved here to prevent circular imports
+This was moved here to prevent circular imports.
 """
 
 from typing import Any, Dict, List
 
-from axolotl.utils.config.models.input.v0_4_1 import (
+from axolotl.utils.schemas.config import (
     AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
 )
-from axolotl.utils.config.models.input.v0_4_1 import (
-    AxolotlInputConfig as AxolotlInputConfigBase,
-)
+from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
 
 
 def merge_input_args():
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index af1d51a46..4266e0c99 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -13,7 +13,7 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly
 from axolotl.prompt_tokenizers import PromptTokenizingStrategy
 from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
 from axolotl.utils.chat_templates import get_chat_template_from_config
-from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
+from axolotl.utils.schemas.datasets import DatasetConfig
 
 # Configure the logger
 LOG = logging.getLogger("axolotl")
diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py
index 9e4986dd4..f04bd7f0d 100644
--- a/src/axolotl/prompt_strategies/dpo/chat_template.py
+++ b/src/axolotl/prompt_strategies/dpo/chat_template.py
@@ -3,7 +3,7 @@ DPO prompt strategies for using tokenizer chat templates.
""" from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template -from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic +from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic def default( diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 47c77619a..94f180ef4 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -33,7 +33,6 @@ from trl.models import unwrap_model_for_generation from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.callbacks.perplexity import Perplexity -from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig from axolotl.utils.distributed import ( barrier, broadcast_dict, @@ -43,6 +42,7 @@ from axolotl.utils.distributed import ( is_main_process, zero_first, ) +from axolotl.utils.schemas.config import AxolotlInputConfig if TYPE_CHECKING: from axolotl.core.trainer_builder import AxolotlTrainingArguments diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index b7096eeab..136acc4a0 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -12,19 +12,13 @@ from transformers.utils.import_utils import is_torch_npu_available from axolotl.integrations.base import PluginManager from axolotl.integrations.config import merge_input_args from axolotl.utils.bench import log_gpu_memory_usage -from axolotl.utils.config.models.input.v0_4_1 import ( - AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, -) -from axolotl.utils.config.models.input.v0_4_1 import ( - AxolotlInputConfig as AxolotlInputConfigBase, -) -from axolotl.utils.config.models.input.v0_4_1 import ( - DPODataset, - KTODataset, - SFTDataset, -) from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model_config +from axolotl.utils.schemas.config import ( + AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, +) +from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase +from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset LOG = logging.getLogger("axolotl") diff --git a/src/axolotl/utils/config/models/input/next/__init__.py b/src/axolotl/utils/config/models/input/next/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/axolotl/utils/config/models/input/__init__.py b/src/axolotl/utils/schemas/__init__.py similarity index 100% rename from src/axolotl/utils/config/models/input/__init__.py rename to src/axolotl/utils/schemas/__init__.py diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/schemas/config.py similarity index 52% rename from src/axolotl/utils/config/models/input/v0_4_1/__init__.py rename to src/axolotl/utils/schemas/config.py index f1c514a7c..7676a50a8 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/schemas/config.py @@ -1,11 +1,10 @@ -"""Module with Pydantic models for configuration.""" +"""Main Axolotl input configuration Pydantic models""" # pylint: disable=too-many-lines import logging import os -from enum import Enum -from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Annotated, Any, Literal from annotated_types import MinLen from packaging import version @@ -17,652 +16,41 @@ from pydantic import ( 
field_validator, model_validator, ) -from transformers import SchedulerType -from transformers.training_args import OptimizerNames from transformers.utils.import_utils import is_torch_npu_available -from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities +from axolotl.utils.schemas.datasets import ( + DatasetConfig, + DPODataset, + KTODataset, + PretrainingDataset, + SFTDataset, + StepwiseSupervisedDataset, +) +from axolotl.utils.schemas.deprecated import DeprecatedParameters, RemappedParameters +from axolotl.utils.schemas.enums import ChatTemplate, RLType +from axolotl.utils.schemas.integrations import ( + CometConfig, + GradioConfig, + LISAConfig, + MLFlowConfig, + RayConfig, + WandbConfig, +) +from axolotl.utils.schemas.internal import EnvCapabilities, GPUCapabilities +from axolotl.utils.schemas.model import ( + ModelInputConfig, + ModelOutputConfig, + SpecialTokensConfig, +) +from axolotl.utils.schemas.peft import LoraConfig, ReLoRAConfig +from axolotl.utils.schemas.training import HyperparametersConfig +from axolotl.utils.schemas.trl import TRLConfig -from .trl import TRLConfig - -LOG = logging.getLogger("axolotl.utils.config.models.input") +LOG = logging.getLogger(__name__) SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} -class RLType(str, Enum): - """RL trainer type configuration subset""" - - dpo = "dpo" # pylint: disable=invalid-name - grpo = "grpo" # pylint: disable=invalid-name - ipo = "ipo" # pylint: disable=invalid-name - orpo = "orpo" # pylint: disable=invalid-name - kto = "kto" # pylint: disable=invalid-name - simpo = "simpo" # pylint: disable=invalid-name - - -class ChatTemplate(str, Enum): - """Chat templates configuration subset""" - - alpaca = "alpaca" # pylint: disable=invalid-name - chatml = "chatml" # pylint: disable=invalid-name - mistral_v1 = "mistral_v1" # pylint: disable=invalid-name - mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name - mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name - gemma = "gemma" # pylint: disable=invalid-name - cohere = "cohere" # pylint: disable=invalid-name - llama3 = "llama3" # pylint: disable=invalid-name - llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name - phi_3 = "phi_3" # pylint: disable=invalid-name - phi_35 = "phi_35" # pylint: disable=invalid-name - deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name - deepseek_v3 = "deepseek_v3" # pylint: disable=invalid-name - jamba = "jamba" # pylint: disable=invalid-name - jinja = "jinja" # pylint: disable=invalid-name - qwen_25 = "qwen_25" # pylint: disable=invalid-name - tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name - exaone = "exaone" # pylint: disable=invalid-name - metharme = "metharme" # pylint: disable=invalid-name - - -class CustomSupportedOptimizers(str, Enum): - """Custom supported optimizers""" - - optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name - ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name - ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name - ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name - adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name - muon = "muon" # pylint: disable=invalid-name - - -class DeprecatedParameters(BaseModel): - """configurations that are deprecated""" - - max_packed_sequence_len: Optional[int] = None - rope_scaling: Optional[Any] = None - noisy_embedding_alpha: Optional[float] = None - dpo_beta: Optional[float] = None - evaluation_strategy: Optional[str] = None - 
- @field_validator("max_packed_sequence_len") - @classmethod - def validate_max_packed_sequence_len(cls, max_packed_sequence_len): - if max_packed_sequence_len: - raise DeprecationWarning("`max_packed_sequence_len` is no longer supported") - return max_packed_sequence_len - - @field_validator("rope_scaling") - @classmethod - def validate_rope_scaling(cls, rope_scaling): - if rope_scaling: - raise DeprecationWarning( - "`rope_scaling` is no longer supported, it should now be be a key under `model_config`" - ) - return rope_scaling - - @field_validator("noisy_embedding_alpha") - @classmethod - def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha): - if noisy_embedding_alpha: - LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha") - return noisy_embedding_alpha - - @field_validator("dpo_beta") - @classmethod - def validate_dpo_beta(cls, dpo_beta): - if dpo_beta is not None: - LOG.warning("dpo_beta is deprecated, use rl_beta instead") - return dpo_beta - - @field_validator("evaluation_strategy") - @classmethod - def validate_evaluation_strategy(cls, evaluation_strategy): - if evaluation_strategy is not None: - LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead") - return evaluation_strategy - - -class RemappedParameters(BaseModel): - """parameters that have been remapped to other names""" - - overrides_of_model_config: Optional[Dict[str, Any]] = Field( - default=None, alias="model_config" - ) - overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field( - default=None, alias="model_kwargs" - ) - type_of_model: Optional[str] = Field(default=None, alias="model_type") - revision_of_model: Optional[str] = Field(default=None, alias="model_revision") - - -class PretrainingDataset(BaseModel): - """pretraining dataset configuration subset""" - - name: Optional[str] = None - path: Optional[str] = None - split: Optional[str] = "train" - text_column: Optional[str] = "text" - type: Optional[str] = "pretrain" - trust_remote_code: Optional[bool] = False - data_files: Optional[str] = None - skip: Optional[int] = None - - -class UserDefinedPrompterType(BaseModel): - """structure for user defined prompt types""" - - system_prompt: Optional[str] = None - system_format: Optional[str] = None - field_system: Optional[str] = None - field_instruction: Optional[str] = None - field_input: Optional[str] = None - field_output: Optional[str] = None - - format: Optional[str] = None - no_input_format: Optional[str] = None - field: Optional[str] = None - - -class LrGroup(BaseModel): - """Custom learning rate group configuration""" - - name: str - modules: List[str] - lr: float - - -class SFTDataset(BaseModel): - """SFT configuration subset""" - - path: Optional[str] = None - split: Optional[str] = None - type: Optional[Union[str, UserDefinedPrompterType]] = None - input_transform: Optional[str] = None - shards: Optional[int] = None - shards_idx: Optional[int] = None - preprocess_shards: Optional[int] = None - conversation: Optional[str] = None - # Do not make this too strict or it will break the validator to choose different dataset class - chat_template: Optional[ - Union[ - ChatTemplate, - str, - ] - ] = None - chat_template_jinja: Optional[str] = None - data_files: Optional[Union[str, List[str]]] = None - input_format: Optional[str] = None - name: Optional[str] = None - ds_type: Optional[str] = None - train_on_split: Optional[str] = None - field: Optional[str] = None - field_human: Optional[str] = None - field_model: Optional[str] = None - field_messages: 
Optional[str] = None - message_field_role: Optional[str] = ( - None # deprecated, use message_property_mappings - ) - message_field_content: Optional[str] = ( - None # deprecated, use message_property_mappings - ) - message_property_mappings: Optional[Dict[str, str]] = None - message_field_training: Optional[str] = None - message_field_training_detail: Optional[str] = None - logprobs_field: Optional[str] = None - temperature: Optional[float] = None - roles_to_train: Optional[List[str]] = None - train_on_eos: Optional[str] = None - roles: Optional[Dict[str, List[str]]] = None - drop_system_message: Optional[bool] = None - trust_remote_code: Optional[bool] = False - revision: Optional[str] = None - - @model_validator(mode="before") - @classmethod - def handle_legacy_message_fields(cls, data): - """Handle backwards compatibility between legacy message field mapping and new property mapping system.""" - return handle_legacy_message_fields_logic(data) - - @model_validator(mode="before") - @classmethod - def check_chat_template_config(cls, data): - if isinstance(data, BaseModel): - data = data.model_dump() - - # Set chat_template to tokenizer_default if not set - if data.get("type") == "chat_template" and not data.get("chat_template"): - data["chat_template"] = ChatTemplate.tokenizer_default - - # if chat_template is set to jinja, chat_template_jinja is required - if data.get("chat_template") == ChatTemplate.jinja and not data.get( - "chat_template_jinja" - ): - raise ValueError( - "chat_template_jinja is required when chat_template is set to jinja" - ) - - # If chat_template_jinja is set, set chat_template to jinja - if data.get("chat_template_jinja") and not data.get("chat_template"): - data["chat_template"] = ChatTemplate.jinja - - return data - - -class UserDefinedDPOType(BaseModel): - """User defined typing for DPO""" - - field_system: Optional[str] = None - field_prompt: Optional[str] = None - field_chosen: Optional[str] = None - field_rejected: Optional[str] = None - prompt_format: Optional[str] = None - chosen_format: Optional[str] = None - rejected_format: Optional[str] = None - - -class DPODataset(BaseModel): - """DPO configuration subset""" - - path: Optional[str] = None - split: Optional[str] = None - type: Optional[Union[UserDefinedDPOType, str]] = None - data_files: Optional[List[str]] = None - revision: Optional[str] = None - field_messages: Optional[str] = None - - -class StepwiseSupervisedDataset(BaseModel): - """Stepwise supervised dataset configuration subset""" - - path: Optional[str] = None - split: Optional[str] = None - data_files: Optional[List[str]] = None - revision: Optional[str] = None - step_separator: Optional[str] = None - max_completion_length: Optional[int] = None - train_on_last_step_only: Optional[bool] = None - - -class UserDefinedKTOType(BaseModel): - """User defined typing for KTO""" - - field_system: Optional[str] = None - field_prompt: Optional[str] = None - field_completion: Optional[str] = None - field_label: Optional[bool] = None - prompt_format: Optional[str] = None - completion_format: Optional[str] = None - - -class KTODataset(BaseModel): - """KTO configuration subset""" - - path: Optional[str] = None - split: Optional[str] = None - type: Optional[Union[UserDefinedKTOType, str]] = None - data_files: Optional[List[str]] = None - trust_remote_code: Optional[bool] = False - revision: Optional[str] = None - - -DatasetConfig = Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset] - - -class LoftQConfig(BaseModel): - """LoftQ configuration 
subset""" - - loftq_bits: int = Field( - default=4, json_schema_extra={"description": "Quantization bits for LoftQ"} - ) - # loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"}) - - -class PeftConfig(BaseModel): - """peftq configuration subset""" - - loftq_config: Optional[LoftQConfig] = None - - -class SpecialTokensConfig(BaseModel): - """Special tokens configuration subset""" - - bos_token: Optional[str] = None - eos_token: Optional[str] = None - pad_token: Optional[str] = None - unk_token: Optional[str] = None - additional_special_tokens: Optional[List[str]] = None - - -class LoraConfig(BaseModel): - """Peft / LoRA configuration subset""" - - load_in_8bit: Optional[bool] = Field(default=False) - load_in_4bit: Optional[bool] = Field(default=False) - - adapter: Optional[str] = None - lora_model_dir: Optional[str] = None - lora_r: Optional[int] = None - lora_alpha: Optional[int] = None - lora_fan_in_fan_out: Optional[bool] = None - lora_target_modules: Optional[Union[str, List[str]]] = None - lora_target_linear: Optional[bool] = None - lora_modules_to_save: Optional[List[str]] = None - lora_dropout: Optional[float] = 0.0 - peft_layers_to_transform: Optional[List[int]] = None - peft_layers_pattern: Optional[List[str]] = None - peft: Optional[PeftConfig] = None - peft_use_dora: Optional[bool] = None - peft_use_rslora: Optional[bool] = None - peft_layer_replication: Optional[List[Tuple[int, int]]] = None - peft_init_lora_weights: Optional[Union[bool, str]] = None - - qlora_sharded_model_loading: Optional[bool] = Field( - default=False, - json_schema_extra={ - "description": "load qlora model in sharded format for FSDP using answer.ai technique." - }, - ) - lora_on_cpu: Optional[bool] = None - gptq: Optional[bool] = None - bnb_config_kwargs: Optional[Dict[str, Any]] = None - - loraplus_lr_ratio: Optional[float] = Field( - default=None, - json_schema_extra={ - "description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4." - }, - ) - loraplus_lr_embedding: Optional[float] = Field( - default=1e-6, - json_schema_extra={ - "description": "loraplus learning rate for lora embedding layers." - }, - ) - - merge_lora: Optional[bool] = None - - @model_validator(mode="before") - @classmethod - def validate_adapter(cls, data): - if ( - not data.get("adapter") - and not data.get("inference") - and (data.get("load_in_8bit") or data.get("load_in_4bit")) - ): - raise ValueError( - "load_in_8bit and load_in_4bit are not supported without setting an adapter for training." - "If you want to full finetune, please turn off load_in_8bit and load_in_4bit." 
- ) - return data - - @model_validator(mode="after") - def validate_qlora(self): - if self.adapter == "qlora": - if self.merge_lora: - # can't merge qlora if loaded in 8bit or 4bit - if self.load_in_8bit: - raise ValueError("Can't merge qlora if loaded in 8bit") - - if self.gptq: - raise ValueError("Can't merge qlora if gptq") - - if self.load_in_4bit: - raise ValueError("Can't merge qlora if loaded in 4bit") - - else: - if self.load_in_8bit: - raise ValueError("Can't load qlora in 8bit") - - if self.gptq: - raise ValueError("Can't load qlora if gptq") - - if not self.load_in_4bit: - raise ValueError("Require cfg.load_in_4bit to be True for qlora") - return self - - @field_validator("loraplus_lr_embedding") - @classmethod - def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding): - if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str): - loraplus_lr_embedding = float(loraplus_lr_embedding) - return loraplus_lr_embedding - - @model_validator(mode="before") - @classmethod - def validate_lora_dropout(cls, data): - if data.get("adapter") is not None and data.get("lora_dropout") is None: - data["lora_dropout"] = 0.0 - return data - - -class ReLoRAConfig(BaseModel): - """ReLoRA configuration subset""" - - relora_steps: Optional[int] = None - relora_warmup_steps: Optional[int] = None - relora_anneal_steps: Optional[int] = None - relora_prune_ratio: Optional[float] = None - relora_cpu_offload: Optional[bool] = None - - -class ModelInputConfig(BaseModel): - """model to train on configuration subset""" - - model_config = {"protected_namespaces": ()} - - base_model: str - base_model_config: Optional[str] = None - cls_model_config: Optional[str] = None - tokenizer_config: Optional[str] = None - tokenizer_use_fast: Optional[bool] = None - tokenizer_legacy: Optional[bool] = None - tokenizer_type: Optional[str] = Field( - default=None, json_schema_extra={"description": "transformers tokenizer class"} - ) - processor_type: Optional[str] = Field( - default=None, json_schema_extra={"description": "transformers processor class"} - ) - trust_remote_code: Optional[bool] = None - - @field_validator("trust_remote_code") - @classmethod - def hint_trust_remote_code(cls, trust_remote_code): - if trust_remote_code: - LOG.warning( - "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model." 
- ) - return trust_remote_code - - -class HyperparametersConfig(BaseModel): - """training hyperparams configuration subset""" - - gradient_accumulation_steps: Optional[int] = Field(default=1) - micro_batch_size: Optional[int] = Field( - default=1, - json_schema_extra={"description": "per gpu micro batch size for training"}, - ) - batch_size: Optional[int] = Field( - default=None, - json_schema_extra={ - "description": "Total batch size, we do not recommended setting this manually" - }, - ) - eval_batch_size: Optional[int] = Field( - default=None, - json_schema_extra={ - "description": "per gpu micro batch size for evals, defaults to value of micro_batch_size" - }, - ) - - auto_find_batch_size: Optional[bool] = None - - train_on_inputs: Optional[bool] = False - group_by_length: Optional[bool] = None - - learning_rate: Union[str, float] - embedding_lr: Optional[float] = None - embedding_lr_scale: Optional[float] = None - weight_decay: Optional[float] = 0.0 - optimizer: Optional[Union[OptimizerNames, CustomSupportedOptimizers]] = ( - OptimizerNames.ADAMW_TORCH_FUSED - ) - optim_args: Optional[Union[str, Dict[str, Any]]] = Field( - default=None, - json_schema_extra={"description": "Optional arguments to supply to optimizer."}, - ) - optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field( - default=None, - json_schema_extra={ - "description": "The target modules to optimize, i.e. the module names that you would like to train." - }, - ) - torchdistx_path: Optional[str] = None - lr_scheduler: Optional[ - Union[SchedulerType, Literal["one_cycle"], Literal["rex"]] - ] = SchedulerType.COSINE - lr_scheduler_kwargs: Optional[Dict[str, Any]] = None - lr_quadratic_warmup: Optional[bool] = None - cosine_min_lr_ratio: Optional[float] = None - cosine_constant_lr_ratio: Optional[float] = None - lr_div_factor: Optional[float] = None - lr_groups: Optional[List[LrGroup]] = None - - adam_epsilon: Optional[float] = None - adam_beta1: Optional[float] = None - adam_beta2: Optional[float] = None - max_grad_norm: Optional[float] = None - num_epochs: float = Field(default=1.0) - - @field_validator("batch_size") - @classmethod - def hint_batch_size_set(cls, batch_size): - if batch_size: - LOG.warning( - "%s\n%s", - "batch_size is not recommended. 
Please use gradient_accumulation_steps instead.", - "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.", - ) - return batch_size - - @field_validator("learning_rate") - @classmethod - def convert_learning_rate(cls, learning_rate): - if learning_rate and isinstance(learning_rate, str): - learning_rate = float(learning_rate) - return learning_rate - - -class ModelOutputConfig(BaseModel): - """model save configuration subset""" - - output_dir: str = Field(default="./model-out") - hub_model_id: Optional[str] = None - hub_strategy: Optional[str] = None - save_safetensors: Optional[bool] = True - - -class MLFlowConfig(BaseModel): - """mlflow configuration subset""" - - use_mlflow: Optional[bool] = None - mlflow_tracking_uri: Optional[str] = None - mlflow_experiment_name: Optional[str] = None - mlflow_run_name: Optional[str] = None - hf_mlflow_log_artifacts: Optional[bool] = None - - -class LISAConfig(BaseModel): - """LISA options""" - - lisa_n_layers: Optional[int] = Field( - default=None, - json_schema_extra={"description": "the number of activate layers in LISA"}, - ) - lisa_step_interval: Optional[int] = Field( - default=None, - json_schema_extra={"description": "how often to switch layers in LISA"}, - ) - lisa_layers_attribute: Optional[str] = Field( - default="model.layers", - json_schema_extra={"description": "path under the model to access the layers"}, - ) - - -class WandbConfig(BaseModel): - """wandb configuration subset""" - - use_wandb: Optional[bool] = None - wandb_name: Optional[str] = None - wandb_run_id: Optional[str] = None - wandb_mode: Optional[str] = None - wandb_project: Optional[str] = None - wandb_entity: Optional[str] = None - wandb_watch: Optional[str] = None - wandb_log_model: Optional[str] = None - - @model_validator(mode="before") - @classmethod - def check_wandb_run(cls, data): - if data.get("wandb_run_id") and not data.get("wandb_name"): - data["wandb_name"] = data.get("wandb_run_id") - - LOG.warning( - "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead." - ) - - return data - - -class CometConfig(BaseModel): - """Comet configuration subset""" - - use_comet: Optional[bool] = None - comet_api_key: Optional[str] = None - comet_workspace: Optional[str] = None - comet_project_name: Optional[str] = None - comet_experiment_key: Optional[str] = None - comet_mode: Optional[str] = None - comet_online: Optional[bool] = None - comet_experiment_config: Optional[Dict[str, Any]] = None - - -class GradioConfig(BaseModel): - """Gradio configuration subset""" - - gradio_title: Optional[str] = None - gradio_share: Optional[bool] = None - gradio_server_name: Optional[str] = None - gradio_server_port: Optional[int] = None - gradio_max_new_tokens: Optional[int] = None - gradio_temperature: Optional[float] = None - - -class RayConfig(BaseModel): - """Ray launcher configuration subset""" - - use_ray: bool = Field(default=False) - ray_run_name: Optional[str] = Field( - default=None, - json_schema_extra={ - "help": "The training results will be saved at `saves/ray_run_name`." - }, - ) - ray_num_workers: int = Field( - default=1, - json_schema_extra={ - "help": "The number of workers for Ray training. Default is 1 worker." - }, - ) - resources_per_worker: dict = Field( - default_factory=lambda: {"GPU": 1}, - json_schema_extra={ - "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker." 
- }, - ) - - # pylint: disable=too-many-public-methods,too-many-ancestors class AxolotlInputConfig( ModelInputConfig, @@ -680,252 +68,250 @@ class AxolotlInputConfig( DeprecatedParameters, BaseModel, ): - """wrapper of all config options""" + """Wrapper of all config options""" model_config = {"populate_by_name": True} - strict: Optional[bool] = Field(default=False) - resume_from_checkpoint: Optional[str] = None - auto_resume_from_checkpoints: Optional[bool] = None - resize_token_embeddings_to_32x: Optional[bool] = None - mean_resizing_embeddings: Optional[bool] = False + strict: bool | None = Field(default=False) + resume_from_checkpoint: str | None = None + auto_resume_from_checkpoints: bool | None = None + resize_token_embeddings_to_32x: bool | None = None + mean_resizing_embeddings: bool | None = False # optionally shrink the embeddings when the tokenizer vocab size is smaller - shrink_embeddings: Optional[bool] = None + shrink_embeddings: bool | None = None - rl: Optional[RLType] = None - trl: Optional[TRLConfig] = Field( + rl: RLType | None = None + trl: TRLConfig | None = Field( default_factory=lambda: TRLConfig(), # pylint: disable=unnecessary-lambda ) - reward_model: Optional[bool] = None - process_reward_model: Optional[bool] = None - num_labels: Optional[int] = None - dpo_use_weighting: Optional[bool] = ( - None # whether to use weighting in DPO trainer. If none, default is false in the trainer. - ) - dpo_use_logits_to_keep: Optional[bool] = None + reward_model: bool | None = None + process_reward_model: bool | None = None + num_labels: int | None = None + # Whether to use weighting in DPO trainer. + # If `None`, default is `False` in the trainer. + dpo_use_weighting: bool | None = None + dpo_use_logits_to_keep: bool | None = None - datasets: Optional[ + datasets: ( Annotated[ - list[Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset]], + list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset], MinLen(1), ] - ] = None + | None + ) = None - test_datasets: Optional[ + test_datasets: ( Annotated[ - list[Union[SFTDataset, DPODataset, KTODataset, StepwiseSupervisedDataset]], + list[SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset], MinLen(1), ] - ] = None - shuffle_merged_datasets: Optional[bool] = True - dataset_prepared_path: Optional[str] = None - dataset_shard_num: Optional[int] = None - dataset_shard_idx: Optional[int] = None - skip_prepare_dataset: Optional[bool] = False + | None + ) = None + shuffle_merged_datasets: bool | None = True + dataset_prepared_path: str | None = None + dataset_shard_num: int | None = None + dataset_shard_idx: int | None = None + skip_prepare_dataset: bool | None = False - pretraining_dataset: Optional[ - Annotated[list[Union[PretrainingDataset, SFTDataset]], MinLen(1)] - ] = Field( + pretraining_dataset: ( + Annotated[list[PretrainingDataset | SFTDataset], MinLen(1)] | None + ) = Field( default=None, json_schema_extra={"description": "streaming dataset to use for pretraining"}, ) - dataset_processes: Optional[int] = Field(default=min(32, os.cpu_count())) # type: ignore[type-var] - dataset_exact_deduplication: Optional[bool] = None - dataset_keep_in_memory: Optional[bool] = None - dataloader_pin_memory: Optional[bool] = None - dataloader_num_workers: Optional[int] = None - dataloader_prefetch_factor: Optional[int] = None - dataloader_drop_last: Optional[bool] = None + dataset_processes: int | None = Field(default=min(32, os.cpu_count())) # type: ignore[type-var] + dataset_exact_deduplication: bool | None 
= None + dataset_keep_in_memory: bool | None = None + dataloader_pin_memory: bool | None = None + dataloader_num_workers: int | None = None + dataloader_prefetch_factor: int | None = None + dataloader_drop_last: bool | None = None - accelerator_config: Optional[Dict[str, Any]] = None + accelerator_config: dict[str, Any] | None = None - remove_unused_columns: Optional[bool] = None + remove_unused_columns: bool | None = None - push_dataset_to_hub: Optional[str] = None - hf_use_auth_token: Optional[bool] = None + push_dataset_to_hub: str | None = None + hf_use_auth_token: bool | None = None - device: Optional[Any] = None - device_map: Optional[Any] = None - world_size: Optional[int] = None - local_rank: Optional[int] = None - ddp: Optional[bool] = None + device: Any | None = None + device_map: Any | None = None + world_size: int | None = None + local_rank: int | None = None + ddp: bool | None = None - seed: Optional[int] = None - ddp_timeout: Optional[int] = None - ddp_bucket_cap_mb: Optional[int] = None - ddp_broadcast_buffers: Optional[bool] = None - ddp_find_unused_parameters: Optional[bool] = None + seed: int | None = None + ddp_timeout: int | None = None + ddp_bucket_cap_mb: int | None = None + ddp_broadcast_buffers: bool | None = None + ddp_find_unused_parameters: bool | None = None - eval_table_size: Optional[int] = None - eval_max_new_tokens: Optional[int] = None - do_causal_lm_eval: Optional[bool] = None - eval_causal_lm_metrics: Optional[List[str]] = None - do_bench_eval: Optional[bool] = None - bench_dataset: Optional[str] = None - bench_split: Optional[str] = None - metric_for_best_model: Optional[str] = None - greater_is_better: Optional[bool] = None + eval_table_size: int | None = None + eval_max_new_tokens: int | None = None + do_causal_lm_eval: bool | None = None + eval_causal_lm_metrics: list[str] | None = None + do_bench_eval: bool | None = None + bench_dataset: str | None = None + bench_split: str | None = None + metric_for_best_model: str | None = None + greater_is_better: bool | None = None - loss_watchdog_threshold: Optional[float] = None - loss_watchdog_patience: Optional[int] = None + loss_watchdog_threshold: float | None = None + loss_watchdog_patience: int | None = None - gc_steps: Optional[int] = None + gc_steps: int | None = None - bf16: Optional[Union[Literal["auto"], bool]] = "auto" - fp16: Optional[bool] = None - bfloat16: Optional[bool] = None # for non-AMP cases - float16: Optional[bool] = None # for non-AMP cases - tf32: Optional[bool] = None - float32: Optional[bool] = None + bf16: Literal["auto"] | bool | None = "auto" + fp16: bool | None = None + bfloat16: bool | None = None # for non-AMP cases + float16: bool | None = None # for non-AMP cases + tf32: bool | None = None + float32: bool | None = None - # torch_dtype: Optional[torch.dtype] + # torch_dtype: torch.dtype | None - gradient_checkpointing: Optional[Union[Literal["unsloth", "offload"], bool]] = ( - Field(default=False) + gradient_checkpointing: Literal["unsloth", "offload"] | bool | None = Field( + default=False ) - gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None + gradient_checkpointing_kwargs: dict[str, Any] | None = None - unfrozen_parameters: Optional[List[str]] = None + unfrozen_parameters: list[str] | None = None sequence_len: int = Field(default=512) - min_sample_len: Optional[int] = None + min_sample_len: int | None = None max_prompt_len: int = Field( default=512, json_schema_extra={"description": "maximum prompt length for RL training"}, ) - sample_packing: Optional[bool] = 
None - sample_packing_group_size: Optional[int] = 100_000 - sample_packing_bin_size: Optional[int] = 200 - eval_sample_packing: Optional[bool] = None - pad_to_sequence_len: Optional[bool] = None - curriculum_sampling: Optional[bool] = None - multipack_real_batches: Optional[bool] = None - pretraining_sample_concatenation: Optional[bool] = Field( + sample_packing: bool | None = None + sample_packing_group_size: int | None = 100_000 + sample_packing_bin_size: int | None = 200 + eval_sample_packing: bool | None = None + pad_to_sequence_len: bool | None = None + curriculum_sampling: bool | None = None + multipack_real_batches: bool | None = None + pretraining_sample_concatenation: bool | None = Field( default=None, json_schema_extra={ "description": "whether to soft pack/concatenate samples during pretraining", }, ) - batch_flattening: Optional[Union[Literal["auto"], bool]] = None + batch_flattening: Literal["auto"] | bool | None = None # for PoSE context length extension - use_pose: Optional[bool] = None - pose_split_on_token_ids: Optional[List[int]] = None - pose_max_context_len: Optional[int] = None - pose_num_chunks: Optional[int] = None + use_pose: bool | None = None + pose_split_on_token_ids: list[int] | None = None + pose_max_context_len: int | None = None + pose_num_chunks: int | None = None - pretrain_multipack_buffer_size: Optional[int] = 10_000 - pretrain_multipack_attn: Optional[bool] = Field( + pretrain_multipack_buffer_size: int | None = 10_000 + pretrain_multipack_attn: bool | None = Field( default=True, json_schema_extra={ "description": "whether to prevent cross attention for packed sequences during pretraining", }, ) - xformers_attention: Optional[bool] = None - sdp_attention: Optional[bool] = None - s2_attention: Optional[bool] = None - flash_attention: Optional[bool] = None - flash_attn_cross_entropy: Optional[bool] = None - flash_attn_rms_norm: Optional[bool] = None - flash_attn_fuse_qkv: Optional[bool] = None - flash_attn_fuse_mlp: Optional[bool] = None - flash_optimum: Optional[bool] = None + xformers_attention: bool | None = None + sdp_attention: bool | None = None + s2_attention: bool | None = None + flash_attention: bool | None = None + flash_attn_cross_entropy: bool | None = None + flash_attn_rms_norm: bool | None = None + flash_attn_fuse_qkv: bool | None = None + flash_attn_fuse_mlp: bool | None = None + flash_optimum: bool | None = None - eager_attention: Optional[bool] = None + eager_attention: bool | None = None - unsloth_cross_entropy_loss: Optional[bool] = None - unsloth_lora_mlp: Optional[bool] = None - unsloth_lora_qkv: Optional[bool] = None - unsloth_lora_o: Optional[bool] = None - unsloth_rms_norm: Optional[bool] = None - unsloth_rope: Optional[bool] = None + unsloth_cross_entropy_loss: bool | None = None + unsloth_lora_mlp: bool | None = None + unsloth_lora_qkv: bool | None = None + unsloth_lora_o: bool | None = None + unsloth_rms_norm: bool | None = None + unsloth_rope: bool | None = None - lora_mlp_kernel: Optional[bool] = None - lora_qkv_kernel: Optional[bool] = None - lora_o_kernel: Optional[bool] = None + lora_mlp_kernel: bool | None = None + lora_qkv_kernel: bool | None = None + lora_o_kernel: bool | None = None - deepspeed: Optional[Union[str, Dict[str, Any]]] = None - fsdp: Optional[List[str]] = None - fsdp_config: Optional[Dict[str, Any]] = None - fsdp_final_state_dict_type: Optional[ - Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] - ] = None + deepspeed: str | dict[str, Any] | None = None + fsdp: list[str] | None = None + 
fsdp_config: dict[str, Any] | None = None + fsdp_final_state_dict_type: ( + Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] | None + ) = None - val_set_size: Optional[float] = Field(default=0.0) + val_set_size: float | None = Field(default=0.0) - special_tokens: Optional[SpecialTokensConfig] = None - tokens: Optional[List[str]] = None - added_tokens_overrides: Optional[Dict[int, str]] = None + special_tokens: SpecialTokensConfig | None = None + tokens: list[str] | None = None + added_tokens_overrides: dict[int, str] | None = None - torch_compile: Optional[Union[Literal["auto"], bool]] = None - torch_compile_backend: Optional[str] = None - torch_compile_mode: Optional[ - Literal["default", "reduce-overhead", "max-autotune"] - ] = None - - max_steps: Optional[int] = None - warmup_steps: Optional[int] = None - warmup_ratio: Optional[float] = None - eval_steps: Optional[Union[int, float]] = None - evals_per_epoch: Optional[int] = None - eval_strategy: Optional[str] = None - save_steps: Optional[Union[int, float]] = None - saves_per_epoch: Optional[int] = None - save_strategy: Optional[str] = None - save_total_limit: Optional[int] = None - logging_steps: Optional[int] = None - early_stopping_patience: Optional[int] = None - load_best_model_at_end: Optional[bool] = False - save_only_model: Optional[bool] = False - use_tensorboard: Optional[bool] = None - profiler_steps: Optional[int] = None - include_tokens_per_second: Optional[bool] = None - - neftune_noise_alpha: Optional[float] = None - - orpo_alpha: Optional[float] = None - rpo_alpha: Optional[float] = None - simpo_gamma: Optional[float] = None - cpo_alpha: Optional[float] = None - - kto_desirable_weight: Optional[float] = None - kto_undesirable_weight: Optional[float] = None - rl_beta: Optional[float] = None - - max_memory: Optional[Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]] = ( + torch_compile: Literal["auto"] | bool | None = None + torch_compile_backend: str | None = None + torch_compile_mode: Literal["default", "reduce-overhead", "max-autotune"] | None = ( None ) - gpu_memory_limit: Optional[Union[int, str]] = None - low_cpu_mem_usage: Optional[bool] = None - chat_template: Optional[ - Union[ - ChatTemplate, - Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")], - ] - ] = None - chat_template_jinja: Optional[str] = None - default_system_message: Optional[str] = None + max_steps: int | None = None + warmup_steps: int | None = None + warmup_ratio: float | None = None + eval_steps: int | float | None = None + evals_per_epoch: int | None = None + eval_strategy: str | None = None + save_steps: int | float | None = None + saves_per_epoch: int | None = None + save_strategy: str | None = None + save_total_limit: int | None = None + logging_steps: int | None = None + early_stopping_patience: int | None = None + load_best_model_at_end: bool | None = False + save_only_model: bool | None = False + use_tensorboard: bool | None = None + profiler_steps: int | None = None + include_tokens_per_second: bool | None = None - fix_untrained_tokens: Optional[Union[int, List[int]]] = None + neftune_noise_alpha: float | None = None + + orpo_alpha: float | None = None + rpo_alpha: float | None = None + simpo_gamma: float | None = None + cpo_alpha: float | None = None + + kto_desirable_weight: float | None = None + kto_undesirable_weight: float | None = None + rl_beta: float | None = None + + max_memory: dict[int | Literal["cpu", "disk"], int | str] | None = None + gpu_memory_limit: int | str | None = None 
+ low_cpu_mem_usage: bool | None = None + + chat_template: ( + ChatTemplate + | Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")] + ) | None = None + chat_template_jinja: str | None = None + default_system_message: str | None = None + + fix_untrained_tokens: int | list[int] | None = None # INTERNALS - document for now, generally not set externally - is_preprocess: Optional[bool] = None - preprocess_iterable: Optional[bool] = None + is_preprocess: bool | None = None + preprocess_iterable: bool | None = None - total_num_tokens: Optional[int] = None - total_supervised_tokens: Optional[int] = None - sample_packing_eff_est: Optional[float] = None - axolotl_config_path: Optional[str] = None + total_num_tokens: int | None = None + total_supervised_tokens: int | None = None + sample_packing_eff_est: float | None = None + axolotl_config_path: str | None = None - is_falcon_derived_model: Optional[bool] = Field(default=None) - is_llama_derived_model: Optional[bool] = Field(default=None) - is_mistral_derived_model: Optional[bool] = Field(default=None) - is_qwen_derived_model: Optional[bool] = Field(default=None) + is_falcon_derived_model: bool | None = Field(default=None) + is_llama_derived_model: bool | None = Field(default=None) + is_mistral_derived_model: bool | None = Field(default=None) + is_qwen_derived_model: bool | None = Field(default=None) - plugins: Optional[List[str]] = Field(default=None) + plugins: list[str] | None = Field(default=None) @field_validator("datasets", mode="before") @classmethod @@ -953,8 +339,8 @@ class AxolotlInputConfig( @field_serializer("datasets") def datasets_serializer( - self, ds_configs: Optional[List[DatasetConfig]] - ) -> Optional[List[Dict[str, Any]]]: + self, ds_configs: list[DatasetConfig] | None + ) -> list[dict[str, Any]] | None: if ds_configs: return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs] return None @@ -1028,6 +414,7 @@ class AxolotlInputConfig( @model_validator(mode="before") @classmethod + # pylint: disable=duplicate-code def check_chat_template_config(cls, data): # if chat_template is set to jinja, chat_template_jinja is required if data.get("chat_template") == ChatTemplate.jinja and not data.get( @@ -1850,77 +1237,3 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): if data["beta"] != data["trl"]["beta"]: raise ValueError("beta and trl.beta must match or one must be removed") return data - - -def handle_legacy_message_fields_logic(data: dict) -> dict: - """ - Handle backwards compatibility between legacy message field mapping and new property mapping system. 
- - Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options: - - message_field_role: Mapped to the role field - - message_field_content: Mapped to the content field - - The new system uses message_property_mappings to support arbitrary field mappings: - message_property_mappings: - role: source_role_field - content: source_content_field - additional_field: source_field - - Args: - data: Dictionary containing configuration data - - Returns: - Updated dictionary with message field mappings consolidated - - Raises: - ValueError: If there are conflicts between legacy and new mappings - """ - data = data.copy() # Create a copy to avoid modifying the original - - if data.get("message_property_mappings") is None: - data["message_property_mappings"] = {} - - # Check for conflicts and handle role - if "message_field_role" in data: - LOG.warning( - "message_field_role is deprecated, use message_property_mappings instead. " - f"Example: message_property_mappings: {{role: {data['message_field_role']}}}" - ) - if ( - "role" in data["message_property_mappings"] - and data["message_property_mappings"]["role"] != data["message_field_role"] - ): - raise ValueError( - f"Conflicting message role fields: message_field_role='{data['message_field_role']}' " - f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'" - ) - data["message_property_mappings"]["role"] = data["message_field_role"] or "role" - - del data["message_field_role"] - elif "role" not in data["message_property_mappings"]: - data["message_property_mappings"]["role"] = "role" - - # Check for conflicts and handle content - if "message_field_content" in data: - LOG.warning( - "message_field_content is deprecated, use message_property_mappings instead. 
" - f"Example: message_property_mappings: {{content: {data['message_field_content']}}}" - ) - if ( - "content" in data["message_property_mappings"] - and data["message_property_mappings"]["content"] - != data["message_field_content"] - ): - raise ValueError( - f"Conflicting message content fields: message_field_content='{data['message_field_content']}' " - f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'" - ) - data["message_property_mappings"]["content"] = ( - data["message_field_content"] or "content" - ) - - del data["message_field_content"] - elif "content" not in data["message_property_mappings"]: - data["message_property_mappings"]["content"] = "content" - - return data diff --git a/src/axolotl/utils/schemas/datasets.py b/src/axolotl/utils/schemas/datasets.py new file mode 100644 index 000000000..57de71da2 --- /dev/null +++ b/src/axolotl/utils/schemas/datasets.py @@ -0,0 +1,165 @@ +"""Pydantic models for datasets-related configuration""" + +from pydantic import BaseModel, model_validator + +from axolotl.utils.schemas.enums import ChatTemplate +from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic + + +class UserDefinedPrompterType(BaseModel): + """Structure for user defined prompt types""" + + system_prompt: str | None = None + system_format: str | None = None + field_system: str | None = None + field_instruction: str | None = None + field_input: str | None = None + field_output: str | None = None + + format: str | None = None + no_input_format: str | None = None + field: str | None = None + + +class SFTDataset(BaseModel): + """SFT configuration subset""" + + path: str | None = None + split: str | None = None + type: str | UserDefinedPrompterType | None = None + input_transform: str | None = None + shards: int | None = None + shards_idx: int | None = None + preprocess_shards: int | None = None + conversation: str | None = None + # Do not make this too strict or it will break the validator to choose different dataset class + chat_template: ChatTemplate | str | None = None + chat_template_jinja: str | None = None + data_files: str | list[str] | None = None + input_format: str | None = None + name: str | None = None + ds_type: str | None = None + train_on_split: str | None = None + field: str | None = None + field_human: str | None = None + field_model: str | None = None + field_messages: str | None = None + # deprecated, use message_property_mappings + message_field_role: str | None = None + # deprecated, use message_property_mappings + message_field_content: str | None = None + message_property_mappings: dict[str, str] | None = None + message_field_training: str | None = None + message_field_training_detail: str | None = None + logprobs_field: str | None = None + temperature: float | None = None + roles_to_train: list[str] | None = None + train_on_eos: str | None = None + roles: dict[str, list[str]] | None = None + drop_system_message: bool | None = None + trust_remote_code: bool | None = False + revision: str | None = None + + @model_validator(mode="before") + @classmethod + def handle_legacy_message_fields(cls, data): + """Handle backwards compatibility between legacy message field mapping and new property mapping system.""" + return handle_legacy_message_fields_logic(data) + + @model_validator(mode="before") + @classmethod + # pylint: disable=duplicate-code + def check_chat_template_config(cls, data): + if isinstance(data, BaseModel): + data = data.model_dump() + + # Set chat_template to tokenizer_default if 
+        # Set chat_template to tokenizer_default if not set
+        if data.get("type") == "chat_template" and not data.get("chat_template"):
+            data["chat_template"] = ChatTemplate.tokenizer_default
+
+        # if chat_template is set to jinja, chat_template_jinja is required
+        if data.get("chat_template") == ChatTemplate.jinja and not data.get(
+            "chat_template_jinja"
+        ):
+            raise ValueError(
+                "chat_template_jinja is required when chat_template is set to jinja"
+            )
+
+        # If chat_template_jinja is set, set chat_template to jinja
+        if data.get("chat_template_jinja") and not data.get("chat_template"):
+            data["chat_template"] = ChatTemplate.jinja
+
+        return data
+
+
+class PretrainingDataset(BaseModel):
+    """Pretraining dataset configuration subset"""
+
+    name: str | None = None
+    path: str | None = None
+    split: str | None = "train"
+    text_column: str | None = "text"
+    type: str | None = "pretrain"
+    trust_remote_code: bool | None = False
+    data_files: str | None = None
+    skip: int | None = None
+
+
+class UserDefinedDPOType(BaseModel):
+    """User defined typing for DPO"""
+
+    field_system: str | None = None
+    field_prompt: str | None = None
+    field_chosen: str | None = None
+    field_rejected: str | None = None
+    prompt_format: str | None = None
+    chosen_format: str | None = None
+    rejected_format: str | None = None
+
+
+class DPODataset(BaseModel):
+    """DPO configuration subset"""
+
+    path: str | None = None
+    split: str | None = None
+    type: UserDefinedDPOType | str | None = None
+    data_files: list[str] | None = None
+    revision: str | None = None
+    field_messages: str | None = None
+
+
+class StepwiseSupervisedDataset(BaseModel):
+    """Stepwise supervised dataset configuration subset"""
+
+    path: str | None = None
+    split: str | None = None
+    data_files: list[str] | None = None
+    revision: str | None = None
+    step_separator: str | None = None
+    max_completion_length: int | None = None
+    train_on_last_step_only: bool | None = None
+
+
+class UserDefinedKTOType(BaseModel):
+    """User defined typing for KTO"""
+
+    field_system: str | None = None
+    field_prompt: str | None = None
+    field_completion: str | None = None
+    field_label: bool | None = None
+    prompt_format: str | None = None
+    completion_format: str | None = None
+
+
+class KTODataset(BaseModel):
+    """KTO configuration subset"""
+
+    path: str | None = None
+    split: str | None = None
+    type: UserDefinedKTOType | str | None = None
+    data_files: list[str] | None = None
+    trust_remote_code: bool | None = False
+    revision: str | None = None
+
+
+DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
diff --git a/src/axolotl/utils/schemas/deprecated.py b/src/axolotl/utils/schemas/deprecated.py
new file mode 100644
index 000000000..d42d6ff9e
--- /dev/null
+++ b/src/axolotl/utils/schemas/deprecated.py
@@ -0,0 +1,68 @@
+"""Pydantic models for deprecated and remapped configuration parameters"""
+
+import logging
+from typing import Any
+
+from pydantic import BaseModel, Field, field_validator
+
+LOG = logging.getLogger(__name__)
+
+
+class DeprecatedParameters(BaseModel):
+    """configurations that are deprecated"""
+
+    max_packed_sequence_len: int | None = None
+    rope_scaling: Any | None = None
+    noisy_embedding_alpha: float | None = None
+    dpo_beta: float | None = None
+    evaluation_strategy: str | None = None
+
+    @field_validator("max_packed_sequence_len")
+    @classmethod
+    def validate_max_packed_sequence_len(cls, max_packed_sequence_len):
+        if max_packed_sequence_len:
+            raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
+        return max_packed_sequence_len
+
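+    # Illustrative pattern note: validators in this class either hard-fail
+    # (e.g. `max_packed_sequence_len`, removed outright) or warn and pass the
+    # value through (e.g. an assumed legacy `dpo_beta: 0.1`, renamed to `rl_beta`).
+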
@field_validator("rope_scaling") + @classmethod + def validate_rope_scaling(cls, rope_scaling): + if rope_scaling: + raise DeprecationWarning( + "`rope_scaling` is no longer supported, it should now be be a key under `model_config`" + ) + return rope_scaling + + @field_validator("noisy_embedding_alpha") + @classmethod + def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha): + if noisy_embedding_alpha: + LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha") + return noisy_embedding_alpha + + @field_validator("dpo_beta") + @classmethod + def validate_dpo_beta(cls, dpo_beta): + if dpo_beta is not None: + LOG.warning("dpo_beta is deprecated, use rl_beta instead") + return dpo_beta + + @field_validator("evaluation_strategy") + @classmethod + def validate_evaluation_strategy(cls, evaluation_strategy): + if evaluation_strategy is not None: + LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead") + return evaluation_strategy + + +class RemappedParameters(BaseModel): + """Parameters that have been remapped to other names""" + + overrides_of_model_config: dict[str, Any] | None = Field( + default=None, alias="model_config" + ) + overrides_of_model_kwargs: dict[str, Any] | None = Field( + default=None, alias="model_kwargs" + ) + type_of_model: str | None = Field(default=None, alias="model_type") + revision_of_model: str | None = Field(default=None, alias="model_revision") diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py new file mode 100644 index 000000000..f376aca5f --- /dev/null +++ b/src/axolotl/utils/schemas/enums.py @@ -0,0 +1,49 @@ +"""Enums for Axolotl input config""" + +from enum import Enum + + +class RLType(str, Enum): + """RL trainer type configuration subset""" + + dpo = "dpo" # pylint: disable=invalid-name + grpo = "grpo" # pylint: disable=invalid-name + ipo = "ipo" # pylint: disable=invalid-name + orpo = "orpo" # pylint: disable=invalid-name + kto = "kto" # pylint: disable=invalid-name + simpo = "simpo" # pylint: disable=invalid-name + + +class ChatTemplate(str, Enum): + """Chat templates configuration subset""" + + alpaca = "alpaca" # pylint: disable=invalid-name + chatml = "chatml" # pylint: disable=invalid-name + mistral_v1 = "mistral_v1" # pylint: disable=invalid-name + mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name + mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name + gemma = "gemma" # pylint: disable=invalid-name + cohere = "cohere" # pylint: disable=invalid-name + llama3 = "llama3" # pylint: disable=invalid-name + llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name + phi_3 = "phi_3" # pylint: disable=invalid-name + phi_35 = "phi_35" # pylint: disable=invalid-name + deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name + deepseek_v3 = "deepseek_v3" # pylint: disable=invalid-name + jamba = "jamba" # pylint: disable=invalid-name + jinja = "jinja" # pylint: disable=invalid-name + qwen_25 = "qwen_25" # pylint: disable=invalid-name + tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name + exaone = "exaone" # pylint: disable=invalid-name + metharme = "metharme" # pylint: disable=invalid-name + + +class CustomSupportedOptimizers(str, Enum): + """Custom supported optimizers""" + + optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name + ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name + ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name + ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name + 
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name + muon = "muon" # pylint: disable=invalid-name diff --git a/src/axolotl/utils/schemas/integrations.py b/src/axolotl/utils/schemas/integrations.py new file mode 100644 index 000000000..9d8f9c190 --- /dev/null +++ b/src/axolotl/utils/schemas/integrations.py @@ -0,0 +1,108 @@ +"""Pydantic models for Axolotl integrations""" + +import logging +from typing import Any + +from pydantic import BaseModel, Field, model_validator + +LOG = logging.getLogger(__name__) + + +class MLFlowConfig(BaseModel): + """MLFlow configuration subset""" + + use_mlflow: bool | None = None + mlflow_tracking_uri: str | None = None + mlflow_experiment_name: str | None = None + mlflow_run_name: str | None = None + hf_mlflow_log_artifacts: bool | None = None + + +class LISAConfig(BaseModel): + """LISA configuration subset""" + + lisa_n_layers: int | None = Field( + default=None, + json_schema_extra={"description": "the number of activate layers in LISA"}, + ) + lisa_step_interval: int | None = Field( + default=None, + json_schema_extra={"description": "how often to switch layers in LISA"}, + ) + lisa_layers_attribute: str | None = Field( + default="model.layers", + json_schema_extra={"description": "path under the model to access the layers"}, + ) + + +class WandbConfig(BaseModel): + """Wandb configuration subset""" + + use_wandb: bool | None = None + wandb_name: str | None = None + wandb_run_id: str | None = None + wandb_mode: str | None = None + wandb_project: str | None = None + wandb_entity: str | None = None + wandb_watch: str | None = None + wandb_log_model: str | None = None + + @model_validator(mode="before") + @classmethod + def check_wandb_run(cls, data): + if data.get("wandb_run_id") and not data.get("wandb_name"): + data["wandb_name"] = data.get("wandb_run_id") + + LOG.warning( + "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead." + ) + + return data + + +class CometConfig(BaseModel): + """Comet configuration subset""" + + use_comet: bool | None = None + comet_api_key: str | None = None + comet_workspace: str | None = None + comet_project_name: str | None = None + comet_experiment_key: str | None = None + comet_mode: str | None = None + comet_online: bool | None = None + comet_experiment_config: dict[str, Any] | None = None + + +class GradioConfig(BaseModel): + """Gradio configuration subset""" + + gradio_title: str | None = None + gradio_share: bool | None = None + gradio_server_name: str | None = None + gradio_server_port: int | None = None + gradio_max_new_tokens: int | None = None + gradio_temperature: float | None = None + + +class RayConfig(BaseModel): + """Ray launcher configuration subset""" + + use_ray: bool = Field(default=False) + ray_run_name: str | None = Field( + default=None, + json_schema_extra={ + "help": "The training results will be saved at `saves/ray_run_name`." + }, + ) + ray_num_workers: int = Field( + default=1, + json_schema_extra={ + "help": "The number of workers for Ray training. Default is 1 worker." + }, + ) + resources_per_worker: dict = Field( + default_factory=lambda: {"GPU": 1}, + json_schema_extra={ + "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker." 
+
+
+class CometConfig(BaseModel):
+    """Comet configuration subset"""
+
+    use_comet: bool | None = None
+    comet_api_key: str | None = None
+    comet_workspace: str | None = None
+    comet_project_name: str | None = None
+    comet_experiment_key: str | None = None
+    comet_mode: str | None = None
+    comet_online: bool | None = None
+    comet_experiment_config: dict[str, Any] | None = None
+
+
+class GradioConfig(BaseModel):
+    """Gradio configuration subset"""
+
+    gradio_title: str | None = None
+    gradio_share: bool | None = None
+    gradio_server_name: str | None = None
+    gradio_server_port: int | None = None
+    gradio_max_new_tokens: int | None = None
+    gradio_temperature: float | None = None
+
+
+class RayConfig(BaseModel):
+    """Ray launcher configuration subset"""
+
+    use_ray: bool = Field(default=False)
+    ray_run_name: str | None = Field(
+        default=None,
+        json_schema_extra={
+            "help": "The training results will be saved at `saves/ray_run_name`."
+        },
+    )
+    ray_num_workers: int = Field(
+        default=1,
+        json_schema_extra={
+            "help": "The number of workers for Ray training. Default is 1 worker."
+        },
+    )
+    resources_per_worker: dict = Field(
+        default_factory=lambda: {"GPU": 1},
+        json_schema_extra={
+            "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
+        },
+    )
diff --git a/src/axolotl/utils/config/models/internals/__init__.py b/src/axolotl/utils/schemas/internal/__init__.py
similarity index 100%
rename from src/axolotl/utils/config/models/internals/__init__.py
rename to src/axolotl/utils/schemas/internal/__init__.py
diff --git a/src/axolotl/utils/schemas/model.py b/src/axolotl/utils/schemas/model.py
new file mode 100644
index 000000000..5f1d26e84
--- /dev/null
+++ b/src/axolotl/utils/schemas/model.py
@@ -0,0 +1,55 @@
+"""Pydantic models for model input / output, etc. configuration"""
+
+import logging
+
+from pydantic import BaseModel, Field, field_validator
+
+LOG = logging.getLogger(__name__)
+
+
+class ModelInputConfig(BaseModel):
+    """Model configuration subset"""
+
+    model_config = {"protected_namespaces": ()}
+
+    base_model: str
+    base_model_config: str | None = None
+    cls_model_config: str | None = None
+    tokenizer_config: str | None = None
+    tokenizer_use_fast: bool | None = None
+    tokenizer_legacy: bool | None = None
+    tokenizer_type: str | None = Field(
+        default=None, json_schema_extra={"description": "transformers tokenizer class"}
+    )
+    processor_type: str | None = Field(
+        default=None, json_schema_extra={"description": "transformers processor class"}
+    )
+    trust_remote_code: bool | None = None
+
+    @field_validator("trust_remote_code")
+    @classmethod
+    def hint_trust_remote_code(cls, trust_remote_code):
+        if trust_remote_code:
+            LOG.warning(
+                "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
+            )
+        return trust_remote_code
+
+
+class ModelOutputConfig(BaseModel):
+    """model save configuration subset"""
+
+    output_dir: str = Field(default="./model-out")
+    hub_model_id: str | None = None
+    hub_strategy: str | None = None
+    save_safetensors: bool | None = True
+
+
+class SpecialTokensConfig(BaseModel):
+    """Special tokens configuration subset"""
+
+    bos_token: str | None = None
+    eos_token: str | None = None
+    pad_token: str | None = None
+    unk_token: str | None = None
+    additional_special_tokens: list[str] | None = None
diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py
new file mode 100644
index 000000000..5d408e1fe
--- /dev/null
+++ b/src/axolotl/utils/schemas/peft.py
@@ -0,0 +1,132 @@
+"""Pydantic models for PEFT-related configuration"""
+
+from typing import Any
+
+from pydantic import BaseModel, Field, field_validator, model_validator
+
+
+class LoftQConfig(BaseModel):
+    """LoftQ configuration subset"""
+
+    loftq_bits: int = Field(
+        default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
+    )
+    # loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
+
+
+class PeftConfig(BaseModel):
+    """PEFT configuration subset"""
+
+    loftq_config: LoftQConfig | None = None
+
+
+class LoraConfig(BaseModel):
+    """Peft / LoRA configuration subset"""
+
+    load_in_8bit: bool | None = Field(default=False)
+    load_in_4bit: bool | None = Field(default=False)
+
+    adapter: str | None = None
+    lora_model_dir: str | None = None
+    lora_r: int | None = None
+    lora_alpha: int | None = None
+    lora_fan_in_fan_out: bool | None = None
+    lora_target_modules: str | list[str] | None = None
+    lora_target_linear: bool | None = None
+    lora_modules_to_save: list[str] | None = None
+    lora_dropout: float | None = 0.0
+    peft_layers_to_transform: list[int] | None = None
+    peft_layers_pattern: list[str] | None = None
+    peft: PeftConfig | None = None
+    peft_use_dora: bool | None = None
+    peft_use_rslora: bool | None = None
+    peft_layer_replication: list[tuple[int, int]] | None = None
+    peft_init_lora_weights: bool | str | None = None
+
+    qlora_sharded_model_loading: bool | None = Field(
+        default=False,
+        json_schema_extra={
+            "description": "load qlora model in sharded format for FSDP using answer.ai technique."
+        },
+    )
+    lora_on_cpu: bool | None = None
+    gptq: bool | None = None
+    bnb_config_kwargs: dict[str, Any] | None = None
+
+    loraplus_lr_ratio: float | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
+        },
+    )
+    loraplus_lr_embedding: float | None = Field(
+        default=1e-6,
+        json_schema_extra={
+            "description": "loraplus learning rate for lora embedding layers."
+        },
+    )
+
+    merge_lora: bool | None = None
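+
+    # A minimal, assumed YAML sketch for a QLoRA adapter using only fields
+    # defined on this model (values are illustrative, and consistent with
+    # validate_qlora below, which requires load_in_4bit for qlora):
+    #
+    #   adapter: qlora
+    #   load_in_4bit: true
+    #   lora_r: 32
+    #   lora_alpha: 16
+    #   lora_dropout: 0.05
+    #   lora_target_linear: true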
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_adapter(cls, data):
+        if (
+            not data.get("adapter")
+            and not data.get("inference")
+            and (data.get("load_in_8bit") or data.get("load_in_4bit"))
+        ):
+            raise ValueError(
+                "load_in_8bit and load_in_4bit are not supported without setting an adapter for training. "
+                "If you want to do a full finetune, please turn off load_in_8bit and load_in_4bit."
+            )
+        return data
+
+    @model_validator(mode="after")
+    def validate_qlora(self):
+        if self.adapter == "qlora":
+            if self.merge_lora:
+                # can't merge qlora if loaded in 8bit or 4bit
+                if self.load_in_8bit:
+                    raise ValueError("Can't merge qlora if loaded in 8bit")
+
+                if self.gptq:
+                    raise ValueError("Can't merge qlora if gptq")
+
+                if self.load_in_4bit:
+                    raise ValueError("Can't merge qlora if loaded in 4bit")
+
+            else:
+                if self.load_in_8bit:
+                    raise ValueError("Can't load qlora in 8bit")
+
+                if self.gptq:
+                    raise ValueError("Can't load qlora if gptq")
+
+                if not self.load_in_4bit:
+                    raise ValueError("Require cfg.load_in_4bit to be True for qlora")
+        return self
+
+    @field_validator("loraplus_lr_embedding")
+    @classmethod
+    def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):
+        if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):
+            loraplus_lr_embedding = float(loraplus_lr_embedding)
+        return loraplus_lr_embedding
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_lora_dropout(cls, data):
+        if data.get("adapter") is not None and data.get("lora_dropout") is None:
+            data["lora_dropout"] = 0.0
+        return data
+
+
+class ReLoRAConfig(BaseModel):
+    """ReLoRA configuration subset"""
+
+    relora_steps: int | None = None
+    relora_warmup_steps: int | None = None
+    relora_anneal_steps: int | None = None
+    relora_prune_ratio: float | None = None
+    relora_cpu_offload: bool | None = None
diff --git a/src/axolotl/utils/schemas/training.py b/src/axolotl/utils/schemas/training.py
new file mode 100644
index 000000000..2ab4b4286
--- /dev/null
+++ b/src/axolotl/utils/schemas/training.py
@@ -0,0 +1,99 @@
+"""Pydantic models for training hyperparameters"""
+
+import logging
+from typing import Any, Literal
+
+from pydantic import BaseModel, Field, field_validator
+from transformers import SchedulerType
+from transformers.training_args import OptimizerNames
+
+from axolotl.utils.schemas.enums import CustomSupportedOptimizers
+
+LOG = logging.getLogger(__name__)
+
+
+class LrGroup(BaseModel):
+    """Custom learning rate group configuration"""
+
+    name: str
+    modules: list[str]
+    lr: float
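+
+
+# Illustrative, assumed YAML for `lr_groups` (keys match the LrGroup fields
+# above; names and values are hypothetical):
+#
+#   lr_groups:
+#     - name: embeddings
+#       modules: ["embed_tokens", "lm_head"]
+#       lr: 1.0e-5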
+
+
+class HyperparametersConfig(BaseModel):
+    """Training hyperparams configuration subset"""
+
+    gradient_accumulation_steps: int | None = Field(default=1)
+    micro_batch_size: int | None = Field(
+        default=1,
+        json_schema_extra={"description": "per gpu micro batch size for training"},
+    )
+    batch_size: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Total batch size; we do not recommend setting this manually"
+        },
+    )
+    eval_batch_size: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
+        },
+    )
+
+    auto_find_batch_size: bool | None = None
+
+    train_on_inputs: bool | None = False
+    group_by_length: bool | None = None
+
+    learning_rate: str | float
+    embedding_lr: float | None = None
+    embedding_lr_scale: float | None = None
+    weight_decay: float | None = 0.0
+    optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = (
+        OptimizerNames.ADAMW_TORCH_FUSED
+    )
+    optim_args: (str | dict[str, Any]) | None = Field(
+        default=None,
+        json_schema_extra={"description": "Optional arguments to supply to optimizer."},
+    )
+    optim_target_modules: (list[str] | Literal["all_linear"]) | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "The target modules to optimize, i.e. the module names that you would like to train."
+        },
+    )
+    torchdistx_path: str | None = None
+    lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = (
+        SchedulerType.COSINE
+    )
+    lr_scheduler_kwargs: dict[str, Any] | None = None
+    lr_quadratic_warmup: bool | None = None
+    cosine_min_lr_ratio: float | None = None
+    cosine_constant_lr_ratio: float | None = None
+    lr_div_factor: float | None = None
+    lr_groups: list[LrGroup] | None = None
+
+    adam_epsilon: float | None = None
+    adam_beta1: float | None = None
+    adam_beta2: float | None = None
+    max_grad_norm: float | None = None
+    num_epochs: float = Field(default=1.0)
+
+    @field_validator("batch_size")
+    @classmethod
+    def hint_batch_size_set(cls, batch_size):
+        if batch_size:
+            LOG.warning(
+                "%s\n%s",
+                "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
+                "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
+            )
+        return batch_size
+
+    @field_validator("learning_rate")
+    @classmethod
+    def convert_learning_rate(cls, learning_rate):
+        if learning_rate and isinstance(learning_rate, str):
+            learning_rate = float(learning_rate)
+        return learning_rate
""" - beta: Optional[float] = Field( + beta: float | None = Field( default=None, json_schema_extra={"description": "Beta for RL training"}, ) - max_completion_length: Optional[int] = Field( + max_completion_length: int | None = Field( default=None, json_schema_extra={ "description": "Maximum length of the completion for RL training" @@ -25,50 +21,50 @@ class TRLConfig(BaseModel): # GRPO specific args # Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22 - use_vllm: Optional[bool] = Field( + use_vllm: bool | None = Field( default=False, json_schema_extra={"description": "Whether to use VLLM for RL training"}, ) - vllm_device: Optional[str] = Field( + vllm_device: str | None = Field( default="auto", json_schema_extra={"description": "Device to use for VLLM"}, ) - vllm_gpu_memory_utilization: Optional[float] = Field( + vllm_gpu_memory_utilization: float | None = Field( default=0.9, json_schema_extra={"description": "GPU memory utilization for VLLM"}, ) - vllm_dtype: Optional[str] = Field( + vllm_dtype: str | None = Field( default="auto", json_schema_extra={"description": "Data type for VLLM"}, ) - vllm_max_model_len: Optional[int] = Field( + vllm_max_model_len: int | None = Field( default=None, json_schema_extra={ "description": "Maximum length of the model context for VLLM" }, ) - reward_funcs: Optional[list[str]] = Field( + reward_funcs: list[str] | None = Field( default=None, json_schema_extra={"description": "List of reward functions to load"}, ) - reward_weights: Optional[list[float]] = Field( + reward_weights: list[float] | None = Field( default=None, json_schema_extra={ "description": "Weights for each reward function. Must match the number of reward functions." }, ) - num_generations: Optional[int] = Field( + num_generations: int | None = Field( default=None, json_schema_extra={ "description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value." }, ) - log_completions: Optional[bool] = Field( + log_completions: bool | None = Field( default=False, json_schema_extra={"description": "Whether to log completions"}, ) - sync_ref_model: Optional[bool] = Field( + sync_ref_model: bool | None = Field( default=False, json_schema_extra={ "description": ( @@ -77,13 +73,13 @@ class TRLConfig(BaseModel): ) }, ) - ref_model_mixup_alpha: Optional[float] = Field( + ref_model_mixup_alpha: float | None = Field( default=0.9, json_schema_extra={ "description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`." }, ) - ref_model_sync_steps: Optional[int] = Field( + ref_model_sync_steps: int | None = Field( default=64, json_schema_extra={ "description": "Sync steps for the reference model. Requires `sync_ref_model=True`." diff --git a/src/axolotl/utils/schemas/utils.py b/src/axolotl/utils/schemas/utils.py new file mode 100644 index 000000000..bf74390f6 --- /dev/null +++ b/src/axolotl/utils/schemas/utils.py @@ -0,0 +1,79 @@ +"""Utilities for Axolotl Pydantic models""" + +import logging + +LOG = logging.getLogger(__name__) + + +def handle_legacy_message_fields_logic(data: dict) -> dict: + """ + Handle backwards compatibility between legacy message field mapping and new property mapping system. 
+
+    Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
+    - message_field_role: Mapped to the role field
+    - message_field_content: Mapped to the content field
+
+    The new system uses message_property_mappings to support arbitrary field mappings:
+        message_property_mappings:
+            role: source_role_field
+            content: source_content_field
+            additional_field: source_field
+
+    Args:
+        data: Dictionary containing configuration data
+
+    Returns:
+        Updated dictionary with message field mappings consolidated
+
+    Raises:
+        ValueError: If there are conflicts between legacy and new mappings
+    """
+    data = data.copy()  # Create a copy to avoid modifying the original
+
+    if data.get("message_property_mappings") is None:
+        data["message_property_mappings"] = {}
+
+    # Check for conflicts and handle role
+    if "message_field_role" in data:
+        LOG.warning(
+            "message_field_role is deprecated, use message_property_mappings instead. "
+            f"Example: message_property_mappings: {{role: {data['message_field_role']}}}"
+        )
+        if (
+            "role" in data["message_property_mappings"]
+            and data["message_property_mappings"]["role"] != data["message_field_role"]
+        ):
+            raise ValueError(
+                f"Conflicting message role fields: message_field_role='{data['message_field_role']}' "
+                f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'"
+            )
+        data["message_property_mappings"]["role"] = data["message_field_role"] or "role"
+
+        del data["message_field_role"]
+    elif "role" not in data["message_property_mappings"]:
+        data["message_property_mappings"]["role"] = "role"
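+
+    # For example (illustrative): an assumed legacy config
+    # {"message_field_role": "from"} is folded into
+    # {"message_property_mappings": {"role": "from", ...}} by the block above;
+    # the block below applies the same logic to "content".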
" + f"Example: message_property_mappings: {{content: {data['message_field_content']}}}" + ) + if ( + "content" in data["message_property_mappings"] + and data["message_property_mappings"]["content"] + != data["message_field_content"] + ): + raise ValueError( + f"Conflicting message content fields: message_field_content='{data['message_field_content']}' " + f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'" + ) + data["message_property_mappings"]["content"] = ( + data["message_field_content"] or "content" + ) + + del data["message_field_content"] + elif "content" not in data["message_property_mappings"]: + data["message_property_mappings"]["content"] = "content" + + return data diff --git a/styles.css b/styles.css index 749ff4366..c5b0768fa 100644 --- a/styles.css +++ b/styles.css @@ -14,7 +14,7 @@ h1 { font-family: var(--font-title); font-weight: 400; - font-size: 5rem; + font-size: 3rem; line-height: 1.1; letter-spacing: -0.05em; font-feature-settings: "ss01" on; @@ -24,7 +24,7 @@ h1 { h2 { font-family: var(--font-title); font-weight: 500; - font-size: 2rem; + font-size: 1.5rem; line-height: 1.2; letter-spacing: -0.03em; font-feature-settings: "ss01" on; @@ -35,7 +35,7 @@ h3, h4 { font-family: var(--font-body); font-weight: 400; - font-size: 1.5rem; + font-size: 1.25rem; line-height: 1.5; letter-spacing: -0.02em; } @@ -191,3 +191,87 @@ code span.er { color: #5cb85c !important; text-decoration: none !important; } + +/* API Documentation Styling */ + +/* Improve docstring section rendering */ +.level3 p { + white-space: pre-line !important; +} + +/* Format docstring sections */ +.level3 p strong { + display: block; + margin-top: 1em; + font-weight: bold; + color: var(--cyan); +} + +/* Add spacing after sections */ +.level3 p:has(strong) { + margin-bottom: 0.5em; +} + +/* Format Args and Returns sections */ +p:has(code) { + line-height: 1.6; +} + +/* Function signatures */ +.sourceCode { + margin-bottom: 1.5em; +} + +/* Parameter tables */ +.doc-section-parameters table, +.doc-section-returns table { + margin-top: 1em; + margin-bottom: 1.5em; +} + +/* Make parameter and returns headers smaller */ +h2.anchored[data-anchor-id="parameters"], +h2.anchored[data-anchor-id="returns"], +.doc-section-parameters h4, +.doc-section-returns h4 { + font-size: 1.25rem; + margin-top: 2rem; + margin-bottom: 1rem; + color: var(--lime); + border-bottom: 1px solid var(--lime); + padding-bottom: 0.3rem; + font-family: var(--font-body); + font-weight: 500; + letter-spacing: normal; +} + +/* Style documentation tables */ +table { + width: 100%; + margin-bottom: 1.5rem; + border-collapse: collapse; +} + +table th { + background-color: #1a1a1a; + padding: 0.5rem 1rem; + border-bottom: 2px solid var(--greige-600); + text-align: left; +} + +table td { + padding: 0.5rem 1rem; + border-bottom: 1px solid var(--greige-600); +} + +/* Code in table cells */ +table td code { + background-color: transparent !important; + padding: 0; +} + +/* Improve spacing in parameter and return tables */ +.doc-section-parameters, +.doc-section-returns { + margin-top: 1rem; +} diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py index 9d41dac76..3262a6981 100644 --- a/tests/patched/test_validation.py +++ b/tests/patched/test_validation.py @@ -11,10 +11,10 @@ from pydantic import ValidationError from axolotl.utils import is_comet_available from axolotl.utils.config import validate_config -from axolotl.utils.config.models.input.v0_4_1 import 
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.mlflow_ import setup_mlflow_env_vars
 from axolotl.utils.models import check_model_config
+from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
 from axolotl.utils.wandb_ import setup_wandb_env_vars

 warnings.filterwarnings("error")
diff --git a/tests/test_validation_dataset.py b/tests/test_validation_dataset.py
index 5c1b5a1f7..47d10ee99 100644
--- a/tests/test_validation_dataset.py
+++ b/tests/test_validation_dataset.py
@@ -6,8 +6,8 @@ from typing import Optional

 import pytest

 from axolotl.utils.config import validate_config
-from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.schemas.datasets import ChatTemplate

 warnings.filterwarnings("error")
diff --git a/tests/utils/test_models.py b/tests/utils/test_models.py
index e78cdb5d7..83678430a 100644
--- a/tests/utils/test_models.py
+++ b/tests/utils/test_models.py
@@ -119,7 +119,7 @@ class TestModelsUtils:

     def test_message_property_mapping(self):
         """Test message property mapping configuration validation"""
-        from axolotl.utils.config.models.input.v0_4_1 import SFTDataset
+        from axolotl.utils.schemas.datasets import SFTDataset

         # Test legacy fields are mapped correctly
         dataset = SFTDataset(