Compare commits
10 Commits
5b7e688fc5
...
quartodoc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0bffef25d0 | ||
|
|
94c00c1d04 | ||
|
|
ddd84d7c65 | ||
|
|
42bdf0bd74 | ||
|
|
b03d96a228 | ||
|
|
2653f170fc | ||
|
|
3bfcce9f0a | ||
|
|
8feb746953 | ||
|
|
a563815fe7 | ||
|
|
81f2203151 |
7
.github/workflows/docs.yml
vendored
7
.github/workflows/docs.yml
vendored
@@ -20,9 +20,12 @@ jobs:
|
|||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.11'
|
python-version: '3.11'
|
||||||
- name: install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install jupyter
|
python3 -m pip install jupyter quartodoc
|
||||||
|
python3 -m pip install -e .
|
||||||
|
- name: Build autodoc
|
||||||
|
run: quartodoc build
|
||||||
- name: Publish to GitHub Pages (and render)
|
- name: Publish to GitHub Pages (and render)
|
||||||
uses: quarto-dev/quarto-actions/publish@v2
|
uses: quarto-dev/quarto-actions/publish@v2
|
||||||
with:
|
with:
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github
|
|||||||
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
||||||
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
||||||
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
||||||
- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/api/) - Auto-generated code documentation
|
- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/docs/api/) - Auto-generated code documentation
|
||||||
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
||||||
|
|
||||||
## 🤝 Getting Help
|
## 🤝 Getting Help
|
||||||
|
|||||||
18
_quarto.yml
18
_quarto.yml
@@ -124,6 +124,18 @@ quartodoc:
|
|||||||
- utils.data.pretraining
|
- utils.data.pretraining
|
||||||
- utils.data.sft
|
- utils.data.sft
|
||||||
- utils.gradient_checkpointing.unsloth
|
- utils.gradient_checkpointing.unsloth
|
||||||
|
- title: Schemas
|
||||||
|
desc: Pydantic data models for Axolotl config
|
||||||
|
contents:
|
||||||
|
- utils.schemas.config
|
||||||
|
- utils.schemas.model
|
||||||
|
- utils.schemas.training
|
||||||
|
- utils.schemas.datasets
|
||||||
|
- utils.schemas.peft
|
||||||
|
- utils.schemas.trl
|
||||||
|
- utils.schemas.integrations
|
||||||
|
- utils.schemas.enums
|
||||||
|
- utils.schemas.utils
|
||||||
- title: Integrations
|
- title: Integrations
|
||||||
desc: Third-party integrations and extensions
|
desc: Third-party integrations and extensions
|
||||||
contents:
|
contents:
|
||||||
@@ -195,12 +207,8 @@ website:
|
|||||||
- docs/inference.qmd
|
- docs/inference.qmd
|
||||||
- docs/cli.qmd
|
- docs/cli.qmd
|
||||||
- docs/config.qmd
|
- docs/config.qmd
|
||||||
|
|
||||||
- section: "Reference"
|
|
||||||
contents:
|
|
||||||
- docs/config.qmd
|
|
||||||
- text: "API Reference"
|
- text: "API Reference"
|
||||||
href: docs/api/index.qmd
|
href: docs/api
|
||||||
|
|
||||||
- section: "Dataset Formats"
|
- section: "Dataset Formats"
|
||||||
contents: docs/dataset-formats/*
|
contents: docs/dataset-formats/*
|
||||||
|
|||||||
@@ -2,3 +2,5 @@ pre-commit
|
|||||||
black
|
black
|
||||||
mypy
|
mypy
|
||||||
types-requests
|
types-requests
|
||||||
|
quartodoc
|
||||||
|
jupyter
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from axolotl.cli.utils import (
|
|||||||
)
|
)
|
||||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
|
|||||||
@@ -13,9 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
"""
|
"""Builder for the training args and trainer"""
|
||||||
Builder for the training args and trainer
|
|
||||||
"""
|
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import importlib
|
import importlib
|
||||||
@@ -85,8 +83,8 @@ from axolotl.utils.collators import (
|
|||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
|
|
||||||
from axolotl.utils.models import ensure_dtype
|
from axolotl.utils.models import ensure_dtype
|
||||||
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import logging
|
|||||||
from trl.trainer.grpo_trainer import RewardFunc
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
|
from axolotl.utils.schemas.trl import TRLConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|||||||
@@ -11,19 +11,17 @@
|
|||||||
# the License.
|
# the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
module to handle merging the plugins' input arguments with the base configurations.
|
Module to handle merging the plugins' input arguments with the base configurations.
|
||||||
|
|
||||||
this was moved here to prevent circular imports
|
This was moved here to prevent circular imports.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.schemas.config import (
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
)
|
)
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
||||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def merge_input_args():
|
def merge_input_args():
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly
|
|||||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
|
from axolotl.utils.schemas.datasets import DatasetConfig
|
||||||
|
|
||||||
# Configure the logger
|
# Configure the logger
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ DPO prompt strategies for using tokenizer chat templates.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic
|
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
||||||
|
|
||||||
|
|
||||||
def default(
|
def default(
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ from trl.models import unwrap_model_for_generation
|
|||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
barrier,
|
barrier,
|
||||||
broadcast_dict,
|
broadcast_dict,
|
||||||
@@ -43,6 +42,7 @@ from axolotl.utils.distributed import (
|
|||||||
is_main_process,
|
is_main_process,
|
||||||
zero_first,
|
zero_first,
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||||
|
|||||||
@@ -12,19 +12,13 @@ from transformers.utils.import_utils import is_torch_npu_available
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.integrations.config import merge_input_args
|
from axolotl.integrations.config import merge_input_args
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
|
||||||
)
|
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
|
||||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
|
||||||
)
|
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
|
||||||
DPODataset,
|
|
||||||
KTODataset,
|
|
||||||
SFTDataset,
|
|
||||||
)
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model_config
|
from axolotl.utils.models import load_model_config
|
||||||
|
from axolotl.utils.schemas.config import (
|
||||||
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
|
)
|
||||||
|
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
||||||
|
from axolotl.utils.schemas.datasets import DPODataset, KTODataset, SFTDataset
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
165
src/axolotl/utils/schemas/datasets.py
Normal file
165
src/axolotl/utils/schemas/datasets.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
"""Pydantic models for datasets-related configuration"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.enums import ChatTemplate
|
||||||
|
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
||||||
|
|
||||||
|
|
||||||
|
class UserDefinedPrompterType(BaseModel):
    """Structure for user defined prompt types.

    Every field is an optional string that defaults to ``None``; the model
    defines no validators of its own.
    """

    # system_* settings
    system_prompt: str | None = None
    system_format: str | None = None

    # field_* settings (presumably names of dataset fields -- confirm against
    # the prompter that consumes this model)
    field_system: str | None = None
    field_instruction: str | None = None
    field_input: str | None = None
    field_output: str | None = None

    # format settings
    format: str | None = None
    no_input_format: str | None = None
    field: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SFTDataset(BaseModel):
    """SFT configuration subset.

    All fields are optional. Two ``mode="before"`` validators run on the raw
    input: one delegates legacy message-field handling to
    ``handle_legacy_message_fields_logic`` and one normalizes the
    ``chat_template`` / ``chat_template_jinja`` pair.
    """

    path: str | None = None
    split: str | None = None
    # Either a plain string or an inline user-defined prompter description.
    type: str | UserDefinedPrompterType | None = None
    input_transform: str | None = None
    shards: int | None = None
    shards_idx: int | None = None
    preprocess_shards: int | None = None
    conversation: str | None = None
    # Do not make this too strict or it will break the validator to choose different dataset class
    chat_template: ChatTemplate | str | None = None
    chat_template_jinja: str | None = None
    data_files: str | list[str] | None = None
    input_format: str | None = None
    name: str | None = None
    ds_type: str | None = None
    train_on_split: str | None = None
    field: str | None = None
    field_human: str | None = None
    field_model: str | None = None
    field_messages: str | None = None
    # deprecated, use message_property_mappings
    message_field_role: str | None = None
    # deprecated, use message_property_mappings
    message_field_content: str | None = None
    message_property_mappings: dict[str, str] | None = None
    message_field_training: str | None = None
    message_field_training_detail: str | None = None
    logprobs_field: str | None = None
    temperature: float | None = None
    roles_to_train: list[str] | None = None
    train_on_eos: str | None = None
    roles: dict[str, list[str]] | None = None
    drop_system_message: bool | None = None
    trust_remote_code: bool | None = False
    revision: str | None = None

    @model_validator(mode="before")
    @classmethod
    def handle_legacy_message_fields(cls, data):
        """Handle backwards compatibility between legacy message field mapping and new property mapping system."""
        return handle_legacy_message_fields_logic(data)

    @model_validator(mode="before")
    @classmethod
    # pylint: disable=duplicate-code
    def check_chat_template_config(cls, data):
        """Normalize and validate the chat-template related keys.

        Args:
            data: Raw input for the model; a ``BaseModel`` instance is first
                converted to a dict via ``model_dump()``.

        Returns:
            The (possibly updated) input mapping.

        Raises:
            ValueError: If ``chat_template`` equals ``ChatTemplate.jinja`` but
                ``chat_template_jinja`` is missing or empty.
        """
        if isinstance(data, BaseModel):
            data = data.model_dump()

        # Set chat_template to tokenizer_default if not set
        is_chat_template_type = data.get("type") == "chat_template"
        if is_chat_template_type and not data.get("chat_template"):
            data["chat_template"] = ChatTemplate.tokenizer_default

        # Nothing below writes chat_template_jinja, so reading it once is safe.
        jinja_source = data.get("chat_template_jinja")

        # if chat_template is set to jinja, chat_template_jinja is required
        if data.get("chat_template") == ChatTemplate.jinja and not jinja_source:
            raise ValueError(
                "chat_template_jinja is required when chat_template is set to jinja"
            )

        # If chat_template_jinja is set, set chat_template to jinja
        if jinja_source and not data.get("chat_template"):
            data["chat_template"] = ChatTemplate.jinja

        return data
|
||||||
|
|
||||||
|
|
||||||
|
class PretrainingDataset(BaseModel):
    """Pretraining dataset configuration subset.

    Unlike the other dataset models in this module, several fields carry
    non-``None`` defaults: ``split="train"``, ``text_column="text"``,
    ``type="pretrain"`` and ``trust_remote_code=False``.
    """

    name: str | None = None
    path: str | None = None
    split: str | None = "train"
    text_column: str | None = "text"
    type: str | None = "pretrain"
    trust_remote_code: bool | None = False
    # Note: a single string here, not ``str | list[str]`` as on SFTDataset.
    data_files: str | None = None
    skip: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserDefinedDPOType(BaseModel):
    """User defined typing for DPO.

    Every field is an optional string that defaults to ``None``.
    """

    # field_* settings (presumably dataset field names -- confirm with caller)
    field_system: str | None = None
    field_prompt: str | None = None
    field_chosen: str | None = None
    field_rejected: str | None = None

    # *_format settings
    prompt_format: str | None = None
    chosen_format: str | None = None
    rejected_format: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class DPODataset(BaseModel):
    """DPO configuration subset.

    All fields are optional and default to ``None``.
    """

    path: str | None = None
    split: str | None = None
    # Either an inline user-defined DPO description or a plain string.
    type: UserDefinedDPOType | str | None = None
    data_files: list[str] | None = None
    revision: str | None = None
    field_messages: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class StepwiseSupervisedDataset(BaseModel):
    """Stepwise supervised dataset configuration subset.

    All fields are optional and default to ``None``.
    """

    path: str | None = None
    split: str | None = None
    data_files: list[str] | None = None
    revision: str | None = None
    step_separator: str | None = None
    max_completion_length: int | None = None
    train_on_last_step_only: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class UserDefinedKTOType(BaseModel):
    """User defined typing for KTO.

    All fields are optional and default to ``None``.
    """

    field_system: str | None = None
    field_prompt: str | None = None
    field_completion: str | None = None
    # NOTE(review): typed ``bool`` while the sibling ``field_*`` entries are
    # ``str`` -- confirm this is intentional.
    field_label: bool | None = None

    prompt_format: str | None = None
    completion_format: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class KTODataset(BaseModel):
    """KTO configuration subset.

    All fields are optional; ``trust_remote_code`` defaults to ``False`` and
    the rest default to ``None``.
    """

    path: str | None = None
    split: str | None = None
    # Either an inline user-defined KTO description or a plain string.
    type: UserDefinedKTOType | str | None = None
    data_files: list[str] | None = None
    trust_remote_code: bool | None = False
    revision: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# Type alias: union of the SFT, DPO, KTO and stepwise-supervised dataset models
# defined above (PretrainingDataset is not part of this union).
DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
|
||||||
68
src/axolotl/utils/schemas/deprecated.py
Normal file
68
src/axolotl/utils/schemas/deprecated.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""Pydantic models for deprecated and remapped configuration parameters"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeprecatedParameters(BaseModel):
    """Configuration keys that are deprecated.

    Every field defaults to ``None``. The field validators below fall into two
    groups: ``max_packed_sequence_len`` and ``rope_scaling`` reject a truthy
    value by raising ``DeprecationWarning`` as an exception, while the
    remaining keys are accepted and only log a warning that names the
    replacement key.
    """

    max_packed_sequence_len: int | None = None
    rope_scaling: Any | None = None
    noisy_embedding_alpha: float | None = None
    dpo_beta: float | None = None
    evaluation_strategy: str | None = None

    @field_validator("max_packed_sequence_len")
    @classmethod
    def validate_max_packed_sequence_len(cls, max_packed_sequence_len):
        """Reject a truthy ``max_packed_sequence_len``.

        Returns:
            The value unchanged when it is falsy.

        Raises:
            DeprecationWarning: If a truthy value is supplied.
        """
        if max_packed_sequence_len:
            raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
        return max_packed_sequence_len

    @field_validator("rope_scaling")
    @classmethod
    def validate_rope_scaling(cls, rope_scaling):
        """Reject a truthy ``rope_scaling``.

        Returns:
            The value unchanged when it is falsy.

        Raises:
            DeprecationWarning: If a truthy value is supplied.
        """
        if rope_scaling:
            raise DeprecationWarning(
                "`rope_scaling` is no longer supported, it should now be a key under `model_config`"
            )
        return rope_scaling

    @field_validator("noisy_embedding_alpha")
    @classmethod
    def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):
        """Log a deprecation warning for a truthy value and return it unchanged."""
        if noisy_embedding_alpha:
            LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
        return noisy_embedding_alpha

    @field_validator("dpo_beta")
    @classmethod
    def validate_dpo_beta(cls, dpo_beta):
        """Log a deprecation warning for any non-``None`` value and return it unchanged."""
        if dpo_beta is not None:
            LOG.warning("dpo_beta is deprecated, use rl_beta instead")
        return dpo_beta

    @field_validator("evaluation_strategy")
    @classmethod
    def validate_evaluation_strategy(cls, evaluation_strategy):
        """Log a deprecation warning for any non-``None`` value and return it unchanged."""
        if evaluation_strategy is not None:
            LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
        return evaluation_strategy
|
||||||
|
|
||||||
|
|
||||||
|
class RemappedParameters(BaseModel):
    """Parameters that have been remapped to other names.

    Each attribute is populated from the input key given by its ``alias``
    (for example ``overrides_of_model_config`` is read from ``model_config``).
    All attributes default to ``None``.
    """

    # NOTE(review): presumably renamed because pydantic reserves the
    # ``model_config`` attribute / ``model_`` prefix -- confirm.
    overrides_of_model_config: dict[str, Any] | None = Field(
        default=None, alias="model_config"
    )
    overrides_of_model_kwargs: dict[str, Any] | None = Field(
        default=None, alias="model_kwargs"
    )
    type_of_model: str | None = Field(default=None, alias="model_type")
    revision_of_model: str | None = Field(default=None, alias="model_revision")
|
||||||
49
src/axolotl/utils/schemas/enums.py
Normal file
49
src/axolotl/utils/schemas/enums.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""Enums for Axolotl input config"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class RLType(str, Enum):
    """RL trainer type configuration subset.

    Members subclass ``str``, so each compares equal to its plain string
    value (e.g. ``RLType.dpo == "dpo"``).
    """

    dpo = "dpo"  # pylint: disable=invalid-name
    grpo = "grpo"  # pylint: disable=invalid-name
    ipo = "ipo"  # pylint: disable=invalid-name
    orpo = "orpo"  # pylint: disable=invalid-name
    kto = "kto"  # pylint: disable=invalid-name
    simpo = "simpo"  # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class ChatTemplate(str, Enum):
    """Chat templates configuration subset.

    Members subclass ``str`` and every member's value is identical to its
    name, so a member compares equal to the matching plain string.
    """

    alpaca = "alpaca"  # pylint: disable=invalid-name
    chatml = "chatml"  # pylint: disable=invalid-name
    mistral_v1 = "mistral_v1"  # pylint: disable=invalid-name
    mistral_v2v3 = "mistral_v2v3"  # pylint: disable=invalid-name
    mistral_v3_tekken = "mistral_v3_tekken"  # pylint: disable=invalid-name
    gemma = "gemma"  # pylint: disable=invalid-name
    cohere = "cohere"  # pylint: disable=invalid-name
    llama3 = "llama3"  # pylint: disable=invalid-name
    llama3_2_vision = "llama3_2_vision"  # pylint: disable=invalid-name
    phi_3 = "phi_3"  # pylint: disable=invalid-name
    phi_35 = "phi_35"  # pylint: disable=invalid-name
    deepseek_v2 = "deepseek_v2"  # pylint: disable=invalid-name
    deepseek_v3 = "deepseek_v3"  # pylint: disable=invalid-name
    jamba = "jamba"  # pylint: disable=invalid-name
    jinja = "jinja"  # pylint: disable=invalid-name
    qwen_25 = "qwen_25"  # pylint: disable=invalid-name
    tokenizer_default = "tokenizer_default"  # pylint: disable=invalid-name
    exaone = "exaone"  # pylint: disable=invalid-name
    metharme = "metharme"  # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class CustomSupportedOptimizers(str, Enum):
    """Custom supported optimizers.

    Members subclass ``str`` and every member's value is identical to its
    name.
    """

    optimi_adamw = "optimi_adamw"  # pylint: disable=invalid-name
    ao_adamw_4bit = "ao_adamw_4bit"  # pylint: disable=invalid-name
    ao_adamw_8bit = "ao_adamw_8bit"  # pylint: disable=invalid-name
    ao_adamw_fp8 = "ao_adamw_fp8"  # pylint: disable=invalid-name
    adopt_adamw = "adopt_adamw"  # pylint: disable=invalid-name
    muon = "muon"  # pylint: disable=invalid-name
|
||||||
108
src/axolotl/utils/schemas/integrations.py
Normal file
108
src/axolotl/utils/schemas/integrations.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""Pydantic models for Axolotl integrations"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MLFlowConfig(BaseModel):
    """MLFlow configuration subset.

    All fields are optional and default to ``None``.
    """

    use_mlflow: bool | None = None
    mlflow_tracking_uri: str | None = None
    mlflow_experiment_name: str | None = None
    mlflow_run_name: str | None = None
    hf_mlflow_log_artifacts: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class LISAConfig(BaseModel):
    """LISA configuration subset.

    Each field attaches a human-readable ``description`` through
    ``json_schema_extra``. ``lisa_layers_attribute`` defaults to
    ``"model.layers"``; the other two fields default to ``None``.
    """

    lisa_n_layers: int | None = Field(
        default=None,
        json_schema_extra={"description": "the number of activate layers in LISA"},
    )
    lisa_step_interval: int | None = Field(
        default=None,
        json_schema_extra={"description": "how often to switch layers in LISA"},
    )
    lisa_layers_attribute: str | None = Field(
        default="model.layers",
        json_schema_extra={"description": "path under the model to access the layers"},
    )
|
||||||
|
|
||||||
|
|
||||||
|
class WandbConfig(BaseModel):
    """Wandb configuration subset.

    All fields are optional and default to ``None``. A ``mode="before"``
    validator copies ``wandb_run_id`` into ``wandb_name`` when only the run id
    is given.
    """

    use_wandb: bool | None = None
    wandb_name: str | None = None
    wandb_run_id: str | None = None
    wandb_mode: str | None = None
    wandb_project: str | None = None
    wandb_entity: str | None = None
    wandb_watch: str | None = None
    wandb_log_model: str | None = None

    @model_validator(mode="before")
    @classmethod
    def check_wandb_run(cls, data):
        """Default ``wandb_name`` to ``wandb_run_id`` and warn about the distinction.

        Args:
            data: Raw input; accessed with ``.get`` and item assignment, so it
                is assumed to be a dict-like mapping.

        Returns:
            The same mapping, with ``wandb_name`` filled in when it was
            missing or empty and ``wandb_run_id`` was set.
        """
        if data.get("wandb_run_id") and not data.get("wandb_name"):
            data["wandb_name"] = data.get("wandb_run_id")

            LOG.warning(
                "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
            )

        return data
|
||||||
|
|
||||||
|
|
||||||
|
class CometConfig(BaseModel):
    """Comet configuration subset.

    All fields are optional and default to ``None``.
    """

    use_comet: bool | None = None
    comet_api_key: str | None = None
    comet_workspace: str | None = None
    comet_project_name: str | None = None
    comet_experiment_key: str | None = None
    comet_mode: str | None = None
    comet_online: bool | None = None
    comet_experiment_config: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class GradioConfig(BaseModel):
    """Gradio configuration subset.

    All fields are optional and default to ``None``.
    """

    gradio_title: str | None = None
    gradio_share: bool | None = None
    gradio_server_name: str | None = None
    gradio_server_port: int | None = None
    gradio_max_new_tokens: int | None = None
    gradio_temperature: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class RayConfig(BaseModel):
    """Ray launcher configuration subset.

    Defaults: ``use_ray=False``, ``ray_run_name=None``, ``ray_num_workers=1``
    and ``resources_per_worker={"GPU": 1}`` (built per instance through
    ``default_factory``). Field help text is stored under the ``"help"`` key
    of ``json_schema_extra``.
    """

    use_ray: bool = Field(default=False)
    ray_run_name: str | None = Field(
        default=None,
        json_schema_extra={
            "help": "The training results will be saved at `saves/ray_run_name`."
        },
    )
    ray_num_workers: int = Field(
        default=1,
        json_schema_extra={
            "help": "The number of workers for Ray training. Default is 1 worker."
        },
    )
    resources_per_worker: dict = Field(
        default_factory=lambda: {"GPU": 1},
        json_schema_extra={
            "help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."
        },
    )
|
||||||
55
src/axolotl/utils/schemas/model.py
Normal file
55
src/axolotl/utils/schemas/model.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""Pydantic models for model input / output, etc. configuration"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInputConfig(BaseModel):
    """Model configuration subset.

    ``base_model`` is the only required field; every other field defaults to
    ``None``. A field validator logs a warning whenever ``trust_remote_code``
    is truthy.
    """

    # pydantic model configuration: an empty ``protected_namespaces`` tuple.
    model_config = {"protected_namespaces": ()}

    base_model: str
    base_model_config: str | None = None
    cls_model_config: str | None = None
    tokenizer_config: str | None = None
    tokenizer_use_fast: bool | None = None
    tokenizer_legacy: bool | None = None
    tokenizer_type: str | None = Field(
        default=None, json_schema_extra={"description": "transformers tokenizer class"}
    )
    processor_type: str | None = Field(
        default=None, json_schema_extra={"description": "transformers processor class"}
    )
    trust_remote_code: bool | None = None

    @field_validator("trust_remote_code")
    @classmethod
    def hint_trust_remote_code(cls, trust_remote_code):
        """Log a reminder to review remote code when the flag is truthy.

        Returns:
            The value unchanged.
        """
        if trust_remote_code:
            LOG.warning(
                "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
            )
        return trust_remote_code
|
||||||
|
|
||||||
|
|
||||||
|
class ModelOutputConfig(BaseModel):
    """Model save configuration subset.

    Defaults: ``output_dir="./model-out"`` and ``save_safetensors=True``; the
    hub-related fields default to ``None``.
    """

    output_dir: str = Field(default="./model-out")
    hub_model_id: str | None = None
    hub_strategy: str | None = None
    save_safetensors: bool | None = True
|
||||||
|
|
||||||
|
|
||||||
|
class SpecialTokensConfig(BaseModel):
    """Special tokens configuration subset.

    All fields are optional and default to ``None``.
    """

    bos_token: str | None = None
    eos_token: str | None = None
    pad_token: str | None = None
    unk_token: str | None = None
    additional_special_tokens: list[str] | None = None
|
||||||
132
src/axolotl/utils/schemas/peft.py
Normal file
132
src/axolotl/utils/schemas/peft.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""Pydantic models for PEFT-related configuration"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
class LoftQConfig(BaseModel):
    """LoftQ configuration subset.

    Currently exposes a single field, ``loftq_bits``, which defaults to 4.
    """

    loftq_bits: int = Field(
        default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
    )
    # loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
|
||||||
|
|
||||||
|
|
||||||
|
class PeftConfig(BaseModel):
    """PEFT configuration subset.

    Holds an optional nested ``LoftQConfig`` that defaults to ``None``.
    """

    loftq_config: LoftQConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class LoraConfig(BaseModel):
    """Peft / LoRA configuration subset.

    Besides declaring the adapter-related fields, this model enforces a few
    cross-field rules:

    * ``load_in_8bit`` / ``load_in_4bit`` require an ``adapter`` unless
      ``inference`` is set (see ``validate_adapter``).
    * ``adapter == "qlora"`` constrains ``load_in_8bit``, ``load_in_4bit``,
      ``gptq`` and ``merge_lora`` (see ``validate_qlora``).
    * ``lora_dropout`` is reset to ``0.0`` when an adapter is configured but
      the dropout was explicitly given as ``None`` (see
      ``validate_lora_dropout``).
    """

    load_in_8bit: bool | None = Field(default=False)
    load_in_4bit: bool | None = Field(default=False)

    adapter: str | None = None
    lora_model_dir: str | None = None
    lora_r: int | None = None
    lora_alpha: int | None = None
    lora_fan_in_fan_out: bool | None = None
    lora_target_modules: str | list[str] | None = None
    lora_target_linear: bool | None = None
    lora_modules_to_save: list[str] | None = None
    lora_dropout: float | None = 0.0
    peft_layers_to_transform: list[int] | None = None
    peft_layers_pattern: list[str] | None = None
    peft: PeftConfig | None = None
    peft_use_dora: bool | None = None
    peft_use_rslora: bool | None = None
    peft_layer_replication: list[tuple[int, int]] | None = None
    peft_init_lora_weights: bool | str | None = None

    qlora_sharded_model_loading: bool | None = Field(
        default=False,
        json_schema_extra={
            "description": "load qlora model in sharded format for FSDP using answer.ai technique."
        },
    )
    lora_on_cpu: bool | None = None
    gptq: bool | None = None
    bnb_config_kwargs: dict[str, Any] | None = None

    loraplus_lr_ratio: float | None = Field(
        default=None,
        json_schema_extra={
            "description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
        },
    )
    loraplus_lr_embedding: float | None = Field(
        default=1e-6,
        json_schema_extra={
            "description": "loraplus learning rate for lora embedding layers."
        },
    )

    merge_lora: bool | None = None

    @model_validator(mode="before")
    @classmethod
    def validate_adapter(cls, data):
        """Reject quantized loading when training without an adapter.

        Runs on the raw input mapping before field validation.

        Raises:
            ValueError: If ``load_in_8bit`` or ``load_in_4bit`` is truthy while
                both ``adapter`` and ``inference`` are falsy/absent.
        """
        if (
            not data.get("adapter")
            and not data.get("inference")
            and (data.get("load_in_8bit") or data.get("load_in_4bit"))
        ):
            # The two literals are concatenated; the trailing space keeps the
            # sentences from running together in the rendered message.
            raise ValueError(
                "load_in_8bit and load_in_4bit are not supported without setting an adapter for training. "
                "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
            )
        return data

    @model_validator(mode="after")
    def validate_qlora(self):
        """Check that the quantization flags are consistent with ``adapter == "qlora"``.

        When merging (``merge_lora`` truthy) none of ``load_in_8bit``, ``gptq``
        or ``load_in_4bit`` may be set. Otherwise ``load_in_8bit`` and ``gptq``
        must be unset and ``load_in_4bit`` must be set.

        Raises:
            ValueError: If any of the above constraints is violated.
        """
        if self.adapter == "qlora":
            if self.merge_lora:
                # can't merge qlora if loaded in 8bit or 4bit
                if self.load_in_8bit:
                    raise ValueError("Can't merge qlora if loaded in 8bit")

                if self.gptq:
                    raise ValueError("Can't merge qlora if gptq")

                if self.load_in_4bit:
                    raise ValueError("Can't merge qlora if loaded in 4bit")

            else:
                if self.load_in_8bit:
                    raise ValueError("Can't load qlora in 8bit")

                if self.gptq:
                    raise ValueError("Can't load qlora if gptq")

                if not self.load_in_4bit:
                    raise ValueError("Require cfg.load_in_4bit to be True for qlora")
        return self

    @field_validator("loraplus_lr_embedding")
    @classmethod
    def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):
        """Convert a non-empty string value to ``float``; pass anything else through."""
        if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):
            loraplus_lr_embedding = float(loraplus_lr_embedding)
        return loraplus_lr_embedding

    @model_validator(mode="before")
    @classmethod
    def validate_lora_dropout(cls, data):
        """Default ``lora_dropout`` to ``0.0`` when an adapter is set and dropout is ``None``.

        Runs on the raw input mapping before field validation and writes the
        default into that mapping.
        """
        if data.get("adapter") is not None and data.get("lora_dropout") is None:
            data["lora_dropout"] = 0.0
        return data
|
||||||
|
|
||||||
|
|
||||||
|
class ReLoRAConfig(BaseModel):
    """ReLoRA configuration subset.

    Every field is optional and defaults to ``None``.
    """

    # Step-count settings (integers).
    relora_steps: int | None = None
    relora_warmup_steps: int | None = None
    relora_anneal_steps: int | None = None

    # Remaining knobs: a float ratio and a boolean switch.
    relora_prune_ratio: float | None = None
    relora_cpu_offload: bool | None = None
|
||||||
99
src/axolotl/utils/schemas/training.py
Normal file
99
src/axolotl/utils/schemas/training.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""Pydantic models for training hyperparameters"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
from transformers import SchedulerType
|
||||||
|
from transformers.training_args import OptimizerNames
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LrGroup(BaseModel):
    """Custom learning rate group configuration.

    All three fields are required (none declares a default).
    """

    # Identifier of the group.
    name: str
    # Module names belonging to the group.
    modules: list[str]
    # Learning rate value for the group.
    lr: float
|
||||||
|
|
||||||
|
|
||||||
|
class HyperparametersConfig(BaseModel):
    """Training hyperparams configuration subset.

    ``learning_rate`` is the only required field. It may be supplied as a
    string (e.g. from YAML) and is converted to ``float`` by
    ``convert_learning_rate``. Setting ``batch_size`` is allowed but logs a
    warning (see ``hint_batch_size_set``).
    """

    gradient_accumulation_steps: int | None = Field(default=1)
    micro_batch_size: int | None = Field(
        default=1,
        json_schema_extra={"description": "per gpu micro batch size for training"},
    )
    batch_size: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Total batch size, we do not recommend setting this manually"
        },
    )
    eval_batch_size: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
        },
    )

    auto_find_batch_size: bool | None = None

    train_on_inputs: bool | None = False
    group_by_length: bool | None = None

    learning_rate: str | float
    embedding_lr: float | None = None
    embedding_lr_scale: float | None = None
    weight_decay: float | None = 0.0
    optimizer: (OptimizerNames | CustomSupportedOptimizers) | None = (
        OptimizerNames.ADAMW_TORCH_FUSED
    )
    optim_args: (str | dict[str, Any]) | None = Field(
        default=None,
        json_schema_extra={"description": "Optional arguments to supply to optimizer."},
    )
    optim_target_modules: (list[str] | Literal["all_linear"]) | None = Field(
        default=None,
        json_schema_extra={
            "description": "The target modules to optimize, i.e. the module names that you would like to train."
        },
    )
    torchdistx_path: str | None = None
    lr_scheduler: (SchedulerType | Literal["one_cycle"] | Literal["rex"]) | None = (
        SchedulerType.COSINE
    )
    lr_scheduler_kwargs: dict[str, Any] | None = None
    lr_quadratic_warmup: bool | None = None
    cosine_min_lr_ratio: float | None = None
    cosine_constant_lr_ratio: float | None = None
    lr_div_factor: float | None = None
    lr_groups: list[LrGroup] | None = None

    adam_epsilon: float | None = None
    adam_beta1: float | None = None
    adam_beta2: float | None = None
    max_grad_norm: float | None = None
    num_epochs: float = Field(default=1.0)

    @field_validator("batch_size")
    @classmethod
    def hint_batch_size_set(cls, batch_size):
        """Warn when ``batch_size`` is set; the value itself is returned unchanged."""
        if batch_size:
            LOG.warning(
                "%s\n%s",
                "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
                "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
            )
        return batch_size

    @field_validator("learning_rate")
    @classmethod
    def convert_learning_rate(cls, learning_rate):
        """Convert a non-empty string learning rate to ``float``; pass anything else through."""
        if learning_rate and isinstance(learning_rate, str):
            learning_rate = float(learning_rate)
        return learning_rate
|
||||||
@@ -1,8 +1,4 @@
|
|||||||
"""
|
"""Pydantic models for TRL trainer configuration"""
|
||||||
GRPO specific configuration args
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -12,11 +8,11 @@ class TRLConfig(BaseModel):
|
|||||||
Input args for TRL.
|
Input args for TRL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
beta: Optional[float] = Field(
|
beta: float | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Beta for RL training"},
|
json_schema_extra={"description": "Beta for RL training"},
|
||||||
)
|
)
|
||||||
max_completion_length: Optional[int] = Field(
|
max_completion_length: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Maximum length of the completion for RL training"
|
"description": "Maximum length of the completion for RL training"
|
||||||
@@ -25,50 +21,50 @@ class TRLConfig(BaseModel):
|
|||||||
|
|
||||||
# GRPO specific args
|
# GRPO specific args
|
||||||
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
# Ref: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/grpo_config.py#L22
|
||||||
use_vllm: Optional[bool] = Field(
|
use_vllm: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
json_schema_extra={"description": "Whether to use VLLM for RL training"},
|
||||||
)
|
)
|
||||||
vllm_device: Optional[str] = Field(
|
vllm_device: str | None = Field(
|
||||||
default="auto",
|
default="auto",
|
||||||
json_schema_extra={"description": "Device to use for VLLM"},
|
json_schema_extra={"description": "Device to use for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_gpu_memory_utilization: Optional[float] = Field(
|
vllm_gpu_memory_utilization: float | None = Field(
|
||||||
default=0.9,
|
default=0.9,
|
||||||
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
json_schema_extra={"description": "GPU memory utilization for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_dtype: Optional[str] = Field(
|
vllm_dtype: str | None = Field(
|
||||||
default="auto",
|
default="auto",
|
||||||
json_schema_extra={"description": "Data type for VLLM"},
|
json_schema_extra={"description": "Data type for VLLM"},
|
||||||
)
|
)
|
||||||
vllm_max_model_len: Optional[int] = Field(
|
vllm_max_model_len: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Maximum length of the model context for VLLM"
|
"description": "Maximum length of the model context for VLLM"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
reward_funcs: Optional[list[str]] = Field(
|
reward_funcs: list[str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "List of reward functions to load"},
|
json_schema_extra={"description": "List of reward functions to load"},
|
||||||
)
|
)
|
||||||
reward_weights: Optional[list[float]] = Field(
|
reward_weights: list[float] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Weights for each reward function. Must match the number of reward functions."
|
"description": "Weights for each reward function. Must match the number of reward functions."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
num_generations: Optional[int] = Field(
|
num_generations: int | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
"description": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) must be divisible by this value."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
log_completions: Optional[bool] = Field(
|
log_completions: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={"description": "Whether to log completions"},
|
json_schema_extra={"description": "Whether to log completions"},
|
||||||
)
|
)
|
||||||
sync_ref_model: Optional[bool] = Field(
|
sync_ref_model: bool | None = Field(
|
||||||
default=False,
|
default=False,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": (
|
"description": (
|
||||||
@@ -77,13 +73,13 @@ class TRLConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ref_model_mixup_alpha: Optional[float] = Field(
|
ref_model_mixup_alpha: float | None = Field(
|
||||||
default=0.9,
|
default=0.9,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
"description": "Mixup alpha for the reference model. Requires `sync_ref_model=True`."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
ref_model_sync_steps: Optional[int] = Field(
|
ref_model_sync_steps: int | None = Field(
|
||||||
default=64,
|
default=64,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
"description": "Sync steps for the reference model. Requires `sync_ref_model=True`."
|
||||||
79
src/axolotl/utils/schemas/utils.py
Normal file
79
src/axolotl/utils/schemas/utils.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Utilities for Axolotl Pydantic models"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_legacy_message_fields_logic(data: dict) -> dict:
    """
    Handle backwards compatibility between legacy message field mapping and new property mapping system.

    Previously, the config only supported mapping 'role' and 'content' fields via dedicated config options:
    - message_field_role: Mapped to the role field
    - message_field_content: Mapped to the content field

    The new system uses message_property_mappings to support arbitrary field mappings:
    message_property_mappings:
        role: source_role_field
        content: source_content_field
        additional_field: source_field

    The input is not modified: both the top-level dictionary and its nested
    ``message_property_mappings`` are copied before any change is made.

    Args:
        data: Dictionary containing configuration data

    Returns:
        Updated dictionary with message field mappings consolidated. The legacy
        ``message_field_role`` / ``message_field_content`` keys are removed and
        ``message_property_mappings`` always contains ``role`` and ``content``.

    Raises:
        ValueError: If there are conflicts between legacy and new mappings
    """
    data = data.copy()  # Shallow copy: protects the caller's top-level keys

    # The shallow copy above still shares the nested mapping with the caller,
    # so take a separate copy of it before writing any defaults into it.
    existing_mappings = data.get("message_property_mappings")
    mappings = {} if existing_mappings is None else dict(existing_mappings)
    data["message_property_mappings"] = mappings

    # Check for conflicts and handle role
    if "message_field_role" in data:
        legacy_role = data.pop("message_field_role")
        LOG.warning(
            "message_field_role is deprecated, use message_property_mappings instead. "
            "Example: message_property_mappings: {role: %s}",
            legacy_role,
        )
        if "role" in mappings and mappings["role"] != legacy_role:
            raise ValueError(
                f"Conflicting message role fields: message_field_role='{legacy_role}' "
                f"conflicts with message_property_mappings.role='{mappings['role']}'"
            )
        # A falsy legacy value (None / empty string) falls back to the default name
        mappings["role"] = legacy_role or "role"
    elif "role" not in mappings:
        mappings["role"] = "role"

    # Check for conflicts and handle content
    if "message_field_content" in data:
        legacy_content = data.pop("message_field_content")
        LOG.warning(
            "message_field_content is deprecated, use message_property_mappings instead. "
            "Example: message_property_mappings: {content: %s}",
            legacy_content,
        )
        if "content" in mappings and mappings["content"] != legacy_content:
            raise ValueError(
                f"Conflicting message content fields: message_field_content='{legacy_content}' "
                f"conflicts with message_property_mappings.content='{mappings['content']}'"
            )
        # A falsy legacy value (None / empty string) falls back to the default name
        mappings["content"] = legacy_content or "content"
    elif "content" not in mappings:
        mappings["content"] = "content"

    return data
|
||||||
@@ -14,7 +14,7 @@
|
|||||||
h1 {
|
h1 {
|
||||||
font-family: var(--font-title);
|
font-family: var(--font-title);
|
||||||
font-weight: 400;
|
font-weight: 400;
|
||||||
font-size: 5rem;
|
font-size: 3rem;
|
||||||
line-height: 1.1;
|
line-height: 1.1;
|
||||||
letter-spacing: -0.05em;
|
letter-spacing: -0.05em;
|
||||||
font-feature-settings: "ss01" on;
|
font-feature-settings: "ss01" on;
|
||||||
@@ -24,7 +24,7 @@ h1 {
|
|||||||
h2 {
|
h2 {
|
||||||
font-family: var(--font-title);
|
font-family: var(--font-title);
|
||||||
font-weight: 500;
|
font-weight: 500;
|
||||||
font-size: 2rem;
|
font-size: 1.5rem;
|
||||||
line-height: 1.2;
|
line-height: 1.2;
|
||||||
letter-spacing: -0.03em;
|
letter-spacing: -0.03em;
|
||||||
font-feature-settings: "ss01" on;
|
font-feature-settings: "ss01" on;
|
||||||
@@ -35,7 +35,7 @@ h3,
|
|||||||
h4 {
|
h4 {
|
||||||
font-family: var(--font-body);
|
font-family: var(--font-body);
|
||||||
font-weight: 400;
|
font-weight: 400;
|
||||||
font-size: 1.5rem;
|
font-size: 1.25rem;
|
||||||
line-height: 1.5;
|
line-height: 1.5;
|
||||||
letter-spacing: -0.02em;
|
letter-spacing: -0.02em;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,10 +11,10 @@ from pydantic import ValidationError
|
|||||||
|
|
||||||
from axolotl.utils import is_comet_available
|
from axolotl.utils import is_comet_available
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||||
from axolotl.utils.models import check_model_config
|
from axolotl.utils.models import check_model_config
|
||||||
|
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
warnings.filterwarnings("error")
|
warnings.filterwarnings("error")
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ from typing import Optional
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from axolotl.utils.config import validate_config
|
from axolotl.utils.config import validate_config
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.schemas.datasets import ChatTemplate
|
||||||
|
|
||||||
warnings.filterwarnings("error")
|
warnings.filterwarnings("error")
|
||||||
|
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ class TestModelsUtils:
|
|||||||
|
|
||||||
def test_message_property_mapping(self):
|
def test_message_property_mapping(self):
|
||||||
"""Test message property mapping configuration validation"""
|
"""Test message property mapping configuration validation"""
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import SFTDataset
|
from axolotl.utils.schemas.datasets import SFTDataset
|
||||||
|
|
||||||
# Test legacy fields are mapped correctly
|
# Test legacy fields are mapped correctly
|
||||||
dataset = SFTDataset(
|
dataset = SFTDataset(
|
||||||
|
|||||||
Reference in New Issue
Block a user