Compare commits
18 Commits
attn-imple
...
quartodoc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0bffef25d0 | ||
|
|
94c00c1d04 | ||
|
|
ddd84d7c65 | ||
|
|
42bdf0bd74 | ||
|
|
b03d96a228 | ||
|
|
2653f170fc | ||
|
|
3bfcce9f0a | ||
|
|
8feb746953 | ||
|
|
a563815fe7 | ||
|
|
81f2203151 | ||
|
|
5b7e688fc5 | ||
|
|
5134aa66cd | ||
|
|
ba9a867adb | ||
|
|
c618f42c39 | ||
|
|
fc1f985296 | ||
|
|
a5e37f183c | ||
|
|
e6a7bbe9ff | ||
|
|
e4fd7aad0b |
7
.github/workflows/docs.yml
vendored
7
.github/workflows/docs.yml
vendored
@@ -20,9 +20,12 @@ jobs:
|
|||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.11'
|
python-version: '3.11'
|
||||||
- name: install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install jupyter
|
python3 -m pip install jupyter quartodoc
|
||||||
|
python3 -m pip install -e .
|
||||||
|
- name: Build autodoc
|
||||||
|
run: quartodoc build
|
||||||
- name: Publish to GitHub Pages (and render)
|
- name: Publish to GitHub Pages (and render)
|
||||||
uses: quarto-dev/quarto-actions/publish@v2
|
uses: quarto-dev/quarto-actions/publish@v2
|
||||||
with:
|
with:
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -181,6 +181,10 @@ prepared-datasets/
|
|||||||
submit.sh
|
submit.sh
|
||||||
*.out*
|
*.out*
|
||||||
|
|
||||||
|
# Quartodoc generated files
|
||||||
|
objects.json
|
||||||
|
site_libs/
|
||||||
|
|
||||||
typings/
|
typings/
|
||||||
out/
|
out/
|
||||||
|
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github
|
|||||||
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
||||||
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
||||||
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
||||||
|
- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/docs/api/) - Auto-generated code documentation
|
||||||
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
||||||
|
|
||||||
## 🤝 Getting Help
|
## 🤝 Getting Help
|
||||||
|
|||||||
193
_quarto.yml
193
_quarto.yml
@@ -1,6 +1,178 @@
|
|||||||
project:
|
project:
|
||||||
type: website
|
type: website
|
||||||
|
|
||||||
|
quartodoc:
|
||||||
|
dir: docs/api
|
||||||
|
package: axolotl
|
||||||
|
title: API Reference
|
||||||
|
parser: google
|
||||||
|
|
||||||
|
sections:
|
||||||
|
- title: Core
|
||||||
|
desc: Core functionality for training
|
||||||
|
contents:
|
||||||
|
- train
|
||||||
|
- evaluate
|
||||||
|
- datasets
|
||||||
|
- convert
|
||||||
|
- prompt_tokenizers
|
||||||
|
- logging_config
|
||||||
|
- core.trainer_builder
|
||||||
|
- core.training_args
|
||||||
|
- core.chat.messages
|
||||||
|
- core.chat.format.chatml
|
||||||
|
- core.chat.format.llama3x
|
||||||
|
- core.chat.format.shared
|
||||||
|
- core.datasets.chat
|
||||||
|
- core.datasets.transforms.chat_builder
|
||||||
|
- title: CLI
|
||||||
|
desc: Command-line interface
|
||||||
|
contents:
|
||||||
|
- cli.main
|
||||||
|
- cli.train
|
||||||
|
- cli.evaluate
|
||||||
|
- cli.args
|
||||||
|
- cli.checks
|
||||||
|
- cli.config
|
||||||
|
- cli.inference
|
||||||
|
- cli.merge_lora
|
||||||
|
- cli.merge_sharded_fsdp_weights
|
||||||
|
- cli.preprocess
|
||||||
|
- cli.sweeps
|
||||||
|
- cli.utils
|
||||||
|
- cli.cloud.base
|
||||||
|
- cli.cloud.modal_
|
||||||
|
- title: Trainers
|
||||||
|
desc: Training implementations
|
||||||
|
contents:
|
||||||
|
- core.trainers.base
|
||||||
|
- core.trainers.trl
|
||||||
|
- core.trainers.dpo.trainer
|
||||||
|
- core.trainers.grpo.trainer
|
||||||
|
- title: Prompt Strategies
|
||||||
|
desc: Prompt formatting strategies
|
||||||
|
contents:
|
||||||
|
- prompt_strategies.base
|
||||||
|
- prompt_strategies.chat_template
|
||||||
|
- prompt_strategies.alpaca_chat
|
||||||
|
- prompt_strategies.alpaca_instruct
|
||||||
|
- prompt_strategies.alpaca_w_system
|
||||||
|
- prompt_strategies.user_defined
|
||||||
|
- prompt_strategies.llama2_chat
|
||||||
|
- prompt_strategies.completion
|
||||||
|
- prompt_strategies.input_output
|
||||||
|
- prompt_strategies.stepwise_supervised
|
||||||
|
- prompt_strategies.metharme
|
||||||
|
- prompt_strategies.orcamini
|
||||||
|
- prompt_strategies.pygmalion
|
||||||
|
- prompt_strategies.messages.chat
|
||||||
|
- prompt_strategies.dpo.chat_template
|
||||||
|
- prompt_strategies.dpo.llama3
|
||||||
|
- prompt_strategies.dpo.chatml
|
||||||
|
- prompt_strategies.dpo.zephyr
|
||||||
|
- prompt_strategies.dpo.user_defined
|
||||||
|
- prompt_strategies.dpo.passthrough
|
||||||
|
- prompt_strategies.kto.llama3
|
||||||
|
- prompt_strategies.kto.chatml
|
||||||
|
- prompt_strategies.kto.user_defined
|
||||||
|
- prompt_strategies.orpo.chat_template
|
||||||
|
- prompt_strategies.bradley_terry.llama3
|
||||||
|
- title: Kernels
|
||||||
|
desc: Low-level performance optimizations
|
||||||
|
contents:
|
||||||
|
- kernels.lora
|
||||||
|
- kernels.geglu
|
||||||
|
- kernels.swiglu
|
||||||
|
- kernels.quantize
|
||||||
|
- kernels.utils
|
||||||
|
- title: MonkeyPatches
|
||||||
|
desc: Runtime patches for model optimizations
|
||||||
|
contents:
|
||||||
|
- monkeypatch.llama_attn_hijack_flash
|
||||||
|
- monkeypatch.llama_attn_hijack_xformers
|
||||||
|
- monkeypatch.mistral_attn_hijack_flash
|
||||||
|
- monkeypatch.multipack
|
||||||
|
- monkeypatch.relora
|
||||||
|
- monkeypatch.llama_expand_mask
|
||||||
|
- monkeypatch.lora_kernels
|
||||||
|
- monkeypatch.utils
|
||||||
|
- monkeypatch.btlm_attn_hijack_flash
|
||||||
|
- monkeypatch.llama_patch_multipack
|
||||||
|
- monkeypatch.stablelm_attn_hijack_flash
|
||||||
|
- monkeypatch.trainer_fsdp_optim
|
||||||
|
- monkeypatch.transformers_fa_utils
|
||||||
|
- monkeypatch.unsloth_
|
||||||
|
- monkeypatch.attention.mllama
|
||||||
|
- monkeypatch.data.batch_dataset_fetcher
|
||||||
|
- monkeypatch.mixtral
|
||||||
|
- title: Utils
|
||||||
|
desc: Utility functions
|
||||||
|
contents:
|
||||||
|
- utils.models
|
||||||
|
- utils.tokenization
|
||||||
|
- utils.chat_templates
|
||||||
|
- utils.lora
|
||||||
|
- utils.lora_embeddings
|
||||||
|
- utils.model_shard_quant
|
||||||
|
- utils.bench
|
||||||
|
- utils.freeze
|
||||||
|
- utils.trainer
|
||||||
|
- utils.schedulers
|
||||||
|
- utils.distributed
|
||||||
|
- utils.dict
|
||||||
|
- utils.optimizers.adopt
|
||||||
|
- utils.data.pretraining
|
||||||
|
- utils.data.sft
|
||||||
|
- utils.gradient_checkpointing.unsloth
|
||||||
|
- title: Schemas
|
||||||
|
desc: Pydantic data models for Axolotl config
|
||||||
|
contents:
|
||||||
|
- utils.schemas.config
|
||||||
|
- utils.schemas.model
|
||||||
|
- utils.schemas.training
|
||||||
|
- utils.schemas.datasets
|
||||||
|
- utils.schemas.peft
|
||||||
|
- utils.schemas.trl
|
||||||
|
- utils.schemas.integrations
|
||||||
|
- utils.schemas.enums
|
||||||
|
- utils.schemas.utils
|
||||||
|
- title: Integrations
|
||||||
|
desc: Third-party integrations and extensions
|
||||||
|
contents:
|
||||||
|
- integrations.base
|
||||||
|
- integrations.cut_cross_entropy.args
|
||||||
|
- integrations.grokfast.optimizer
|
||||||
|
- integrations.kd.trainer
|
||||||
|
- integrations.liger.args
|
||||||
|
- integrations.lm_eval.args
|
||||||
|
- integrations.spectrum.args
|
||||||
|
- title: Common
|
||||||
|
desc: Common utilities and shared functionality
|
||||||
|
contents:
|
||||||
|
- common.architectures
|
||||||
|
- common.const
|
||||||
|
- common.datasets
|
||||||
|
- title: Models
|
||||||
|
desc: Custom model implementations
|
||||||
|
contents:
|
||||||
|
- models.mamba.modeling_mamba
|
||||||
|
- title: Data Processing
|
||||||
|
desc: Data processing utilities
|
||||||
|
contents:
|
||||||
|
- utils.collators.core
|
||||||
|
- utils.collators.batching
|
||||||
|
- utils.collators.mamba
|
||||||
|
- utils.collators.mm_chat
|
||||||
|
- utils.samplers.multipack
|
||||||
|
- title: Callbacks
|
||||||
|
desc: Training callbacks
|
||||||
|
contents:
|
||||||
|
- utils.callbacks.perplexity
|
||||||
|
- utils.callbacks.profiler
|
||||||
|
- utils.callbacks.lisa
|
||||||
|
- utils.callbacks.mlflow_
|
||||||
|
- utils.callbacks.comet_
|
||||||
|
|
||||||
website:
|
website:
|
||||||
title: "Axolotl"
|
title: "Axolotl"
|
||||||
description: "We make fine-tuning accessible, scalable, and fun"
|
description: "We make fine-tuning accessible, scalable, and fun"
|
||||||
@@ -35,6 +207,8 @@ website:
|
|||||||
- docs/inference.qmd
|
- docs/inference.qmd
|
||||||
- docs/cli.qmd
|
- docs/cli.qmd
|
||||||
- docs/config.qmd
|
- docs/config.qmd
|
||||||
|
- text: "API Reference"
|
||||||
|
href: docs/api
|
||||||
|
|
||||||
- section: "Dataset Formats"
|
- section: "Dataset Formats"
|
||||||
contents: docs/dataset-formats/*
|
contents: docs/dataset-formats/*
|
||||||
@@ -80,3 +254,22 @@ format:
|
|||||||
theme: darkly
|
theme: darkly
|
||||||
css: styles.css
|
css: styles.css
|
||||||
toc: true
|
toc: true
|
||||||
|
# Enable better handling of line breaks in markdown
|
||||||
|
preserve-tabs: true
|
||||||
|
html-math-method: mathjax
|
||||||
|
# Improved markdown processing options
|
||||||
|
md-extensions:
|
||||||
|
- markdown_it
|
||||||
|
- def_list
|
||||||
|
- attr_list
|
||||||
|
- fenced_divs
|
||||||
|
- tables
|
||||||
|
- html_admonition
|
||||||
|
- lineblocks
|
||||||
|
- fancy_lists
|
||||||
|
# Control whitespace handling
|
||||||
|
whitespace: preserve
|
||||||
|
# Process newlines in paragraphs
|
||||||
|
wrap: preserve
|
||||||
|
# Better line break handling
|
||||||
|
preserve-linebreaks: true
|
||||||
|
|||||||
2
docs/.gitignore
vendored
2
docs/.gitignore
vendored
@@ -1,2 +1,4 @@
|
|||||||
/.quarto/
|
/.quarto/
|
||||||
_site/
|
_site/
|
||||||
|
/api/*.qmd
|
||||||
|
/api/*.html
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: "CLI Reference"
|
title: "Command Line Interface (CLI)"
|
||||||
format:
|
format:
|
||||||
html:
|
html:
|
||||||
toc: true
|
toc: true
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ description: How datasets are processed
|
|||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
||||||
the [dataset format](docs/dataset-formats) and prompt strategies to:
|
the [dataset format](dataset-formats) and prompt strategies to:
|
||||||
|
|
||||||
- parse the dataset based on the *dataset format*
|
- parse the dataset based on the *dataset format*
|
||||||
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
||||||
|
|||||||
@@ -2,3 +2,5 @@ pre-commit
|
|||||||
black
|
black
|
||||||
mypy
|
mypy
|
||||||
types-requests
|
types-requests
|
||||||
|
quartodoc
|
||||||
|
jupyter
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from axolotl.cli.utils import (
|
|||||||
)
|
)
|
||||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
|
|||||||
@@ -13,9 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
"""
|
"""Builder for the training args and trainer"""
|
||||||
Builder for the training args and trainer
|
|
||||||
"""
|
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import importlib
|
import importlib
|
||||||
@@ -85,8 +83,8 @@ from axolotl.utils.collators import (
|
|||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
|
|
||||||
from axolotl.utils.models import ensure_dtype
|
from axolotl.utils.models import ensure_dtype
|
||||||
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import logging
|
|||||||
from trl.trainer.grpo_trainer import RewardFunc
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
|
from axolotl.utils.schemas.trl import TRLConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from typing import Dict, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
|
from datasets import Dataset
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
@@ -25,18 +27,18 @@ LOG = get_logger("axolotl.evaluate")
|
|||||||
|
|
||||||
|
|
||||||
def evaluate_dataset(
|
def evaluate_dataset(
|
||||||
trainer, dataset, dataset_type: str, flash_optimum: bool = False
|
trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False
|
||||||
) -> Optional[Dict[str, float]]:
|
) -> Optional[Dict[str, float]]:
|
||||||
"""Helper function to evaluate a single dataset safely.
|
"""Helper function to evaluate a single dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
trainer: The trainer instance
|
trainer: The trainer instance.
|
||||||
dataset: Dataset to evaluate
|
dataset: Dataset to evaluate.
|
||||||
dataset_type: Type of dataset ('train' or 'eval')
|
dataset_type: Type of dataset ('train' or 'eval').
|
||||||
flash_optimum: Whether to use flash optimum
|
flash_optimum: Whether to use flash optimum.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary of metrics or None if dataset is None
|
Dictionary of metrics or None if dataset is None.
|
||||||
"""
|
"""
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
return None
|
return None
|
||||||
@@ -63,17 +65,14 @@ def evaluate_dataset(
|
|||||||
|
|
||||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Evaluate a model on training and validation datasets
|
Evaluate a model on training and validation datasets.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
dataset_meta: Dataset metadata containing training and evaluation datasets.
|
dataset_meta: Dataset metadata containing training and evaluation datasets.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing:
|
Dictionary mapping metric names to their values.
|
||||||
- The model (either PeftModel or PreTrainedModel)
|
|
||||||
- The tokenizer
|
|
||||||
- Dictionary of evaluation metrics
|
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
|
|||||||
@@ -11,19 +11,17 @@
|
|||||||
# the License.
|
# the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
module to handle merging the plugins' input arguments with the base configurations.
|
Module to handle merging the plugins' input arguments with the base configurations.
|
||||||
|
|
||||||
this was moved here to prevent circular imports
|
This was moved here to prevent circular imports.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.schemas.config import (
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
)
|
)
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
||||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def merge_input_args():
|
def merge_input_args():
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly
|
|||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
|
from axolotl.utils.schemas.datasets import DatasetConfig
|
||||||
|
|
||||||
# Configure the logger
|
# Configure the logger
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ DPO prompt strategies for using tokenizer chat templates.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic
|
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
||||||
|
|
||||||
|
|
||||||
def default(
|
def default(
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ from trl.models import unwrap_model_for_generation
|
|||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
barrier,
|
barrier,
|
||||||
broadcast_dict,
|
broadcast_dict,
|
||||||
@@ -43,6 +42,7 @@ from axolotl.utils.distributed import (
|
|||||||
is_main_process,
|
is_main_process,
|
||||||
zero_first,
|
zero_first,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||||
|
|||||||
@@ -12,19 +12,13 @@ from transformers.utils.import_utils import is_torch_npu_available
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.integrations.config import merge_input_args
|
from axolotl.integrations.config import merge_input_args
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
|
||||||
)
|
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
|
||||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
|
||||||
)
|
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
|
||||||
DPODataset,
|
|
||||||
KTODataset,
|
|
||||||
SFTDataset,
|
|
||||||
)
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model_config
|
from axolotl.utils.models import load_model_config
|
||||||
|
from axolotl.utils.schemas.config import (
|
||||||
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
|
)
|
||||||
|
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
||||||
|
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
165
src/axolotl/utils/schemas/datasets.py
Normal file
165
src/axolotl/utils/schemas/datasets.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
"""Pydantic models for datasets-related configuration"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.enums import ChatTemplate
|
||||||
|
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
||||||
|
|
||||||
|
|
||||||
|
class UserDefinedPrompterType(BaseModel):
    """Structure for user defined prompt types.

    Every field is an optional string that defaults to ``None``.
    """

    # NOTE(review): the `system_*` / `field_*` names suggest a system prompt and
    # the dataset keys to read from — confirm against the consuming prompt strategy.
    system_prompt: str | None = None
    system_format: str | None = None
    field_system: str | None = None
    field_instruction: str | None = None
    field_input: str | None = None
    field_output: str | None = None

    # NOTE(review): presumably prompt format strings (with / without an input)
    # — confirm against the consumer.
    format: str | None = None
    no_input_format: str | None = None
    field: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SFTDataset(BaseModel):
    """Schema for one supervised fine-tuning (SFT) dataset entry.

    All fields are optional. Two ``before`` validators normalise the raw input:
    one hands it to the legacy message-field helper, the other fills in or
    checks the chat-template settings.
    """

    path: str | None = None
    split: str | None = None
    type: str | UserDefinedPrompterType | None = None
    input_transform: str | None = None
    shards: int | None = None
    shards_idx: int | None = None
    preprocess_shards: int | None = None
    conversation: str | None = None
    # Do not make this too strict or it will break the validator to choose different dataset class
    chat_template: ChatTemplate | str | None = None
    chat_template_jinja: str | None = None
    data_files: str | list[str] | None = None
    input_format: str | None = None
    name: str | None = None
    ds_type: str | None = None
    train_on_split: str | None = None
    field: str | None = None
    field_human: str | None = None
    field_model: str | None = None
    field_messages: str | None = None
    # deprecated, use message_property_mappings
    message_field_role: str | None = None
    # deprecated, use message_property_mappings
    message_field_content: str | None = None
    message_property_mappings: dict[str, str] | None = None
    message_field_training: str | None = None
    message_field_training_detail: str | None = None
    logprobs_field: str | None = None
    temperature: float | None = None
    roles_to_train: list[str] | None = None
    train_on_eos: str | None = None
    roles: dict[str, list[str]] | None = None
    drop_system_message: bool | None = None
    trust_remote_code: bool | None = False
    revision: str | None = None

    @model_validator(mode="before")
    @classmethod
    def handle_legacy_message_fields(cls, data):
        """Pass the raw input through the shared legacy message-field helper.

        Returns whatever `handle_legacy_message_fields_logic` returns for `data`.
        """
        migrated = handle_legacy_message_fields_logic(data)
        return migrated

    @model_validator(mode="before")
    @classmethod
    # pylint: disable=duplicate-code
    def check_chat_template_config(cls, data):
        """Fill in and sanity-check the chat template settings.

        Args:
            data: Raw input; a `BaseModel` instance is first dumped to a dict.

        Returns:
            The (possibly updated) input mapping.

        Raises:
            ValueError: If `chat_template` is `jinja` but no
                `chat_template_jinja` is provided.
        """
        if isinstance(data, BaseModel):
            data = data.model_dump()

        # A `chat_template` dataset type without an explicit template falls
        # back to the tokenizer's own template.
        wants_chat_template = data.get("type") == "chat_template"
        if wants_chat_template and not data.get("chat_template"):
            data["chat_template"] = ChatTemplate.tokenizer_default

        jinja_source = data.get("chat_template_jinja")
        if not jinja_source:
            # Without a jinja source, selecting the jinja template is an error.
            if data.get("chat_template") == ChatTemplate.jinja:
                raise ValueError(
                    "chat_template_jinja is required when chat_template is set to jinja"
                )
            return data

        # A jinja source was given: select the jinja template unless another
        # template was chosen explicitly.
        if not data.get("chat_template"):
            data["chat_template"] = ChatTemplate.jinja

        return data
|
||||||
|
|
||||||
|
|
||||||
|
class PretrainingDataset(BaseModel):
    """Pretraining dataset configuration subset.

    Unlike most fields, `split`, `text_column` and `type` carry non-``None``
    defaults ("train", "text" and "pretrain" respectively).
    """

    name: str | None = None
    path: str | None = None
    split: str | None = "train"
    text_column: str | None = "text"
    type: str | None = "pretrain"
    trust_remote_code: bool | None = False
    # Only a single string is accepted here (no list form).
    data_files: str | None = None
    # NOTE(review): presumably the number of leading samples to skip — confirm
    # against the loader that reads this field.
    skip: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserDefinedDPOType(BaseModel):
    """User defined typing for DPO.

    Every field is an optional string that defaults to ``None``.
    """

    # NOTE(review): presumably the dataset keys holding the system text, prompt,
    # chosen and rejected responses — confirm against the DPO prompt strategy.
    field_system: str | None = None
    field_prompt: str | None = None
    field_chosen: str | None = None
    field_rejected: str | None = None
    # NOTE(review): presumably format strings applied to the values above — confirm.
    prompt_format: str | None = None
    chosen_format: str | None = None
    rejected_format: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DPODataset(BaseModel):
    """DPO configuration subset.

    All fields are optional and default to ``None``.
    """

    path: str | None = None
    split: str | None = None
    # Either a structured user-defined type or a plain string.
    type: UserDefinedDPOType | str | None = None
    data_files: list[str] | None = None
    revision: str | None = None
    field_messages: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class StepwiseSupervisedDataset(BaseModel):
    """Stepwise supervised dataset configuration subset.

    All fields are optional and default to ``None``.
    """

    path: str | None = None
    split: str | None = None
    data_files: list[str] | None = None
    revision: str | None = None
    # NOTE(review): the three options below look specific to stepwise
    # supervision; their exact semantics are defined by the consumer — confirm.
    step_separator: str | None = None
    max_completion_length: int | None = None
    train_on_last_step_only: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserDefinedKTOType(BaseModel):
    """User defined typing for KTO.

    All fields are optional and default to ``None``.
    """

    field_system: str | None = None
    field_prompt: str | None = None
    field_completion: str | None = None
    # NOTE(review): typed as bool while the sibling `field_*` entries are str —
    # confirm this is intentional.
    field_label: bool | None = None
    prompt_format: str | None = None
    completion_format: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class KTODataset(BaseModel):
    """KTO configuration subset.

    All fields are optional; `trust_remote_code` defaults to ``False`` and the
    rest default to ``None``.
    """

    path: str | None = None
    split: str | None = None
    # Either a structured user-defined type or a plain string.
    type: UserDefinedKTOType | str | None = None
    data_files: list[str] | None = None
    trust_remote_code: bool | None = False
    revision: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# Union of the SFT, DPO, KTO and stepwise-supervised dataset schema models.
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
|
||||||
68
src/axolotl/utils/schemas/deprecated.py
Normal file
68
src/axolotl/utils/schemas/deprecated.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""Pydantic models for deprecated and remapped configuration parameters"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeprecatedParameters(BaseModel):
    """Configuration options that are deprecated.

    Each field has a validator that either rejects the value by raising
    `DeprecationWarning` (options that are no longer supported) or logs a
    warning naming the replacement option and lets the value through.
    """

    max_packed_sequence_len: int | None = None
    rope_scaling: Any | None = None
    noisy_embedding_alpha: float | None = None
    dpo_beta: float | None = None
    evaluation_strategy: str | None = None

    @field_validator("max_packed_sequence_len")
    @classmethod
    def validate_max_packed_sequence_len(cls, max_packed_sequence_len):
        """Reject any truthy `max_packed_sequence_len`.

        Returns:
            The value unchanged when it is falsy.

        Raises:
            DeprecationWarning: If a truthy value is supplied.
        """
        # NOTE(review): raised as an exception rather than emitted through
        # `warnings.warn` — presumably so the config is rejected; confirm.
        if max_packed_sequence_len:
            raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
        return max_packed_sequence_len

    @field_validator("rope_scaling")
    @classmethod
    def validate_rope_scaling(cls, rope_scaling):
        """Reject any truthy top-level `rope_scaling`.

        Returns:
            The value unchanged when it is falsy.

        Raises:
            DeprecationWarning: If a truthy value is supplied.
        """
        if rope_scaling:
            raise DeprecationWarning(
                "`rope_scaling` is no longer supported, it should now be a key under `model_config`"
            )
        return rope_scaling

    @field_validator("noisy_embedding_alpha")
    @classmethod
    def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):
        """Log a deprecation warning for a truthy `noisy_embedding_alpha`.

        Returns:
            The value unchanged.
        """
        if noisy_embedding_alpha:
            LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
        return noisy_embedding_alpha

    @field_validator("dpo_beta")
    @classmethod
    def validate_dpo_beta(cls, dpo_beta):
        """Log a deprecation warning whenever `dpo_beta` is not ``None``.

        Returns:
            The value unchanged.
        """
        if dpo_beta is not None:
            LOG.warning("dpo_beta is deprecated, use rl_beta instead")
        return dpo_beta

    @field_validator("evaluation_strategy")
    @classmethod
    def validate_evaluation_strategy(cls, evaluation_strategy):
        """Log a deprecation warning whenever `evaluation_strategy` is not ``None``.

        Returns:
            The value unchanged.
        """
        if evaluation_strategy is not None:
            LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
        return evaluation_strategy
|
||||||
|
|
||||||
|
|
||||||
|
class RemappedParameters(BaseModel):
    """Parameters that have been remapped to other names.

    Each field declares an alias, so the input key (e.g. `model_config`)
    differs from the attribute name (e.g. `overrides_of_model_config`).
    """

    # NOTE(review): presumably renamed because pydantic reserves `model_`-prefixed
    # attribute names — confirm.
    overrides_of_model_config: dict[str, Any] | None = Field(
        default=None, alias="model_config"
    )
    overrides_of_model_kwargs: dict[str, Any] | None = Field(
        default=None, alias="model_kwargs"
    )
    type_of_model: str | None = Field(default=None, alias="model_type")
    revision_of_model: str | None = Field(default=None, alias="model_revision")
|
||||||
49
src/axolotl/utils/schemas/enums.py
Normal file
49
src/axolotl/utils/schemas/enums.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""Enums for Axolotl input config"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class RLType(str, Enum):
    """String enum of the RL trainer types accepted in the config.

    Every member's value is identical to its (lowercase) name.
    """

    # Lowercase member names are deliberate; silence the naming check once for
    # the whole class body instead of on every line.
    # pylint: disable=invalid-name
    dpo = "dpo"
    grpo = "grpo"
    ipo = "ipo"
    orpo = "orpo"
    kto = "kto"
    simpo = "simpo"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatTemplate(str, Enum):
    """String enum of the chat template identifiers accepted in the config.

    Every member's value is identical to its (lowercase) name.
    """

    # Lowercase member names are deliberate; silence the naming check once for
    # the whole class body instead of on every line.
    # pylint: disable=invalid-name
    alpaca = "alpaca"
    chatml = "chatml"
    mistral_v1 = "mistral_v1"
    mistral_v2v3 = "mistral_v2v3"
    mistral_v3_tekken = "mistral_v3_tekken"
    gemma = "gemma"
    cohere = "cohere"
    llama3 = "llama3"
    llama3_2_vision = "llama3_2_vision"
    phi_3 = "phi_3"
    phi_35 = "phi_35"
    deepseek_v2 = "deepseek_v2"
    deepseek_v3 = "deepseek_v3"
    jamba = "jamba"
    jinja = "jinja"
    qwen_25 = "qwen_25"
    tokenizer_default = "tokenizer_default"
    exaone = "exaone"
    metharme = "metharme"
|
||||||
|
|
||||||
|
|
||||||
|
class CustomSupportedOptimizers(str, Enum):
    """String enum of the custom optimizer names accepted in the config.

    Every member's value is identical to its (lowercase) name.
    """

    # Lowercase member names are deliberate; silence the naming check once for
    # the whole class body instead of on every line.
    # pylint: disable=invalid-name
    optimi_adamw = "optimi_adamw"
    ao_adamw_4bit = "ao_adamw_4bit"
    ao_adamw_8bit = "ao_adamw_8bit"
    ao_adamw_fp8 = "ao_adamw_fp8"
    adopt_adamw = "adopt_adamw"
    muon = "muon"
|
||||||
108
src/axolotl/utils/schemas/integrations.py
Normal file
108
src/axolotl/utils/schemas/integrations.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""Pydantic models for Axolotl integrations"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MLFlowConfig(BaseModel):
    """MLflow-related config keys.

    Every field is optional and defaults to `None` (unset).
    """

    use_mlflow: bool | None = None

    # Free-form string settings; no validation beyond the type is applied here.
    mlflow_tracking_uri: str | None = None
    mlflow_experiment_name: str | None = None
    mlflow_run_name: str | None = None

    hf_mlflow_log_artifacts: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class LISAConfig(BaseModel):
    """LISA configuration subset.

    The per-field descriptions are attached through `json_schema_extra`, so
    they appear in the generated JSON schema.
    """

    lisa_n_layers: int | None = Field(
        default=None,
        json_schema_extra={"description": "the number of activated layers in LISA"},
    )
    lisa_step_interval: int | None = Field(
        default=None,
        json_schema_extra={"description": "how often to switch layers in LISA"},
    )
    # The only field in this model with a non-None default.
    lisa_layers_attribute: str | None = Field(
        default="model.layers",
        json_schema_extra={"description": "path under the model to access the layers"},
    )
|
||||||
|
|
||||||
|
|
||||||
|
class WandbConfig(BaseModel):
    """Weights & Biases config keys; every field is optional (`None` by default)."""

    use_wandb: bool | None = None
    wandb_name: str | None = None
    wandb_run_id: str | None = None
    wandb_mode: str | None = None
    wandb_project: str | None = None
    wandb_entity: str | None = None
    wandb_watch: str | None = None
    wandb_log_model: str | None = None

    @model_validator(mode="before")
    @classmethod
    def check_wandb_run(cls, data):
        """Default `wandb_name` to `wandb_run_id` when only the run id is set.

        Runs on the raw input (before field validation) and updates `data` in
        place, logging a hint that `wandb_run_id` is the run's ID, not its name.
        """
        run_id = data.get("wandb_run_id")
        if run_id and not data.get("wandb_name"):
            data["wandb_name"] = run_id
            LOG.warning(
                "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
            )
        return data
|
||||||
|
|
||||||
|
|
||||||
|
class CometConfig(BaseModel):
    """Comet-related config keys.

    Every field is optional and defaults to `None` (unset).
    """

    use_comet: bool | None = None

    # String-valued settings; no validation beyond the type is applied here.
    comet_api_key: str | None = None
    comet_workspace: str | None = None
    comet_project_name: str | None = None
    comet_experiment_key: str | None = None
    comet_mode: str | None = None

    comet_online: bool | None = None
    # Arbitrary key/value mapping; its contents are not validated by this model.
    comet_experiment_config: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GradioConfig(BaseModel):
    """Gradio-related config keys.

    Every field is optional and defaults to `None` (unset).
    """

    gradio_title: str | None = None
    gradio_share: bool | None = None

    gradio_server_name: str | None = None
    gradio_server_port: int | None = None

    gradio_max_new_tokens: int | None = None
    gradio_temperature: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class RayConfig(BaseModel):
    """Ray launcher config keys.

    NOTE(review): the schema text here is stored under a `help` key, while
    other models in this module use `description` -- confirm this is intended.
    """

    use_ray: bool = False
    ray_run_name: str | None = Field(
        default=None,
        json_schema_extra={"help": "The training results will be saved at `saves/ray_run_name`."},
    )
    ray_num_workers: int = Field(
        default=1,
        json_schema_extra={"help": "The number of workers for Ray training. Default is 1 worker."},
    )
    # default_factory builds a fresh dict for each instance.
    resources_per_worker: dict = Field(
        default_factory=lambda: {"GPU": 1},
        json_schema_extra={
            "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
        },
    )
|
||||||
55
src/axolotl/utils/schemas/model.py
Normal file
55
src/axolotl/utils/schemas/model.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""Pydantic models for model input / output, etc. configuration"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInputConfig(BaseModel):
    """Config keys selecting the base model, tokenizer and processor.

    `base_model` is the only required field; everything else defaults to
    `None`.
    """

    # An empty tuple clears pydantic's protected namespaces for this model.
    model_config = {"protected_namespaces": ()}

    base_model: str
    base_model_config: str | None = None
    cls_model_config: str | None = None

    tokenizer_config: str | None = None
    tokenizer_use_fast: bool | None = None
    tokenizer_legacy: bool | None = None
    tokenizer_type: str | None = Field(
        default=None,
        json_schema_extra={"description": "transformers tokenizer class"},
    )
    processor_type: str | None = Field(
        default=None,
        json_schema_extra={"description": "transformers processor class"},
    )

    trust_remote_code: bool | None = None

    @field_validator("trust_remote_code")
    @classmethod
    def hint_trust_remote_code(cls, trust_remote_code):
        """Log a reminder when `trust_remote_code` is enabled; value unchanged."""
        if not trust_remote_code:
            return trust_remote_code

        LOG.warning(
            "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
        )
        return trust_remote_code
|
||||||
|
|
||||||
|
|
||||||
|
class ModelOutputConfig(BaseModel):
    """Config keys controlling where and how the model is saved."""

    # Defaults to "./model-out" when not specified.
    output_dir: str = "./model-out"

    hub_model_id: str | None = None
    hub_strategy: str | None = None

    # Enabled by default; an explicit None is also accepted by the type.
    save_safetensors: bool | None = True
|
||||||
|
|
||||||
|
|
||||||
|
class SpecialTokensConfig(BaseModel):
    """Special-token overrides.

    Every field is optional and defaults to `None` (unset).
    """

    # Single-token overrides
    bos_token: str | None = None
    eos_token: str | None = None
    pad_token: str | None = None
    unk_token: str | None = None

    # Extra tokens, given as a list of strings
    additional_special_tokens: list[str] | None = None
|
||||||
132
src/axolotl/utils/schemas/peft.py
Normal file
132
src/axolotl/utils/schemas/peft.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""Pydantic models for PEFT-related configuration"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
class LoftQConfig(BaseModel):
    """LoftQ settings; currently a single field, `loftq_bits` (default 4)."""

    loftq_bits: int = Field(
        default=4,
        json_schema_extra={"description": "Quantization bits for LoftQ"},
    )
    # Disabled field kept from the original source for reference:
    # loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
|
||||||
|
|
||||||
|
|
||||||
|
class PeftConfig(BaseModel):
    """PEFT configuration subset; wraps the optional nested LoftQ settings."""

    # Left as None when no LoftQ settings are supplied.
    loftq_config: LoftQConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class LoraConfig(BaseModel):
    """Peft / LoRA configuration subset.

    Besides declaring the adapter-related fields, this model checks that
    quantized loading is paired with an adapter, enforces the qlora loading
    constraints, and fills in a default `lora_dropout`.
    """

    load_in_8bit: bool | None = Field(default=False)
    load_in_4bit: bool | None = Field(default=False)

    adapter: str | None = None
    lora_model_dir: str | None = None
    lora_r: int | None = None
    lora_alpha: int | None = None
    lora_fan_in_fan_out: bool | None = None
    lora_target_modules: str | list[str] | None = None
    lora_target_linear: bool | None = None
    lora_modules_to_save: list[str] | None = None
    lora_dropout: float | None = 0.0
    peft_layers_to_transform: list[int] | None = None
    peft_layers_pattern: list[str] | None = None
    peft: PeftConfig | None = None
    peft_use_dora: bool | None = None
    peft_use_rslora: bool | None = None
    peft_layer_replication: list[tuple[int, int]] | None = None
    peft_init_lora_weights: bool | str | None = None

    qlora_sharded_model_loading: bool | None = Field(
        default=False,
        json_schema_extra={
            "description": "load qlora model in sharded format for FSDP using answer.ai technique."
        },
    )
    lora_on_cpu: bool | None = None
    gptq: bool | None = None
    bnb_config_kwargs: dict[str, Any] | None = None

    loraplus_lr_ratio: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
        },
    )
    loraplus_lr_embedding: float | None = Field(
        default=1e-6,
        json_schema_extra={
            "description": "loraplus learning rate for lora embedding layers."
        },
    )

    merge_lora: bool | None = None

    @model_validator(mode="before")
    @classmethod
    def validate_adapter(cls, data):
        """Reject 8-bit/4-bit loading when training without an adapter.

        Runs on the raw input mapping. The check is skipped when the input
        has a truthy `inference` key.

        Raises:
            ValueError: If `load_in_8bit` or `load_in_4bit` is set while
                neither `adapter` nor `inference` is.
        """
        if (
            not data.get("adapter")
            and not data.get("inference")
            and (data.get("load_in_8bit") or data.get("load_in_4bit"))
        ):
            # The first literal ends with a space so the two sentences do not
            # run together after implicit string concatenation.
            raise ValueError(
                "load_in_8bit and load_in_4bit are not supported without setting an adapter for training. "
                "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
            )
        return data

    @model_validator(mode="after")
    def validate_qlora(self):
        """Enforce the loading constraints that apply when `adapter == "qlora"`.

        When merging (`merge_lora` truthy) none of `load_in_8bit`, `gptq` or
        `load_in_4bit` may be set. Otherwise `load_in_8bit` and `gptq` must be
        unset and `load_in_4bit` must be set.

        Raises:
            ValueError: If any of the constraints above is violated.
        """
        if self.adapter == "qlora":
            if self.merge_lora:
                # can't merge qlora if loaded in 8bit or 4bit
                if self.load_in_8bit:
                    raise ValueError("Can't merge qlora if loaded in 8bit")

                if self.gptq:
                    raise ValueError("Can't merge qlora if gptq")

                if self.load_in_4bit:
                    raise ValueError("Can't merge qlora if loaded in 4bit")

            else:
                if self.load_in_8bit:
                    raise ValueError("Can't load qlora in 8bit")

                if self.gptq:
                    raise ValueError("Can't load qlora if gptq")

                if not self.load_in_4bit:
                    raise ValueError("Require cfg.load_in_4bit to be True for qlora")
        return self

    @field_validator("loraplus_lr_embedding")
    @classmethod
    def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):
        """Convert a non-empty string value to `float`; pass anything else through."""
        if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):
            loraplus_lr_embedding = float(loraplus_lr_embedding)
        return loraplus_lr_embedding

    @model_validator(mode="before")
    @classmethod
    def validate_lora_dropout(cls, data):
        """Set `lora_dropout` to 0.0 when an adapter is configured without one.

        Runs on the raw input mapping and updates it in place.
        """
        if data.get("adapter") is not None and data.get("lora_dropout") is None:
            data["lora_dropout"] = 0.0
        return data
|
||||||
|
|
||||||
|
|
||||||
|
class ReLoRAConfig(BaseModel):
    """ReLoRA-related config keys.

    Every field is optional and defaults to `None` (unset).
    """

    # Step counts
    relora_steps: int | None = None
    relora_warmup_steps: int | None = None
    relora_anneal_steps: int | None = None

    relora_prune_ratio: float | None = None
    relora_cpu_offload: bool | None = None
|
||||||
99
src/axolotl/utils/schemas/training.py
Normal file
99
src/axolotl/utils/schemas/training.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""Pydantic models for training hyperparameters"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
from transformers import SchedulerType
|
||||||
|
from transformers.training_args import OptimizerNames
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LrGroup(BaseModel):
    """One custom learning-rate group.

    All three fields are required: the group's `name`, the list of `modules`
    it names, and the `lr` value for the group.
    """

    name: str
    modules: list[str]
    lr: float
|
||||||
|
|
||||||
|
|
||||||
|
class HyperparametersConfig(BaseModel):
    """Training hyperparameter configuration subset.

    `learning_rate` is the only required field. It may be given as a string
    (e.g. "2e-5"), which is converted to `float` by `convert_learning_rate`.
    """

    gradient_accumulation_steps: int | None = Field(default=1)
    micro_batch_size: int | None = Field(
        default=1,
        json_schema_extra={"description": "per gpu micro batch size for training"},
    )
    batch_size: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Total batch size, we do not recommend setting this manually"
        },
    )
    eval_batch_size: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
        },
    )

    auto_find_batch_size: bool | None = None

    train_on_inputs: bool | None = False
    group_by_length: bool | None = None

    learning_rate: str | float
    embedding_lr: float | None = None
    embedding_lr_scale: float | None = None
    weight_decay: float | None = 0.0
    # Accepts either a transformers optimizer name or one of the custom ones.
    optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = (
        OptimizerNames.ADAMW_TORCH_FUSED
    )
    optim_args: (str | dict[str, Any]) | None = Field(
        default=None,
        json_schema_extra={"description": "Optional arguments to supply to optimizer."},
    )
    optim_target_modules: (list[str] | Literal["all_linear"]) | None = Field(
        default=None,
        json_schema_extra={
            "description": "The target modules to optimize, i.e. the module names that you would like to train."
        },
    )
    torchdistx_path: str | None = None
    # Accepts a transformers scheduler type plus the literals "one_cycle" and "rex".
    lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = (
        SchedulerType.COSINE
    )
    lr_scheduler_kwargs: dict[str, Any] | None = None
    lr_quadratic_warmup: bool | None = None
    cosine_min_lr_ratio: float | None = None
    cosine_constant_lr_ratio: float | None = None
    lr_div_factor: float | None = None
    lr_groups: list[LrGroup] | None = None

    adam_epsilon: float | None = None
    adam_beta1: float | None = None
    adam_beta2: float | None = None
    max_grad_norm: float | None = None
    num_epochs: float = Field(default=1.0)

    @field_validator("batch_size")
    @classmethod
    def hint_batch_size_set(cls, batch_size):
        """Log a hint to use `gradient_accumulation_steps` when `batch_size` is set.

        The value is returned unchanged.
        """
        if batch_size:
            LOG.warning(
                "%s\n%s",
                "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
                "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
            )
        return batch_size

    @field_validator("learning_rate")
    @classmethod
    def convert_learning_rate(cls, learning_rate):
        """Convert a non-empty string learning rate to `float`.

        Non-string values and the empty string are returned unchanged.
        """
        if learning_rate and isinstance(learning_rate, str):
            learning_rate = float(learning_rate)
        return learning_rate
|
||||||
@@ -1,8 +1,4 @@
|
|||||||
"""
|
"""Pydantic models for TRL trainer configuration"""
|
||||||
GRPO specific configuration args
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -12,11 +8,11 @@ class TRLConfig(BaseModel):
|
|||||||
Input args for TRL.
|
Input args for TRL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
beta: Optional[float] = Field(
|
beta: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Beta for RL training"},
|
json_schema_extra={"description": "Beta for RL training"},
|
||||||
)
|
)
|
||||||
max_completion_length: Optional[int] = Field(
|
max_completion_length: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Maximum length of the completion for RL training"
|
"description": "Maximum length of the completion for RL training"
|
||||||
@@ -25,50 +21,50 @@ class TRLConfig(BaseModel):
|
|||||||
|
|
||||||
# GRPO specific args
|
# GRPO specific args
|
||||||
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
||||||
use_vllm: Optional[bool] = Field(
|
use_vllm: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
||||||
)
|
)
|
||||||
vllm_device: Optional[str] = Field(
|
vllm_device: str | None = Field(
|
||||||
default="auto",
|
default="auto",
|
||||||
json_schema_extra={"description": "Device to use for VLLM"},
|
json_schema_extra={"description": "Device to use for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_gpu_memory_utilization: Optional[float] = Field(
|
vllm_gpu_memory_utilization: float | None = Field(
|
||||||
default=0.9,
|
default=0.9,
|
||||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_dtype: Optional[str] = Field(
|
vllm_dtype: str | None = Field(
|
||||||
default="auto",
|
default="auto",
|
||||||
json_schema_extra={"description": "Data type for VLLM"},
|
json_schema_extra={"description": "Data type for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_max_model_len: Optional[int] = Field(
|
vllm_max_model_len: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Maximum length of the model context for VLLM"
|
"description": "Maximum length of the model context for VLLM"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
reward_funcs: Optional[list[str]] = Field(
|
reward_funcs: list[str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "List of reward functions to load"},
|
json_schema_extra={"description": "List of reward functions to load"},
|
||||||
)
|
)
|
||||||
reward_weights: Optional[list[float]] = Field(
|
reward_weights: list[float] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Weights for each reward function. Must match the number of reward functions."
|
"description": "Weights for each reward function. Must match the number of reward functions."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
num_generations: Optional[int] = Field(
|
num_generations: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
log_completions: Optional[bool] = Field(
|
log_completions: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "Whether to log completions"},
|
json_schema_extra={"description": "Whether to log completions"},
|
||||||
)
|
)
|
||||||
sync_ref_model: Optional[bool] = Field(
|
sync_ref_model: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": (
|
"description": (
|
||||||
@@ -77,13 +73,13 @@ class TRLConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ref_model_mixup_alpha: Optional[float] = Field(
|
ref_model_mixup_alpha: float | None = Field(
|
||||||
default=0.9,
|
default=0.9,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ref_model_sync_steps: Optional[int] = Field(
|
ref_model_sync_steps: int | None = Field(
|
||||||
default=64,
|
default=64,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
||||||
79
src/axolotl/utils/schemas/utils.py
Normal file
79
src/axolotl/utils/schemas/utils.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Utilities for Axolotl Pydantic models"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_legacy_message_field(data: dict, mappings: dict, prop: str) -> None:
    """Fold the legacy `message_field_<prop>` key of `data` into `mappings`, in place."""
    legacy_key = f"message_field_{prop}"
    if legacy_key not in data:
        # No legacy option given: default to the identity mapping unless the
        # property is already mapped.
        mappings.setdefault(prop, prop)
        return

    legacy_value = data.pop(legacy_key)
    LOG.warning(
        "message_field_%s is deprecated, use message_property_mappings instead. "
        "Example: message_property_mappings: {%s: %s}",
        prop,
        prop,
        legacy_value,
    )
    if prop in mappings and mappings[prop] != legacy_value:
        raise ValueError(
            f"Conflicting message {prop} fields: message_field_{prop}='{legacy_value}' "
            f"conflicts with message_property_mappings.{prop}='{mappings[prop]}'"
        )
    # A falsy legacy value (None or "") falls back to the identity mapping.
    mappings[prop] = legacy_value or prop


def handle_legacy_message_fields_logic(data: dict) -> dict:
    """
    Handle backwards compatibility between legacy message field mapping and new property mapping system.

    Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
    - message_field_role: Mapped to the role field
    - message_field_content: Mapped to the content field

    The new system uses message_property_mappings to support arbitrary field mappings:
    message_property_mappings:
        role: source_role_field
        content: source_content_field
        additional_field: source_field

    Neither `data` nor its nested `message_property_mappings` is modified; both
    are copied before being updated.

    Args:
        data: Dictionary containing configuration data

    Returns:
        Updated dictionary with message field mappings consolidated: the legacy
        keys are removed and `message_property_mappings` always contains `role`
        and `content`.

    Raises:
        ValueError: If there are conflicts between legacy and new mappings
    """
    data = data.copy()  # Create a copy to avoid modifying the original

    # `dict.copy()` is shallow, so copy the nested mapping too; otherwise the
    # updates below would leak into the caller's dictionary.
    mappings = data.get("message_property_mappings")
    mappings = {} if mappings is None else mappings.copy()
    data["message_property_mappings"] = mappings

    # Role is handled before content, matching the original order of
    # warnings and conflict errors.
    for prop in ("role", "content"):
        _merge_legacy_message_field(data, mappings, prop)

    return data
|
||||||
90
styles.css
90
styles.css
@@ -14,7 +14,7 @@
|
|||||||
h1 {
|
h1 {
|
||||||
font-family: var(--font-title);
|
font-family: var(--font-title);
|
||||||
font-weight: 400;
|
font-weight: 400;
|
||||||
font-size: 5rem;
|
font-size: 3rem;
|
||||||
line-height: 1.1;
|
line-height: 1.1;
|
||||||
letter-spacing: -0.05em;
|
letter-spacing: -0.05em;
|
||||||
font-feature-settings: "ss01" on;
|
font-feature-settings: "ss01" on;
|
||||||
@@ -24,7 +24,7 @@ h1 {
|
|||||||
h2 {
|
h2 {
|
||||||
font-family: var(--font-title);
|
font-family: var(--font-title);
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
font-size: 2rem;
|
font-size: 1.5rem;
|
||||||
line-height: 1.2;
|
line-height: 1.2;
|
||||||
letter-spacing: -0.03em;
|
letter-spacing: -0.03em;
|
||||||
font-feature-settings: "ss01" on;
|
font-feature-settings: "ss01" on;
|
||||||
@@ -35,7 +35,7 @@ h3,
|
|||||||
h4 {
|
h4 {
|
||||||
font-family: var(--font-body);
|
font-family: var(--font-body);
|
||||||
font-weight: 400;
|
font-weight: 400;
|
||||||
font-size: 1.5rem;
|
font-size: 1.25rem;
|
||||||
line-height: 1.5;
|
line-height: 1.5;
|
||||||
letter-spacing: -0.02em;
|
letter-spacing: -0.02em;
|
||||||
}
|
}
|
||||||
@@ -191,3 +191,87 @@ code span.er {
|
|||||||
color: #5cb85c !important;
|
color: #5cb85c !important;
|
||||||
text-decoration: none !important;
|
text-decoration: none !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* API Documentation Styling */
|
||||||
|
|
||||||
|
/* Improve docstring section rendering */
|
||||||
|
.level3 p {
|
||||||
|
white-space: pre-line !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Format docstring sections */
|
||||||
|
.level3 p strong {
|
||||||
|
display: block;
|
||||||
|
margin-top: 1em;
|
||||||
|
font-weight: bold;
|
||||||
|
color: var(--cyan);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Add spacing after sections */
|
||||||
|
.level3 p:has(strong) {
|
||||||
|
margin-bottom: 0.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Format Args and Returns sections */
|
||||||
|
p:has(code) {
|
||||||
|
line-height: 1.6;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Function signatures */
|
||||||
|
.sourceCode {
|
||||||
|
margin-bottom: 1.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Parameter tables */
|
||||||
|
.doc-section-parameters table,
|
||||||
|
.doc-section-returns table {
|
||||||
|
margin-top: 1em;
|
||||||
|
margin-bottom: 1.5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Make parameter and returns headers smaller */
|
||||||
|
h2.anchored[data-anchor-id="parameters"],
|
||||||
|
h2.anchored[data-anchor-id="returns"],
|
||||||
|
.doc-section-parameters h4,
|
||||||
|
.doc-section-returns h4 {
|
||||||
|
font-size: 1.25rem;
|
||||||
|
margin-top: 2rem;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
color: var(--lime);
|
||||||
|
border-bottom: 1px solid var(--lime);
|
||||||
|
padding-bottom: 0.3rem;
|
||||||
|
font-family: var(--font-body);
|
||||||
|
font-weight: 500;
|
||||||
|
letter-spacing: normal;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Style documentation tables */
|
||||||
|
table {
|
||||||
|
width: 100%;
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
border-collapse: collapse;
|
||||||
|
}
|
||||||
|
|
||||||
|
table th {
|
||||||
|
background-color: #1a1a1a;
|
||||||
|
padding: 0.5rem 1rem;
|
||||||
|
border-bottom: 2px solid var(--greige-600);
|
||||||
|
text-align: left;
|
||||||
|
}
|
||||||
|
|
||||||
|
table td {
|
||||||
|
padding: 0.5rem 1rem;
|
||||||
|
border-bottom: 1px solid var(--greige-600);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Code in table cells */
|
||||||
|
table td code {
|
||||||
|
background-color: transparent !important;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Improve spacing in parameter and return tables */
|
||||||
|
.doc-section-parameters,
|
||||||
|
.doc-section-returns {
|
||||||
|
margin-top: 1rem;
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,10 +11,10 @@ from pydantic import ValidationError
|
|||||||
|
|
||||||
from axolotl.utils import is_comet_available
|
from axolotl.utils import is_comet_available
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.models import check_model_config
|
from axolotl.utils.models import check_model_config
|
||||||
|
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
warnings.filterwarnings("error")
|
warnings.filterwarnings("error")
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ from typing import Optional
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.schemas.datasets import ChatTemplate
|
||||||
|
|
||||||
warnings.filterwarnings("error")
|
warnings.filterwarnings("error")
|
||||||
|
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ class TestModelsUtils:
|
|||||||
|
|
||||||
def test_message_property_mapping(self):
|
def test_message_property_mapping(self):
|
||||||
"""Test message property mapping configuration validation"""
|
"""Test message property mapping configuration validation"""
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import SFTDataset
|
from axolotl.utils.schemas.datasets import SFTDataset
|
||||||
|
|
||||||
# Test legacy fields are mapped correctly
|
# Test legacy fields are mapped correctly
|
||||||
dataset = SFTDataset(
|
dataset = SFTDataset(
|
||||||
|
|||||||
Reference in New Issue
Block a user