Compare commits
4 Commits
quartodoc
...
pre-commit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
156fede4f7 | ||
|
|
dcbbd7af79 | ||
|
|
21bac7ce1a | ||
|
|
aaa4571826 |
7
.github/workflows/docs.yml
vendored
7
.github/workflows/docs.yml
vendored
@@ -20,12 +20,9 @@ jobs:
|
|||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.11'
|
python-version: '3.11'
|
||||||
- name: Install dependencies
|
- name: install dependencies
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install jupyter quartodoc
|
python3 -m pip install jupyter
|
||||||
python3 -m pip install -e .
|
|
||||||
- name: Build autodoc
|
|
||||||
run: quartodoc build
|
|
||||||
- name: Publish to GitHub Pages (and render)
|
- name: Publish to GitHub Pages (and render)
|
||||||
uses: quarto-dev/quarto-actions/publish@v2
|
uses: quarto-dev/quarto-actions/publish@v2
|
||||||
with:
|
with:
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -181,10 +181,6 @@ prepared-datasets/
|
|||||||
submit.sh
|
submit.sh
|
||||||
*.out*
|
*.out*
|
||||||
|
|
||||||
# Quartodoc generated files
|
|
||||||
objects.json
|
|
||||||
site_libs/
|
|
||||||
|
|
||||||
typings/
|
typings/
|
||||||
out/
|
out/
|
||||||
|
|
||||||
|
|||||||
@@ -97,7 +97,6 @@ That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github
|
|||||||
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
||||||
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
||||||
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
||||||
- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/docs/api/) - Auto-generated code documentation
|
|
||||||
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
||||||
|
|
||||||
## 🤝 Getting Help
|
## 🤝 Getting Help
|
||||||
|
|||||||
193
_quarto.yml
193
_quarto.yml
@@ -1,178 +1,6 @@
|
|||||||
project:
|
project:
|
||||||
type: website
|
type: website
|
||||||
|
|
||||||
quartodoc:
|
|
||||||
dir: docs/api
|
|
||||||
package: axolotl
|
|
||||||
title: API Reference
|
|
||||||
parser: google
|
|
||||||
|
|
||||||
sections:
|
|
||||||
- title: Core
|
|
||||||
desc: Core functionality for training
|
|
||||||
contents:
|
|
||||||
- train
|
|
||||||
- evaluate
|
|
||||||
- datasets
|
|
||||||
- convert
|
|
||||||
- prompt_tokenizers
|
|
||||||
- logging_config
|
|
||||||
- core.trainer_builder
|
|
||||||
- core.training_args
|
|
||||||
- core.chat.messages
|
|
||||||
- core.chat.format.chatml
|
|
||||||
- core.chat.format.llama3x
|
|
||||||
- core.chat.format.shared
|
|
||||||
- core.datasets.chat
|
|
||||||
- core.datasets.transforms.chat_builder
|
|
||||||
- title: CLI
|
|
||||||
desc: Command-line interface
|
|
||||||
contents:
|
|
||||||
- cli.main
|
|
||||||
- cli.train
|
|
||||||
- cli.evaluate
|
|
||||||
- cli.args
|
|
||||||
- cli.checks
|
|
||||||
- cli.config
|
|
||||||
- cli.inference
|
|
||||||
- cli.merge_lora
|
|
||||||
- cli.merge_sharded_fsdp_weights
|
|
||||||
- cli.preprocess
|
|
||||||
- cli.sweeps
|
|
||||||
- cli.utils
|
|
||||||
- cli.cloud.base
|
|
||||||
- cli.cloud.modal_
|
|
||||||
- title: Trainers
|
|
||||||
desc: Training implementations
|
|
||||||
contents:
|
|
||||||
- core.trainers.base
|
|
||||||
- core.trainers.trl
|
|
||||||
- core.trainers.dpo.trainer
|
|
||||||
- core.trainers.grpo.trainer
|
|
||||||
- title: Prompt Strategies
|
|
||||||
desc: Prompt formatting strategies
|
|
||||||
contents:
|
|
||||||
- prompt_strategies.base
|
|
||||||
- prompt_strategies.chat_template
|
|
||||||
- prompt_strategies.alpaca_chat
|
|
||||||
- prompt_strategies.alpaca_instruct
|
|
||||||
- prompt_strategies.alpaca_w_system
|
|
||||||
- prompt_strategies.user_defined
|
|
||||||
- prompt_strategies.llama2_chat
|
|
||||||
- prompt_strategies.completion
|
|
||||||
- prompt_strategies.input_output
|
|
||||||
- prompt_strategies.stepwise_supervised
|
|
||||||
- prompt_strategies.metharme
|
|
||||||
- prompt_strategies.orcamini
|
|
||||||
- prompt_strategies.pygmalion
|
|
||||||
- prompt_strategies.messages.chat
|
|
||||||
- prompt_strategies.dpo.chat_template
|
|
||||||
- prompt_strategies.dpo.llama3
|
|
||||||
- prompt_strategies.dpo.chatml
|
|
||||||
- prompt_strategies.dpo.zephyr
|
|
||||||
- prompt_strategies.dpo.user_defined
|
|
||||||
- prompt_strategies.dpo.passthrough
|
|
||||||
- prompt_strategies.kto.llama3
|
|
||||||
- prompt_strategies.kto.chatml
|
|
||||||
- prompt_strategies.kto.user_defined
|
|
||||||
- prompt_strategies.orpo.chat_template
|
|
||||||
- prompt_strategies.bradley_terry.llama3
|
|
||||||
- title: Kernels
|
|
||||||
desc: Low-level performance optimizations
|
|
||||||
contents:
|
|
||||||
- kernels.lora
|
|
||||||
- kernels.geglu
|
|
||||||
- kernels.swiglu
|
|
||||||
- kernels.quantize
|
|
||||||
- kernels.utils
|
|
||||||
- title: MonkeyPatches
|
|
||||||
desc: Runtime patches for model optimizations
|
|
||||||
contents:
|
|
||||||
- monkeypatch.llama_attn_hijack_flash
|
|
||||||
- monkeypatch.llama_attn_hijack_xformers
|
|
||||||
- monkeypatch.mistral_attn_hijack_flash
|
|
||||||
- monkeypatch.multipack
|
|
||||||
- monkeypatch.relora
|
|
||||||
- monkeypatch.llama_expand_mask
|
|
||||||
- monkeypatch.lora_kernels
|
|
||||||
- monkeypatch.utils
|
|
||||||
- monkeypatch.btlm_attn_hijack_flash
|
|
||||||
- monkeypatch.llama_patch_multipack
|
|
||||||
- monkeypatch.stablelm_attn_hijack_flash
|
|
||||||
- monkeypatch.trainer_fsdp_optim
|
|
||||||
- monkeypatch.transformers_fa_utils
|
|
||||||
- monkeypatch.unsloth_
|
|
||||||
- monkeypatch.attention.mllama
|
|
||||||
- monkeypatch.data.batch_dataset_fetcher
|
|
||||||
- monkeypatch.mixtral
|
|
||||||
- title: Utils
|
|
||||||
desc: Utility functions
|
|
||||||
contents:
|
|
||||||
- utils.models
|
|
||||||
- utils.tokenization
|
|
||||||
- utils.chat_templates
|
|
||||||
- utils.lora
|
|
||||||
- utils.lora_embeddings
|
|
||||||
- utils.model_shard_quant
|
|
||||||
- utils.bench
|
|
||||||
- utils.freeze
|
|
||||||
- utils.trainer
|
|
||||||
- utils.schedulers
|
|
||||||
- utils.distributed
|
|
||||||
- utils.dict
|
|
||||||
- utils.optimizers.adopt
|
|
||||||
- utils.data.pretraining
|
|
||||||
- utils.data.sft
|
|
||||||
- utils.gradient_checkpointing.unsloth
|
|
||||||
- title: Schemas
|
|
||||||
desc: Pydantic data models for Axolotl config
|
|
||||||
contents:
|
|
||||||
- utils.schemas.config
|
|
||||||
- utils.schemas.model
|
|
||||||
- utils.schemas.training
|
|
||||||
- utils.schemas.datasets
|
|
||||||
- utils.schemas.peft
|
|
||||||
- utils.schemas.trl
|
|
||||||
- utils.schemas.integrations
|
|
||||||
- utils.schemas.enums
|
|
||||||
- utils.schemas.utils
|
|
||||||
- title: Integrations
|
|
||||||
desc: Third-party integrations and extensions
|
|
||||||
contents:
|
|
||||||
- integrations.base
|
|
||||||
- integrations.cut_cross_entropy.args
|
|
||||||
- integrations.grokfast.optimizer
|
|
||||||
- integrations.kd.trainer
|
|
||||||
- integrations.liger.args
|
|
||||||
- integrations.lm_eval.args
|
|
||||||
- integrations.spectrum.args
|
|
||||||
- title: Common
|
|
||||||
desc: Common utilities and shared functionality
|
|
||||||
contents:
|
|
||||||
- common.architectures
|
|
||||||
- common.const
|
|
||||||
- common.datasets
|
|
||||||
- title: Models
|
|
||||||
desc: Custom model implementations
|
|
||||||
contents:
|
|
||||||
- models.mamba.modeling_mamba
|
|
||||||
- title: Data Processing
|
|
||||||
desc: Data processing utilities
|
|
||||||
contents:
|
|
||||||
- utils.collators.core
|
|
||||||
- utils.collators.batching
|
|
||||||
- utils.collators.mamba
|
|
||||||
- utils.collators.mm_chat
|
|
||||||
- utils.samplers.multipack
|
|
||||||
- title: Callbacks
|
|
||||||
desc: Training callbacks
|
|
||||||
contents:
|
|
||||||
- utils.callbacks.perplexity
|
|
||||||
- utils.callbacks.profiler
|
|
||||||
- utils.callbacks.lisa
|
|
||||||
- utils.callbacks.mlflow_
|
|
||||||
- utils.callbacks.comet_
|
|
||||||
|
|
||||||
website:
|
website:
|
||||||
title: "Axolotl"
|
title: "Axolotl"
|
||||||
description: "We make fine-tuning accessible, scalable, and fun"
|
description: "We make fine-tuning accessible, scalable, and fun"
|
||||||
@@ -207,8 +35,6 @@ website:
|
|||||||
- docs/inference.qmd
|
- docs/inference.qmd
|
||||||
- docs/cli.qmd
|
- docs/cli.qmd
|
||||||
- docs/config.qmd
|
- docs/config.qmd
|
||||||
- text: "API Reference"
|
|
||||||
href: docs/api
|
|
||||||
|
|
||||||
- section: "Dataset Formats"
|
- section: "Dataset Formats"
|
||||||
contents: docs/dataset-formats/*
|
contents: docs/dataset-formats/*
|
||||||
@@ -254,22 +80,3 @@ format:
|
|||||||
theme: darkly
|
theme: darkly
|
||||||
css: styles.css
|
css: styles.css
|
||||||
toc: true
|
toc: true
|
||||||
# Enable better handling of line breaks in markdown
|
|
||||||
preserve-tabs: true
|
|
||||||
html-math-method: mathjax
|
|
||||||
# Improved markdown processing options
|
|
||||||
md-extensions:
|
|
||||||
- markdown_it
|
|
||||||
- def_list
|
|
||||||
- attr_list
|
|
||||||
- fenced_divs
|
|
||||||
- tables
|
|
||||||
- html_admonition
|
|
||||||
- lineblocks
|
|
||||||
- fancy_lists
|
|
||||||
# Control whitespace handling
|
|
||||||
whitespace: preserve
|
|
||||||
# Process newlines in paragraphs
|
|
||||||
wrap: preserve
|
|
||||||
# Better line break handling
|
|
||||||
preserve-linebreaks: true
|
|
||||||
|
|||||||
2
docs/.gitignore
vendored
2
docs/.gitignore
vendored
@@ -1,4 +1,2 @@
|
|||||||
/.quarto/
|
/.quarto/
|
||||||
_site/
|
_site/
|
||||||
/api/*.qmd
|
|
||||||
/api/*.html
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: "Command Line Interface (CLI)"
|
title: "CLI Reference"
|
||||||
format:
|
format:
|
||||||
html:
|
html:
|
||||||
toc: true
|
toc: true
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ description: How datasets are processed
|
|||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
||||||
the [dataset format](dataset-formats) and prompt strategies to:
|
the [dataset format](docs/dataset-formats) and prompt strategies to:
|
||||||
|
|
||||||
- parse the dataset based on the *dataset format*
|
- parse the dataset based on the *dataset format*
|
||||||
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
||||||
|
|||||||
@@ -2,5 +2,3 @@ pre-commit
|
|||||||
black
|
black
|
||||||
mypy
|
mypy
|
||||||
types-requests
|
types-requests
|
||||||
quartodoc
|
|
||||||
jupyter
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from axolotl.cli.utils import (
|
|||||||
)
|
)
|
||||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
|
|||||||
@@ -13,7 +13,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
"""Builder for the training args and trainer"""
|
"""
|
||||||
|
Builder for the training args and trainer
|
||||||
|
"""
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import importlib
|
import importlib
|
||||||
@@ -83,8 +85,8 @@ from axolotl.utils.collators import (
|
|||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
|
||||||
from axolotl.utils.models import ensure_dtype
|
from axolotl.utils.models import ensure_dtype
|
||||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import logging
|
|||||||
from trl.trainer.grpo_trainer import RewardFunc
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
from axolotl.utils.schemas.trl import TRLConfig
|
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ from typing import Dict, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from datasets import Dataset
|
|
||||||
from transformers.trainer import Trainer
|
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
@@ -27,18 +25,18 @@ LOG = get_logger("axolotl.evaluate")
|
|||||||
|
|
||||||
|
|
||||||
def evaluate_dataset(
|
def evaluate_dataset(
|
||||||
trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False
|
trainer, dataset, dataset_type: str, flash_optimum: bool = False
|
||||||
) -> Optional[Dict[str, float]]:
|
) -> Optional[Dict[str, float]]:
|
||||||
"""Helper function to evaluate a single dataset.
|
"""Helper function to evaluate a single dataset safely.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
trainer: The trainer instance.
|
trainer: The trainer instance
|
||||||
dataset: Dataset to evaluate.
|
dataset: Dataset to evaluate
|
||||||
dataset_type: Type of dataset ('train' or 'eval').
|
dataset_type: Type of dataset ('train' or 'eval')
|
||||||
flash_optimum: Whether to use flash optimum.
|
flash_optimum: Whether to use flash optimum
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary of metrics or None if dataset is None.
|
Dictionary of metrics or None if dataset is None
|
||||||
"""
|
"""
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
return None
|
return None
|
||||||
@@ -65,14 +63,17 @@ def evaluate_dataset(
|
|||||||
|
|
||||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Evaluate a model on training and validation datasets.
|
Evaluate a model on training and validation datasets
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
dataset_meta: Dataset metadata containing training and evaluation datasets.
|
dataset_meta: Dataset metadata containing training and evaluation datasets.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary mapping metric names to their values.
|
Tuple containing:
|
||||||
|
- The model (either PeftModel or PreTrainedModel)
|
||||||
|
- The tokenizer
|
||||||
|
- Dictionary of evaluation metrics
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
|
|||||||
@@ -11,17 +11,19 @@
|
|||||||
# the License.
|
# the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Module to handle merging the plugins' input arguments with the base configurations.
|
module to handle merging the plugins' input arguments with the base configurations.
|
||||||
|
|
||||||
This was moved here to prevent circular imports.
|
this was moved here to prevent circular imports
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from axolotl.utils.schemas.config import (
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
|
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def merge_input_args():
|
def merge_input_args():
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly
|
|||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
from axolotl.utils.schemas.datasets import DatasetConfig
|
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
|
||||||
|
|
||||||
# Configure the logger
|
# Configure the logger
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ DPO prompt strategies for using tokenizer chat templates.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
||||||
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic
|
||||||
|
|
||||||
|
|
||||||
def default(
|
def default(
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from trl.models import unwrap_model_for_generation
|
|||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
barrier,
|
barrier,
|
||||||
broadcast_dict,
|
broadcast_dict,
|
||||||
@@ -42,7 +43,6 @@ from axolotl.utils.distributed import (
|
|||||||
is_main_process,
|
is_main_process,
|
||||||
zero_first,
|
zero_first,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||||
|
|||||||
@@ -12,13 +12,19 @@ from transformers.utils.import_utils import is_torch_npu_available
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.integrations.config import merge_input_args
|
from axolotl.integrations.config import merge_input_args
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
from axolotl.utils.models import load_model_config
|
|
||||||
from axolotl.utils.schemas.config import (
|
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
)
|
)
|
||||||
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||||
|
)
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
|
DPODataset,
|
||||||
|
KTODataset,
|
||||||
|
SFTDataset,
|
||||||
|
)
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.models import load_model_config
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,8 @@
|
|||||||
"""Pydantic models for TRL trainer configuration"""
|
"""
|
||||||
|
GRPO specific configuration args
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -8,11 +12,11 @@ class TRLConfig(BaseModel):
|
|||||||
Input args for TRL.
|
Input args for TRL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
beta: float | None = Field(
|
beta: Optional[float] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Beta for RL training"},
|
json_schema_extra={"description": "Beta for RL training"},
|
||||||
)
|
)
|
||||||
max_completion_length: int | None = Field(
|
max_completion_length: Optional[int] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Maximum length of the completion for RL training"
|
"description": "Maximum length of the completion for RL training"
|
||||||
@@ -21,50 +25,50 @@ class TRLConfig(BaseModel):
|
|||||||
|
|
||||||
# GRPO specific args
|
# GRPO specific args
|
||||||
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
||||||
use_vllm: bool | None = Field(
|
use_vllm: Optional[bool] = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
||||||
)
|
)
|
||||||
vllm_device: str | None = Field(
|
vllm_device: Optional[str] = Field(
|
||||||
default="auto",
|
default="auto",
|
||||||
json_schema_extra={"description": "Device to use for VLLM"},
|
json_schema_extra={"description": "Device to use for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_gpu_memory_utilization: float | None = Field(
|
vllm_gpu_memory_utilization: Optional[float] = Field(
|
||||||
default=0.9,
|
default=0.9,
|
||||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_dtype: str | None = Field(
|
vllm_dtype: Optional[str] = Field(
|
||||||
default="auto",
|
default="auto",
|
||||||
json_schema_extra={"description": "Data type for VLLM"},
|
json_schema_extra={"description": "Data type for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_max_model_len: int | None = Field(
|
vllm_max_model_len: Optional[int] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Maximum length of the model context for VLLM"
|
"description": "Maximum length of the model context for VLLM"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
reward_funcs: list[str] | None = Field(
|
reward_funcs: Optional[list[str]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "List of reward functions to load"},
|
json_schema_extra={"description": "List of reward functions to load"},
|
||||||
)
|
)
|
||||||
reward_weights: list[float] | None = Field(
|
reward_weights: Optional[list[float]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Weights for each reward function. Must match the number of reward functions."
|
"description": "Weights for each reward function. Must match the number of reward functions."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
num_generations: int | None = Field(
|
num_generations: Optional[int] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
log_completions: bool | None = Field(
|
log_completions: Optional[bool] = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "Whether to log completions"},
|
json_schema_extra={"description": "Whether to log completions"},
|
||||||
)
|
)
|
||||||
sync_ref_model: bool | None = Field(
|
sync_ref_model: Optional[bool] = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": (
|
"description": (
|
||||||
@@ -73,13 +77,13 @@ class TRLConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ref_model_mixup_alpha: float | None = Field(
|
ref_model_mixup_alpha: Optional[float] = Field(
|
||||||
default=0.9,
|
default=0.9,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ref_model_sync_steps: int | None = Field(
|
ref_model_sync_steps: Optional[int] = Field(
|
||||||
default=64,
|
default=64,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
||||||
@@ -1,165 +0,0 @@
|
|||||||
"""Pydantic models for datasets-related configuration"""
|
|
||||||
|
|
||||||
from pydantic import BaseModel, model_validator
|
|
||||||
|
|
||||||
from axolotl.utils.schemas.enums import ChatTemplate
|
|
||||||
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedPrompterType(BaseModel):
|
|
||||||
"""Structure for user defined prompt types"""
|
|
||||||
|
|
||||||
system_prompt: str | None = None
|
|
||||||
system_format: str | None = None
|
|
||||||
field_system: str | None = None
|
|
||||||
field_instruction: str | None = None
|
|
||||||
field_input: str | None = None
|
|
||||||
field_output: str | None = None
|
|
||||||
|
|
||||||
format: str | None = None
|
|
||||||
no_input_format: str | None = None
|
|
||||||
field: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class SFTDataset(BaseModel):
|
|
||||||
"""SFT configuration subset"""
|
|
||||||
|
|
||||||
path: str | None = None
|
|
||||||
split: str | None = None
|
|
||||||
type: str | UserDefinedPrompterType | None = None
|
|
||||||
input_transform: str | None = None
|
|
||||||
shards: int | None = None
|
|
||||||
shards_idx: int | None = None
|
|
||||||
preprocess_shards: int | None = None
|
|
||||||
conversation: str | None = None
|
|
||||||
# Do not make this too strict or it will break the validator to choose different dataset class
|
|
||||||
chat_template: ChatTemplate | str | None = None
|
|
||||||
chat_template_jinja: str | None = None
|
|
||||||
data_files: str | list[str] | None = None
|
|
||||||
input_format: str | None = None
|
|
||||||
name: str | None = None
|
|
||||||
ds_type: str | None = None
|
|
||||||
train_on_split: str | None = None
|
|
||||||
field: str | None = None
|
|
||||||
field_human: str | None = None
|
|
||||||
field_model: str | None = None
|
|
||||||
field_messages: str | None = None
|
|
||||||
# deprecated, use message_property_mappings
|
|
||||||
message_field_role: str | None = None
|
|
||||||
# deprecated, use message_property_mappings
|
|
||||||
message_field_content: str | None = None
|
|
||||||
message_property_mappings: dict[str, str] | None = None
|
|
||||||
message_field_training: str | None = None
|
|
||||||
message_field_training_detail: str | None = None
|
|
||||||
logprobs_field: str | None = None
|
|
||||||
temperature: float | None = None
|
|
||||||
roles_to_train: list[str] | None = None
|
|
||||||
train_on_eos: str | None = None
|
|
||||||
roles: dict[str, list[str]] | None = None
|
|
||||||
drop_system_message: bool | None = None
|
|
||||||
trust_remote_code: bool | None = False
|
|
||||||
revision: str | None = None
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def handle_legacy_message_fields(cls, data):
|
|
||||||
"""Handle backwards compatibility between legacy message field mapping and new property mapping system."""
|
|
||||||
return handle_legacy_message_fields_logic(data)
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
def check_chat_template_config(cls, data):
|
|
||||||
if isinstance(data, BaseModel):
|
|
||||||
data = data.model_dump()
|
|
||||||
|
|
||||||
# Set chat_template to tokenizer_default if not set
|
|
||||||
if data.get("type") == "chat_template" and not data.get("chat_template"):
|
|
||||||
data["chat_template"] = ChatTemplate.tokenizer_default
|
|
||||||
|
|
||||||
# if chat_template is set to jinja, chat_template_jinja is required
|
|
||||||
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
|
||||||
"chat_template_jinja"
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"chat_template_jinja is required when chat_template is set to jinja"
|
|
||||||
)
|
|
||||||
|
|
||||||
# If chat_template_jinja is set, set chat_template to jinja
|
|
||||||
if data.get("chat_template_jinja") and not data.get("chat_template"):
|
|
||||||
data["chat_template"] = ChatTemplate.jinja
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class PretrainingDataset(BaseModel):
|
|
||||||
"""Pretraining dataset configuration subset"""
|
|
||||||
|
|
||||||
name: str | None = None
|
|
||||||
path: str | None = None
|
|
||||||
split: str | None = "train"
|
|
||||||
text_column: str | None = "text"
|
|
||||||
type: str | None = "pretrain"
|
|
||||||
trust_remote_code: bool | None = False
|
|
||||||
data_files: str | None = None
|
|
||||||
skip: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedDPOType(BaseModel):
|
|
||||||
"""User defined typing for DPO"""
|
|
||||||
|
|
||||||
field_system: str | None = None
|
|
||||||
field_prompt: str | None = None
|
|
||||||
field_chosen: str | None = None
|
|
||||||
field_rejected: str | None = None
|
|
||||||
prompt_format: str | None = None
|
|
||||||
chosen_format: str | None = None
|
|
||||||
rejected_format: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class DPODataset(BaseModel):
|
|
||||||
"""DPO configuration subset"""
|
|
||||||
|
|
||||||
path: str | None = None
|
|
||||||
split: str | None = None
|
|
||||||
type: UserDefinedDPOType | str | None = None
|
|
||||||
data_files: list[str] | None = None
|
|
||||||
revision: str | None = None
|
|
||||||
field_messages: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class StepwiseSupervisedDataset(BaseModel):
|
|
||||||
"""Stepwise supervised dataset configuration subset"""
|
|
||||||
|
|
||||||
path: str | None = None
|
|
||||||
split: str | None = None
|
|
||||||
data_files: list[str] | None = None
|
|
||||||
revision: str | None = None
|
|
||||||
step_separator: str | None = None
|
|
||||||
max_completion_length: int | None = None
|
|
||||||
train_on_last_step_only: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class UserDefinedKTOType(BaseModel):
|
|
||||||
"""User defined typing for KTO"""
|
|
||||||
|
|
||||||
field_system: str | None = None
|
|
||||||
field_prompt: str | None = None
|
|
||||||
field_completion: str | None = None
|
|
||||||
field_label: bool | None = None
|
|
||||||
prompt_format: str | None = None
|
|
||||||
completion_format: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class KTODataset(BaseModel):
|
|
||||||
"""KTO configuration subset"""
|
|
||||||
|
|
||||||
path: str | None = None
|
|
||||||
split: str | None = None
|
|
||||||
type: UserDefinedKTOType | str | None = None
|
|
||||||
data_files: list[str] | None = None
|
|
||||||
trust_remote_code: bool | None = False
|
|
||||||
revision: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
"""Pydantic models for deprecated and remapped configuration parameters"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DeprecatedParameters(BaseModel):
|
|
||||||
"""configurations that are deprecated"""
|
|
||||||
|
|
||||||
max_packed_sequence_len: int | None = None
|
|
||||||
rope_scaling: Any | None = None
|
|
||||||
noisy_embedding_alpha: float | None = None
|
|
||||||
dpo_beta: float | None = None
|
|
||||||
evaluation_strategy: str | None = None
|
|
||||||
|
|
||||||
@field_validator("max_packed_sequence_len")
|
|
||||||
@classmethod
|
|
||||||
def validate_max_packed_sequence_len(cls, max_packed_sequence_len):
|
|
||||||
if max_packed_sequence_len:
|
|
||||||
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
|
||||||
return max_packed_sequence_len
|
|
||||||
|
|
||||||
@field_validator("rope_scaling")
|
|
||||||
@classmethod
|
|
||||||
def validate_rope_scaling(cls, rope_scaling):
|
|
||||||
if rope_scaling:
|
|
||||||
raise DeprecationWarning(
|
|
||||||
"`rope_scaling` is no longer supported, it should now be be a key under `model_config`"
|
|
||||||
)
|
|
||||||
return rope_scaling
|
|
||||||
|
|
||||||
@field_validator("noisy_embedding_alpha")
|
|
||||||
@classmethod
|
|
||||||
def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):
|
|
||||||
if noisy_embedding_alpha:
|
|
||||||
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
|
||||||
return noisy_embedding_alpha
|
|
||||||
|
|
||||||
@field_validator("dpo_beta")
|
|
||||||
@classmethod
|
|
||||||
def validate_dpo_beta(cls, dpo_beta):
|
|
||||||
if dpo_beta is not None:
|
|
||||||
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
|
|
||||||
return dpo_beta
|
|
||||||
|
|
||||||
@field_validator("evaluation_strategy")
|
|
||||||
@classmethod
|
|
||||||
def validate_evaluation_strategy(cls, evaluation_strategy):
|
|
||||||
if evaluation_strategy is not None:
|
|
||||||
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
|
|
||||||
return evaluation_strategy
|
|
||||||
|
|
||||||
|
|
||||||
class RemappedParameters(BaseModel):
|
|
||||||
"""Parameters that have been remapped to other names"""
|
|
||||||
|
|
||||||
overrides_of_model_config: dict[str, Any] | None = Field(
|
|
||||||
default=None, alias="model_config"
|
|
||||||
)
|
|
||||||
overrides_of_model_kwargs: dict[str, Any] | None = Field(
|
|
||||||
default=None, alias="model_kwargs"
|
|
||||||
)
|
|
||||||
type_of_model: str | None = Field(default=None, alias="model_type")
|
|
||||||
revision_of_model: str | None = Field(default=None, alias="model_revision")
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
"""Enums for Axolotl input config"""
|
|
||||||
|
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
|
|
||||||
class RLType(str, Enum):
|
|
||||||
"""RL trainer type configuration subset"""
|
|
||||||
|
|
||||||
dpo = "dpo" # pylint: disable=invalid-name
|
|
||||||
grpo = "grpo" # pylint: disable=invalid-name
|
|
||||||
ipo = "ipo" # pylint: disable=invalid-name
|
|
||||||
orpo = "orpo" # pylint: disable=invalid-name
|
|
||||||
kto = "kto" # pylint: disable=invalid-name
|
|
||||||
simpo = "simpo" # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplate(str, Enum):
|
|
||||||
"""Chat templates configuration subset"""
|
|
||||||
|
|
||||||
alpaca = "alpaca" # pylint: disable=invalid-name
|
|
||||||
chatml = "chatml" # pylint: disable=invalid-name
|
|
||||||
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
|
|
||||||
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
|
|
||||||
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
|
|
||||||
gemma = "gemma" # pylint: disable=invalid-name
|
|
||||||
cohere = "cohere" # pylint: disable=invalid-name
|
|
||||||
llama3 = "llama3" # pylint: disable=invalid-name
|
|
||||||
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
|
||||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
|
||||||
phi_35 = "phi_35" # pylint: disable=invalid-name
|
|
||||||
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
|
||||||
deepseek_v3 = "deepseek_v3" # pylint: disable=invalid-name
|
|
||||||
jamba = "jamba" # pylint: disable=invalid-name
|
|
||||||
jinja = "jinja" # pylint: disable=invalid-name
|
|
||||||
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
|
||||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
|
||||||
exaone = "exaone" # pylint: disable=invalid-name
|
|
||||||
metharme = "metharme" # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
|
|
||||||
class CustomSupportedOptimizers(str, Enum):
|
|
||||||
"""Custom supported optimizers"""
|
|
||||||
|
|
||||||
optimi_adamw = "optimi_adamw" # pylint: disable=invalid-name
|
|
||||||
ao_adamw_4bit = "ao_adamw_4bit" # pylint: disable=invalid-name
|
|
||||||
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
|
|
||||||
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
|
||||||
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
|
||||||
muon = "muon" # pylint: disable=invalid-name
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
"""Pydantic models for Axolotl integrations"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MLFlowConfig(BaseModel):
|
|
||||||
"""MLFlow configuration subset"""
|
|
||||||
|
|
||||||
use_mlflow: bool | None = None
|
|
||||||
mlflow_tracking_uri: str | None = None
|
|
||||||
mlflow_experiment_name: str | None = None
|
|
||||||
mlflow_run_name: str | None = None
|
|
||||||
hf_mlflow_log_artifacts: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class LISAConfig(BaseModel):
|
|
||||||
"""LISA configuration subset"""
|
|
||||||
|
|
||||||
lisa_n_layers: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={"description": "the number of activate layers in LISA"},
|
|
||||||
)
|
|
||||||
lisa_step_interval: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={"description": "how often to switch layers in LISA"},
|
|
||||||
)
|
|
||||||
lisa_layers_attribute: str | None = Field(
|
|
||||||
default="model.layers",
|
|
||||||
json_schema_extra={"description": "path under the model to access the layers"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class WandbConfig(BaseModel):
|
|
||||||
"""Wandb configuration subset"""
|
|
||||||
|
|
||||||
use_wandb: bool | None = None
|
|
||||||
wandb_name: str | None = None
|
|
||||||
wandb_run_id: str | None = None
|
|
||||||
wandb_mode: str | None = None
|
|
||||||
wandb_project: str | None = None
|
|
||||||
wandb_entity: str | None = None
|
|
||||||
wandb_watch: str | None = None
|
|
||||||
wandb_log_model: str | None = None
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_wandb_run(cls, data):
|
|
||||||
if data.get("wandb_run_id") and not data.get("wandb_name"):
|
|
||||||
data["wandb_name"] = data.get("wandb_run_id")
|
|
||||||
|
|
||||||
LOG.warning(
|
|
||||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class CometConfig(BaseModel):
|
|
||||||
"""Comet configuration subset"""
|
|
||||||
|
|
||||||
use_comet: bool | None = None
|
|
||||||
comet_api_key: str | None = None
|
|
||||||
comet_workspace: str | None = None
|
|
||||||
comet_project_name: str | None = None
|
|
||||||
comet_experiment_key: str | None = None
|
|
||||||
comet_mode: str | None = None
|
|
||||||
comet_online: bool | None = None
|
|
||||||
comet_experiment_config: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class GradioConfig(BaseModel):
|
|
||||||
"""Gradio configuration subset"""
|
|
||||||
|
|
||||||
gradio_title: str | None = None
|
|
||||||
gradio_share: bool | None = None
|
|
||||||
gradio_server_name: str | None = None
|
|
||||||
gradio_server_port: int | None = None
|
|
||||||
gradio_max_new_tokens: int | None = None
|
|
||||||
gradio_temperature: float | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class RayConfig(BaseModel):
|
|
||||||
"""Ray launcher configuration subset"""
|
|
||||||
|
|
||||||
use_ray: bool = Field(default=False)
|
|
||||||
ray_run_name: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"help": "The training results will be saved at `saves/ray_run_name`."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
ray_num_workers: int = Field(
|
|
||||||
default=1,
|
|
||||||
json_schema_extra={
|
|
||||||
"help": "The number of workers for Ray training. Default is 1 worker."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
resources_per_worker: dict = Field(
|
|
||||||
default_factory=lambda: {"GPU": 1},
|
|
||||||
json_schema_extra={
|
|
||||||
"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@@ -1,55 +0,0 @@
|
|||||||
"""Pydantic models for model input / output, etc. configuration"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelInputConfig(BaseModel):
|
|
||||||
"""Model configuration subset"""
|
|
||||||
|
|
||||||
model_config = {"protected_namespaces": ()}
|
|
||||||
|
|
||||||
base_model: str
|
|
||||||
base_model_config: str | None = None
|
|
||||||
cls_model_config: str | None = None
|
|
||||||
tokenizer_config: str | None = None
|
|
||||||
tokenizer_use_fast: bool | None = None
|
|
||||||
tokenizer_legacy: bool | None = None
|
|
||||||
tokenizer_type: str | None = Field(
|
|
||||||
default=None, json_schema_extra={"description": "transformers tokenizer class"}
|
|
||||||
)
|
|
||||||
processor_type: str | None = Field(
|
|
||||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
|
||||||
)
|
|
||||||
trust_remote_code: bool | None = None
|
|
||||||
|
|
||||||
@field_validator("trust_remote_code")
|
|
||||||
@classmethod
|
|
||||||
def hint_trust_remote_code(cls, trust_remote_code):
|
|
||||||
if trust_remote_code:
|
|
||||||
LOG.warning(
|
|
||||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
|
||||||
)
|
|
||||||
return trust_remote_code
|
|
||||||
|
|
||||||
|
|
||||||
class ModelOutputConfig(BaseModel):
|
|
||||||
"""model save configuration subset"""
|
|
||||||
|
|
||||||
output_dir: str = Field(default="./model-out")
|
|
||||||
hub_model_id: str | None = None
|
|
||||||
hub_strategy: str | None = None
|
|
||||||
save_safetensors: bool | None = True
|
|
||||||
|
|
||||||
|
|
||||||
class SpecialTokensConfig(BaseModel):
|
|
||||||
"""Special tokens configuration subset"""
|
|
||||||
|
|
||||||
bos_token: str | None = None
|
|
||||||
eos_token: str | None = None
|
|
||||||
pad_token: str | None = None
|
|
||||||
unk_token: str | None = None
|
|
||||||
additional_special_tokens: list[str] | None = None
|
|
||||||
@@ -1,132 +0,0 @@
|
|||||||
"""Pydantic models for PEFT-related configuration"""
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
||||||
|
|
||||||
|
|
||||||
class LoftQConfig(BaseModel):
|
|
||||||
"""LoftQ configuration subset"""
|
|
||||||
|
|
||||||
loftq_bits: int = Field(
|
|
||||||
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
|
|
||||||
)
|
|
||||||
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
|
|
||||||
|
|
||||||
|
|
||||||
class PeftConfig(BaseModel):
|
|
||||||
"""peftq configuration subset"""
|
|
||||||
|
|
||||||
loftq_config: LoftQConfig | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class LoraConfig(BaseModel):
|
|
||||||
"""Peft / LoRA configuration subset"""
|
|
||||||
|
|
||||||
load_in_8bit: bool | None = Field(default=False)
|
|
||||||
load_in_4bit: bool | None = Field(default=False)
|
|
||||||
|
|
||||||
adapter: str | None = None
|
|
||||||
lora_model_dir: str | None = None
|
|
||||||
lora_r: int | None = None
|
|
||||||
lora_alpha: int | None = None
|
|
||||||
lora_fan_in_fan_out: bool | None = None
|
|
||||||
lora_target_modules: str | list[str] | None = None
|
|
||||||
lora_target_linear: bool | None = None
|
|
||||||
lora_modules_to_save: list[str] | None = None
|
|
||||||
lora_dropout: float | None = 0.0
|
|
||||||
peft_layers_to_transform: list[int] | None = None
|
|
||||||
peft_layers_pattern: list[str] | None = None
|
|
||||||
peft: PeftConfig | None = None
|
|
||||||
peft_use_dora: bool | None = None
|
|
||||||
peft_use_rslora: bool | None = None
|
|
||||||
peft_layer_replication: list[tuple[int, int]] | None = None
|
|
||||||
peft_init_lora_weights: bool | str | None = None
|
|
||||||
|
|
||||||
qlora_sharded_model_loading: bool | None = Field(
|
|
||||||
default=False,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
lora_on_cpu: bool | None = None
|
|
||||||
gptq: bool | None = None
|
|
||||||
bnb_config_kwargs: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
loraplus_lr_ratio: float | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
loraplus_lr_embedding: float | None = Field(
|
|
||||||
default=1e-6,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "loraplus learning rate for lora embedding layers."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
merge_lora: bool | None = None
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def validate_adapter(cls, data):
|
|
||||||
if (
|
|
||||||
not data.get("adapter")
|
|
||||||
and not data.get("inference")
|
|
||||||
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
|
|
||||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def validate_qlora(self):
|
|
||||||
if self.adapter == "qlora":
|
|
||||||
if self.merge_lora:
|
|
||||||
# can't merge qlora if loaded in 8bit or 4bit
|
|
||||||
if self.load_in_8bit:
|
|
||||||
raise ValueError("Can't merge qlora if loaded in 8bit")
|
|
||||||
|
|
||||||
if self.gptq:
|
|
||||||
raise ValueError("Can't merge qlora if gptq")
|
|
||||||
|
|
||||||
if self.load_in_4bit:
|
|
||||||
raise ValueError("Can't merge qlora if loaded in 4bit")
|
|
||||||
|
|
||||||
else:
|
|
||||||
if self.load_in_8bit:
|
|
||||||
raise ValueError("Can't load qlora in 8bit")
|
|
||||||
|
|
||||||
if self.gptq:
|
|
||||||
raise ValueError("Can't load qlora if gptq")
|
|
||||||
|
|
||||||
if not self.load_in_4bit:
|
|
||||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
|
||||||
return self
|
|
||||||
|
|
||||||
@field_validator("loraplus_lr_embedding")
|
|
||||||
@classmethod
|
|
||||||
def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):
|
|
||||||
if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):
|
|
||||||
loraplus_lr_embedding = float(loraplus_lr_embedding)
|
|
||||||
return loraplus_lr_embedding
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def validate_lora_dropout(cls, data):
|
|
||||||
if data.get("adapter") is not None and data.get("lora_dropout") is None:
|
|
||||||
data["lora_dropout"] = 0.0
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class ReLoRAConfig(BaseModel):
|
|
||||||
"""ReLoRA configuration subset"""
|
|
||||||
|
|
||||||
relora_steps: int | None = None
|
|
||||||
relora_warmup_steps: int | None = None
|
|
||||||
relora_anneal_steps: int | None = None
|
|
||||||
relora_prune_ratio: float | None = None
|
|
||||||
relora_cpu_offload: bool | None = None
|
|
||||||
@@ -1,99 +0,0 @@
|
|||||||
"""Pydantic models for training hyperparameters"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
|
||||||
from transformers import SchedulerType
|
|
||||||
from transformers.training_args import OptimizerNames
|
|
||||||
|
|
||||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LrGroup(BaseModel):
|
|
||||||
"""Custom learning rate group configuration"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
modules: list[str]
|
|
||||||
lr: float
|
|
||||||
|
|
||||||
|
|
||||||
class HyperparametersConfig(BaseModel):
|
|
||||||
"""Training hyperparams configuration subset"""
|
|
||||||
|
|
||||||
gradient_accumulation_steps: int | None = Field(default=1)
|
|
||||||
micro_batch_size: int | None = Field(
|
|
||||||
default=1,
|
|
||||||
json_schema_extra={"description": "per gpu micro batch size for training"},
|
|
||||||
)
|
|
||||||
batch_size: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "Total batch size, we do not recommended setting this manually"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
eval_batch_size: int | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
auto_find_batch_size: bool | None = None
|
|
||||||
|
|
||||||
train_on_inputs: bool | None = False
|
|
||||||
group_by_length: bool | None = None
|
|
||||||
|
|
||||||
learning_rate: str | float
|
|
||||||
embedding_lr: float | None = None
|
|
||||||
embedding_lr_scale: float | None = None
|
|
||||||
weight_decay: float | None = 0.0
|
|
||||||
optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = (
|
|
||||||
OptimizerNames.ADAMW_TORCH_FUSED
|
|
||||||
)
|
|
||||||
optim_args: (str | dict[str, Any]) | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
|
||||||
)
|
|
||||||
optim_target_modules: (list[str] | Literal["all_linear"]) | None = Field(
|
|
||||||
default=None,
|
|
||||||
json_schema_extra={
|
|
||||||
"description": "The target modules to optimize, i.e. the module names that you would like to train."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
torchdistx_path: str | None = None
|
|
||||||
lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = (
|
|
||||||
SchedulerType.COSINE
|
|
||||||
)
|
|
||||||
lr_scheduler_kwargs: dict[str, Any] | None = None
|
|
||||||
lr_quadratic_warmup: bool | None = None
|
|
||||||
cosine_min_lr_ratio: float | None = None
|
|
||||||
cosine_constant_lr_ratio: float | None = None
|
|
||||||
lr_div_factor: float | None = None
|
|
||||||
lr_groups: list[LrGroup] | None = None
|
|
||||||
|
|
||||||
adam_epsilon: float | None = None
|
|
||||||
adam_beta1: float | None = None
|
|
||||||
adam_beta2: float | None = None
|
|
||||||
max_grad_norm: float | None = None
|
|
||||||
num_epochs: float = Field(default=1.0)
|
|
||||||
|
|
||||||
@field_validator("batch_size")
|
|
||||||
@classmethod
|
|
||||||
def hint_batch_size_set(cls, batch_size):
|
|
||||||
if batch_size:
|
|
||||||
LOG.warning(
|
|
||||||
"%s\n%s",
|
|
||||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
|
||||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
|
||||||
)
|
|
||||||
return batch_size
|
|
||||||
|
|
||||||
@field_validator("learning_rate")
|
|
||||||
@classmethod
|
|
||||||
def convert_learning_rate(cls, learning_rate):
|
|
||||||
if learning_rate and isinstance(learning_rate, str):
|
|
||||||
learning_rate = float(learning_rate)
|
|
||||||
return learning_rate
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
"""Utilities for Axolotl Pydantic models"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def handle_legacy_message_fields_logic(data: dict) -> dict:
|
|
||||||
"""
|
|
||||||
Handle backwards compatibility between legacy message field mapping and new property mapping system.
|
|
||||||
|
|
||||||
Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
|
|
||||||
- message_field_role: Mapped to the role field
|
|
||||||
- message_field_content: Mapped to the content field
|
|
||||||
|
|
||||||
The new system uses message_property_mappings to support arbitrary field mappings:
|
|
||||||
message_property_mappings:
|
|
||||||
role: source_role_field
|
|
||||||
content: source_content_field
|
|
||||||
additional_field: source_field
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: Dictionary containing configuration data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated dictionary with message field mappings consolidated
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If there are conflicts between legacy and new mappings
|
|
||||||
"""
|
|
||||||
data = data.copy() # Create a copy to avoid modifying the original
|
|
||||||
|
|
||||||
if data.get("message_property_mappings") is None:
|
|
||||||
data["message_property_mappings"] = {}
|
|
||||||
|
|
||||||
# Check for conflicts and handle role
|
|
||||||
if "message_field_role" in data:
|
|
||||||
LOG.warning(
|
|
||||||
"message_field_role is deprecated, use message_property_mappings instead. "
|
|
||||||
f"Example: message_property_mappings: {{role: {data['message_field_role']}}}"
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
"role" in data["message_property_mappings"]
|
|
||||||
and data["message_property_mappings"]["role"] != data["message_field_role"]
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"Conflicting message role fields: message_field_role='{data['message_field_role']}' "
|
|
||||||
f"conflicts with message_property_mappings.role='{data['message_property_mappings']['role']}'"
|
|
||||||
)
|
|
||||||
data["message_property_mappings"]["role"] = data["message_field_role"] or "role"
|
|
||||||
|
|
||||||
del data["message_field_role"]
|
|
||||||
elif "role" not in data["message_property_mappings"]:
|
|
||||||
data["message_property_mappings"]["role"] = "role"
|
|
||||||
|
|
||||||
# Check for conflicts and handle content
|
|
||||||
if "message_field_content" in data:
|
|
||||||
LOG.warning(
|
|
||||||
"message_field_content is deprecated, use message_property_mappings instead. "
|
|
||||||
f"Example: message_property_mappings: {{content: {data['message_field_content']}}}"
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
"content" in data["message_property_mappings"]
|
|
||||||
and data["message_property_mappings"]["content"]
|
|
||||||
!= data["message_field_content"]
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"Conflicting message content fields: message_field_content='{data['message_field_content']}' "
|
|
||||||
f"conflicts with message_property_mappings.content='{data['message_property_mappings']['content']}'"
|
|
||||||
)
|
|
||||||
data["message_property_mappings"]["content"] = (
|
|
||||||
data["message_field_content"] or "content"
|
|
||||||
)
|
|
||||||
|
|
||||||
del data["message_field_content"]
|
|
||||||
elif "content" not in data["message_property_mappings"]:
|
|
||||||
data["message_property_mappings"]["content"] = "content"
|
|
||||||
|
|
||||||
return data
|
|
||||||
90
styles.css
90
styles.css
@@ -14,7 +14,7 @@
|
|||||||
h1 {
|
h1 {
|
||||||
font-family: var(--font-title);
|
font-family: var(--font-title);
|
||||||
font-weight: 400;
|
font-weight: 400;
|
||||||
font-size: 3rem;
|
font-size: 5rem;
|
||||||
line-height: 1.1;
|
line-height: 1.1;
|
||||||
letter-spacing: -0.05em;
|
letter-spacing: -0.05em;
|
||||||
font-feature-settings: "ss01" on;
|
font-feature-settings: "ss01" on;
|
||||||
@@ -24,7 +24,7 @@ h1 {
|
|||||||
h2 {
|
h2 {
|
||||||
font-family: var(--font-title);
|
font-family: var(--font-title);
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
font-size: 1.5rem;
|
font-size: 2rem;
|
||||||
line-height: 1.2;
|
line-height: 1.2;
|
||||||
letter-spacing: -0.03em;
|
letter-spacing: -0.03em;
|
||||||
font-feature-settings: "ss01" on;
|
font-feature-settings: "ss01" on;
|
||||||
@@ -35,7 +35,7 @@ h3,
|
|||||||
h4 {
|
h4 {
|
||||||
font-family: var(--font-body);
|
font-family: var(--font-body);
|
||||||
font-weight: 400;
|
font-weight: 400;
|
||||||
font-size: 1.25rem;
|
font-size: 1.5rem;
|
||||||
line-height: 1.5;
|
line-height: 1.5;
|
||||||
letter-spacing: -0.02em;
|
letter-spacing: -0.02em;
|
||||||
}
|
}
|
||||||
@@ -191,87 +191,3 @@ code span.er {
|
|||||||
color: #5cb85c !important;
|
color: #5cb85c !important;
|
||||||
text-decoration: none !important;
|
text-decoration: none !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* API Documentation Styling */
|
|
||||||
|
|
||||||
/* Improve docstring section rendering */
|
|
||||||
.level3 p {
|
|
||||||
white-space: pre-line !important;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Format docstring sections */
|
|
||||||
.level3 p strong {
|
|
||||||
display: block;
|
|
||||||
margin-top: 1em;
|
|
||||||
font-weight: bold;
|
|
||||||
color: var(--cyan);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Add spacing after sections */
|
|
||||||
.level3 p:has(strong) {
|
|
||||||
margin-bottom: 0.5em;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Format Args and Returns sections */
|
|
||||||
p:has(code) {
|
|
||||||
line-height: 1.6;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Function signatures */
|
|
||||||
.sourceCode {
|
|
||||||
margin-bottom: 1.5em;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Parameter tables */
|
|
||||||
.doc-section-parameters table,
|
|
||||||
.doc-section-returns table {
|
|
||||||
margin-top: 1em;
|
|
||||||
margin-bottom: 1.5em;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Make parameter and returns headers smaller */
|
|
||||||
h2.anchored[data-anchor-id="parameters"],
|
|
||||||
h2.anchored[data-anchor-id="returns"],
|
|
||||||
.doc-section-parameters h4,
|
|
||||||
.doc-section-returns h4 {
|
|
||||||
font-size: 1.25rem;
|
|
||||||
margin-top: 2rem;
|
|
||||||
margin-bottom: 1rem;
|
|
||||||
color: var(--lime);
|
|
||||||
border-bottom: 1px solid var(--lime);
|
|
||||||
padding-bottom: 0.3rem;
|
|
||||||
font-family: var(--font-body);
|
|
||||||
font-weight: 500;
|
|
||||||
letter-spacing: normal;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Style documentation tables */
|
|
||||||
table {
|
|
||||||
width: 100%;
|
|
||||||
margin-bottom: 1.5rem;
|
|
||||||
border-collapse: collapse;
|
|
||||||
}
|
|
||||||
|
|
||||||
table th {
|
|
||||||
background-color: #1a1a1a;
|
|
||||||
padding: 0.5rem 1rem;
|
|
||||||
border-bottom: 2px solid var(--greige-600);
|
|
||||||
text-align: left;
|
|
||||||
}
|
|
||||||
|
|
||||||
table td {
|
|
||||||
padding: 0.5rem 1rem;
|
|
||||||
border-bottom: 1px solid var(--greige-600);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Code in table cells */
|
|
||||||
table td code {
|
|
||||||
background-color: transparent !important;
|
|
||||||
padding: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Improve spacing in parameter and return tables */
|
|
||||||
.doc-section-parameters,
|
|
||||||
.doc-section-returns {
|
|
||||||
margin-top: 1rem;
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,10 +11,10 @@ from pydantic import ValidationError
|
|||||||
|
|
||||||
from axolotl.utils import is_comet_available
|
from axolotl.utils import is_comet_available
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.models import check_model_config
|
from axolotl.utils.models import check_model_config
|
||||||
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
warnings.filterwarnings("error")
|
warnings.filterwarnings("error")
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ from typing import Optional
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.schemas.datasets import ChatTemplate
|
|
||||||
|
|
||||||
warnings.filterwarnings("error")
|
warnings.filterwarnings("error")
|
||||||
|
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ class TestModelsUtils:
|
|||||||
|
|
||||||
def test_message_property_mapping(self):
|
def test_message_property_mapping(self):
|
||||||
"""Test message property mapping configuration validation"""
|
"""Test message property mapping configuration validation"""
|
||||||
from axolotl.utils.schemas.datasets import SFTDataset
|
from axolotl.utils.config.models.input.v0_4_1 import SFTDataset
|
||||||
|
|
||||||
# Test legacy fields are mapped orrectly
|
# Test legacy fields are mapped orrectly
|
||||||
dataset = SFTDataset(
|
dataset = SFTDataset(
|
||||||
|
|||||||
Reference in New Issue
Block a user