diff --git a/.mypy.ini b/.mypy.ini
index bb9a21c65..ede9fef88 100644
--- a/.mypy.ini
+++ b/.mypy.ini
@@ -1,5 +1,5 @@
[mypy]
-
+plugins = pydantic.mypy
exclude = venv
[mypy-alpaca_lora_4bit.*]
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c811a6eb3..6c5f20589 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -31,6 +31,7 @@ repos:
additional_dependencies:
[
'types-PyYAML',
+ 'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.7.5
diff --git a/README.md b/README.md
index bf10d1ec8..1179529a0 100644
--- a/README.md
+++ b/README.md
@@ -543,7 +543,7 @@ is_mistral_derived_model:
is_qwen_derived_model:
# optional overrides to the base model configuration
-model_config:
+model_config_overrides:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
@@ -560,8 +560,6 @@ bnb_config_kwargs:
# Whether you are training a 4-bit GPTQ quantized model
gptq: true
-gptq_groupsize: 128 # group size
-gptq_model_v1: false # v1 or v2
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true
@@ -819,10 +817,6 @@ cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosin
# For one_cycle optim
lr_div_factor: # Learning rate div factor
-# For log_sweep optim
-log_sweep_min_lr:
-log_sweep_max_lr:
-
# Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
diff --git a/requirements.txt b/requirements.txt
index 6532d3999..8dce6daa7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,6 +6,7 @@ tokenizers==0.15.0
bitsandbytes>=0.41.1
accelerate==0.26.1
deepspeed>=0.13.1
+pydantic>=2.5.3
addict
fire
PyYAML>=6.0
@@ -27,7 +28,7 @@ scipy
scikit-learn==1.2.2
pynvml
art
-fschat==0.2.34
+fschat==0.2.36
gradio==3.50.2
tensorboard
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 6b3894cb5..a15634247 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -24,11 +24,13 @@ from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
+from transformers.utils import is_torch_bf16_gpu_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.config import (
+ GPUCapabilities,
normalize_cfg_datasets,
normalize_config,
validate_config,
@@ -328,7 +330,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
- cfg.axolotl_config_path = config
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
@@ -341,7 +342,21 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
else:
cfg[k] = kwargs[k]
- validate_config(cfg)
+ cfg.axolotl_config_path = config
+
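+    # probe the local GPU so validation can cross-check hardware-dependent options;
+    # compute capability is encoded as "sm_<major><minor>" (e.g. "sm_80" for Ampere)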
+ try:
+ device_props = torch.cuda.get_device_properties("cuda")
+ gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
+ except: # pylint: disable=bare-except # noqa: E722
+ gpu_version = None
+
+ capabilities = GPUCapabilities(
+ bf16=is_torch_bf16_gpu_available(),
+ n_gpu=os.environ.get("WORLD_SIZE", 1),
+ compute_capability=gpu_version,
+ )
+
+ cfg = validate_config(cfg, capabilities=capabilities)
prepare_optim_env(cfg)
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config/__init__.py
similarity index 97%
rename from src/axolotl/utils/config.py
rename to src/axolotl/utils/config/__init__.py
index 1fc470da9..b21db3176 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -3,11 +3,17 @@ import json
import logging
import os
from pathlib import Path
+from typing import Optional
import torch
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.config.models.input.v0_4_1 import (
+ AxolotlConfigWCapabilities,
+ AxolotlInputConfig,
+)
+from axolotl.utils.config.models.internals import GPUCapabilities
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
@@ -191,7 +197,15 @@ def normalize_cfg_datasets(cfg):
cfg.datasets[idx].conversation = "chatml"
-def validate_config(cfg):
+def validate_config(cfg: DictDefault, capabilities: Optional[GPUCapabilities] = None):
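+    """
+    Validate the yaml-derived config against the pydantic input models and return the
+    result as a DictDefault. When GPU capabilities are supplied, the hardware-aware
+    checks in AxolotlConfigWCapabilities are applied as well.
+    """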
+ if capabilities:
+ return DictDefault(
+ dict(AxolotlConfigWCapabilities(**cfg.to_dict(), capabilities=capabilities))
+ )
+ return DictDefault(dict(AxolotlInputConfig(**cfg.to_dict())))
+
+
+def legacy_validate_config(cfg):
"""
This is a "pre-validation" step that handles the yaml configuration before we have any
information about the model architecture
@@ -480,9 +494,6 @@ def validate_config(cfg):
if cfg.rope_scaling:
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
- if cfg.warmup_steps and cfg.warmup_ratio:
- raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
-
if cfg.wandb_run_id and not cfg.wandb_name:
cfg.wandb_name = cfg.wandb_run_id
diff --git a/src/axolotl/utils/config/models/__init__.py b/src/axolotl/utils/config/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/utils/config/models/input/__init__.py b/src/axolotl/utils/config/models/input/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/utils/config/models/input/next/__init__.py b/src/axolotl/utils/config/models/input/next/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
new file mode 100644
index 000000000..433c84af1
--- /dev/null
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -0,0 +1,931 @@
+"""
+Module for pydantic models for configuration
+"""
+
+import logging
+import os
+from enum import Enum
+from typing import Any, Dict, List, Literal, Optional, Union
+
+from pydantic import BaseModel, Field, conlist, field_validator, model_validator
+from transformers import SchedulerType
+from transformers.training_args import OptimizerNames
+
+from axolotl.utils.config.models.internals import GPUCapabilities
+
+LOG = logging.getLogger("axolotl.utils.config.models.input")
+
+
+class DeprecatedParameters(BaseModel):
+ """configurations that are deprecated"""
+
+ max_packed_sequence_len: Optional[int] = None
+ rope_scaling: Optional[Any] = None
+ noisy_embedding_alpha: Optional[float] = None
+
+ @field_validator("max_packed_sequence_len")
+ @classmethod
+ def validate_max_packed_sequence_len(cls, max_packed_sequence_len):
+ if max_packed_sequence_len:
+ raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
+ return max_packed_sequence_len
+
+ @field_validator("rope_scaling")
+ @classmethod
+ def validate_rope_scaling(cls, rope_scaling):
+ if rope_scaling:
+ raise DeprecationWarning(
+                "`rope_scaling` is no longer supported, it should now be a key under `model_config_overrides`"
+ )
+ return rope_scaling
+
+ @field_validator("noisy_embedding_alpha")
+ @classmethod
+ def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):
+ if noisy_embedding_alpha:
+ LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
+ return noisy_embedding_alpha
+
+
+class PretrainingDataset(BaseModel):
+ """pretraining dataset configuration subset"""
+
+ path: Optional[str] = None
+
+
+class UserDefinedPrompterType(BaseModel):
+ """structure for user defined prompt types"""
+
+ system_prompt: Optional[str] = None
+ system_format: Optional[str] = None
+ field_system: Optional[str] = None
+ field_instruction: Optional[str] = None
+ field_input: Optional[str] = None
+ field_output: Optional[str] = None
+
+ format: Optional[str] = None
+ no_input_format: Optional[str] = None
+ field: Optional[str] = None
+
+
+class SFTDataset(BaseModel):
+ """SFT configuration subset"""
+
+ path: Optional[str] = None
+ split: Optional[str] = None
+ type: Optional[Union[str, UserDefinedPrompterType]] = None
+ shards: Optional[int] = None
+ conversation: Optional[str] = None
+ data_files: Optional[List[str]] = None
+ name: Optional[str] = None
+ ds_type: Optional[str] = None
+ train_on_split: Optional[str] = None
+
+ field_human: Optional[str] = None
+ field_model: Optional[str] = None
+
+
+class DPODataset(BaseModel):
+ """DPO configuration subset"""
+
+ path: Optional[str] = None
+ split: Optional[str] = None
+ type: Optional[str] = None
+ data_files: Optional[List[str]] = None
+
+
+class RLType(str, Enum):
+ """RL trainer type configuration subset"""
+
+ dpo = "dpo" # pylint: disable=invalid-name
+ ipo = "ipo" # pylint: disable=invalid-name
+ kto_pair = "kto_pair" # pylint: disable=invalid-name
+
+
+class ChatTemplate(str, Enum):
+ """Chat templates configuration subset"""
+
+ chatml = "chatml" # pylint: disable=invalid-name
+ inst = "inst" # pylint: disable=invalid-name
+
+
+class LoftQConfig(BaseModel):
+ """LoftQ configuration subset"""
+
+ loftq_bits: int = Field(default=4, metadata={"help": "Quantization bits for LoftQ"})
+ # loftq_iter: int = Field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
+
+
+class PeftConfig(BaseModel):
+    """peft configuration subset"""
+
+ loftq_config: Optional[LoftQConfig] = None
+
+
+class AutoType(str, Enum):
+ """auto type string configuration subset - used for bf16"""
+
+ AUTO = "auto"
+
+
+class SpecialTokensConfig(BaseModel):
+ """Special tokens configuration subset"""
+
+ bos_token: Optional[str] = None
+ eos_token: Optional[str] = None
+ pad_token: Optional[str] = None
+ unk_token: Optional[str] = None
+ additional_special_tokens: Optional[List[str]] = None
+
+
+class LoraConfig(BaseModel):
+ """Peft / LoRA configuration subset"""
+
+ load_in_8bit: Optional[bool] = Field(default=False)
+ load_in_4bit: Optional[bool] = Field(default=False)
+
+ adapter: Optional[str] = None
+ lora_model_dir: Optional[str] = None
+ lora_rank: Optional[int] = None
+ lora_alpha: Optional[int] = None
+ lora_fan_in_fan_out: Optional[bool] = None
+ lora_target_modules: Optional[List[str]] = None
+ lora_target_linear: Optional[bool] = None
+ lora_modules_to_save: Optional[List[str]] = None
+ lora_dropout: Optional[float] = None
+ peft_layers_to_transform: Optional[List[int]] = None
+ peft: Optional[PeftConfig] = None
+
+ lora_on_cpu: Optional[bool] = None
+ gptq: Optional[bool] = None
+ bnb_config_kwargs: Optional[Dict[str, Any]] = None
+
+ merge_lora: Optional[bool] = None
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_adapter(cls, data):
+ if not data.get("adapter") and (
+ data.get("load_in_8bit") or data.get("load_in_4bit")
+ ):
+ raise ValueError(
+                "load_in_8bit and load_in_4bit are not supported without setting an adapter. "
+                "If you want to do a full fine-tune, please turn off load_in_8bit and load_in_4bit."
+ )
+ return data
+
+ @model_validator(mode="after")
+ def validate_qlora(self):
+ if self.adapter == "qlora":
+ if self.merge_lora:
+ # can't merge qlora if loaded in 8bit or 4bit
+ if self.load_in_8bit:
+ raise ValueError("Can't merge qlora if loaded in 8bit")
+
+ if self.gptq:
+ raise ValueError("Can't merge qlora if gptq")
+
+ if self.load_in_4bit:
+ raise ValueError("Can't merge qlora if loaded in 4bit")
+
+ else:
+ if self.load_in_8bit:
+ raise ValueError("Can't load qlora in 8bit")
+
+ if self.gptq:
+ raise ValueError("Can't load qlora if gptq")
+
+ if not self.load_in_4bit:
+ raise ValueError("Require cfg.load_in_4bit to be True for qlora")
+ return self
+
+
+class ReLoRAConfig(BaseModel):
+ """ReLoRA configuration subset"""
+
+ relora_steps: Optional[int] = None
+ relora_warmup_steps: Optional[int] = None
+ relora_anneal_steps: Optional[int] = None
+ relora_prune_ratio: Optional[float] = None
+ relora_cpu_offload: Optional[bool] = None
+
+
+class ModelInputConfig(BaseModel):
+ """model to train on configuration subset"""
+
+ base_model: str
+ base_model_config: Optional[str] = None
+ tokenizer_config: Optional[str] = None
+ tokenizer_use_fast: Optional[bool] = None
+ tokenizer_legacy: Optional[bool] = None
+ tokenizer_type: Optional[str] = Field(
+ default=None, metadata={"help": "transformers tokenizer class"}
+ )
+ model_type: Optional[str] = Field(default=None)
+ model_revision: Optional[str] = None
+ trust_remote_code: Optional[bool] = None
+
+ model_config_overrides: Optional[Dict[str, Any]] = None
+
+ @field_validator("trust_remote_code")
+ @classmethod
+ def hint_trust_remote_code(cls, trust_remote_code):
+ if trust_remote_code:
+ LOG.warning(
+ "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
+ )
+ return trust_remote_code
+
+
+class HyperparametersConfig(BaseModel):
+ """training hyperparams configuration subset"""
+
+ gradient_accumulation_steps: Optional[int] = Field(default=1)
+ micro_batch_size: Optional[int] = Field(
+ default=1,
+ metadata={"help": "per gpu micro batch size for training"},
+ )
+ batch_size: Optional[int] = Field(
+ default=None,
+ metadata={
+            "help": "Total batch size; we do not recommend setting this manually"
+ },
+ )
+ eval_batch_size: Optional[int] = Field(
+ default=None,
+ metadata={
+ "help": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
+ },
+ )
+
+ train_on_inputs: Optional[bool] = None
+ group_by_length: Optional[bool] = None
+
+ learning_rate: Union[str, float]
+ weight_decay: Optional[float] = None
+ optimizer: Optional[OptimizerNames] = None
+ torchdistx_path: Optional[str] = None
+ lr_scheduler: Optional[SchedulerType] = None
+ lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
+ lr_quadratic_warmup: Optional[bool] = None
+ cosine_min_lr_ratio: Optional[float] = None
+ cosine_constant_lr_ratio: Optional[float] = None
+ lr_div_factor: Optional[float] = None
+
+ adam_epsilon: Optional[float] = None
+ adam_beta1: Optional[float] = None
+ adam_beta2: Optional[float] = None
+ max_grad_norm: Optional[float] = None
+ num_epochs: int = Field(default=1)
+
+ @field_validator("batch_size")
+ @classmethod
+ def hint_batch_size_set(cls, batch_size):
+ if batch_size:
+ LOG.warning(
+ "%s\n%s",
+ "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
+ "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
+ )
+ return batch_size
+
+
+class ModelOutputConfig(BaseModel):
+ """model save configuration subset"""
+
+ output_dir: str = Field(default="./model-out")
+ hub_model_id: Optional[str] = None
+ hub_strategy: Optional[str] = None
+ save_safetensors: Optional[bool] = None
+
+
+class MLFlowConfig(BaseModel):
+ """mlflow configuration subset"""
+
+ use_mlflow: Optional[str] = None
+ mlflow_tracking_uri: Optional[str] = None
+ mlflow_experiment_name: Optional[str] = None
+
+
+class WandbConfig(BaseModel):
+ """wandb configuration subset"""
+
+ use_wandb: Optional[bool] = None
+ wandb_name: Optional[str] = None
+ wandb_run_id: Optional[str] = None
+ wandb_mode: Optional[str] = None
+ wandb_project: Optional[str] = None
+ wandb_entity: Optional[str] = None
+ wandb_watch: Optional[str] = None
+ wandb_log_model: Optional[str] = None
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_wandb_run(cls, data):
+ if data.get("wandb_run_id") and not data.get("wandb_name"):
+ data["wandb_name"] = data.get("wandb_run_id")
+
+ LOG.warning(
+ "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
+ )
+
+ return data
+
+
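+# The top-level input config is assembled via multiple inheritance from the focused
+# config subsets above; pydantic merges their fields and runs every validator against
+# the flat yaml-derived dict.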
+# pylint: disable=too-many-public-methods,too-many-ancestors
+class AxolotlInputConfig(
+ ModelInputConfig,
+ LoraConfig,
+ ReLoRAConfig,
+ HyperparametersConfig,
+ WandbConfig,
+ MLFlowConfig,
+ DeprecatedParameters,
+ BaseModel,
+):
+ """wrapper of all config options"""
+
+ strict: Optional[bool] = Field(default=False)
+ resume_from_checkpoint: Optional[str] = None
+ auto_resume_from_checkpoints: Optional[bool] = None
+ resize_token_embeddings_to_32x: Optional[bool] = None
+
+ rl: Optional[RLType] = None
+
+ datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
+ dataset_prepared_path: Optional[str] = None
+ dataset_shard_num: Optional[int] = None
+ dataset_shard_idx: Optional[int] = None
+
+ pretraining_dataset: Optional[ # type: ignore
+ conlist(Union[SFTDataset, PretrainingDataset], min_length=1)
+ ] = Field(
+        default=None, metadata={"help": "streaming dataset to use for pretraining"}
+ )
+ dataset_processes: Optional[int] = Field(default=os.cpu_count())
+ dataset_keep_in_memory: Optional[bool] = None
+ dataloader_pin_memory: Optional[bool] = None
+ dataloader_num_workers: Optional[int] = None
+ dataloader_prefetch_factor: Optional[int] = None
+ dataloader_drop_last: Optional[bool] = None
+
+ push_dataset_to_hub: Optional[str] = None
+ hf_use_auth_token: Optional[bool] = None
+
+ device: Optional[Any] = None
+ device_map: Optional[Any] = None
+ world_size: Optional[int] = None
+ local_rank: Optional[int] = None
+ ddp: Optional[bool] = None
+
+ seed: Optional[int] = None
+ ddp_timeout: Optional[int] = None
+ ddp_bucket_cap_mb: Optional[int] = None
+ ddp_broadcast_buffers: Optional[bool] = None
+ ddp_find_unused_parameters: Optional[bool] = None
+
+ eval_table_size: Optional[int] = None
+ eval_max_new_tokens: Optional[int] = None
+ do_causal_lm_eval: Optional[bool] = None
+ eval_causal_lm_metrics: Optional[List[str]] = None
+ do_bench_eval: Optional[bool] = None
+ bench_dataset: Optional[str] = None
+ metric_for_best_model: Optional[str] = None
+ greater_is_better: Optional[bool] = None
+
+ loss_watchdog_threshold: Optional[float] = None
+ loss_watchdog_patience: Optional[int] = None
+
+ bf16: Optional[Union[AutoType, bool]] = AutoType.AUTO
+ fp16: Optional[bool] = None
+ bfloat16: Optional[bool] = None # for non-AMP cases
+ float16: Optional[bool] = None # for non-AMP cases
+ tf32: Optional[bool] = None
+ float32: Optional[bool] = None
+
+ # torch_dtype: Optional[torch.dtype]
+
+ gradient_checkpointing: Optional[bool] = Field(default=False)
+ gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
+
+ unfrozen_parameters: Optional[List[str]] = None
+
+ sequence_len: int = Field(default=1024)
+ sample_packing: Optional[bool] = None
+ eval_sample_packing: Optional[bool] = None
+ pad_to_sequence_len: Optional[bool] = None
+
+ xformers_attention: Optional[bool] = None
+ sdp_attention: Optional[bool] = None
+ s2_attention: Optional[bool] = None
+ flash_attention: Optional[bool] = None
+ flash_attn_cross_entropy: Optional[bool] = None
+ flash_attn_rms_norm: Optional[bool] = None
+ flash_attn_fuse_qkv: Optional[bool] = None
+ flash_attn_fuse_mlp: Optional[bool] = None
+ flash_optimum: Optional[bool] = None
+
+ deepspeed: Optional[Union[str, Dict[str, Any]]] = None
+ fsdp: Optional[List[str]] = None
+ fsdp_config: Optional[Dict[str, Any]] = None
+
+ val_set_size: Optional[float] = Field(default=0.0)
+
+ special_tokens: Optional[SpecialTokensConfig] = None
+ tokens: Optional[List[str]] = None
+
+ torch_compile: Optional[bool] = None
+ torch_compile_backend: Optional[str] = None
+
+ max_steps: Optional[int] = None
+ warmup_steps: Optional[int] = None
+ warmup_ratio: Optional[float] = None
+ eval_steps: Optional[int] = None
+ evaluation_strategy: Optional[str] = None
+ save_steps: Optional[int] = None
+ saves_per_epoch: Optional[int] = None
+ save_strategy: Optional[str] = None
+ save_total_limit: Optional[int] = None
+ logging_steps: Optional[int] = None
+ early_stopping_patience: Optional[int] = None
+
+ neftune_noise_alpha: Optional[float] = None
+
+ max_memory: Optional[Union[int, str]] = None
+ gpu_memory_limit: Optional[Union[int, str]] = None
+
+ chat_template: Optional[Union[Literal["chatml", "inst"], ChatTemplate]] = None
+ default_system_message: Optional[str] = None
+
+ # INTERNALS - document for now, generally not set externally
+ is_preprocess: Optional[bool] = None
+
+ total_num_tokens: Optional[int] = None
+ total_supervised_tokens: Optional[int] = None
+ sample_packing_eff_est: Optional[float] = None
+ axolotl_config_path: Optional[str] = None
+
+ is_falcon_derived_model: Optional[bool] = Field(default=False)
+ is_llama_derived_model: Optional[bool] = Field(default=False)
+ is_mistral_derived_model: Optional[bool] = Field(default=False)
+ is_qwen_derived_model: Optional[bool] = Field(default=False)
+
+ @field_validator("datasets", mode="before")
+ @classmethod
+ def fix_sharegpt_datasets(cls, datasets):
+ for idx, ds_cfg in enumerate(datasets):
+ if not ds_cfg["type"]:
+ continue
+ if ds_cfg["type"] == "sharegpt:chat":
+ LOG.warning(
+ PendingDeprecationWarning(
+ "`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
+ )
+ )
+ datasets[idx]["type"] = "sharegpt"
+ if "sharegpt_simple" in ds_cfg["type"]:
+ LOG.warning(
+ PendingDeprecationWarning(
+ "`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
+ )
+ )
+ datasets[idx]["type"] = datasets[idx]["type"].replace(
+ "sharegpt_simple", "sharegpt"
+ )
+ return datasets
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_batch_size_fields(cls, data):
+ fields = ("micro_batch_size", "gradient_accumulation_steps", "batch_size")
+ non_empty_count = sum(1 for field in fields if data.get(field))
+
+ if non_empty_count < 2:
+ raise ValueError(f"At least two of {', '.join(fields)} must be set")
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_pretraining_w_max_steps(cls, data):
+ if data.get("pretraining_dataset") and not data.get("max_steps"):
+ raise ValueError(
+                "max_steps must be set when using an iterable pretraining_dataset; the Trainer can't infer the length or schedule the optimizer/learning rate without it!"
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_pretraining_w_group_by_length(cls, data):
+ if data.get("pretraining_dataset") and data.get("group_by_length"):
+ LOG.warning(
+ "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_gptq_w_revision(cls, data):
+ if data.get("gptq") and data.get("model_revision"):
+ raise ValueError(
+ "model_revision is not supported for GPTQ models. "
+ + "Please download the model from HuggingFace Hub manually for correct branch, "
+ + "point to its path, and remove model_revision from the config."
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_sample_packing_w_xformers(cls, data):
+ if data.get("sample_packing") and data.get("xformers_attention"):
+ raise ValueError(
+ "sample_packing not compatible with xformers_attention. Use flash_attention"
+ )
+
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_sample_packing_w_rl(cls, data):
+ if data.get("sample_packing") and data.get("rl"):
+ raise ValueError("`sample_packing: true` does not work with RLHF training")
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def hint_sample_packing_padding(cls, data):
+ if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
+ LOG.warning(
+ "`pad_to_sequence_len: true` is recommended when using sample_packing"
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_gas_bsz(cls, data):
+ if data.get("gradient_accumulation_steps") and data.get("batch_size"):
+ raise ValueError(
+ "please set only one of gradient_accumulation_steps or batch_size"
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def hint_eval_train_mbsz(cls, data):
+ if (
+ data.get("eval_batch_size")
+ and data.get("micro_batch_size")
+ and data.get("eval_batch_size") != data.get("micro_batch_size")
+ ):
+ LOG.warning(
+ "eval_batch_size != micro_batch_size. This can lead to VRAM instability."
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_push_ds_auth(cls, data):
+ if (
+ data.get("push_dataset_to_hub")
+ and data.get("hf_use_auth_token") is not True
+ ):
+ raise ValueError(
+ "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
+ )
+ return data
+
+ @model_validator(mode="after")
+ def check_falcon_fsdp(self):
+ if (self.base_model and "falcon" in self.base_model.lower()) and self.fsdp:
+ raise ValueError("FSDP is not supported for falcon models")
+ return self
+
+ @model_validator(mode="after")
+ def check_mpt_checkpointing(self):
+ if (
+ self.base_model and "mpt" in self.base_model.lower()
+ ) and self.gradient_checkpointing:
+ raise ValueError("gradient_checkpointing is not supported for MPT models")
+ return self
+
+ @model_validator(mode="after")
+ def check_better_transformers(self):
+ if self.flash_optimum is True:
+ if self.adapter:
+ LOG.warning(
+ "BetterTransformers probably doesn't work with PEFT adapters"
+ )
+ if self.fp16 or self.bf16:
+ raise ValueError("AMP is not supported with BetterTransformer")
+ if self.float16 is not True and self.bfloat16 is not True:
+ LOG.warning(
+ "You should probably set bfloat16 or float16 to true to "
+ "load the model in float16 for BetterTransformers"
+ )
+ return self
+
+ @model_validator(mode="after")
+ def check_adamw_optimizer_params(self):
+ if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (
+ not self.optimizer or "adamw" not in self.optimizer.value
+ ):
+ LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
+ return self
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_saves(cls, data):
+ if (
+ data.get("save_strategy")
+ and data.get("save_steps")
+ and data.get("save_strategy") != "steps"
+ ):
+ raise ValueError(
+ "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
+ )
+ if data.get("saves_per_epoch") and data.get("save_steps"):
+ raise ValueError(
+ "save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_push_save(cls, data):
+ if data.get("hub_model_id") and not (
+ data.get("save_steps") or data.get("saves_per_epoch")
+ ):
+ LOG.warning(
+ "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_evals(cls, data):
+ if (
+ data.get("evaluation_strategy")
+ and data.get("eval_steps")
+ and data.get("evaluation_strategy") != "steps"
+ ):
+ raise ValueError(
+ "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
+ )
+
+ if (
+ data.get("val_set_size") == 0
+ and (data.get("eval_steps") or data.get("evaluation_strategy"))
+ and not data.get("test_datasets")
+ ):
+ raise ValueError(
+ "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
+ )
+ if data.get("evals_per_epoch") and data.get("eval_steps"):
+ raise ValueError(
+ "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
+ )
+ if (
+ data.get("evals_per_epoch")
+ and data.get("evaluation_strategy")
+ and data.get("evaluation_strategy") != "steps"
+ ):
+ raise ValueError(
+ "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
+ )
+
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_eval_packing(cls, data):
+ if (
+ data.get("sample_packing")
+ and data.get("eval_table_size")
+ and data.get("eval_sample_packing") is not False
+ ):
+ raise ValueError(
+ "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_warmup(cls, data):
+ if data.get("warmup_steps") and data.get("warmup_ratio"):
+ raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_neftune(cls, data):
+ if data.get("noisy_embedding_alpha") and not data.get("neftune_noise_alpha"):
+ data["neftune_noise_alpha"] = data["noisy_embedding_alpha"]
+ del data["noisy_embedding_alpha"]
+        elif data.get("noisy_embedding_alpha") and data.get("neftune_noise_alpha"):
+ raise ValueError(
+ "noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
+ )
+ return data
+
+ @field_validator("neftune_noise_alpha")
+ @classmethod
+ def validate_neftune_noise_alpha(cls, neftune_noise_alpha):
+ if neftune_noise_alpha is not None and neftune_noise_alpha <= 0.0:
+ raise ValueError("neftune_noise_alpha must be > 0.0")
+ return neftune_noise_alpha
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_frozen(cls, data):
+ if (
+ data.get("adapter")
+ and data.get("peft_layers_to_transform")
+ and data.get("unfrozen_parameters")
+ ):
+ raise ValueError(
+ "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
+ )
+
+ return data
+
+ @model_validator(mode="after")
+ def check_fft_possible_bad_config(self):
+ if (
+ # pylint: disable=too-many-boolean-expressions
+ not (self.bf16 or self.bfloat16)
+ and (self.fp16 or self.float16)
+ and not self.adapter
+ and not self.flash_attention
+ and self.sample_packing
+ ):
+ LOG.warning(
+ "Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
+ )
+ # ValueError: Attempting to unscale FP16 gradients.
+ # OR
+ # RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
+ return self
+
+ @model_validator(mode="after")
+ def check_fused_lora(self):
+ if self.adapter in ["lora", "qlora"] and (
+ self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp
+ ):
+ raise ValueError("Fused modules are not supported with LoRA/QLoRA")
+ return self
+
+ @model_validator(mode="after")
+ def hint_lora_8bit(self):
+ loftq = (
+ self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
+ )
+ if not self.load_in_8bit and self.adapter == "lora" and not loftq:
+ LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
+ return self
+
+ @model_validator(mode="after")
+ def check_early_stopping(self):
+ if self.early_stopping_patience:
+            if not self.save_steps or not self.eval_steps:
+ raise ValueError(
+ "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
+ )
+ if self.save_steps % self.eval_steps != 0:
+ raise ValueError(
+ "`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
+ )
+ return self
+
+ @model_validator(mode="after")
+ def check_relora(self):
+ if self.relora_steps:
+ if self.adapter not in ("lora", "qlora"):
+ raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
+
+ if self.fsdp:
+ raise ValueError("fsdp not supported with ReLoRA")
+
+ if self.deepspeed:
+ raise ValueError("deepspeed not supported with ReLoRA")
+
+ if self.lr_scheduler == "one_cycle":
+ raise ValueError(
+ "ReLoRA is not compatible with the one_cycle scheduler"
+ )
+
+ if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp:
+ raise ValueError("Fused modules are not supported with ReLoRA")
+ return self
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_mem_mismatch(cls, data):
+ if (
+ data.get("max_memory") is not None
+ and data.get("gpu_memory_limit") is not None
+ ):
+ raise ValueError(
+ "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_use_reentrant_mismatch(cls, data):
+ if (
+ data.get("unfrozen_parameters")
+ and data.get("gradient_checkpointing_kwargs")
+ and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
+ is True
+ ):
+ # https://github.com/huggingface/transformers/issues/21381
+ raise ValueError(
+                "`use_reentrant` must be false when used with a partially frozen model."
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_val_w_test_datasets(cls, data):
+ if data.get("test_datasets") and data.get("val_set_size"):
+ raise ValueError(
+ "non-zero val_set_size should not be used with test_datasets configuration"
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_fsdp_w_8bit_optimizer(cls, data):
+ if data.get("fsdp") and "bnb" in data.get("optimizer", ""):
+ raise ValueError(f"FSDP not compatible with {data.get('optimizer')}")
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_causal_lm_evals(cls, data):
+ if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"):
+ raise ValueError(
+ "do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
+ )
+
+ if data.get("eval_causal_lm_metrics"):
+ supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
+ if not isinstance(data.get("eval_causal_lm_metrics"), list):
+ raise ValueError("eval_causal_lm_metrics must be a list")
+ # only ["sacrebleu", "comet", "ter", "chrf"] supported
+ if set(data.get("eval_causal_lm_metrics")) - set(supported_metrics):
+ raise ValueError(
+ f"eval_causal_lm_metrics must be one of {supported_metrics}"
+ )
+ return data
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_dataset_or_pretraining_dataset(cls, data):
+ if data.get("datasets") is None and data.get("pretraining_dataset") is None:
+ raise ValueError("either datasets or pretraining_dataset is required")
+ return data
+
+
+class AxolotlConfigWCapabilities(AxolotlInputConfig):
+    """wrapper to validate gpu capabilities with the configured options"""
+
+ capabilities: GPUCapabilities
+
+ @model_validator(mode="after")
+ def check_bf16(self):
+ if self.capabilities.bf16:
+ if not self.bf16 and not self.bfloat16:
+ LOG.info(
+ "bf16 support detected, but not enabled for this configuration."
+ )
+ else:
+ if (
+ not self.merge_lora
+ and not self.is_preprocess
+ and (self.bf16 is True or self.bfloat16 is True)
+ ):
+ raise ValueError(
+ "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
+ )
+ return self
+
+ @model_validator(mode="before")
+ @classmethod
+ def check_sample_packing_w_sdpa_bf16(cls, data):
+        capabilities = data.get("capabilities")
+        if capabilities is not None and not isinstance(capabilities, dict):
+            # capabilities may arrive as a GPUCapabilities instance from validate_config
+            capabilities = capabilities.model_dump()
+        is_sm_90: bool = bool(
+            capabilities and capabilities.get("compute_capability") == "sm_90"
+        )
+ if (
+ data.get("sample_packing")
+ and data.get("sdp_attention")
+ and (data.get("bfloat16") or data.get("bf16"))
+ and not is_sm_90
+ ):
+ # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
+ LOG.warning(
+                "sample_packing & torch sdpa with bf16 is unsupported and may result in 0.0 loss. "
+ "This may work on H100s."
+ )
+
+ return data
diff --git a/src/axolotl/utils/config/models/internals/__init__.py b/src/axolotl/utils/config/models/internals/__init__.py
new file mode 100644
index 000000000..dd742caf4
--- /dev/null
+++ b/src/axolotl/utils/config/models/internals/__init__.py
@@ -0,0 +1,14 @@
+"""module for gpu capabilities"""
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class GPUCapabilities(BaseModel):
+ """model to manage the gpu capabilities statically"""
+
+ bf16: bool = Field(default=False)
+ fp8: bool = Field(default=False)
+ n_gpu: int = Field(default=1)
+ n_node: int = Field(default=1)
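+    # expected in the "sm_<major><minor>" form built in load_cfg from torch device
+    # properties (e.g. "sm_90" for Hopper)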
+ compute_capability: Optional[str] = Field(default=None)
diff --git a/src/axolotl/utils/dict.py b/src/axolotl/utils/dict.py
index 69567c604..409d088e6 100644
--- a/src/axolotl/utils/dict.py
+++ b/src/axolotl/utils/dict.py
@@ -12,4 +12,4 @@ class DictDefault(Dict):
return None
def __or__(self, other):
- return DictDefault(super().__or__(other))
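+        # delegate to dict.__ror__ so the left-hand operand wins on conflicting keys,
+        # i.e. `DictDefault(overrides) | base_cfg` keeps the values from `overrides`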
+ return DictDefault(super().__ror__(other))
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index c2006997d..c94908f3d 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -104,8 +104,8 @@ def load_model_config(cfg):
)
raise err
- if cfg.model_config:
- for key, val in cfg.model_config.items():
+ if cfg.model_config_overrides:
+ for key, val in cfg.model_config_overrides.items():
setattr(model_config, key, val)
check_model_config(cfg, model_config)
diff --git a/tests/test_dict.py b/tests/test_dict.py
index 8367e7c2a..2007cb085 100644
--- a/tests/test_dict.py
+++ b/tests/test_dict.py
@@ -39,7 +39,9 @@ class DictDefaultTest(unittest.TestCase):
), "DictDefault should support in operator for existing keys in list"
def test_dict_or_operator(self):
- cfg = DictDefault(
+ cfg = DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"})
+
+ cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"key_a": {"key_b": "value_a"},
"key_c": "value_c",
@@ -48,10 +50,6 @@ class DictDefaultTest(unittest.TestCase):
}
)
- cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation
- {"key_a": {"key_b": "value_b"}, "key_f": "value_g"}
- )
-
assert (
cfg.key_a.key_b == "value_b"
), "DictDefault should support OR operator for existing nested keys"
diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py
index cea39d0ad..cf662d95f 100644
--- a/tests/test_prompt_tokenizers.py
+++ b/tests/test_prompt_tokenizers.py
@@ -204,13 +204,13 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
# fmt: off
# System message, multi-turn conversations
mt_ids = tokenize(test_data['multi_turn_sys'])
- assert decode(mt_ids) == ' [INST] lorem\nabc [/INST] ipsum [INST] 123 [/INST] sit'
- assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
+ assert decode(mt_ids) == ' [INST] lorem\nabc [/INST] ipsum [INST] 123 [/INST] sit'
+ assert mt_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
# System message, single-turn conversations
st_ids = tokenize(test_data['single_turn_sys'])
- assert decode(st_ids) == ' [INST] lorem\nabc [/INST] ipsum'
- assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
+ assert decode(st_ids) == ' [INST] lorem\nabc [/INST] ipsum'
+ assert st_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
# No system message, single-turn
ns_ids = tokenize(test_data['single_turn_no_sys'])
diff --git a/tests/test_validation.py b/tests/test_validation.py
index e5a74394c..4ec544351 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -1,20 +1,39 @@
+# pylint: disable=too-many-lines
"""Module for testing the validation module"""
import logging
import os
-import unittest
from typing import Optional
import pytest
-from transformers.utils import is_torch_bf16_gpu_available
+from pydantic import ValidationError
from axolotl.utils.config import validate_config
+from axolotl.utils.config.models.input.v0_4_1 import AxolotlConfigWCapabilities
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import check_model_config
from axolotl.utils.wandb_ import setup_wandb_env_vars
-class BaseValidation(unittest.TestCase):
+@pytest.fixture(name="minimal_cfg")
+def fixture_cfg():
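+    # minimal config covering the required pydantic fields (base_model, learning_rate,
+    # datasets, batch sizing); individual tests extend it via `DictDefault({...}) | minimal_cfg`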
+ return DictDefault(
+ {
+ "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+ "learning_rate": 0.000001,
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ }
+ ],
+ "micro_batch_size": 1,
+ "gradient_accumulation_steps": 1,
+ }
+ )
+
+
+class BaseValidation:
"""
Base validation module to setup the log capture
"""
@@ -27,14 +46,110 @@ class BaseValidation(unittest.TestCase):
# pylint: disable=too-many-public-methods
-class ValidationTest(BaseValidation):
+class TestValidation(BaseValidation):
"""
Test the validation module
"""
+ def test_datasets_min_length(self):
+ cfg = DictDefault(
+ {
+ "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+ "learning_rate": 0.000001,
+ "datasets": [],
+ "micro_batch_size": 1,
+ "gradient_accumulation_steps": 1,
+ }
+ )
+
+ with pytest.raises(
+ ValidationError,
+ match=r".*List should have at least 1 item after validation*",
+ ):
+ validate_config(cfg)
+
+ def test_datasets_min_length_empty(self):
+ cfg = DictDefault(
+ {
+ "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+ "learning_rate": 0.000001,
+ "micro_batch_size": 1,
+ "gradient_accumulation_steps": 1,
+ }
+ )
+
+ with pytest.raises(
+ ValueError, match=r".*either datasets or pretraining_dataset is required*"
+ ):
+ validate_config(cfg)
+
+ def test_pretrain_dataset_min_length(self):
+ cfg = DictDefault(
+ {
+ "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+ "learning_rate": 0.000001,
+ "pretraining_dataset": [],
+ "micro_batch_size": 1,
+ "gradient_accumulation_steps": 1,
+ "max_steps": 100,
+ }
+ )
+
+ with pytest.raises(
+ ValidationError,
+ match=r".*List should have at least 1 item after validation*",
+ ):
+ validate_config(cfg)
+
+ def test_valid_pretrain_dataset(self):
+ cfg = DictDefault(
+ {
+ "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+ "learning_rate": 0.000001,
+ "pretraining_dataset": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ }
+ ],
+ "micro_batch_size": 1,
+ "gradient_accumulation_steps": 1,
+ "max_steps": 100,
+ }
+ )
+
+ validate_config(cfg)
+
+ def test_valid_sft_dataset(self):
+ cfg = DictDefault(
+ {
+ "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+ "learning_rate": 0.000001,
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ }
+ ],
+ "micro_batch_size": 1,
+ "gradient_accumulation_steps": 1,
+ }
+ )
+
+ validate_config(cfg)
+
def test_batch_size_unused_warning(self):
cfg = DictDefault(
{
+ "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+ "learning_rate": 0.000001,
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ }
+ ],
+ "micro_batch_size": 4,
"batch_size": 32,
}
)
@@ -43,104 +158,163 @@ class ValidationTest(BaseValidation):
validate_config(cfg)
assert "batch_size is not recommended" in self._caplog.records[0].message
- def test_qlora(self):
- base_cfg = DictDefault(
+ def test_batch_size_more_params(self):
+ cfg = DictDefault(
{
- "adapter": "qlora",
+ "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+ "learning_rate": 0.000001,
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ }
+ ],
+ "batch_size": 32,
}
)
- cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
- {
- "load_in_8bit": True,
- }
+ with pytest.raises(ValueError, match=r".*At least two of*"):
+ validate_config(cfg)
+
+ def test_qlora(self, minimal_cfg):
+ base_cfg = (
+ DictDefault(
+ {
+ "adapter": "qlora",
+ }
+ )
+ | minimal_cfg
+ )
+
+ cfg = (
+ DictDefault( # pylint: disable=unsupported-binary-operation
+ {
+ "load_in_8bit": True,
+ }
+ )
+ | base_cfg
)
with pytest.raises(ValueError, match=r".*8bit.*"):
validate_config(cfg)
- cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
- {
- "gptq": True,
- }
+ cfg = (
+ DictDefault( # pylint: disable=unsupported-binary-operation
+ {
+ "gptq": True,
+ }
+ )
+ | base_cfg
)
with pytest.raises(ValueError, match=r".*gptq.*"):
validate_config(cfg)
- cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
- {
- "load_in_4bit": False,
- }
+ cfg = (
+ DictDefault( # pylint: disable=unsupported-binary-operation
+ {
+ "load_in_4bit": False,
+ }
+ )
+ | base_cfg
)
with pytest.raises(ValueError, match=r".*4bit.*"):
validate_config(cfg)
- cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
- {
- "load_in_4bit": True,
- }
+ cfg = (
+ DictDefault( # pylint: disable=unsupported-binary-operation
+ {
+ "load_in_4bit": True,
+ }
+ )
+ | base_cfg
)
validate_config(cfg)
- def test_qlora_merge(self):
- base_cfg = DictDefault(
- {
- "adapter": "qlora",
- "merge_lora": True,
- }
+ def test_qlora_merge(self, minimal_cfg):
+ base_cfg = (
+ DictDefault(
+ {
+ "adapter": "qlora",
+ "merge_lora": True,
+ }
+ )
+ | minimal_cfg
)
- cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
- {
- "load_in_8bit": True,
- }
+ cfg = (
+ DictDefault( # pylint: disable=unsupported-binary-operation
+ {
+ "load_in_8bit": True,
+ }
+ )
+ | base_cfg
)
with pytest.raises(ValueError, match=r".*8bit.*"):
validate_config(cfg)
- cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
- {
- "gptq": True,
- }
+ cfg = (
+ DictDefault( # pylint: disable=unsupported-binary-operation
+ {
+ "gptq": True,
+ }
+ )
+ | base_cfg
)
with pytest.raises(ValueError, match=r".*gptq.*"):
validate_config(cfg)
- cfg = base_cfg | DictDefault( # pylint: disable=unsupported-binary-operation
- {
- "load_in_4bit": True,
- }
+ cfg = (
+ DictDefault( # pylint: disable=unsupported-binary-operation
+ {
+ "load_in_4bit": True,
+ }
+ )
+ | base_cfg
)
with pytest.raises(ValueError, match=r".*4bit.*"):
validate_config(cfg)
- def test_hf_use_auth_token(self):
- cfg = DictDefault(
- {
- "push_dataset_to_hub": "namespace/repo",
- }
+ def test_hf_use_auth_token(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {
+ "push_dataset_to_hub": "namespace/repo",
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(ValueError, match=r".*hf_use_auth_token.*"):
validate_config(cfg)
- cfg = DictDefault(
- {
- "push_dataset_to_hub": "namespace/repo",
- "hf_use_auth_token": True,
- }
+ cfg = (
+ DictDefault(
+ {
+ "push_dataset_to_hub": "namespace/repo",
+ "hf_use_auth_token": True,
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
def test_gradient_accumulations_or_batch_size(self):
cfg = DictDefault(
{
+ "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
+ "learning_rate": 0.000001,
+ "datasets": [
+ {
+ "path": "mhenrichsen/alpaca_2k_test",
+ "type": "alpaca",
+ }
+ ],
"gradient_accumulation_steps": 1,
"batch_size": 1,
}
@@ -151,75 +325,75 @@ class ValidationTest(BaseValidation):
):
validate_config(cfg)
- cfg = DictDefault(
- {
- "batch_size": 1,
- }
- )
-
- validate_config(cfg)
-
- cfg = DictDefault(
- {
- "gradient_accumulation_steps": 1,
- }
- )
-
- validate_config(cfg)
-
- def test_falcon_fsdp(self):
+ def test_falcon_fsdp(self, minimal_cfg):
regex_exp = r".*FSDP is not supported for falcon models.*"
# Check for lower-case
- cfg = DictDefault(
- {
- "base_model": "tiiuae/falcon-7b",
- "fsdp": ["full_shard", "auto_wrap"],
- }
+ cfg = (
+ DictDefault(
+ {
+ "base_model": "tiiuae/falcon-7b",
+ "fsdp": ["full_shard", "auto_wrap"],
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
# Check for upper-case
- cfg = DictDefault(
- {
- "base_model": "Falcon-7b",
- "fsdp": ["full_shard", "auto_wrap"],
- }
+ cfg = (
+ DictDefault(
+ {
+ "base_model": "Falcon-7b",
+ "fsdp": ["full_shard", "auto_wrap"],
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
- cfg = DictDefault(
- {
- "base_model": "tiiuae/falcon-7b",
- }
+ cfg = (
+ DictDefault(
+ {
+ "base_model": "tiiuae/falcon-7b",
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- def test_mpt_gradient_checkpointing(self):
+ def test_mpt_gradient_checkpointing(self, minimal_cfg):
regex_exp = r".*gradient_checkpointing is not supported for MPT models*"
# Check for lower-case
- cfg = DictDefault(
- {
- "base_model": "mosaicml/mpt-7b",
- "gradient_checkpointing": True,
- }
+ cfg = (
+ DictDefault(
+ {
+ "base_model": "mosaicml/mpt-7b",
+ "gradient_checkpointing": True,
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
- def test_flash_optimum(self):
- cfg = DictDefault(
- {
- "flash_optimum": True,
- "adapter": "lora",
- }
+ def test_flash_optimum(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {
+ "flash_optimum": True,
+ "adapter": "lora",
+ "bf16": False,
+ }
+ )
+ | minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
@@ -230,10 +404,14 @@ class ValidationTest(BaseValidation):
for record in self._caplog.records
)
- cfg = DictDefault(
- {
- "flash_optimum": True,
- }
+ cfg = (
+ DictDefault(
+ {
+ "flash_optimum": True,
+ "bf16": False,
+ }
+ )
+ | minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
@@ -243,34 +421,43 @@ class ValidationTest(BaseValidation):
for record in self._caplog.records
)
- cfg = DictDefault(
- {
- "flash_optimum": True,
- "fp16": True,
- }
+ cfg = (
+ DictDefault(
+ {
+ "flash_optimum": True,
+ "fp16": True,
+ }
+ )
+ | minimal_cfg
)
regex_exp = r".*AMP is not supported.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
- cfg = DictDefault(
- {
- "flash_optimum": True,
- "bf16": True,
- }
+ cfg = (
+ DictDefault(
+ {
+ "flash_optimum": True,
+ "bf16": True,
+ }
+ )
+ | minimal_cfg
)
regex_exp = r".*AMP is not supported.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)
- def test_adamw_hyperparams(self):
- cfg = DictDefault(
- {
- "optimizer": None,
- "adam_epsilon": 0.0001,
- }
+ def test_adamw_hyperparams(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {
+ "optimizer": None,
+ "adam_epsilon": 0.0001,
+ }
+ )
+ | minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
@@ -281,11 +468,14 @@ class ValidationTest(BaseValidation):
for record in self._caplog.records
)
- cfg = DictDefault(
- {
- "optimizer": "adafactor",
- "adam_beta1": 0.0001,
- }
+ cfg = (
+ DictDefault(
+ {
+ "optimizer": "adafactor",
+ "adam_beta1": 0.0001,
+ }
+ )
+ | minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
@@ -296,30 +486,39 @@ class ValidationTest(BaseValidation):
for record in self._caplog.records
)
- cfg = DictDefault(
- {
- "optimizer": "adamw_bnb_8bit",
- "adam_beta1": 0.9,
- "adam_beta2": 0.99,
- "adam_epsilon": 0.0001,
- }
+ cfg = (
+ DictDefault(
+ {
+ "optimizer": "adamw_bnb_8bit",
+ "adam_beta1": 0.9,
+ "adam_beta2": 0.99,
+ "adam_epsilon": 0.0001,
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- cfg = DictDefault(
- {
- "optimizer": "adafactor",
- }
+ cfg = (
+ DictDefault(
+ {
+ "optimizer": "adafactor",
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- def test_deprecated_packing(self):
- cfg = DictDefault(
- {
- "max_packed_sequence_len": 1024,
- }
+ def test_deprecated_packing(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {
+ "max_packed_sequence_len": 1024,
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(
DeprecationWarning,
@@ -327,12 +526,15 @@ class ValidationTest(BaseValidation):
):
validate_config(cfg)
- def test_packing(self):
- cfg = DictDefault(
- {
- "sample_packing": True,
- "pad_to_sequence_len": None,
- }
+ def test_packing(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {
+ "sample_packing": True,
+ "pad_to_sequence_len": None,
+ }
+ )
+ | minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
@@ -342,62 +544,79 @@ class ValidationTest(BaseValidation):
for record in self._caplog.records
)
- @pytest.mark.skipif(
- is_torch_bf16_gpu_available(),
- reason="test should only run on gpus w/o bf16 support",
- )
- def test_merge_lora_no_bf16_fail(self):
+ def test_merge_lora_no_bf16_fail(self, minimal_cfg):
"""
This is assumed to be run on a CPU machine, so bf16 is not supported.
"""
- cfg = DictDefault(
- {
- "bf16": True,
- }
+ cfg = (
+ DictDefault(
+ {
+ "bf16": True,
+ "capabilities": {"bf16": False},
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"):
- validate_config(cfg)
+ AxolotlConfigWCapabilities(**cfg.to_dict())
- cfg = DictDefault(
- {
- "bf16": True,
- "merge_lora": True,
- }
+ cfg = (
+ DictDefault(
+ {
+ "bf16": True,
+ "merge_lora": True,
+ "capabilities": {"bf16": False},
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- def test_sharegpt_deprecation(self):
- cfg = DictDefault(
- {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
+ def test_sharegpt_deprecation(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
+ )
+ | minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
- validate_config(cfg)
+ new_cfg = validate_config(cfg)
assert any(
"`type: sharegpt:chat` will soon be deprecated." in record.message
for record in self._caplog.records
)
- assert cfg.datasets[0].type == "sharegpt"
+ assert new_cfg.datasets[0].type == "sharegpt"
- cfg = DictDefault(
- {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}]}
+ cfg = (
+ DictDefault(
+ {
+ "datasets": [
+ {"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}
+ ]
+ }
+ )
+ | minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
- validate_config(cfg)
+ new_cfg = validate_config(cfg)
assert any(
"`type: sharegpt_simple` will soon be deprecated." in record.message
for record in self._caplog.records
)
- assert cfg.datasets[0].type == "sharegpt:load_role"
+ assert new_cfg.datasets[0].type == "sharegpt:load_role"
- def test_no_conflict_save_strategy(self):
- cfg = DictDefault(
- {
- "save_strategy": "epoch",
- "save_steps": 10,
- }
+ def test_no_conflict_save_strategy(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {
+ "save_strategy": "epoch",
+ "save_steps": 10,
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(
@@ -405,11 +624,14 @@ class ValidationTest(BaseValidation):
):
validate_config(cfg)
- cfg = DictDefault(
- {
- "save_strategy": "no",
- "save_steps": 10,
- }
+ cfg = (
+ DictDefault(
+ {
+ "save_strategy": "no",
+ "save_steps": 10,
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(
@@ -417,45 +639,60 @@ class ValidationTest(BaseValidation):
):
validate_config(cfg)
- cfg = DictDefault(
- {
- "save_strategy": "steps",
- }
+ cfg = (
+ DictDefault(
+ {
+ "save_strategy": "steps",
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- cfg = DictDefault(
- {
- "save_strategy": "steps",
- "save_steps": 10,
- }
+ cfg = (
+ DictDefault(
+ {
+ "save_strategy": "steps",
+ "save_steps": 10,
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- cfg = DictDefault(
- {
- "save_steps": 10,
- }
+ cfg = (
+ DictDefault(
+ {
+ "save_steps": 10,
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- cfg = DictDefault(
- {
- "save_strategy": "no",
- }
+ cfg = (
+ DictDefault(
+ {
+ "save_strategy": "no",
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- def test_no_conflict_eval_strategy(self):
- cfg = DictDefault(
- {
- "evaluation_strategy": "epoch",
- "eval_steps": 10,
- }
+ def test_no_conflict_eval_strategy(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {
+ "evaluation_strategy": "epoch",
+ "eval_steps": 10,
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(
@@ -463,11 +700,14 @@ class ValidationTest(BaseValidation):
):
validate_config(cfg)
- cfg = DictDefault(
- {
- "evaluation_strategy": "no",
- "eval_steps": 10,
- }
+ cfg = (
+ DictDefault(
+ {
+ "evaluation_strategy": "no",
+ "eval_steps": 10,
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(
@@ -475,44 +715,59 @@ class ValidationTest(BaseValidation):
):
validate_config(cfg)
- cfg = DictDefault(
- {
- "evaluation_strategy": "steps",
- }
+ cfg = (
+ DictDefault(
+ {
+ "evaluation_strategy": "steps",
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- cfg = DictDefault(
- {
- "evaluation_strategy": "steps",
- "eval_steps": 10,
- }
+ cfg = (
+ DictDefault(
+ {
+ "evaluation_strategy": "steps",
+ "eval_steps": 10,
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- cfg = DictDefault(
- {
- "eval_steps": 10,
- }
+ cfg = (
+ DictDefault(
+ {
+ "eval_steps": 10,
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- cfg = DictDefault(
- {
- "evaluation_strategy": "no",
- }
+ cfg = (
+ DictDefault(
+ {
+ "evaluation_strategy": "no",
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- cfg = DictDefault(
- {
- "evaluation_strategy": "epoch",
- "val_set_size": 0,
- }
+ cfg = (
+ DictDefault(
+ {
+ "evaluation_strategy": "epoch",
+ "val_set_size": 0,
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(
@@ -521,11 +776,14 @@ class ValidationTest(BaseValidation):
):
validate_config(cfg)
- cfg = DictDefault(
- {
- "eval_steps": 10,
- "val_set_size": 0,
- }
+ cfg = (
+ DictDefault(
+ {
+ "eval_steps": 10,
+ "val_set_size": 0,
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(
@@ -534,38 +792,50 @@ class ValidationTest(BaseValidation):
):
validate_config(cfg)
- cfg = DictDefault(
- {
- "val_set_size": 0,
- }
+ cfg = (
+ DictDefault(
+ {
+ "val_set_size": 0,
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- cfg = DictDefault(
- {
- "eval_steps": 10,
- "val_set_size": 0.01,
- }
+ cfg = (
+ DictDefault(
+ {
+ "eval_steps": 10,
+ "val_set_size": 0.01,
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- cfg = DictDefault(
- {
- "evaluation_strategy": "epoch",
- "val_set_size": 0.01,
- }
+ cfg = (
+ DictDefault(
+ {
+ "evaluation_strategy": "epoch",
+ "val_set_size": 0.01,
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- def test_eval_table_size_conflict_eval_packing(self):
- cfg = DictDefault(
- {
- "sample_packing": True,
- "eval_table_size": 100,
- }
+ def test_eval_table_size_conflict_eval_packing(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {
+ "sample_packing": True,
+ "eval_table_size": 100,
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(
@@ -573,39 +843,51 @@ class ValidationTest(BaseValidation):
):
validate_config(cfg)
- cfg = DictDefault(
- {
- "sample_packing": True,
- "eval_sample_packing": False,
- }
+ cfg = (
+ DictDefault(
+ {
+ "sample_packing": True,
+ "eval_sample_packing": False,
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- cfg = DictDefault(
- {
- "sample_packing": False,
- "eval_table_size": 100,
- }
+ cfg = (
+ DictDefault(
+ {
+ "sample_packing": False,
+ "eval_table_size": 100,
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- cfg = DictDefault(
- {
- "sample_packing": True,
- "eval_table_size": 100,
- "eval_sample_packing": False,
- }
+ cfg = (
+ DictDefault(
+ {
+ "sample_packing": True,
+ "eval_table_size": 100,
+ "eval_sample_packing": False,
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- def test_load_in_x_bit_without_adapter(self):
- cfg = DictDefault(
- {
- "load_in_4bit": True,
- }
+ def test_load_in_x_bit_without_adapter(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {
+ "load_in_4bit": True,
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(
@@ -614,10 +896,13 @@ class ValidationTest(BaseValidation):
):
validate_config(cfg)
- cfg = DictDefault(
- {
- "load_in_8bit": True,
- }
+ cfg = (
+ DictDefault(
+ {
+ "load_in_8bit": True,
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(
@@ -626,30 +911,39 @@ class ValidationTest(BaseValidation):
):
validate_config(cfg)
- cfg = DictDefault(
- {
- "load_in_4bit": True,
- "adapter": "qlora",
- }
+ cfg = (
+ DictDefault(
+ {
+ "load_in_4bit": True,
+ "adapter": "qlora",
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- cfg = DictDefault(
- {
- "load_in_8bit": True,
- "adapter": "lora",
- }
+ cfg = (
+ DictDefault(
+ {
+ "load_in_8bit": True,
+ "adapter": "lora",
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- def test_warmup_step_no_conflict(self):
- cfg = DictDefault(
- {
- "warmup_steps": 10,
- "warmup_ratio": 0.1,
- }
+ def test_warmup_step_no_conflict(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {
+ "warmup_steps": 10,
+ "warmup_ratio": 0.1,
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(
@@ -658,29 +952,40 @@ class ValidationTest(BaseValidation):
):
validate_config(cfg)
- cfg = DictDefault(
- {
- "warmup_steps": 10,
- }
+ cfg = (
+ DictDefault(
+ {
+ "warmup_steps": 10,
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- cfg = DictDefault(
- {
- "warmup_ratio": 0.1,
- }
+ cfg = (
+ DictDefault(
+ {
+ "warmup_ratio": 0.1,
+ }
+ )
+ | minimal_cfg
)
validate_config(cfg)
- def test_unfrozen_parameters_w_peft_layers_to_transform(self):
- cfg = DictDefault(
- {
- "adapter": "lora",
- "unfrozen_parameters": ["model.layers.2[0-9]+.block_sparse_moe.gate.*"],
- "peft_layers_to_transform": [0, 1],
- }
+ def test_unfrozen_parameters_w_peft_layers_to_transform(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {
+ "adapter": "lora",
+ "unfrozen_parameters": [
+ "model.layers.2[0-9]+.block_sparse_moe.gate.*"
+ ],
+ "peft_layers_to_transform": [0, 1],
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(
@@ -689,8 +994,8 @@ class ValidationTest(BaseValidation):
):
validate_config(cfg)
- def test_hub_model_id_save_value_warns(self):
- cfg = DictDefault({"hub_model_id": "test"})
+ def test_hub_model_id_save_value_warns(self, minimal_cfg):
+ cfg = DictDefault({"hub_model_id": "test"}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
@@ -698,22 +1003,25 @@ class ValidationTest(BaseValidation):
"set without any models being saved" in self._caplog.records[0].message
)
- def test_hub_model_id_save_value(self):
- cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4})
+ def test_hub_model_id_save_value(self, minimal_cfg):
+ cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert len(self._caplog.records) == 0
-class ValidationCheckModelConfig(BaseValidation):
+class TestValidationCheckModelConfig(BaseValidation):
"""
Test the validation for the config when the model config is available
"""
- def test_llama_add_tokens_adapter(self):
- cfg = DictDefault(
- {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+ def test_llama_add_tokens_adapter(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+ )
+ | minimal_cfg
)
model_config = DictDefault({"model_type": "llama"})
@@ -723,13 +1031,16 @@ class ValidationCheckModelConfig(BaseValidation):
):
check_model_config(cfg, model_config)
- cfg = DictDefault(
- {
- "adapter": "qlora",
- "load_in_4bit": True,
- "tokens": ["<|imstart|>"],
- "lora_modules_to_save": ["embed_tokens"],
- }
+ cfg = (
+ DictDefault(
+ {
+ "adapter": "qlora",
+ "load_in_4bit": True,
+ "tokens": ["<|imstart|>"],
+ "lora_modules_to_save": ["embed_tokens"],
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(
@@ -738,20 +1049,26 @@ class ValidationCheckModelConfig(BaseValidation):
):
check_model_config(cfg, model_config)
- cfg = DictDefault(
- {
- "adapter": "qlora",
- "load_in_4bit": True,
- "tokens": ["<|imstart|>"],
- "lora_modules_to_save": ["embed_tokens", "lm_head"],
- }
+ cfg = (
+ DictDefault(
+ {
+ "adapter": "qlora",
+ "load_in_4bit": True,
+ "tokens": ["<|imstart|>"],
+ "lora_modules_to_save": ["embed_tokens", "lm_head"],
+ }
+ )
+ | minimal_cfg
)
check_model_config(cfg, model_config)
- def test_phi_add_tokens_adapter(self):
- cfg = DictDefault(
- {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+ def test_phi_add_tokens_adapter(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
+ )
+ | minimal_cfg
)
model_config = DictDefault({"model_type": "phi"})
@@ -761,13 +1078,16 @@ class ValidationCheckModelConfig(BaseValidation):
):
check_model_config(cfg, model_config)
- cfg = DictDefault(
- {
- "adapter": "qlora",
- "load_in_4bit": True,
- "tokens": ["<|imstart|>"],
- "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
- }
+ cfg = (
+ DictDefault(
+ {
+ "adapter": "qlora",
+ "load_in_4bit": True,
+ "tokens": ["<|imstart|>"],
+ "lora_modules_to_save": ["embd.wte", "lm_head.linear"],
+ }
+ )
+ | minimal_cfg
)
with pytest.raises(
@@ -776,66 +1096,78 @@ class ValidationCheckModelConfig(BaseValidation):
):
check_model_config(cfg, model_config)
- cfg = DictDefault(
- {
- "adapter": "qlora",
- "load_in_4bit": True,
- "tokens": ["<|imstart|>"],
- "lora_modules_to_save": ["embed_tokens", "lm_head"],
- }
+ cfg = (
+ DictDefault(
+ {
+ "adapter": "qlora",
+ "load_in_4bit": True,
+ "tokens": ["<|imstart|>"],
+ "lora_modules_to_save": ["embed_tokens", "lm_head"],
+ }
+ )
+ | minimal_cfg
)
check_model_config(cfg, model_config)
-class ValidationWandbTest(BaseValidation):
+class TestValidationWandb(BaseValidation):
"""
Validation test for wandb
"""
- def test_wandb_set_run_id_to_name(self):
- cfg = DictDefault(
- {
- "wandb_run_id": "foo",
- }
+ def test_wandb_set_run_id_to_name(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {
+ "wandb_run_id": "foo",
+ }
+ )
+ | minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
- validate_config(cfg)
+ new_cfg = validate_config(cfg)
assert any(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
in record.message
for record in self._caplog.records
)
- assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
+ assert new_cfg.wandb_name == "foo" and new_cfg.wandb_run_id == "foo"
- cfg = DictDefault(
- {
- "wandb_name": "foo",
- }
+ cfg = (
+ DictDefault(
+ {
+ "wandb_name": "foo",
+ }
+ )
+ | minimal_cfg
)
- validate_config(cfg)
+ new_cfg = validate_config(cfg)
- assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
+ assert new_cfg.wandb_name == "foo" and new_cfg.wandb_run_id is None
- def test_wandb_sets_env(self):
- cfg = DictDefault(
- {
- "wandb_project": "foo",
- "wandb_name": "bar",
- "wandb_run_id": "bat",
- "wandb_entity": "baz",
- "wandb_mode": "online",
- "wandb_watch": "false",
- "wandb_log_model": "checkpoint",
- }
+ def test_wandb_sets_env(self, minimal_cfg):
+ cfg = (
+ DictDefault(
+ {
+ "wandb_project": "foo",
+ "wandb_name": "bar",
+ "wandb_run_id": "bat",
+ "wandb_entity": "baz",
+ "wandb_mode": "online",
+ "wandb_watch": "false",
+ "wandb_log_model": "checkpoint",
+ }
+ )
+ | minimal_cfg
)
- validate_config(cfg)
+ new_cfg = validate_config(cfg)
- setup_wandb_env_vars(cfg)
+ setup_wandb_env_vars(new_cfg)
assert os.environ.get("WANDB_PROJECT", "") == "foo"
assert os.environ.get("WANDB_NAME", "") == "bar"
@@ -855,24 +1187,27 @@ class ValidationWandbTest(BaseValidation):
os.environ.pop("WANDB_LOG_MODEL", None)
os.environ.pop("WANDB_DISABLED", None)
- def test_wandb_set_disabled(self):
- cfg = DictDefault({})
+ def test_wandb_set_disabled(self, minimal_cfg):
+ cfg = DictDefault({}) | minimal_cfg
- validate_config(cfg)
+ new_cfg = validate_config(cfg)
- setup_wandb_env_vars(cfg)
+ setup_wandb_env_vars(new_cfg)
assert os.environ.get("WANDB_DISABLED", "") == "true"
- cfg = DictDefault(
- {
- "wandb_project": "foo",
- }
+ cfg = (
+ DictDefault(
+ {
+ "wandb_project": "foo",
+ }
+ )
+ | minimal_cfg
)
- validate_config(cfg)
+ new_cfg = validate_config(cfg)
- setup_wandb_env_vars(cfg)
+ setup_wandb_env_vars(new_cfg)
assert os.environ.get("WANDB_DISABLED", "") != "true"