Compare commits
18 Commits
colab-misc...quartodoc
| Author | SHA1 | Date |
|---|---|---|
|  | 0bffef25d0 |  |
|  | 94c00c1d04 |  |
|  | ddd84d7c65 |  |
|  | 42bdf0bd74 |  |
|  | b03d96a228 |  |
|  | 2653f170fc |  |
|  | 3bfcce9f0a |  |
|  | 8feb746953 |  |
|  | a563815fe7 |  |
|  | 81f2203151 |  |
|  | 5b7e688fc5 |  |
|  | 5134aa66cd |  |
|  | ba9a867adb |  |
|  | c618f42c39 |  |
|  | fc1f985296 |  |
|  | a5e37f183c |  |
|  | e6a7bbe9ff |  |
|  | e4fd7aad0b |  |
.github/workflows/docs.yml (vendored): 7 lines changed

```diff
@@ -20,9 +20,12 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: '3.11'
-      - name: install dependencies
+      - name: Install dependencies
         run: |
-          python3 -m pip install jupyter
+          python3 -m pip install jupyter quartodoc
           python3 -m pip install -e .
+      - name: Build autodoc
+        run: quartodoc build
       - name: Publish to GitHub Pages (and render)
         uses: quarto-dev/quarto-actions/publish@v2
         with:
```
.gitignore (vendored): 4 lines changed

```diff
@@ -181,6 +181,10 @@ prepared-datasets/
 submit.sh
 *.out*
 
+# Quartodoc generated files
+objects.json
+site_libs/
+
 typings/
 out/
```
```diff
@@ -97,6 +97,7 @@ That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github
 - [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
 - [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
 - [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
+- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/docs/api/) - Auto-generated code documentation
 - [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
 
 ## 🤝 Getting Help
```
_quarto.yml: 193 lines changed

```diff
@@ -1,6 +1,178 @@
 project:
   type: website
 
+quartodoc:
+  dir: docs/api
+  package: axolotl
+  title: API Reference
+  parser: google
+
+  sections:
+    - title: Core
+      desc: Core functionality for training
+      contents:
+        - train
+        - evaluate
+        - datasets
+        - convert
+        - prompt_tokenizers
+        - logging_config
+        - core.trainer_builder
+        - core.training_args
+        - core.chat.messages
+        - core.chat.format.chatml
+        - core.chat.format.llama3x
+        - core.chat.format.shared
+        - core.datasets.chat
+        - core.datasets.transforms.chat_builder
+    - title: CLI
+      desc: Command-line interface
+      contents:
+        - cli.main
+        - cli.train
+        - cli.evaluate
+        - cli.args
+        - cli.checks
+        - cli.config
+        - cli.inference
+        - cli.merge_lora
+        - cli.merge_sharded_fsdp_weights
+        - cli.preprocess
+        - cli.sweeps
+        - cli.utils
+        - cli.cloud.base
+        - cli.cloud.modal_
+    - title: Trainers
+      desc: Training implementations
+      contents:
+        - core.trainers.base
+        - core.trainers.trl
+        - core.trainers.dpo.trainer
+        - core.trainers.grpo.trainer
+    - title: Prompt Strategies
+      desc: Prompt formatting strategies
+      contents:
+        - prompt_strategies.base
+        - prompt_strategies.chat_template
+        - prompt_strategies.alpaca_chat
+        - prompt_strategies.alpaca_instruct
+        - prompt_strategies.alpaca_w_system
+        - prompt_strategies.user_defined
+        - prompt_strategies.llama2_chat
+        - prompt_strategies.completion
+        - prompt_strategies.input_output
+        - prompt_strategies.stepwise_supervised
+        - prompt_strategies.metharme
+        - prompt_strategies.orcamini
+        - prompt_strategies.pygmalion
+        - prompt_strategies.messages.chat
+        - prompt_strategies.dpo.chat_template
+        - prompt_strategies.dpo.llama3
+        - prompt_strategies.dpo.chatml
+        - prompt_strategies.dpo.zephyr
+        - prompt_strategies.dpo.user_defined
+        - prompt_strategies.dpo.passthrough
+        - prompt_strategies.kto.llama3
+        - prompt_strategies.kto.chatml
+        - prompt_strategies.kto.user_defined
+        - prompt_strategies.orpo.chat_template
+        - prompt_strategies.bradley_terry.llama3
+    - title: Kernels
+      desc: Low-level performance optimizations
+      contents:
+        - kernels.lora
+        - kernels.geglu
+        - kernels.swiglu
+        - kernels.quantize
+        - kernels.utils
+    - title: MonkeyPatches
+      desc: Runtime patches for model optimizations
+      contents:
+        - monkeypatch.llama_attn_hijack_flash
+        - monkeypatch.llama_attn_hijack_xformers
+        - monkeypatch.mistral_attn_hijack_flash
+        - monkeypatch.multipack
+        - monkeypatch.relora
+        - monkeypatch.llama_expand_mask
+        - monkeypatch.lora_kernels
+        - monkeypatch.utils
+        - monkeypatch.btlm_attn_hijack_flash
+        - monkeypatch.llama_patch_multipack
+        - monkeypatch.stablelm_attn_hijack_flash
+        - monkeypatch.trainer_fsdp_optim
+        - monkeypatch.transformers_fa_utils
+        - monkeypatch.unsloth_
+        - monkeypatch.attention.mllama
+        - monkeypatch.data.batch_dataset_fetcher
+        - monkeypatch.mixtral
+    - title: Utils
+      desc: Utility functions
+      contents:
+        - utils.models
+        - utils.tokenization
+        - utils.chat_templates
+        - utils.lora
+        - utils.lora_embeddings
+        - utils.model_shard_quant
+        - utils.bench
+        - utils.freeze
+        - utils.trainer
+        - utils.schedulers
+        - utils.distributed
+        - utils.dict
+        - utils.optimizers.adopt
+        - utils.data.pretraining
+        - utils.data.sft
+        - utils.gradient_checkpointing.unsloth
+    - title: Schemas
+      desc: Pydantic data models for Axolotl config
+      contents:
+        - utils.schemas.config
+        - utils.schemas.model
+        - utils.schemas.training
+        - utils.schemas.datasets
+        - utils.schemas.peft
+        - utils.schemas.trl
+        - utils.schemas.integrations
+        - utils.schemas.enums
+        - utils.schemas.utils
+    - title: Integrations
+      desc: Third-party integrations and extensions
+      contents:
+        - integrations.base
+        - integrations.cut_cross_entropy.args
+        - integrations.grokfast.optimizer
+        - integrations.kd.trainer
+        - integrations.liger.args
+        - integrations.lm_eval.args
+        - integrations.spectrum.args
+    - title: Common
+      desc: Common utilities and shared functionality
+      contents:
+        - common.architectures
+        - common.const
+        - common.datasets
+    - title: Models
+      desc: Custom model implementations
+      contents:
+        - models.mamba.modeling_mamba
+    - title: Data Processing
+      desc: Data processing utilities
+      contents:
+        - utils.collators.core
+        - utils.collators.batching
+        - utils.collators.mamba
+        - utils.collators.mm_chat
+        - utils.samplers.multipack
+    - title: Callbacks
+      desc: Training callbacks
+      contents:
+        - utils.callbacks.perplexity
+        - utils.callbacks.profiler
+        - utils.callbacks.lisa
+        - utils.callbacks.mlflow_
+        - utils.callbacks.comet_
+
 website:
   title: "Axolotl"
   description: "We make fine-tuning accessible, scalable, and fun"
```
```diff
@@ -35,6 +207,8 @@ website:
         - docs/inference.qmd
         - docs/cli.qmd
         - docs/config.qmd
+        - text: "API Reference"
+          href: docs/api
 
       - section: "Dataset Formats"
         contents: docs/dataset-formats/*
```
```diff
@@ -80,3 +254,22 @@ format:
     theme: darkly
     css: styles.css
     toc: true
+    # Enable better handling of line breaks in markdown
+    preserve-tabs: true
+    html-math-method: mathjax
+    # Improved markdown processing options
+    md-extensions:
+      - markdown_it
+      - def_list
+      - attr_list
+      - fenced_divs
+      - tables
+      - html_admonition
+      - lineblocks
+      - fancy_lists
+    # Control whitespace handling
+    whitespace: preserve
+    # Process newlines in paragraphs
+    wrap: preserve
+    # Better line break handling
+    preserve-linebreaks: true
```
docs/.gitignore (vendored): 2 lines changed

```diff
@@ -1,2 +1,4 @@
 /.quarto/
 _site/
+/api/*.qmd
+/api/*.html
```
```diff
@@ -1,5 +1,5 @@
 ---
-title: "CLI Reference"
+title: "Command Line Interface (CLI)"
 format:
   html:
     toc: true
```
```diff
@@ -6,7 +6,7 @@ description: How datasets are processed
 ## Overview
 
 Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
-the [dataset format](docs/dataset-formats) and prompt strategies to:
+the [dataset format](dataset-formats) and prompt strategies to:
 
 - parse the dataset based on the *dataset format*
 - transform the dataset to how you would interact with the model based on the *prompt strategy*
```
```diff
@@ -2,3 +2,5 @@ pre-commit
 black
 mypy
 types-requests
+quartodoc
+jupyter
```
```diff
@@ -25,7 +25,7 @@ from axolotl.cli.utils import (
 )
 from axolotl.integrations.lm_eval.cli import lm_eval
 from axolotl.utils import set_pytorch_cuda_alloc_conf
-from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
+from axolotl.utils.schemas.config import AxolotlInputConfig
 
 
 @click.group()
```
```diff
@@ -13,9 +13,7 @@
 # limitations under the License.
 
 # pylint: disable=too-many-lines
-"""
-Builder for the training args and trainer
-"""
+"""Builder for the training args and trainer"""
 
 import abc
 import importlib
@@ -85,8 +83,8 @@ from axolotl.utils.collators import (
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
 from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
-from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
 from axolotl.utils.models import ensure_dtype
+from axolotl.utils.schemas.enums import CustomSupportedOptimizers
 
 try:
     import torch._dynamo  # pylint: disable=ungrouped-imports
```
```diff
@@ -9,7 +9,7 @@ import logging
 from trl.trainer.grpo_trainer import RewardFunc
 
 from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
-from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
+from axolotl.utils.schemas.trl import TRLConfig
 
 LOG = logging.getLogger("axolotl")
```
```diff
@@ -8,6 +8,8 @@ from typing import Dict, Optional
 
 import torch
 from accelerate.logging import get_logger
+from datasets import Dataset
+from transformers.trainer import Trainer
 
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
@@ -25,18 +27,18 @@ LOG = get_logger("axolotl.evaluate")
 
 
 def evaluate_dataset(
-    trainer, dataset, dataset_type: str, flash_optimum: bool = False
+    trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False
 ) -> Optional[Dict[str, float]]:
-    """Helper function to evaluate a single dataset safely.
+    """Helper function to evaluate a single dataset.
 
     Args:
-        trainer: The trainer instance
-        dataset: Dataset to evaluate
-        dataset_type: Type of dataset ('train' or 'eval')
-        flash_optimum: Whether to use flash optimum
+        trainer: The trainer instance.
+        dataset: Dataset to evaluate.
+        dataset_type: Type of dataset ('train' or 'eval').
+        flash_optimum: Whether to use flash optimum.
 
     Returns:
-        Dictionary of metrics or None if dataset is None
+        Dictionary of metrics or None if dataset is None.
     """
     if dataset is None:
         return None
@@ -63,17 +65,14 @@ def evaluate_dataset(
 
 def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
     """
-    Evaluate a model on training and validation datasets
+    Evaluate a model on training and validation datasets.
 
     Args:
         cfg: Dictionary mapping `axolotl` config keys to values.
         dataset_meta: Dataset metadata containing training and evaluation datasets.
 
     Returns:
-        Tuple containing:
-            - The model (either PeftModel or PreTrainedModel)
-            - The tokenizer
-            - Dictionary of evaluation metrics
+        Dictionary mapping metric names to their values.
     """
     # pylint: disable=duplicate-code
     # Enable expandable segments for cuda allocation to improve VRAM usage
```
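The tightened signature and docstring pin the contract down: no dataset, no metrics. A standalone sketch of that guard (the stub trainer and the metric-prefix convention are illustrative assumptions for this sketch, not axolotl's actual implementation):

```python
from typing import Dict, Optional


def evaluate_dataset(
    trainer, dataset, dataset_type: str, flash_optimum: bool = False
) -> Optional[Dict[str, float]]:
    """Evaluate one dataset; return None when that dataset is not configured."""
    if dataset is None:
        return None
    # Prefix each metric with the dataset type so train/eval metrics stay
    # distinct (an assumed convention for this sketch).
    metrics = trainer.evaluate(dataset)
    return {f"{dataset_type}_{key}": value for key, value in metrics.items()}


class StubTrainer:
    """Hypothetical stand-in for a transformers Trainer."""

    def evaluate(self, dataset):
        return {"loss": 0.25}
```

Callers can then merge the per-dataset dicts without worrying about an unset eval split.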
```diff
@@ -11,19 +11,17 @@
 # the License.
 
 """
-module to handle merging the plugins' input arguments with the base configurations.
+Module to handle merging the plugins' input arguments with the base configurations.
 
-this was moved here to prevent circular imports
+This was moved here to prevent circular imports.
 """
 
 from typing import Any, Dict, List
 
-from axolotl.utils.config.models.input.v0_4_1 import (
+from axolotl.utils.schemas.config import (
     AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
 )
-from axolotl.utils.config.models.input.v0_4_1 import (
-    AxolotlInputConfig as AxolotlInputConfigBase,
-)
+from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
 
 
 def merge_input_args():
```
```diff
@@ -13,7 +13,7 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly
 from axolotl.prompt_tokenizers import PromptTokenizingStrategy
 from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
 from axolotl.utils.chat_templates import get_chat_template_from_config
-from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
+from axolotl.utils.schemas.datasets import DatasetConfig
 
 # Configure the logger
 LOG = logging.getLogger("axolotl")
```
```diff
@@ -3,7 +3,7 @@ DPO prompt strategies for using tokenizer chat templates.
 """
 
 from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
-from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic
+from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
 
 
 def default(
```
```diff
@@ -33,7 +33,6 @@ from trl.models import unwrap_model_for_generation
 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.callbacks.perplexity import Perplexity
-from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
 from axolotl.utils.distributed import (
     barrier,
     broadcast_dict,
@@ -43,6 +42,7 @@ from axolotl.utils.distributed import (
     is_main_process,
     zero_first,
 )
+from axolotl.utils.schemas.config import AxolotlInputConfig
 
 if TYPE_CHECKING:
     from axolotl.core.trainer_builder import AxolotlTrainingArguments
```
```diff
@@ -12,19 +12,13 @@ from transformers.utils.import_utils import is_torch_npu_available
 from axolotl.integrations.base import PluginManager
 from axolotl.integrations.config import merge_input_args
 from axolotl.utils.bench import log_gpu_memory_usage
-from axolotl.utils.config.models.input.v0_4_1 import (
-    AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
-)
-from axolotl.utils.config.models.input.v0_4_1 import (
-    AxolotlInputConfig as AxolotlInputConfigBase,
-)
-from axolotl.utils.config.models.input.v0_4_1 import (
-    DPODataset,
-    KTODataset,
-    SFTDataset,
-)
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model_config
+from axolotl.utils.schemas.config import (
+    AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
+)
+from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
+from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
 
 LOG = logging.getLogger("axolotl")
```
File diff suppressed because it is too large.
src/axolotl/utils/schemas/datasets.py (new file, 165 lines)

```python
"""Pydantic models for datasets-related configuration"""

from pydantic import BaseModel, model_validator

from axolotl.utils.schemas.enums import ChatTemplate
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic


class UserDefinedPrompterType(BaseModel):
    """Structure for user defined prompt types"""

    system_prompt: str | None = None
    system_format: str | None = None
    field_system: str | None = None
    field_instruction: str | None = None
    field_input: str | None = None
    field_output: str | None = None

    format: str | None = None
    no_input_format: str | None = None
    field: str | None = None


class SFTDataset(BaseModel):
    """SFT configuration subset"""

    path: str | None = None
    split: str | None = None
    type: str | UserDefinedPrompterType | None = None
    input_transform: str | None = None
    shards: int | None = None
    shards_idx: int | None = None
    preprocess_shards: int | None = None
    conversation: str | None = None
    # Do not make this too strict or it will break the validator to choose different dataset class
    chat_template: ChatTemplate | str | None = None
    chat_template_jinja: str | None = None
    data_files: str | list[str] | None = None
    input_format: str | None = None
    name: str | None = None
    ds_type: str | None = None
    train_on_split: str | None = None
    field: str | None = None
    field_human: str | None = None
    field_model: str | None = None
    field_messages: str | None = None
    # deprecated, use message_property_mappings
    message_field_role: str | None = None
    # deprecated, use message_property_mappings
    message_field_content: str | None = None
    message_property_mappings: dict[str, str] | None = None
    message_field_training: str | None = None
    message_field_training_detail: str | None = None
    logprobs_field: str | None = None
    temperature: float | None = None
    roles_to_train: list[str] | None = None
    train_on_eos: str | None = None
    roles: dict[str, list[str]] | None = None
    drop_system_message: bool | None = None
    trust_remote_code: bool | None = False
    revision: str | None = None

    @model_validator(mode="before")
    @classmethod
    def handle_legacy_message_fields(cls, data):
        """Handle backwards compatibility between legacy message field mapping and new property mapping system."""
        return handle_legacy_message_fields_logic(data)

    @model_validator(mode="before")
    @classmethod
    # pylint: disable=duplicate-code
    def check_chat_template_config(cls, data):
        if isinstance(data, BaseModel):
            data = data.model_dump()

        # Set chat_template to tokenizer_default if not set
        if data.get("type") == "chat_template" and not data.get("chat_template"):
            data["chat_template"] = ChatTemplate.tokenizer_default

        # if chat_template is set to jinja, chat_template_jinja is required
        if data.get("chat_template") == ChatTemplate.jinja and not data.get(
            "chat_template_jinja"
        ):
            raise ValueError(
                "chat_template_jinja is required when chat_template is set to jinja"
            )

        # If chat_template_jinja is set, set chat_template to jinja
        if data.get("chat_template_jinja") and not data.get("chat_template"):
            data["chat_template"] = ChatTemplate.jinja

        return data


class PretrainingDataset(BaseModel):
    """Pretraining dataset configuration subset"""

    name: str | None = None
    path: str | None = None
    split: str | None = "train"
    text_column: str | None = "text"
    type: str | None = "pretrain"
    trust_remote_code: bool | None = False
    data_files: str | None = None
    skip: int | None = None


class UserDefinedDPOType(BaseModel):
    """User defined typing for DPO"""

    field_system: str | None = None
    field_prompt: str | None = None
    field_chosen: str | None = None
    field_rejected: str | None = None
    prompt_format: str | None = None
    chosen_format: str | None = None
    rejected_format: str | None = None


class DPODataset(BaseModel):
    """DPO configuration subset"""

    path: str | None = None
    split: str | None = None
    type: UserDefinedDPOType | str | None = None
    data_files: list[str] | None = None
    revision: str | None = None
    field_messages: str | None = None


class StepwiseSupervisedDataset(BaseModel):
    """Stepwise supervised dataset configuration subset"""

    path: str | None = None
    split: str | None = None
    data_files: list[str] | None = None
    revision: str | None = None
    step_separator: str | None = None
    max_completion_length: int | None = None
    train_on_last_step_only: bool | None = None


class UserDefinedKTOType(BaseModel):
    """User defined typing for KTO"""

    field_system: str | None = None
    field_prompt: str | None = None
    field_completion: str | None = None
    field_label: bool | None = None
    prompt_format: str | None = None
    completion_format: str | None = None


class KTODataset(BaseModel):
    """KTO configuration subset"""

    path: str | None = None
    split: str | None = None
    type: UserDefinedKTOType | str | None = None
    data_files: list[str] | None = None
    trust_remote_code: bool | None = False
    revision: str | None = None


DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
```
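Stripped of Pydantic, the `check_chat_template_config` rules above are three dict checks. A standalone restatement for illustration, with plain strings standing in for the `ChatTemplate` enum members:

```python
def check_chat_template_config(data: dict) -> dict:
    """Normalize chat_template / chat_template_jinja the way the validator does."""
    # type=chat_template with no explicit template falls back to the tokenizer's own
    if data.get("type") == "chat_template" and not data.get("chat_template"):
        data["chat_template"] = "tokenizer_default"

    # asking for jinja without supplying a template is an error
    if data.get("chat_template") == "jinja" and not data.get("chat_template_jinja"):
        raise ValueError(
            "chat_template_jinja is required when chat_template is set to jinja"
        )

    # supplying a raw Jinja template implies chat_template=jinja
    if data.get("chat_template_jinja") and not data.get("chat_template"):
        data["chat_template"] = "jinja"

    return data
```

Because the validator runs in `mode="before"`, these rewrites happen on the raw config dict before any field typing is applied.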
src/axolotl/utils/schemas/deprecated.py (new file, 68 lines)

```python
"""Pydantic models for deprecated and remapped configuration parameters"""

import logging
from typing import Any

from pydantic import BaseModel, Field, field_validator

LOG = logging.getLogger(__name__)


class DeprecatedParameters(BaseModel):
    """Configurations that are deprecated"""

    max_packed_sequence_len: int | None = None
    rope_scaling: Any | None = None
    noisy_embedding_alpha: float | None = None
    dpo_beta: float | None = None
    evaluation_strategy: str | None = None

    @field_validator("max_packed_sequence_len")
    @classmethod
    def validate_max_packed_sequence_len(cls, max_packed_sequence_len):
        if max_packed_sequence_len:
            raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
        return max_packed_sequence_len

    @field_validator("rope_scaling")
    @classmethod
    def validate_rope_scaling(cls, rope_scaling):
        if rope_scaling:
            raise DeprecationWarning(
                "`rope_scaling` is no longer supported, it should now be a key under `model_config`"
            )
        return rope_scaling

    @field_validator("noisy_embedding_alpha")
    @classmethod
    def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):
        if noisy_embedding_alpha:
            LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
        return noisy_embedding_alpha

    @field_validator("dpo_beta")
    @classmethod
    def validate_dpo_beta(cls, dpo_beta):
        if dpo_beta is not None:
            LOG.warning("dpo_beta is deprecated, use rl_beta instead")
        return dpo_beta

    @field_validator("evaluation_strategy")
    @classmethod
    def validate_evaluation_strategy(cls, evaluation_strategy):
        if evaluation_strategy is not None:
            LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
        return evaluation_strategy


class RemappedParameters(BaseModel):
    """Parameters that have been remapped to other names"""

    overrides_of_model_config: dict[str, Any] | None = Field(
        default=None, alias="model_config"
    )
    overrides_of_model_kwargs: dict[str, Any] | None = Field(
        default=None, alias="model_kwargs"
    )
    type_of_model: str | None = Field(default=None, alias="model_type")
    revision_of_model: str | None = Field(default=None, alias="model_revision")
```
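`RemappedParameters` leans on Pydantic's `alias` because user-facing keys such as `model_config` collide with Pydantic v2's reserved `model_config` attribute (the collision rationale is inferred, not stated in the diff). The key remapping itself, restated as a plain dict transform for illustration (`remap_deprecated_keys` is a hypothetical helper name):

```python
# Mapping of user-facing config keys to the internal field names they alias.
ALIASES = {
    "model_config": "overrides_of_model_config",
    "model_kwargs": "overrides_of_model_kwargs",
    "model_type": "type_of_model",
    "model_revision": "revision_of_model",
}


def remap_deprecated_keys(data: dict) -> dict:
    """Return a copy of data with aliased keys renamed; other keys pass through."""
    return {ALIASES.get(key, key): value for key, value in data.items()}
```

Pydantic performs the equivalent renaming automatically during model construction when `alias` is set.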
src/axolotl/utils/schemas/enums.py (new file, 49 lines)

```python
"""Enums for Axolotl input config"""

from enum import Enum


class RLType(str, Enum):
    """RL trainer type configuration subset"""

    dpo = "dpo"  # pylint: disable=invalid-name
    grpo = "grpo"  # pylint: disable=invalid-name
    ipo = "ipo"  # pylint: disable=invalid-name
    orpo = "orpo"  # pylint: disable=invalid-name
    kto = "kto"  # pylint: disable=invalid-name
    simpo = "simpo"  # pylint: disable=invalid-name


class ChatTemplate(str, Enum):
    """Chat templates configuration subset"""

    alpaca = "alpaca"  # pylint: disable=invalid-name
    chatml = "chatml"  # pylint: disable=invalid-name
    mistral_v1 = "mistral_v1"  # pylint: disable=invalid-name
    mistral_v2v3 = "mistral_v2v3"  # pylint: disable=invalid-name
    mistral_v3_tekken = "mistral_v3_tekken"  # pylint: disable=invalid-name
    gemma = "gemma"  # pylint: disable=invalid-name
    cohere = "cohere"  # pylint: disable=invalid-name
    llama3 = "llama3"  # pylint: disable=invalid-name
    llama3_2_vision = "llama3_2_vision"  # pylint: disable=invalid-name
    phi_3 = "phi_3"  # pylint: disable=invalid-name
    phi_35 = "phi_35"  # pylint: disable=invalid-name
    deepseek_v2 = "deepseek_v2"  # pylint: disable=invalid-name
    deepseek_v3 = "deepseek_v3"  # pylint: disable=invalid-name
    jamba = "jamba"  # pylint: disable=invalid-name
    jinja = "jinja"  # pylint: disable=invalid-name
    qwen_25 = "qwen_25"  # pylint: disable=invalid-name
    tokenizer_default = "tokenizer_default"  # pylint: disable=invalid-name
    exaone = "exaone"  # pylint: disable=invalid-name
    metharme = "metharme"  # pylint: disable=invalid-name


class CustomSupportedOptimizers(str, Enum):
    """Custom supported optimizers"""

    optimi_adamw = "optimi_adamw"  # pylint: disable=invalid-name
    ao_adamw_4bit = "ao_adamw_4bit"  # pylint: disable=invalid-name
    ao_adamw_8bit = "ao_adamw_8bit"  # pylint: disable=invalid-name
    ao_adamw_fp8 = "ao_adamw_fp8"  # pylint: disable=invalid-name
    adopt_adamw = "adopt_adamw"  # pylint: disable=invalid-name
    muon = "muon"  # pylint: disable=invalid-name
```
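All three enums subclass `str`, so members compare equal to the literal strings that appear in YAML configs, and Pydantic can coerce either form. A minimal stdlib reproduction of the pattern, using a few of the `ChatTemplate` values:

```python
from enum import Enum


class ChatTemplate(str, Enum):
    """Cut-down illustration of the str-Enum pattern used in schemas/enums.py."""

    chatml = "chatml"
    llama3 = "llama3"
    tokenizer_default = "tokenizer_default"


# str subclassing means plain string comparison works without .value
config_value = "chatml"
template = ChatTemplate(config_value)  # coercion from the raw config string
```

This is what lets the `SFTDataset` validator compare `data.get("chat_template")` (a raw string) against enum members directly.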
src/axolotl/utils/schemas/integrations.py (new file, 108 lines)

```python
"""Pydantic models for Axolotl integrations"""

import logging
from typing import Any

from pydantic import BaseModel, Field, model_validator

LOG = logging.getLogger(__name__)


class MLFlowConfig(BaseModel):
    """MLFlow configuration subset"""

    use_mlflow: bool | None = None
    mlflow_tracking_uri: str | None = None
    mlflow_experiment_name: str | None = None
    mlflow_run_name: str | None = None
    hf_mlflow_log_artifacts: bool | None = None


class LISAConfig(BaseModel):
    """LISA configuration subset"""

    lisa_n_layers: int | None = Field(
        default=None,
        json_schema_extra={"description": "the number of active layers in LISA"},
    )
    lisa_step_interval: int | None = Field(
        default=None,
        json_schema_extra={"description": "how often to switch layers in LISA"},
    )
    lisa_layers_attribute: str | None = Field(
        default="model.layers",
        json_schema_extra={"description": "path under the model to access the layers"},
    )


class WandbConfig(BaseModel):
    """Wandb configuration subset"""

    use_wandb: bool | None = None
    wandb_name: str | None = None
    wandb_run_id: str | None = None
    wandb_mode: str | None = None
    wandb_project: str | None = None
    wandb_entity: str | None = None
    wandb_watch: str | None = None
    wandb_log_model: str | None = None

    @model_validator(mode="before")
    @classmethod
    def check_wandb_run(cls, data):
        if data.get("wandb_run_id") and not data.get("wandb_name"):
            data["wandb_name"] = data.get("wandb_run_id")

            LOG.warning(
                "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
            )

        return data


class CometConfig(BaseModel):
    """Comet configuration subset"""

    use_comet: bool | None = None
    comet_api_key: str | None = None
    comet_workspace: str | None = None
    comet_project_name: str | None = None
    comet_experiment_key: str | None = None
    comet_mode: str | None = None
    comet_online: bool | None = None
    comet_experiment_config: dict[str, Any] | None = None


class GradioConfig(BaseModel):
    """Gradio configuration subset"""

    gradio_title: str | None = None
    gradio_share: bool | None = None
    gradio_server_name: str | None = None
    gradio_server_port: int | None = None
    gradio_max_new_tokens: int | None = None
    gradio_temperature: float | None = None


class RayConfig(BaseModel):
    """Ray launcher configuration subset"""

    use_ray: bool = Field(default=False)
    ray_run_name: str | None = Field(
        default=None,
        json_schema_extra={
            "help": "The training results will be saved at `saves/ray_run_name`."
        },
    )
    ray_num_workers: int = Field(
        default=1,
        json_schema_extra={
            "help": "The number of workers for Ray training. Default is 1 worker."
        },
    )
    resources_per_worker: dict = Field(
        default_factory=lambda: {"GPU": 1},
        json_schema_extra={
            "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
        },
    )
```
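`WandbConfig.check_wandb_run` backfills the run name from the run ID before validation. The same logic as a plain-dict function, restated here only for illustration:

```python
import logging

LOG = logging.getLogger(__name__)


def check_wandb_run(data: dict) -> dict:
    """If only a run ID is given, reuse it as the display name and warn."""
    if data.get("wandb_run_id") and not data.get("wandb_name"):
        data["wandb_name"] = data.get("wandb_run_id")
        LOG.warning(
            "wandb_run_id sets the ID of the run. If you would like to set "
            "the name, please use wandb_name instead."
        )
    return data
```

An explicitly set `wandb_name` always wins; the fallback only fires when the name is absent.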
src/axolotl/utils/schemas/model.py (new file, 55 lines)

```python
"""Pydantic models for model input / output, etc. configuration"""

import logging

from pydantic import BaseModel, Field, field_validator

LOG = logging.getLogger(__name__)


class ModelInputConfig(BaseModel):
    """Model configuration subset"""

    model_config = {"protected_namespaces": ()}

    base_model: str
    base_model_config: str | None = None
    cls_model_config: str | None = None
    tokenizer_config: str | None = None
    tokenizer_use_fast: bool | None = None
    tokenizer_legacy: bool | None = None
    tokenizer_type: str | None = Field(
        default=None, json_schema_extra={"description": "transformers tokenizer class"}
    )
    processor_type: str | None = Field(
        default=None, json_schema_extra={"description": "transformers processor class"}
    )
    trust_remote_code: bool | None = None

    @field_validator("trust_remote_code")
    @classmethod
    def hint_trust_remote_code(cls, trust_remote_code):
        if trust_remote_code:
            LOG.warning(
                "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
            )
        return trust_remote_code


class ModelOutputConfig(BaseModel):
    """Model save configuration subset"""

    output_dir: str = Field(default="./model-out")
    hub_model_id: str | None = None
    hub_strategy: str | None = None
    save_safetensors: bool | None = True


class SpecialTokensConfig(BaseModel):
    """Special tokens configuration subset"""

    bos_token: str | None = None
    eos_token: str | None = None
    pad_token: str | None = None
    unk_token: str | None = None
    additional_special_tokens: list[str] | None = None
```
132 src/axolotl/utils/schemas/peft.py (new file)
@@ -0,0 +1,132 @@
"""Pydantic models for PEFT-related configuration"""

from typing import Any

from pydantic import BaseModel, Field, field_validator, model_validator


class LoftQConfig(BaseModel):
    """LoftQ configuration subset"""

    loftq_bits: int = Field(
        default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
    )
    # loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})


class PeftConfig(BaseModel):
    """PEFT configuration subset"""

    loftq_config: LoftQConfig | None = None


class LoraConfig(BaseModel):
    """Peft / LoRA configuration subset"""

    load_in_8bit: bool | None = Field(default=False)
    load_in_4bit: bool | None = Field(default=False)

    adapter: str | None = None
    lora_model_dir: str | None = None
    lora_r: int | None = None
    lora_alpha: int | None = None
    lora_fan_in_fan_out: bool | None = None
    lora_target_modules: str | list[str] | None = None
    lora_target_linear: bool | None = None
    lora_modules_to_save: list[str] | None = None
    lora_dropout: float | None = 0.0
    peft_layers_to_transform: list[int] | None = None
    peft_layers_pattern: list[str] | None = None
    peft: PeftConfig | None = None
    peft_use_dora: bool | None = None
    peft_use_rslora: bool | None = None
    peft_layer_replication: list[tuple[int, int]] | None = None
    peft_init_lora_weights: bool | str | None = None

    qlora_sharded_model_loading: bool | None = Field(
        default=False,
        json_schema_extra={
            "description": "load qlora model in sharded format for FSDP using answer.ai technique."
        },
    )
    lora_on_cpu: bool | None = None
    gptq: bool | None = None
    bnb_config_kwargs: dict[str, Any] | None = None

    loraplus_lr_ratio: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
        },
    )
    loraplus_lr_embedding: float | None = Field(
        default=1e-6,
        json_schema_extra={
            "description": "loraplus learning rate for lora embedding layers."
        },
    )

    merge_lora: bool | None = None

    @model_validator(mode="before")
    @classmethod
    def validate_adapter(cls, data):
        if (
            not data.get("adapter")
            and not data.get("inference")
            and (data.get("load_in_8bit") or data.get("load_in_4bit"))
        ):
            raise ValueError(
                "load_in_8bit and load_in_4bit are not supported without setting an adapter for training. "
                "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
            )
        return data

    @model_validator(mode="after")
    def validate_qlora(self):
        if self.adapter == "qlora":
            if self.merge_lora:
                # can't merge qlora if loaded in 8bit or 4bit
                if self.load_in_8bit:
                    raise ValueError("Can't merge qlora if loaded in 8bit")

                if self.gptq:
                    raise ValueError("Can't merge qlora if gptq")

                if self.load_in_4bit:
                    raise ValueError("Can't merge qlora if loaded in 4bit")

            else:
                if self.load_in_8bit:
                    raise ValueError("Can't load qlora in 8bit")

                if self.gptq:
                    raise ValueError("Can't load qlora if gptq")

                if not self.load_in_4bit:
                    raise ValueError("Require cfg.load_in_4bit to be True for qlora")
        return self

    @field_validator("loraplus_lr_embedding")
    @classmethod
    def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):
        if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):
            loraplus_lr_embedding = float(loraplus_lr_embedding)
        return loraplus_lr_embedding

    @model_validator(mode="before")
    @classmethod
    def validate_lora_dropout(cls, data):
        if data.get("adapter") is not None and data.get("lora_dropout") is None:
            data["lora_dropout"] = 0.0
        return data


class ReLoRAConfig(BaseModel):
    """ReLoRA configuration subset"""

    relora_steps: int | None = None
    relora_warmup_steps: int | None = None
    relora_anneal_steps: int | None = None
    relora_prune_ratio: float | None = None
    relora_cpu_offload: bool | None = None
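The interplay of `adapter`, `load_in_8bit`/`load_in_4bit`, and `merge_lora` above can be sketched as plain functions (illustrative names, no pydantic dependency; in the real code these checks run as model validators on `LoraConfig`):

```python
def check_adapter_quantization(adapter, inference, load_in_8bit, load_in_4bit):
    """Sketch of validate_adapter: quantized loading requires an adapter for training."""
    if not adapter and not inference and (load_in_8bit or load_in_4bit):
        raise ValueError(
            "load_in_8bit and load_in_4bit are not supported without setting "
            "an adapter for training."
        )


def check_qlora(merge_lora, load_in_8bit, load_in_4bit, gptq):
    """Sketch of validate_qlora for adapter == 'qlora'."""
    if merge_lora:
        # can't merge qlora if loaded in 8bit/4bit or with gptq
        if load_in_8bit or load_in_4bit or gptq:
            raise ValueError("Can't merge qlora with quantized loading")
    else:
        if load_in_8bit or gptq:
            raise ValueError("Can't load qlora in 8bit or with gptq")
        if not load_in_4bit:
            raise ValueError("Require load_in_4bit to be True for qlora")
```

Note that `validate_adapter` runs in `mode="before"` on the raw dict (so it can see keys like `inference` that are not fields of this model), while `validate_qlora` runs in `mode="after"` on the constructed instance.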
99 src/axolotl/utils/schemas/training.py (new file)
@@ -0,0 +1,99 @@
"""Pydantic models for training hyperparameters"""

import logging
from typing import Any, Literal

from pydantic import BaseModel, Field, field_validator
from transformers import SchedulerType
from transformers.training_args import OptimizerNames

from axolotl.utils.schemas.enums import CustomSupportedOptimizers

LOG = logging.getLogger(__name__)


class LrGroup(BaseModel):
    """Custom learning rate group configuration"""

    name: str
    modules: list[str]
    lr: float


class HyperparametersConfig(BaseModel):
    """Training hyperparams configuration subset"""

    gradient_accumulation_steps: int | None = Field(default=1)
    micro_batch_size: int | None = Field(
        default=1,
        json_schema_extra={"description": "per gpu micro batch size for training"},
    )
    batch_size: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Total batch size; we do not recommend setting this manually"
        },
    )
    eval_batch_size: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
        },
    )

    auto_find_batch_size: bool | None = None

    train_on_inputs: bool | None = False
    group_by_length: bool | None = None

    learning_rate: str | float
    embedding_lr: float | None = None
    embedding_lr_scale: float | None = None
    weight_decay: float | None = 0.0
    optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = (
        OptimizerNames.ADAMW_TORCH_FUSED
    )
    optim_args: (str | dict[str, Any]) | None = Field(
        default=None,
        json_schema_extra={"description": "Optional arguments to supply to optimizer."},
    )
    optim_target_modules: (list[str] | Literal["all_linear"]) | None = Field(
        default=None,
        json_schema_extra={
            "description": "The target modules to optimize, i.e. the module names that you would like to train."
        },
    )
    torchdistx_path: str | None = None
    lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = (
        SchedulerType.COSINE
    )
    lr_scheduler_kwargs: dict[str, Any] | None = None
    lr_quadratic_warmup: bool | None = None
    cosine_min_lr_ratio: float | None = None
    cosine_constant_lr_ratio: float | None = None
    lr_div_factor: float | None = None
    lr_groups: list[LrGroup] | None = None

    adam_epsilon: float | None = None
    adam_beta1: float | None = None
    adam_beta2: float | None = None
    max_grad_norm: float | None = None
    num_epochs: float = Field(default=1.0)

    @field_validator("batch_size")
    @classmethod
    def hint_batch_size_set(cls, batch_size):
        if batch_size:
            LOG.warning(
                "%s\n%s",
                "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
                "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
            )
        return batch_size

    @field_validator("learning_rate")
    @classmethod
    def convert_learning_rate(cls, learning_rate):
        if learning_rate and isinstance(learning_rate, str):
            learning_rate = float(learning_rate)
        return learning_rate
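The hint in `hint_batch_size_set` is simple arithmetic; a small helper (the function name is illustrative, not part of the codebase) makes the conversion concrete:

```python
def equivalent_grad_accum_steps(batch_size, micro_batch_size, num_gpus):
    # gradient_accumulation_steps = batch_size / micro_batch_size / number of gpus
    steps, remainder = divmod(batch_size, micro_batch_size * num_gpus)
    if remainder:
        raise ValueError("batch_size must be divisible by micro_batch_size * num_gpus")
    return steps


# e.g. a total batch size of 64 with micro_batch_size=2 on 4 GPUs
# is equivalent to gradient_accumulation_steps=8
assert equivalent_grad_accum_steps(64, 2, 4) == 8
```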
@@ -1,8 +1,4 @@
-"""
-GRPO specific configuration args
-"""
-
-from typing import Optional
+"""Pydantic models for TRL trainer configuration"""

 from pydantic import BaseModel, Field

@@ -12,11 +8,11 @@ class TRLConfig(BaseModel):
     Input args for TRL.
     """

-    beta: Optional[float] = Field(
+    beta: float | None = Field(
         default=None,
         json_schema_extra={"description": "Beta for RL training"},
     )
-    max_completion_length: Optional[int] = Field(
+    max_completion_length: int | None = Field(
         default=None,
         json_schema_extra={
             "description": "Maximum length of the completion for RL training"
@@ -25,50 +21,50 @@ class TRLConfig(BaseModel):

     # GRPO specific args
     # Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
-    use_vllm: Optional[bool] = Field(
+    use_vllm: bool | None = Field(
         default=False,
         json_schema_extra={"description": "Whether to use VLLM for RL training"},
     )
-    vllm_device: Optional[str] = Field(
+    vllm_device: str | None = Field(
         default="auto",
         json_schema_extra={"description": "Device to use for VLLM"},
     )
-    vllm_gpu_memory_utilization: Optional[float] = Field(
+    vllm_gpu_memory_utilization: float | None = Field(
         default=0.9,
         json_schema_extra={"description": "GPU memory utilization for VLLM"},
     )
-    vllm_dtype: Optional[str] = Field(
+    vllm_dtype: str | None = Field(
         default="auto",
         json_schema_extra={"description": "Data type for VLLM"},
     )
-    vllm_max_model_len: Optional[int] = Field(
+    vllm_max_model_len: int | None = Field(
         default=None,
         json_schema_extra={
             "description": "Maximum length of the model context for VLLM"
         },
     )

-    reward_funcs: Optional[list[str]] = Field(
+    reward_funcs: list[str] | None = Field(
         default=None,
         json_schema_extra={"description": "List of reward functions to load"},
     )
-    reward_weights: Optional[list[float]] = Field(
+    reward_weights: list[float] | None = Field(
         default=None,
         json_schema_extra={
             "description": "Weights for each reward function. Must match the number of reward functions."
         },
     )
-    num_generations: Optional[int] = Field(
+    num_generations: int | None = Field(
         default=None,
         json_schema_extra={
             "description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
         },
    )
-    log_completions: Optional[bool] = Field(
+    log_completions: bool | None = Field(
         default=False,
         json_schema_extra={"description": "Whether to log completions"},
     )
-    sync_ref_model: Optional[bool] = Field(
+    sync_ref_model: bool | None = Field(
         default=False,
         json_schema_extra={
             "description": (
@@ -77,13 +73,13 @@ class TRLConfig(BaseModel):
             )
         },
     )
-    ref_model_mixup_alpha: Optional[float] = Field(
+    ref_model_mixup_alpha: float | None = Field(
         default=0.9,
         json_schema_extra={
             "description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
         },
     )
-    ref_model_sync_steps: Optional[int] = Field(
+    ref_model_sync_steps: int | None = Field(
         default=64,
         json_schema_extra={
             "description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
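The divisibility constraint stated in the `num_generations` description can be checked up front (helper name is illustrative, not part of the codebase):

```python
def check_num_generations(num_processes, per_device_batch_size, num_generations):
    # The global batch size (num_processes * per_device_batch_size) must be
    # divisible by num_generations, per the field description above.
    global_batch_size = num_processes * per_device_batch_size
    if global_batch_size % num_generations != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by "
            f"num_generations={num_generations}"
        )
    return global_batch_size // num_generations


# 2 processes x 8 per-device samples -> global batch 16, fits 4 generations
assert check_num_generations(2, 8, 4) == 4
```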
79 src/axolotl/utils/schemas/utils.py (new file)
@@ -0,0 +1,79 @@
"""Utilities for Axolotl Pydantic models"""

import logging

LOG = logging.getLogger(__name__)


def handle_legacy_message_fields_logic(data: dict) -> dict:
    """
    Handle backwards compatibility between legacy message field mapping and new property mapping system.

    Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
    - message_field_role: Mapped to the role field
    - message_field_content: Mapped to the content field

    The new system uses message_property_mappings to support arbitrary field mappings:
        message_property_mappings:
            role: source_role_field
            content: source_content_field
            additional_field: source_field

    Args:
        data: Dictionary containing configuration data

    Returns:
        Updated dictionary with message field mappings consolidated

    Raises:
        ValueError: If there are conflicts between legacy and new mappings
    """
    data = data.copy()  # Create a copy to avoid modifying the original

    if data.get("message_property_mappings") is None:
        data["message_property_mappings"] = {}

    # Check for conflicts and handle role
    if "message_field_role" in data:
        LOG.warning(
            "message_field_role is deprecated, use message_property_mappings instead. "
            f"Example: message_property_mappings: {{role: {data['message_field_role']}}}"
        )
        if (
            "role" in data["message_property_mappings"]
            and data["message_property_mappings"]["role"] != data["message_field_role"]
        ):
            raise ValueError(
                f"Conflicting message role fields: message_field_role='{data['message_field_role']}' "
                f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'"
            )
        data["message_property_mappings"]["role"] = data["message_field_role"] or "role"

        del data["message_field_role"]
    elif "role" not in data["message_property_mappings"]:
        data["message_property_mappings"]["role"] = "role"

    # Check for conflicts and handle content
    if "message_field_content" in data:
        LOG.warning(
            "message_field_content is deprecated, use message_property_mappings instead. "
            f"Example: message_property_mappings: {{content: {data['message_field_content']}}}"
        )
        if (
            "content" in data["message_property_mappings"]
            and data["message_property_mappings"]["content"]
            != data["message_field_content"]
        ):
            raise ValueError(
                f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
                f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
            )
        data["message_property_mappings"]["content"] = (
            data["message_field_content"] or "content"
        )

        del data["message_field_content"]
    elif "content" not in data["message_property_mappings"]:
        data["message_property_mappings"]["content"] = "content"

    return data
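A condensed, self-contained sketch of how the function folds the legacy `message_field_role` key into `message_property_mappings` (the deprecation warning is omitted, and `fold_legacy_role` is an illustrative name, not the actual API):

```python
def fold_legacy_role(data: dict) -> dict:
    """Condensed sketch of handle_legacy_message_fields_logic for 'role' only."""
    data = data.copy()
    mappings = dict(data.get("message_property_mappings") or {})
    if "message_field_role" in data:
        legacy = data.pop("message_field_role")
        # conflicting explicit mappings are an error, matching the full function
        if "role" in mappings and mappings["role"] != legacy:
            raise ValueError("Conflicting message role fields")
        mappings["role"] = legacy or "role"
    else:
        mappings.setdefault("role", "role")
    data["message_property_mappings"] = mappings
    return data


assert fold_legacy_role({"message_field_role": "from"}) == {
    "message_property_mappings": {"role": "from"}
}
assert fold_legacy_role({})["message_property_mappings"] == {"role": "role"}
```

The `content` branch of the real function follows the same shape with `message_field_content`.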
90 styles.css
@@ -14,7 +14,7 @@
 h1 {
   font-family: var(--font-title);
   font-weight: 400;
-  font-size: 5rem;
+  font-size: 3rem;
   line-height: 1.1;
   letter-spacing: -0.05em;
   font-feature-settings: "ss01" on;
@@ -24,7 +24,7 @@ h1 {
 h2 {
   font-family: var(--font-title);
   font-weight: 500;
-  font-size: 2rem;
+  font-size: 1.5rem;
   line-height: 1.2;
   letter-spacing: -0.03em;
   font-feature-settings: "ss01" on;
@@ -35,7 +35,7 @@ h3,
 h4 {
   font-family: var(--font-body);
   font-weight: 400;
-  font-size: 1.5rem;
+  font-size: 1.25rem;
   line-height: 1.5;
   letter-spacing: -0.02em;
 }
@@ -191,3 +191,87 @@ code span.er {
   color: #5cb85c !important;
   text-decoration: none !important;
 }
+
+/* API Documentation Styling */
+
+/* Improve docstring section rendering */
+.level3 p {
+  white-space: pre-line !important;
+}
+
+/* Format docstring sections */
+.level3 p strong {
+  display: block;
+  margin-top: 1em;
+  font-weight: bold;
+  color: var(--cyan);
+}
+
+/* Add spacing after sections */
+.level3 p:has(strong) {
+  margin-bottom: 0.5em;
+}
+
+/* Format Args and Returns sections */
+p:has(code) {
+  line-height: 1.6;
+}
+
+/* Function signatures */
+.sourceCode {
+  margin-bottom: 1.5em;
+}
+
+/* Parameter tables */
+.doc-section-parameters table,
+.doc-section-returns table {
+  margin-top: 1em;
+  margin-bottom: 1.5em;
+}
+
+/* Make parameter and returns headers smaller */
+h2.anchored[data-anchor-id="parameters"],
+h2.anchored[data-anchor-id="returns"],
+.doc-section-parameters h4,
+.doc-section-returns h4 {
+  font-size: 1.25rem;
+  margin-top: 2rem;
+  margin-bottom: 1rem;
+  color: var(--lime);
+  border-bottom: 1px solid var(--lime);
+  padding-bottom: 0.3rem;
+  font-family: var(--font-body);
+  font-weight: 500;
+  letter-spacing: normal;
+}
+
+/* Style documentation tables */
+table {
+  width: 100%;
+  margin-bottom: 1.5rem;
+  border-collapse: collapse;
+}
+
+table th {
+  background-color: #1a1a1a;
+  padding: 0.5rem 1rem;
+  border-bottom: 2px solid var(--greige-600);
+  text-align: left;
+}
+
+table td {
+  padding: 0.5rem 1rem;
+  border-bottom: 1px solid var(--greige-600);
+}
+
+/* Code in table cells */
+table td code {
+  background-color: transparent !important;
+  padding: 0;
+}
+
+/* Improve spacing in parameter and return tables */
+.doc-section-parameters,
+.doc-section-returns {
+  margin-top: 1rem;
+}
@@ -11,10 +11,10 @@ from pydantic import ValidationError

 from axolotl.utils import is_comet_available
 from axolotl.utils.config import validate_config
-from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.mlflow_ import setup_mlflow_env_vars
 from axolotl.utils.models import check_model_config
+from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
 from axolotl.utils.wandb_ import setup_wandb_env_vars

 warnings.filterwarnings("error")
@@ -6,8 +6,8 @@ from typing import Optional
 import pytest

 from axolotl.utils.config import validate_config
-from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.schemas.datasets import ChatTemplate

 warnings.filterwarnings("error")
@@ -119,7 +119,7 @@ class TestModelsUtils:

     def test_message_property_mapping(self):
         """Test message property mapping configuration validation"""
-        from axolotl.utils.config.models.input.v0_4_1 import SFTDataset
+        from axolotl.utils.schemas.datasets import SFTDataset

         # Test legacy fields are mapped correctly
         dataset = SFTDataset(