Pydantic 2.x cfg (#1239)

* WIP conversion to use pydantic for config validation

* wip, more fields, add capabilities

* wip

* update pydantic validation to match existing tests

* tweak requirements

* set up deprecated params pydantic model

* more validations

* wrap up rest of the validations

* flesh out the rest of the options from the readme into pydantic

* fix model validators as class methods

remember to return in validator
missing return
add missing relora attributes
fix test for DictDefault change
fix sys template for mistral from fastchat change in PR 2872
fix test for batch size warning

* more missing attributes for cfg

* updates from PR feedback

* fix validation for datasets and pretrain datasets

* fix test for lora check
Wing Lian
2024-02-26 12:24:14 -05:00
committed by GitHub
parent 5894f0e57e
commit cc3cebfa70
16 changed files with 1710 additions and 410 deletions

View File

@@ -1,5 +1,5 @@
 [mypy]
+plugins = pydantic.mypy
 exclude = venv
 [mypy-alpaca_lora_4bit.*]

View File

@@ -31,6 +31,7 @@ repos:
 additional_dependencies:
 [
 'types-PyYAML',
+'pydantic>=2.5.3',
 ]
 - repo: https://github.com/PyCQA/bandit
 rev: 1.7.5

View File

@@ -543,7 +543,7 @@ is_mistral_derived_model:
 is_qwen_derived_model:
 # optional overrides to the base model configuration
-model_config:
+model_config_overrides:
 # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
 rope_scaling:
 type: # linear | dynamic
@@ -560,8 +560,6 @@ bnb_config_kwargs:
 # Whether you are training a 4-bit GPTQ quantized model
 gptq: true
-gptq_groupsize: 128 # group size
-gptq_model_v1: false # v1 or v2
 # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
 load_in_8bit: true
@@ -819,10 +817,6 @@ cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosin
 # For one_cycle optim
 lr_div_factor: # Learning rate div factor
-# For log_sweep optim
-log_sweep_min_lr:
-log_sweep_max_lr:
 # Specify optimizer
 # Valid values are driven by the Transformers OptimizerNames class, see:
 # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134

View File

@@ -6,6 +6,7 @@ tokenizers==0.15.0
 bitsandbytes>=0.41.1
 accelerate==0.26.1
 deepspeed>=0.13.1
+pydantic>=2.5.3
 addict
 fire
 PyYAML>=6.0
@@ -27,7 +28,7 @@ scipy
 scikit-learn==1.2.2
 pynvml
 art
-fschat==0.2.34
+fschat==0.2.36
 gradio==3.50.2
 tensorboard

View File

@@ -24,11 +24,13 @@ from art import text2art
 from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError
 from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
+from transformers.utils import is_torch_bf16_gpu_available
 from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
 from axolotl.logging_config import configure_logging
 from axolotl.train import TrainDatasetMeta
 from axolotl.utils.config import (
+    GPUCapabilities,
     normalize_cfg_datasets,
     normalize_config,
     validate_config,
@@ -328,7 +330,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
 # load the config from the yaml file
 with open(config, encoding="utf-8") as file:
     cfg: DictDefault = DictDefault(yaml.safe_load(file))
-cfg.axolotl_config_path = config
 # if there are any options passed in the cli, if it is something that seems valid from the yaml,
 # then overwrite the value
 cfg_keys = cfg.keys()
@@ -341,7 +342,21 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
 else:
     cfg[k] = kwargs[k]
-validate_config(cfg)
+cfg.axolotl_config_path = config
+
+try:
+    device_props = torch.cuda.get_device_properties("cuda")
+    gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
+except:  # pylint: disable=bare-except  # noqa: E722
+    gpu_version = None
+
+capabilities = GPUCapabilities(
+    bf16=is_torch_bf16_gpu_available(),
+    n_gpu=os.environ.get("WORLD_SIZE", 1),
+    compute_capability=gpu_version,
+)
+
+cfg = validate_config(cfg, capabilities=capabilities)
 prepare_optim_env(cfg)

View File

@@ -3,11 +3,17 @@ import json
 import logging
 import os
 from pathlib import Path
+from typing import Optional
 import torch
 from transformers.utils import is_torch_bf16_gpu_available
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.config.models.input.v0_4_1 import (
+    AxolotlConfigWCapabilities,
+    AxolotlInputConfig,
+)
+from axolotl.utils.config.models.internals import GPUCapabilities
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model_config
@@ -191,7 +197,15 @@ def normalize_cfg_datasets(cfg):
 cfg.datasets[idx].conversation = "chatml"

-def validate_config(cfg):
+def validate_config(cfg: DictDefault, capabilities: Optional[GPUCapabilities] = None):
+    if capabilities:
+        return DictDefault(
+            dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities))
+        )
+    return DictDefault(dict(AxolotlInputConfig(**cfg.to_dict())))
+
+
+def legacy_validate_config(cfg):
     """
     This is a "pre-validation" step that handles the yaml configuration before we have any
     information about the model architecture
@@ -480,9 +494,6 @@ def validate_config(cfg):
 if cfg.rope_scaling:
     LOG.warning("`rope_scaling` should now be a key under `model_config`")

-if cfg.warmup_steps and cfg.warmup_ratio:
-    raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")

 if cfg.wandb_run_id and not cfg.wandb_name:
     cfg.wandb_name = cfg.wandb_run_id

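For orientation, a minimal sketch of how the reworked validate_config entry point is meant to be driven once this lands; the model name and dataset path below are placeholders (not values from this commit), and the no-capabilities path is shown so only AxolotlInputConfig is exercised:

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

# placeholder config values for illustration only
cfg = DictDefault(
    {
        "base_model": "NousResearch/Llama-2-7b-hf",
        "learning_rate": 0.0002,
        "micro_batch_size": 1,
        "gradient_accumulation_steps": 4,
        "datasets": [{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
    }
)

# returns a DictDefault rebuilt from the validated pydantic model, so downstream
# code keeps working with attribute-style access
cfg = validate_config(cfg)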
View File

@@ -0,0 +1,931 @@
"""
Module for pydantic models for configuration
"""
import logging
import os
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
from axolotl.utils.config.models.internals import GPUCapabilities
LOG = logging.getLogger("axolotl.utils.config.models.input")
class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
max_packed_sequence_len: Optional[int] = None
rope_scaling: Optional[Any] = None
noisy_embedding_alpha: Optional[float] = None
@field_validator("max_packed_sequence_len")
@classmethod
def validate_max_packed_sequence_len(cls, max_packed_sequence_len):
if max_packed_sequence_len:
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
return max_packed_sequence_len
@field_validator("rope_scaling")
@classmethod
def validate_rope_scaling(cls, rope_scaling):
if rope_scaling:
raise DeprecationWarning(
"`rope_scaling` is no longer supported, it should now be a key under `model_config`"
)
return rope_scaling
@field_validator("noisy_embedding_alpha")
@classmethod
def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):
if noisy_embedding_alpha:
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
return noisy_embedding_alpha
class PretrainingDataset(BaseModel):
"""pretraining dataset configuration subset"""
path: Optional[str] = None
class UserDefinedPrompterType(BaseModel):
"""structure for user defined prompt types"""
system_prompt: Optional[str] = None
system_format: Optional[str] = None
field_system: Optional[str] = None
field_instruction: Optional[str] = None
field_input: Optional[str] = None
field_output: Optional[str] = None
format: Optional[str] = None
no_input_format: Optional[str] = None
field: Optional[str] = None
class SFTDataset(BaseModel):
"""SFT configuration subset"""
path: Optional[str] = None
split: Optional[str] = None
type: Optional[Union[str, UserDefinedPrompterType]] = None
shards: Optional[int] = None
conversation: Optional[str] = None
data_files: Optional[List[str]] = None
name: Optional[str] = None
ds_type: Optional[str] = None
train_on_split: Optional[str] = None
field_human: Optional[str] = None
field_model: Optional[str] = None
class DPODataset(BaseModel):
"""DPO configuration subset"""
path: Optional[str] = None
split: Optional[str] = None
type: Optional[str] = None
data_files: Optional[List[str]] = None
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
kto_pair = "kto_pair" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
chatml = "chatml" # pylint: disable=invalid-name
inst = "inst" # pylint: disable=invalid-name
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
loftq_bits: int = Field(default=4, metadata={"help": "Quantization bits for LoftQ"})
# loftq_iter: int = Field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
class PeftConfig(BaseModel):
"""PEFT configuration subset"""
loftq_config: Optional[LoftQConfig] = None
class AutoType(str, Enum):
"""auto type string configuration subset - used for bf16"""
AUTO = "auto"
class SpecialTokensConfig(BaseModel):
"""Special tokens configuration subset"""
bos_token: Optional[str] = None
eos_token: Optional[str] = None
pad_token: Optional[str] = None
unk_token: Optional[str] = None
additional_special_tokens: Optional[List[str]] = None
class LoraConfig(BaseModel):
"""Peft / LoRA configuration subset"""
load_in_8bit: Optional[bool] = Field(default=False)
load_in_4bit: Optional[bool] = Field(default=False)
adapter: Optional[str] = None
lora_model_dir: Optional[str] = None
lora_rank: Optional[int] = None
lora_alpha: Optional[int] = None
lora_fan_in_fan_out: Optional[bool] = None
lora_target_modules: Optional[List[str]] = None
lora_target_linear: Optional[bool] = None
lora_modules_to_save: Optional[List[str]] = None
lora_dropout: Optional[float] = None
peft_layers_to_transform: Optional[List[int]] = None
peft: Optional[PeftConfig] = None
lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None
bnb_config_kwargs: Optional[Dict[str, Any]] = None
merge_lora: Optional[bool] = None
@model_validator(mode="before")
@classmethod
def validate_adapter(cls, data):
if not data.get("adapter") and (
data.get("load_in_8bit") or data.get("load_in_4bit")
):
raise ValueError(
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)
return data
@model_validator(mode="after")
def validate_qlora(self):
if self.adapter == "qlora":
if self.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit
if self.load_in_8bit:
raise ValueError("Can't merge qlora if loaded in 8bit")
if self.gptq:
raise ValueError("Can't merge qlora if gptq")
if self.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
else:
if self.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if self.gptq:
raise ValueError("Can't load qlora if gptq")
if not self.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""
relora_steps: Optional[int] = None
relora_warmup_steps: Optional[int] = None
relora_anneal_steps: Optional[int] = None
relora_prune_ratio: Optional[float] = None
relora_cpu_offload: Optional[bool] = None
class ModelInputConfig(BaseModel):
"""model to train on configuration subset"""
base_model: str
base_model_config: Optional[str] = None
tokenizer_config: Optional[str] = None
tokenizer_use_fast: Optional[bool] = None
tokenizer_legacy: Optional[bool] = None
tokenizer_type: Optional[str] = Field(
default=None, metadata={"help": "transformers tokenizer class"}
)
model_type: Optional[str] = Field(default=None)
model_revision: Optional[str] = None
trust_remote_code: Optional[bool] = None
model_config_overrides: Optional[Dict[str, Any]] = None
@field_validator("trust_remote_code")
@classmethod
def hint_trust_remote_code(cls, trust_remote_code):
if trust_remote_code:
LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
)
return trust_remote_code
class HyperparametersConfig(BaseModel):
"""training hyperparams configuration subset"""
gradient_accumulation_steps: Optional[int] = Field(default=1)
micro_batch_size: Optional[int] = Field(
default=1,
metadata={"help": "per gpu micro batch size for training"},
)
batch_size: Optional[int] = Field(
default=None,
metadata={
"help": "Total batch size; we do not recommend setting this manually"
},
)
eval_batch_size: Optional[int] = Field(
default=None,
metadata={
"help": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
},
)
train_on_inputs: Optional[bool] = None
group_by_length: Optional[bool] = None
learning_rate: Union[str, float]
weight_decay: Optional[float] = None
optimizer: Optional[OptimizerNames] = None
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[SchedulerType] = None
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None
cosine_min_lr_ratio: Optional[float] = None
cosine_constant_lr_ratio: Optional[float] = None
lr_div_factor: Optional[float] = None
adam_epsilon: Optional[float] = None
adam_beta1: Optional[float] = None
adam_beta2: Optional[float] = None
max_grad_norm: Optional[float] = None
num_epochs: int = Field(default=1)
@field_validator("batch_size")
@classmethod
def hint_batch_size_set(cls, batch_size):
if batch_size:
LOG.warning(
"%s\n%s",
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
return batch_size
class ModelOutputConfig(BaseModel):
"""model save configuration subset"""
output_dir: str = Field(default="./model-out")
hub_model_id: Optional[str] = None
hub_strategy: Optional[str] = None
save_safetensors: Optional[bool] = None
class MLFlowConfig(BaseModel):
"""mlflow configuration subset"""
use_mlflow: Optional[str] = None
mlflow_tracking_uri: Optional[str] = None
mlflow_experiment_name: Optional[str] = None
class WandbConfig(BaseModel):
"""wandb configuration subset"""
use_wandb: Optional[bool] = None
wandb_name: Optional[str] = None
wandb_run_id: Optional[str] = None
wandb_mode: Optional[str] = None
wandb_project: Optional[str] = None
wandb_entity: Optional[str] = None
wandb_watch: Optional[str] = None
wandb_log_model: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_wandb_run(cls, data):
if data.get("wandb_run_id") and not data.get("wandb_name"):
data["wandb_name"] = data.get("wandb_run_id")
LOG.warning(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
)
return data
# pylint: disable=too-many-public-methods,too-many-ancestors
class AxolotlInputConfig(
ModelInputConfig,
LoraConfig,
ReLoRAConfig,
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
DeprecatedParameters,
BaseModel,
):
"""wrapper of all config options"""
strict: Optional[bool] = Field(default=False)
resume_from_checkpoint: Optional[str] = None
auto_resume_from_checkpoints: Optional[bool] = None
resize_token_embeddings_to_32x: Optional[bool] = None
rl: Optional[RLType] = None
datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
dataset_prepared_path: Optional[str] = None
dataset_shard_num: Optional[int] = None
dataset_shard_idx: Optional[int] = None
pretraining_dataset: Optional[ # type: ignore
conlist(Union[SFTDataset, PretrainingDataset], min_length=1)
] = Field(
default=None, metadata={"help": "streaming dataset to use for pretraining"}
)
dataset_processes: Optional[int] = Field(default=os.cpu_count())
dataset_keep_in_memory: Optional[bool] = None
dataloader_pin_memory: Optional[bool] = None
dataloader_num_workers: Optional[int] = None
dataloader_prefetch_factor: Optional[int] = None
dataloader_drop_last: Optional[bool] = None
push_dataset_to_hub: Optional[str] = None
hf_use_auth_token: Optional[bool] = None
device: Optional[Any] = None
device_map: Optional[Any] = None
world_size: Optional[int] = None
local_rank: Optional[int] = None
ddp: Optional[bool] = None
seed: Optional[int] = None
ddp_timeout: Optional[int] = None
ddp_bucket_cap_mb: Optional[int] = None
ddp_broadcast_buffers: Optional[bool] = None
ddp_find_unused_parameters: Optional[bool] = None
eval_table_size: Optional[int] = None
eval_max_new_tokens: Optional[int] = None
do_causal_lm_eval: Optional[bool] = None
eval_causal_lm_metrics: Optional[List[str]] = None
do_bench_eval: Optional[bool] = None
bench_dataset: Optional[str] = None
metric_for_best_model: Optional[str] = None
greater_is_better: Optional[bool] = None
loss_watchdog_threshold: Optional[float] = None
loss_watchdog_patience: Optional[int] = None
bf16: Optional[Union[AutoType, bool]] = AutoType.AUTO
fp16: Optional[bool] = None
bfloat16: Optional[bool] = None # for non-AMP cases
float16: Optional[bool] = None # for non-AMP cases
tf32: Optional[bool] = None
float32: Optional[bool] = None
# torch_dtype: Optional[torch.dtype]
gradient_checkpointing: Optional[bool] = Field(default=False)
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
unfrozen_parameters: Optional[List[str]] = None
sequence_len: int = Field(default=1024)
sample_packing: Optional[bool] = None
eval_sample_packing: Optional[bool] = None
pad_to_sequence_len: Optional[bool] = None
xformers_attention: Optional[bool] = None
sdp_attention: Optional[bool] = None
s2_attention: Optional[bool] = None
flash_attention: Optional[bool] = None
flash_attn_cross_entropy: Optional[bool] = None
flash_attn_rms_norm: Optional[bool] = None
flash_attn_fuse_qkv: Optional[bool] = None
flash_attn_fuse_mlp: Optional[bool] = None
flash_optimum: Optional[bool] = None
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None
fsdp_config: Optional[Dict[str, Any]] = None
val_set_size: Optional[float] = Field(default=0.0)
special_tokens: Optional[SpecialTokensConfig] = None
tokens: Optional[List[str]] = None
torch_compile: Optional[bool] = None
torch_compile_backend: Optional[str] = None
max_steps: Optional[int] = None
warmup_steps: Optional[int] = None
warmup_ratio: Optional[float] = None
eval_steps: Optional[int] = None
evaluation_strategy: Optional[str] = None
save_steps: Optional[int] = None
saves_per_epoch: Optional[int] = None
save_strategy: Optional[str] = None
save_total_limit: Optional[int] = None
logging_steps: Optional[int] = None
early_stopping_patience: Optional[int] = None
neftune_noise_alpha: Optional[float] = None
max_memory: Optional[Union[int, str]] = None
gpu_memory_limit: Optional[Union[int, str]] = None
chat_template: Optional[Union[Literal["chatml", "inst"], ChatTemplate]] = None
default_system_message: Optional[str] = None
# INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None
total_num_tokens: Optional[int] = None
total_supervised_tokens: Optional[int] = None
sample_packing_eff_est: Optional[float] = None
axolotl_config_path: Optional[str] = None
is_falcon_derived_model: Optional[bool] = Field(default=False)
is_llama_derived_model: Optional[bool] = Field(default=False)
is_mistral_derived_model: Optional[bool] = Field(default=False)
is_qwen_derived_model: Optional[bool] = Field(default=False)
@field_validator("datasets", mode="before")
@classmethod
def fix_sharegpt_datasets(cls, datasets):
for idx, ds_cfg in enumerate(datasets):
if not ds_cfg["type"]:
continue
if ds_cfg["type"] == "sharegpt:chat":
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
datasets[idx]["type"] = "sharegpt"
if "sharegpt_simple" in ds_cfg["type"]:
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
datasets[idx]["type"] = datasets[idx]["type"].replace(
"sharegpt_simple", "sharegpt"
)
return datasets
@model_validator(mode="before")
@classmethod
def check_batch_size_fields(cls, data):
fields = ("micro_batch_size", "gradient_accumulation_steps", "batch_size")
non_empty_count = sum(1 for field in fields if data.get(field))
if non_empty_count < 2:
raise ValueError(f"At least two of {', '.join(fields)} must be set")
return data
@model_validator(mode="before")
@classmethod
def check_pretraining_w_max_steps(cls, data):
if data.get("pretraining_dataset") and not data.get("max_steps"):
raise ValueError(
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
)
return data
@model_validator(mode="before")
@classmethod
def check_pretraining_w_group_by_length(cls, data):
if data.get("pretraining_dataset") and data.get("group_by_length"):
LOG.warning(
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
)
return data
@model_validator(mode="before")
@classmethod
def check_gptq_w_revision(cls, data):
if data.get("gptq") and data.get("model_revision"):
raise ValueError(
"model_revision is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove model_revision from the config."
)
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_xformers(cls, data):
if data.get("sample_packing") and data.get("xformers_attention"):
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_rl(cls, data):
if data.get("sample_packing") and data.get("rl"):
raise ValueError("`sample_packing: true` does not work with RLHF training")
return data
@model_validator(mode="before")
@classmethod
def hint_sample_packing_padding(cls, data):
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
return data
@model_validator(mode="before")
@classmethod
def check_gas_bsz(cls, data):
if data.get("gradient_accumulation_steps") and data.get("batch_size"):
raise ValueError(
"please set only one of gradient_accumulation_steps or batch_size"
)
return data
@model_validator(mode="before")
@classmethod
def hint_eval_train_mbsz(cls, data):
if (
data.get("eval_batch_size")
and data.get("micro_batch_size")
and data.get("eval_batch_size") != data.get("micro_batch_size")
):
LOG.warning(
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
)
return data
@model_validator(mode="before")
@classmethod
def check_push_ds_auth(cls, data):
if (
data.get("push_dataset_to_hub")
and data.get("hf_use_auth_token") is not True
):
raise ValueError(
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
)
return data
@model_validator(mode="after")
def check_falcon_fsdp(self):
if (self.base_model and "falcon" in self.base_model.lower()) and self.fsdp:
raise ValueError("FSDP is not supported for falcon models")
return self
@model_validator(mode="after")
def check_mpt_checkpointing(self):
if (
self.base_model and "mpt" in self.base_model.lower()
) and self.gradient_checkpointing:
raise ValueError("gradient_checkpointing is not supported for MPT models")
return self
@model_validator(mode="after")
def check_better_transformers(self):
if self.flash_optimum is True:
if self.adapter:
LOG.warning(
"BetterTransformers probably doesn't work with PEFT adapters"
)
if self.fp16 or self.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if self.float16 is not True and self.bfloat16 is not True:
LOG.warning(
"You should probably set bfloat16 or float16 to true to "
"load the model in float16 for BetterTransformers"
)
return self
@model_validator(mode="after")
def check_adamw_optimizer_params(self):
if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (
not self.optimizer or "adamw" not in self.optimizer.value
):
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self
@model_validator(mode="before")
@classmethod
def check_saves(cls, data):
if (
data.get("save_strategy")
and data.get("save_steps")
and data.get("save_strategy") != "steps"
):
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if data.get("saves_per_epoch") and data.get("save_steps"):
raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
)
return data
@model_validator(mode="before")
@classmethod
def check_push_save(cls, data):
if data.get("hub_model_id") and not (
data.get("save_steps") or data.get("saves_per_epoch")
):
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
)
return data
@model_validator(mode="before")
@classmethod
def check_evals(cls, data):
if (
data.get("evaluation_strategy")
and data.get("eval_steps")
and data.get("evaluation_strategy") != "steps"
):
raise ValueError(
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
)
if (
data.get("val_set_size") == 0
and (data.get("eval_steps") or data.get("evaluation_strategy"))
and not data.get("test_datasets")
):
raise ValueError(
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
)
if data.get("evals_per_epoch") and data.get("eval_steps"):
raise ValueError(
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
)
if (
data.get("evals_per_epoch")
and data.get("evaluation_strategy")
and data.get("evaluation_strategy") != "steps"
):
raise ValueError(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
return data
@model_validator(mode="before")
@classmethod
def check_eval_packing(cls, data):
if (
data.get("sample_packing")
and data.get("eval_table_size")
and data.get("eval_sample_packing") is not False
):
raise ValueError(
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
)
return data
@model_validator(mode="before")
@classmethod
def check_warmup(cls, data):
if data.get("warmup_steps") and data.get("warmup_ratio"):
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
return data
@model_validator(mode="before")
@classmethod
def check_neftune(cls, data):
if data.get("noisy_embedding_alpha") and not data.get("neftune_noise_alpha"):
data["neftune_noise_alpha"] = data["noisy_embedding_alpha"]
del data["noisy_embedding_alpha"]
elif data.get("noisy_embedding_alpha") and data.get("neftune_noise_alpha"):
raise ValueError(
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
)
return data
@field_validator("neftune_noise_alpha")
@classmethod
def validate_neftune_noise_alpha(cls, neftune_noise_alpha):
if neftune_noise_alpha is not None and neftune_noise_alpha <= 0.0:
raise ValueError("neftune_noise_alpha must be > 0.0")
return neftune_noise_alpha
@model_validator(mode="before")
@classmethod
def check_frozen(cls, data):
if (
data.get("adapter")
and data.get("peft_layers_to_transform")
and data.get("unfrozen_parameters")
):
raise ValueError(
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
)
return data
@model_validator(mode="after")
def check_fft_possible_bad_config(self):
if (
# pylint: disable=too-many-boolean-expressions
not (self.bf16 or self.bfloat16)
and (self.fp16 or self.float16)
and not self.adapter
and not self.flash_attention
and self.sample_packing
):
LOG.warning(
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
)
# ValueError: Attempting to unscale FP16 gradients.
# OR
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
return self
@model_validator(mode="after")
def check_fused_lora(self):
if self.adapter in ["lora", "qlora"] and (
self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp
):
raise ValueError("Fused modules are not supported with LoRA/QLoRA")
return self
@model_validator(mode="after")
def hint_lora_8bit(self):
loftq = (
self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
)
if not self.load_in_8bit and self.adapter == "lora" and not loftq:
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
return self
@model_validator(mode="after")
def check_early_stopping(self):
if self.early_stopping_patience:
if not self.save_steps or not self.eval_steps:
raise ValueError(
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
)
if self.save_steps % self.eval_steps != 0:
raise ValueError(
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
)
return self
@model_validator(mode="after")
def check_relora(self):
if self.relora_steps:
if self.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
if self.fsdp:
raise ValueError("fsdp not supported with ReLoRA")
if self.deepspeed:
raise ValueError("deepspeed not supported with ReLoRA")
if self.lr_scheduler == "one_cycle":
raise ValueError(
"ReLoRA is not compatible with the one_cycle scheduler"
)
if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with ReLoRA")
return self
@model_validator(mode="before")
@classmethod
def check_mem_mismatch(cls, data):
if (
data.get("max_memory") is not None
and data.get("gpu_memory_limit") is not None
):
raise ValueError(
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
)
return data
@model_validator(mode="before")
@classmethod
def check_use_reentrant_mismatch(cls, data):
if (
data.get("unfrozen_parameters")
and data.get("gradient_checkpointing_kwargs")
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
is True
):
# https://github.com/huggingface/transformers/issues/21381
raise ValueError(
"`use_reentrant` must be false when used with partially frozen model."
)
return data
@model_validator(mode="before")
@classmethod
def check_val_w_test_datasets(cls, data):
if data.get("test_datasets") and data.get("val_set_size"):
raise ValueError(
"non-zero val_set_size should not be used with test_datasets configuration"
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_w_8bit_optimizer(cls, data):
if data.get("fsdp") and "bnb" in data.get("optimizer", ""):
raise ValueError(f"FSDP not compatible with {data.get('optimizer')}")
return data
@model_validator(mode="before")
@classmethod
def check_causal_lm_evals(cls, data):
if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"):
raise ValueError(
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
)
if data.get("eval_causal_lm_metrics"):
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
if not isinstance(data.get("eval_causal_lm_metrics"), list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(data.get("eval_causal_lm_metrics")) - set(supported_metrics):
raise ValueError(
f"eval_causal_lm_metrics must be one of {supported_metrics}"
)
return data
@model_validator(mode="before")
@classmethod
def check_dataset_or_pretraining_dataset(cls, data):
if data.get("datasets") is None and data.get("pretraining_dataset") is None:
raise ValueError("either datasets or pretraining_dataset is required")
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to validate gpu capabilities with the configured options"""
capabilities: GPUCapabilities
@model_validator(mode="after")
def check_bf16(self):
if self.capabilities.bf16:
if not self.bf16 and not self.bfloat16:
LOG.info(
"bf16 support detected, but not enabled for this configuration."
)
else:
if (
not self.merge_lora
and not self.is_preprocess
and (self.bf16 is True or self.bfloat16 is True)
):
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)
return self
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_sdpa_bf16(cls, data):
is_sm_90: bool = (
data["capabilities"]
and data["capabilities"].get("compute_capability") == "sm_90"
)
if (
data.get("sample_packing")
and data.get("sdp_attention")
and (data.get("bfloat16") or data.get("bf16"))
and not is_sm_90
):
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
LOG.warning(
"sample_packing & torch sdpa with bf16 is unsupported and may result in 0.0 loss. "
"This may work on H100s."
)
return data

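To make the behaviour of these validators concrete, a rough example of one of them rejecting a conflicting config; the model name and dataset path are made-up placeholders, and any of the mutually exclusive checks above would trip the same way:

from pydantic import ValidationError

from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig

try:
    AxolotlInputConfig(
        base_model="NousResearch/Llama-2-7b-hf",  # placeholder
        learning_rate=0.0002,
        datasets=[{"path": "some/dataset", "type": "alpaca"}],  # placeholder
        micro_batch_size=2,
        gradient_accumulation_steps=4,
        batch_size=8,  # conflicts with gradient_accumulation_steps
    )
except ValidationError as err:
    # check_gas_bsz: "please set only one of gradient_accumulation_steps or batch_size"
    print(err)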
View File

@@ -0,0 +1,14 @@
"""module for gpu capabilities"""
from typing import Optional
from pydantic import BaseModel, Field
class GPUCapabilities(BaseModel):
"""model to manage the gpu capabilities statically"""
bf16: bool = Field(default=False)
fp8: bool = Field(default=False)
n_gpu: int = Field(default=1)
n_node: int = Field(default=1)
compute_capability: Optional[str] = Field(default=None)

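A short sketch of how this capabilities snapshot gets populated at runtime, mirroring the cli diff above; the CUDA probe is best-effort and simply falls back to None on machines without a GPU:

import os

import torch
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.utils.config.models.internals import GPUCapabilities

try:
    props = torch.cuda.get_device_properties("cuda")
    compute_capability = f"sm_{props.major}{props.minor}"
except Exception:  # no CUDA device available
    compute_capability = None

capabilities = GPUCapabilities(
    bf16=is_torch_bf16_gpu_available(),
    n_gpu=int(os.environ.get("WORLD_SIZE", 1)),
    compute_capability=compute_capability,
)
print(capabilities.model_dump())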
View File

@@ -12,4 +12,4 @@ class DictDefault(Dict):
 return None
 def __or__(self, other):
-    return DictDefault(super().__or__(other))
+    return DictDefault(super().__ror__(other))

View File

@@ -104,8 +104,8 @@ def load_model_config(cfg):
 )
 raise err
-if cfg.model_config:
-    for key, val in cfg.model_config.items():
+if cfg.model_config_overrides:
+    for key, val in cfg.model_config_overrides.items():
         setattr(model_config, key, val)
 check_model_config(cfg, model_config)

View File

@@ -39,7 +39,9 @@ class DictDefaultTest(unittest.TestCase):
     ), "DictDefault should support in operator for existing keys in list"

     def test_dict_or_operator(self):
-        cfg = DictDefault(
+        cfg = DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"})
+        cfg = cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
             {
                 "key_a": {"key_b": "value_a"},
                 "key_c": "value_c",
@@ -48,10 +50,6 @@ class DictDefaultTest(unittest.TestCase):
             }
         )
-        cfg = cfg | DictDefault(  # pylint: disable=unsupported-binary-operation
-            {"key_a": {"key_b": "value_b"}, "key_f": "value_g"}
-        )
         assert (
             cfg.key_a.key_b == "value_b"
         ), "DictDefault should support OR operator for existing nested keys"

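The practical effect of the __or__ change above, as the updated test asserts, is that values already present on the left-hand DictDefault now survive the merge; a quick sketch with illustrative keys:

from axolotl.utils.dict import DictDefault

base = DictDefault({"key_a": {"key_b": "value_b"}})
override = DictDefault({"key_a": {"key_b": "value_a"}, "key_c": "value_c"})

merged = base | override
assert merged.key_a.key_b == "value_b"  # left-hand value wins after this change
assert merged.key_c == "value_c"        # keys only on the right are still merged in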
View File

@@ -204,13 +204,13 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
 # fmt: off
 # System message, multi-turn conversations
 mt_ids = tokenize(test_data['multi_turn_sys'])
 assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
-assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
+assert mt_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
 # System message, single-turn conversations
 st_ids = tokenize(test_data['single_turn_sys'])
 assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
-assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
+assert st_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
 # No system message, single-turn
 ns_ids = tokenize(test_data['single_turn_no_sys'])

File diff suppressed because it is too large